[tests] adapt events tests to recent changes
[leap_pycommon.git] / src / leap / common / events / zmq_components.py
index 74abb76..8919cd9 100644 (file)
@@ -57,12 +57,14 @@ class TxZmqComponent(object):
 
     _component_type = None
 
-    def __init__(self, path_prefix=None, enable_curve=True):
+    def __init__(self, path_prefix=None, enable_curve=True, factory=None):
         """
         Initialize the txzmq component.
         """
         if path_prefix is None:
             path_prefix = get_path_prefix(flags.STANDALONE)
+        if factory is not None:
+            self._factory = factory
         self._config_prefix = os.path.join(path_prefix, "leap", "events")
         self._connections = []
         if enable_curve:
@@ -78,64 +80,69 @@ class TxZmqComponent(object):
                 "define a self._component_type!")
         return self._component_type
 
-    def _zmq_connect(self, connClass, address):
+    def _zmq_bind(self, connClass, address):
         """
-        Connect to an address.
+        Bind to an address.
 
         :param connClass: The connection class to be used.
         :type connClass: txzmq.ZmqConnection
-        :param address: The address to connect to.
+        :param address: The address to bind to.
         :type address: str
 
-        :return: The binded connection.
-        :rtype: txzmq.ZmqConnection
+        :return: The bound connection and port.
+        :rtype: (txzmq.ZmqConnection, int)
         """
-        endpoint = ZmqEndpoint(ZmqEndpointType.connect, address)
+        proto, addr, port = ADDRESS_RE.search(address).groups()
+
+        endpoint = ZmqEndpoint(ZmqEndpointType.bind, address)
         connection = connClass(self._factory)
 
         if self.use_curve:
             socket = connection.socket
+
             public, secret = maybe_create_and_get_certificates(
                 self._config_prefix, self.component_type)
-            server_public_file = os.path.join(
-                self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
-
-            server_public, _ = zmq.auth.load_certificate(server_public_file)
             socket.curve_publickey = public
             socket.curve_secretkey = secret
-            socket.curve_serverkey = server_public
+            self._start_authentication(connection.socket)
 
-        connection.addEndpoints([endpoint])
-        return connection
+        if proto == 'tcp' and int(port) == 0:
+            connection.endpoints.extend([endpoint])
+            port = connection.socket.bind_to_random_port('tcp://%s' % addr)
+        else:
+            connection.addEndpoints([endpoint])
 
-    def _zmq_bind(self, connClass, address):
+        return connection, int(port)
+
+    def _zmq_connect(self, connClass, address):
         """
-        Bind to an address.
+        Connect to an address.
 
         :param connClass: The connection class to be used.
         :type connClass: txzmq.ZmqConnection
-        :param address: The address to bind to.
+        :param address: The address to connect to.
         :type address: str
 
-        :return: The binded connection and port.
-        :rtype: (txzmq.ZmqConnection, int)
+        :return: The new connection, connected to ``address``.
+        :rtype: txzmq.ZmqConnection
         """
-        proto, addr, port = ADDRESS_RE.search(address).groups()
-
-        endpoint = ZmqEndpoint(ZmqEndpointType.bind, address)
+        endpoint = ZmqEndpoint(ZmqEndpointType.connect, address)
         connection = connClass(self._factory)
 
         if self.use_curve:
             socket = connection.socket
-
             public, secret = maybe_create_and_get_certificates(
                 self._config_prefix, self.component_type)
+            server_public_file = os.path.join(
+                self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
+
+            server_public, _ = zmq.auth.load_certificate(server_public_file)
             socket.curve_publickey = public
             socket.curve_secretkey = secret
-            self._start_authentication(connection.socket)
+            socket.curve_serverkey = server_public
 
         connection.addEndpoints([endpoint])
-        return connection, port
+        return connection
 
     def _start_authentication(self, socket):
 
@@ -150,6 +157,7 @@ class TxZmqComponent(object):
         # tell authenticator to use the certificate in a directory
         public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
         auth_req.configure_curve(domain="*", location=public_keys_dir)
+        auth_req.shutdown()
 
         # This has to be set before binding the socket, that's why this method
         # has to be called before addEndpoints()