[tests] adapt events tests to recent changes
[leap_pycommon.git] / src / leap / common / events / client.py
1 # -*- coding: utf-8 -*-
2 # client.py
3 # Copyright (C) 2013, 2014, 2015 LEAP
4 #
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
17 """
18 The client end point of the events mechanism.
19
20 Clients are the communicating parties of the events mechanism. They
21 communicate by sending messages to a server, which in turn redistributes
22 messages to other clients.
23
24 When a client registers a callback for a given event, it also tells the
25 server that it wants to be notified whenever events of that type are sent by
26 some other client.
27 """
28 import logging
29 import collections
30 import uuid
31 import threading
32 import time
33 import pickle
34 import os
35
36 from abc import ABCMeta
37 from abc import abstractmethod
38
39 import zmq
40 from zmq.eventloop import zmqstream
41 from zmq.eventloop import ioloop
42
43 # XXX some distros don't package libsodium, so we have to be prepared for
44 #     absence of zmq.auth
45 try:
46     import zmq.auth
47 except ImportError:
48     pass
49
50 from leap.common.config import flags, get_path_prefix
51 from leap.common.zmq_utils import zmq_has_curve
52 from leap.common.zmq_utils import maybe_create_and_get_certificates
53 from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX
54
55 from leap.common.events.errors import CallbackAlreadyRegisteredError
56 from leap.common.events.server import EMIT_ADDR
57 from leap.common.events.server import REG_ADDR
58 from leap.common.events import catalog
59
60
61 logger = logging.getLogger(__name__)
62
63
64 _emit_addr = EMIT_ADDR
65 _reg_addr = REG_ADDR
66 _factory = None
67 _enable_curve = True
68
69
70 def configure_client(emit_addr, reg_addr, factory=None, enable_curve=True):
71     global _emit_addr, _reg_addr, _factory, _enable_curve
72     logger.debug("Configuring client with addresses: (%s, %s)" %
73                  (emit_addr, reg_addr))
74     _emit_addr = emit_addr
75     _reg_addr = reg_addr
76     _factory = factory
77     _enable_curve = enable_curve
78
79
80 class EventsClient(object):
81     """
82     A singleton client for the events mechanism.
83     """
84
85     __metaclass__ = ABCMeta
86
87     _instance = None
88     _instance_lock = threading.Lock()
89
90     def __init__(self, emit_addr, reg_addr):
91         """
92         Initialize the events client.
93         """
94         logger.debug("Creating client instance.")
95         self._callbacks = collections.defaultdict(dict)
96         self._emit_addr = emit_addr
97         self._reg_addr = reg_addr
98
99     @property
100     def callbacks(self):
101         return self._callbacks
102
103     @classmethod
104     def instance(cls):
105         """
106         Return a singleton EventsClient instance.
107         """
108         with cls._instance_lock:
109             if cls._instance is None:
110                 cls._instance = cls(
111                     _emit_addr, _reg_addr, factory=_factory,
112                     enable_curve=_enable_curve)
113         return cls._instance
114
115     def register(self, event, callback, uid=None, replace=False):
116         """
117         Register a callback to be executed when an event is received.
118
119         :param event: The event that triggers the callback.
120         :type event: Event
121         :param callback: The callback to be executed.
122         :type callback: callable(event, *content)
123         :param uid: The callback uid.
124         :type uid: str
125         :param replace: Wether an eventual callback with same ID should be
126                         replaced.
127         :type replace: bool
128
129         :return: The callback uid.
130         :rtype: str
131
132         :raises CallbackAlreadyRegisteredError: when there's already a callback
133                 identified by the given uid and replace is False.
134         """
135         logger.debug("Subscribing to event: %s" % event)
136         if not uid:
137             uid = uuid.uuid4()
138         elif uid in self._callbacks[event] and not replace:
139             raise CallbackAlreadyRegisteredError()
140         self._callbacks[event][uid] = callback
141         self._subscribe(str(event))
142         return uid
143
144     def unregister(self, event, uid=None):
145         """
146         Unregister callbacks for an event.
147
148         If uid is not None, then only the callback identified by the given uid
149         is removed. Otherwise, all callbacks for the event are removed.
150
151         :param event: The event that triggers the callback.
152         :type event: Event
153         :param uid: The callback uid.
154         :type uid: str
155         """
156         if not uid:
157             logger.debug(
158                 "Unregistering all callbacks from event %s." % event)
159             self._callbacks[event] = {}
160         else:
161             logger.debug(
162                 "Unregistering callback %s from event %s." % (uid, event))
163             if uid in self._callbacks[event]:
164                 del self._callbacks[event][uid]
165         if not self._callbacks[event]:
166             del self._callbacks[event]
167             self._unsubscribe(str(event))
168
169     def emit(self, event, *content):
170         """
171         Send an event.
172
173         :param event: The event to be sent.
174         :type event: Event
175         :param content: The content of the event.
176         :type content: list
177         """
178         logger.debug("Emitting event: (%s, %s)" % (event, content))
179         payload = str(event) + b'\0' + pickle.dumps(content)
180         self._send(payload)
181
182     def _handle_event(self, event, content):
183         """
184         Handle an incoming event.
185
186         :param event: The event to be sent.
187         :type event: Event
188         :param content: The content of the event.
189         :type content: list
190         """
191         logger.debug("Handling event %s..." % event)
192         for uid in self._callbacks[event]:
193             callback = self._callbacks[event][uid]
194             logger.debug("Executing callback %s." % uid)
195             self._run_callback(callback, event, content)
196
197     @abstractmethod
198     def _run_callback(self, callback, event, content):
199         """
200         Run a callback.
201
202         :param callback: The callback to be run.
203         :type callback: callable(event, *content)
204         :param event: The event to be sent.
205         :type event: Event
206         :param content: The content of the event.
207         :type content: list
208         """
209         pass
210
211     @abstractmethod
212     def _subscribe(self, tag):
213         """
214         Subscribe to a tag on the zmq SUB socket.
215
216         :param tag: The tag to be subscribed.
217         :type tag: str
218         """
219         pass
220
221     @abstractmethod
222     def _unsubscribe(self, tag):
223         """
224         Unsubscribe from a tag on the zmq SUB socket.
225
226         :param tag: The tag to be unsubscribed.
227         :type tag: str
228         """
229         pass
230
231     @abstractmethod
232     def _send(self, data):
233         """
234         Send data through PUSH socket.
235
236         :param data: The data to be sent.
237         :type event: str
238         """
239         pass
240
241     def shutdown(self):
242         self.__class__.reset()
243
244     @classmethod
245     def reset(cls):
246         with cls._instance_lock:
247             cls._instance = None
248
249
250 class EventsIOLoop(ioloop.ZMQIOLoop):
251     """
252     An extension of zmq's ioloop that can wait until there are no callbacks
253     in the queue before stopping.
254     """
255
256     def stop(self, wait=False):
257         """
258         Stop the I/O loop.
259
260         :param wait: Whether we should wait for callbacks in queue to finish
261                      before stopping.
262         :type wait: bool
263         """
264         if wait:
265             # prevent new callbacks from being added
266             with self._callback_lock:
267                 self._closing = True
268             # wait until all callbacks have been executed
269             while self._callbacks:
270                 time.sleep(0.1)
271         ioloop.ZMQIOLoop.stop(self)
272
273
274 class EventsClientThread(threading.Thread, EventsClient):
275     """
276     A threaded version of the events client.
277     """
278
279     def __init__(self, emit_addr, reg_addr, factory=None, enable_curve=True):
280         """
281         Initialize the events client.
282         """
283         threading.Thread.__init__(self)
284         EventsClient.__init__(self, emit_addr, reg_addr)
285         self._lock = threading.Lock()
286         self._initialized = threading.Event()
287         self._config_prefix = os.path.join(
288             get_path_prefix(flags.STANDALONE), "leap", "events")
289         self._loop = None
290         self._factory = factory
291         self._context = None
292         self._push = None
293         self._sub = None
294
295         if enable_curve:
296             self.use_curve = zmq_has_curve()
297         else:
298             self.use_curve = False
299
300     def _init_zmq(self):
301         """
302         Initialize ZMQ connections.
303         """
304         self._loop = EventsIOLoop()
305         # we need a new context for each thread
306         self._context = zmq.Context()
307         # connect SUB first, otherwise we might miss some event sent from this
308         # same client
309         self._sub = self._zmq_connect_sub()
310         self._push = self._zmq_connect_push()
311
312     def _zmq_connect(self, socktype, address):
313         """
314         Connect to an address using with a zmq socktype.
315
316         :param socktype: The ZMQ socket type.
317         :type socktype: int
318         :param address: The address to connect to.
319         :type address: str
320
321         :return: A ZMQ connection stream.
322         :rtype: ZMQStream
323         """
324         logger.debug("Connecting %s to %s." % (socktype, address))
325         socket = self._context.socket(socktype)
326         # configure curve authentication
327         if self.use_curve:
328             public, private = maybe_create_and_get_certificates(
329                 self._config_prefix, "client")
330             server_public_file = os.path.join(
331                 self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
332             server_public, _ = zmq.auth.load_certificate(server_public_file)
333             socket.curve_publickey = public
334             socket.curve_secretkey = private
335             socket.curve_serverkey = server_public
336         stream = zmqstream.ZMQStream(socket, self._loop)
337         socket.connect(address)
338         return stream
339
340     def _zmq_connect_push(self):
341         """
342         Initialize the client's PUSH connection.
343
344         :return: A ZMQ connection stream.
345         :rtype: ZMQStream
346         """
347         return self._zmq_connect(zmq.PUSH, self._emit_addr)
348
349     def _zmq_connect_sub(self):
350         """
351         Initialize the client's SUB connection.
352
353         :return: A ZMQ connection stream.
354         :rtype: ZMQStream
355         """
356         stream = self._zmq_connect(zmq.SUB, self._reg_addr)
357         stream.on_recv(self._on_recv)
358         return stream
359
360     def _on_recv(self, msg):
361         """
362         Handle an incoming message in the SUB socket.
363
364         :param msg: The received message.
365         :type msg: str
366         """
367         ev_str, content_pickle = msg[0].split(b'\0', 1)  # undo txzmq tagging
368         event = getattr(catalog, ev_str)
369         content = pickle.loads(content_pickle)
370         self._handle_event(event, content)
371
372     def _subscribe(self, tag):
373         """
374         Subscribe from a tag on the zmq SUB socket.
375
376         :param tag: The tag to be subscribed.
377         :type tag: str
378         """
379         self._sub.socket.setsockopt(zmq.SUBSCRIBE, tag)
380
381     def _unsubscribe(self, tag):
382         """
383         Unsubscribe from a tag on the zmq SUB socket.
384
385         :param tag: The tag to be unsubscribed.
386         :type tag: str
387         """
388         self._sub.socket.setsockopt(zmq.UNSUBSCRIBE, tag)
389
390     def _send(self, data):
391         """
392         Send data through PUSH socket.
393
394         :param data: The data to be sent.
395         :type event: str
396         """
397         # add send() as a callback for ioloop so it works between threads
398         self._loop.add_callback(lambda: self._push.send(data))
399
400     def _run_callback(self, callback, event, content):
401         """
402         Run a callback.
403
404         :param callback: The callback to be run.
405         :type callback: callable(event, *content)
406         :param event: The event to be sent.
407         :type event: Event
408         :param content: The content of the event.
409         :type content: list
410         """
411         self._loop.add_callback(lambda: callback(event, *content))
412
413     def register(self, event, callback, uid=None, replace=False):
414         """
415         Register a callback to be executed when an event is received.
416
417         :param event: The event that triggers the callback.
418         :type event: Event
419         :param callback: The callback to be executed.
420         :type callback: callable(event, *content)
421         :param uid: The callback uid.
422         :type uid: str
423         :param replace: Wether an eventual callback with same ID should be
424                         replaced.
425         :type replace: bool
426
427         :return: The callback uid.
428         :rtype: str
429
430         :raises CallbackAlreadyRegisteredError: when there's already a
431                 callback identified by the given uid and replace is False.
432         """
433         self.ensure_client()
434         return EventsClient.register(
435             self, event, callback, uid=uid, replace=replace)
436
437     def unregister(self, event, uid=None):
438         """
439         Unregister callbacks for an event.
440
441         If uid is not None, then only the callback identified by the given uid
442         is removed. Otherwise, all callbacks for the event are removed.
443
444         :param event: The event that triggers the callback.
445         :type event: Event
446         :param uid: The callback uid.
447         :type uid: str
448         """
449         self.ensure_client()
450         EventsClient.unregister(self, event, uid=uid)
451
452     def emit(self, event, *content):
453         """
454         Send an event.
455
456         :param event: The event to be sent.
457         :type event: Event
458         :param content: The content of the event.
459         :type content: list
460         """
461         self.ensure_client()
462         EventsClient.emit(self, event, *content)
463
464     def run(self):
465         """
466         Run the events client.
467         """
468         logger.debug("Starting ioloop.")
469         self._init_zmq()
470         self._initialized.set()
471         self._loop.start()
472         self._loop.close()
473         logger.debug("Ioloop finished.")
474
475     def ensure_client(self):
476         """
477         Make sure the events client thread is started.
478         """
479         with self._lock:
480             if not self.is_alive():
481                 self.daemon = True
482                 self.start()
483                 self._initialized.wait()
484
485     def shutdown(self):
486         """
487         Shutdown the events client thread.
488         """
489         logger.debug("Shutting down client...")
490         with self._lock:
491             if self.is_alive():
492                 self._loop.stop(wait=True)
493         EventsClient.shutdown(self)
494
495
496 def shutdown():
497     """
498     Shutdown the events client thread.
499     """
500     EventsClientThread.instance().shutdown()
501
502
503 def register(event, callback, uid=None, replace=False):
504     """
505     Register a callback to be executed when an event is received.
506
507     :param event: The event that triggers the callback.
508     :type event: str
509     :param callback: The callback to be executed.
510     :type callback: callable(event, content)
511     :param uid: The callback uid.
512     :type uid: str
513     :param replace: Wether an eventual callback with same ID should be
514                     replaced.
515     :type replace: bool
516
517     :return: The callback uid.
518     :rtype: str
519
520     :raises CallbackAlreadyRegisteredError: when there's already a callback
521             identified by the given uid and replace is False.
522     """
523     return EventsClientThread.instance().register(
524         event, callback, uid=uid, replace=replace)
525
526
527 def unregister(event, uid=None):
528     """
529     Unregister callbacks for an event.
530
531     If uid is not None, then only the callback identified by the given uid is
532     removed. Otherwise, all callbacks for the event are removed.
533
534     :param event: The event that triggers the callback.
535     :type event: str
536     :param uid: The callback uid.
537     :type uid: str
538     """
539     return EventsClientThread.instance().unregister(event, uid=uid)
540
541
542 def emit(event, *content):
543     """
544     Send an event.
545
546     :param event: The event to be sent.
547     :type event: str
548     :param content: The content of the event.
549     :type content: list
550     """
551     return EventsClientThread.instance().emit(event, *content)
552
553
554 def instance():
555     """
556     Return an instance of the events client.
557
558     :return: An instance of the events client.
559     :rtype: EventsClientThread
560     """
561     return EventsClientThread.instance()