From 027ad7eed50947608738ce0009fccf776936e55c Mon Sep 17 00:00:00 2001 From: Kali Kaneko Date: Mon, 29 Feb 2016 19:33:28 -0400 Subject: [tests] adapt events tests to recent changes --- src/leap/common/events/zmq_components.py | 58 ++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 25 deletions(-) (limited to 'src/leap/common/events/zmq_components.py') diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py index 74abb76..8919cd9 100644 --- a/src/leap/common/events/zmq_components.py +++ b/src/leap/common/events/zmq_components.py @@ -57,12 +57,14 @@ class TxZmqComponent(object): _component_type = None - def __init__(self, path_prefix=None, enable_curve=True): + def __init__(self, path_prefix=None, enable_curve=True, factory=None): """ Initialize the txzmq component. """ if path_prefix is None: path_prefix = get_path_prefix(flags.STANDALONE) + if factory is not None: + self._factory = factory self._config_prefix = os.path.join(path_prefix, "leap", "events") self._connections = [] if enable_curve: @@ -78,64 +80,69 @@ class TxZmqComponent(object): "define a self._component_type!") return self._component_type - def _zmq_connect(self, connClass, address): + def _zmq_bind(self, connClass, address): """ - Connect to an address. + Bind to an address. :param connClass: The connection class to be used. :type connClass: txzmq.ZmqConnection - :param address: The address to connect to. + :param address: The address to bind to. :type address: str - :return: The binded connection. - :rtype: txzmq.ZmqConnection + :return: The binded connection and port. + :rtype: (txzmq.ZmqConnection, int) """ - endpoint = ZmqEndpoint(ZmqEndpointType.connect, address) + proto, addr, port = ADDRESS_RE.search(address).groups() + + endpoint = ZmqEndpoint(ZmqEndpointType.bind, address) connection = connClass(self._factory) if self.use_curve: socket = connection.socket + public, secret = maybe_create_and_get_certificates( self._config_prefix, self.component_type) - server_public_file = os.path.join( - self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") - - server_public, _ = zmq.auth.load_certificate(server_public_file) socket.curve_publickey = public socket.curve_secretkey = secret - socket.curve_serverkey = server_public + self._start_authentication(connection.socket) - connection.addEndpoints([endpoint]) - return connection + if proto == 'tcp' and int(port) == 0: + connection.endpoints.extend([endpoint]) + port = connection.socket.bind_to_random_port('tcp://%s' % addr) + else: + connection.addEndpoints([endpoint]) - def _zmq_bind(self, connClass, address): + return connection, int(port) + + def _zmq_connect(self, connClass, address): """ - Bind to an address. + Connect to an address. :param connClass: The connection class to be used. :type connClass: txzmq.ZmqConnection - :param address: The address to bind to. + :param address: The address to connect to. :type address: str - :return: The binded connection and port. - :rtype: (txzmq.ZmqConnection, int) + :return: The binded connection. + :rtype: txzmq.ZmqConnection """ - proto, addr, port = ADDRESS_RE.search(address).groups() - - endpoint = ZmqEndpoint(ZmqEndpointType.bind, address) + endpoint = ZmqEndpoint(ZmqEndpointType.connect, address) connection = connClass(self._factory) if self.use_curve: socket = connection.socket - public, secret = maybe_create_and_get_certificates( self._config_prefix, self.component_type) + server_public_file = os.path.join( + self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") + + server_public, _ = zmq.auth.load_certificate(server_public_file) socket.curve_publickey = public socket.curve_secretkey = secret - self._start_authentication(connection.socket) + socket.curve_serverkey = server_public connection.addEndpoints([endpoint]) - return connection, port + return connection def _start_authentication(self, socket): @@ -150,6 +157,7 @@ class TxZmqComponent(object): # tell authenticator to use the certificate in a directory public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) auth_req.configure_curve(domain="*", location=public_keys_dir) + auth_req.shutdown() # This has to be set before binding the socket, that's why this method # has to be called before addEndpoints() -- cgit v1.2.3