[tests] adapt events tests to recent changes
[leap_pycommon.git] / src / leap / common / events / client.py
index 60d24bc..78617de 100644 (file)
@@ -63,14 +63,18 @@ logger = logging.getLogger(__name__)
 
 _emit_addr = EMIT_ADDR
 _reg_addr = REG_ADDR
+_factory = None  # set by configure_client(); passed on by instance()
+_enable_curve = True  # set by configure_client(); False skips CURVE auth
 
 
-def configure_client(emit_addr, reg_addr):
-    global _emit_addr, _reg_addr
+def configure_client(emit_addr, reg_addr, factory=None, enable_curve=True):
+    global _emit_addr, _reg_addr, _factory, _enable_curve
     logger.debug("Configuring client with addresses: (%s, %s)" %
                  (emit_addr, reg_addr))
     _emit_addr = emit_addr
     _reg_addr = reg_addr
+    _factory = factory  # read by instance() only when it creates the client
+    _enable_curve = enable_curve  # False skips CURVE auth setup
 
 
 class EventsClient(object):
@@ -103,7 +107,9 @@ class EventsClient(object):
         """
         with cls._instance_lock:
             if cls._instance is None:
-                cls._instance = cls(_emit_addr, _reg_addr)
+                cls._instance = cls(
+                    _emit_addr, _reg_addr, factory=_factory,
+                    enable_curve=_enable_curve)
         return cls._instance
 
     def register(self, event, callback, uid=None, replace=False):
@@ -270,7 +276,7 @@ class EventsClientThread(threading.Thread, EventsClient):
     A threaded version of the events client.
     """
 
-    def __init__(self, emit_addr, reg_addr):
+    def __init__(self, emit_addr, reg_addr, factory=None, enable_curve=True):
         """
         Initialize the events client.
         """
@@ -281,15 +287,22 @@ class EventsClientThread(threading.Thread, EventsClient):
         self._config_prefix = os.path.join(
             get_path_prefix(flags.STANDALONE), "leap", "events")
         self._loop = None
+        self._factory = factory
         self._context = None
         self._push = None
         self._sub = None
 
+        if enable_curve:
+            self.use_curve = zmq_has_curve()
+        else:
+            self.use_curve = False
+
     def _init_zmq(self):
         """
         Initialize ZMQ connections.
         """
         self._loop = EventsIOLoop()
+        # we need a new context for each thread
         self._context = zmq.Context()
         # connect SUB first, otherwise we might miss some event sent from this
         # same client
@@ -311,7 +324,7 @@ class EventsClientThread(threading.Thread, EventsClient):
         logger.debug("Connecting %s to %s." % (socktype, address))
         socket = self._context.socket(socktype)
         # configure curve authentication
-        if zmq_has_curve():
+        if self.use_curve:
             public, private = maybe_create_and_get_certificates(
                 self._config_prefix, "client")
             server_public_file = os.path.join(