summaryrefslogtreecommitdiff
path: root/src/leap/common/events/zmq_components.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/common/events/zmq_components.py')
-rw-r--r--src/leap/common/events/zmq_components.py58
1 files changed, 33 insertions, 25 deletions
diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py
index 74abb76..8919cd9 100644
--- a/src/leap/common/events/zmq_components.py
+++ b/src/leap/common/events/zmq_components.py
@@ -57,12 +57,14 @@ class TxZmqComponent(object):
_component_type = None
- def __init__(self, path_prefix=None, enable_curve=True):
+ def __init__(self, path_prefix=None, enable_curve=True, factory=None):
"""
Initialize the txzmq component.
"""
if path_prefix is None:
path_prefix = get_path_prefix(flags.STANDALONE)
+ if factory is not None:
+ self._factory = factory
self._config_prefix = os.path.join(path_prefix, "leap", "events")
self._connections = []
if enable_curve:
@@ -78,64 +80,69 @@ class TxZmqComponent(object):
"define a self._component_type!")
return self._component_type
- def _zmq_connect(self, connClass, address):
+ def _zmq_bind(self, connClass, address):
"""
- Connect to an address.
+ Bind to an address.
:param connClass: The connection class to be used.
:type connClass: txzmq.ZmqConnection
- :param address: The address to connect to.
+ :param address: The address to bind to.
:type address: str
- :return: The binded connection.
- :rtype: txzmq.ZmqConnection
+ :return: The binded connection and port.
+ :rtype: (txzmq.ZmqConnection, int)
"""
- endpoint = ZmqEndpoint(ZmqEndpointType.connect, address)
+ proto, addr, port = ADDRESS_RE.search(address).groups()
+
+ endpoint = ZmqEndpoint(ZmqEndpointType.bind, address)
connection = connClass(self._factory)
if self.use_curve:
socket = connection.socket
+
public, secret = maybe_create_and_get_certificates(
self._config_prefix, self.component_type)
- server_public_file = os.path.join(
- self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
-
- server_public, _ = zmq.auth.load_certificate(server_public_file)
socket.curve_publickey = public
socket.curve_secretkey = secret
- socket.curve_serverkey = server_public
+ self._start_authentication(connection.socket)
- connection.addEndpoints([endpoint])
- return connection
+ if proto == 'tcp' and int(port) == 0:
+ connection.endpoints.extend([endpoint])
+ port = connection.socket.bind_to_random_port('tcp://%s' % addr)
+ else:
+ connection.addEndpoints([endpoint])
- def _zmq_bind(self, connClass, address):
+ return connection, int(port)
+
+ def _zmq_connect(self, connClass, address):
"""
- Bind to an address.
+ Connect to an address.
:param connClass: The connection class to be used.
:type connClass: txzmq.ZmqConnection
- :param address: The address to bind to.
+ :param address: The address to connect to.
:type address: str
- :return: The binded connection and port.
- :rtype: (txzmq.ZmqConnection, int)
+ :return: The binded connection.
+ :rtype: txzmq.ZmqConnection
"""
- proto, addr, port = ADDRESS_RE.search(address).groups()
-
- endpoint = ZmqEndpoint(ZmqEndpointType.bind, address)
+ endpoint = ZmqEndpoint(ZmqEndpointType.connect, address)
connection = connClass(self._factory)
if self.use_curve:
socket = connection.socket
-
public, secret = maybe_create_and_get_certificates(
self._config_prefix, self.component_type)
+ server_public_file = os.path.join(
+ self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
+
+ server_public, _ = zmq.auth.load_certificate(server_public_file)
socket.curve_publickey = public
socket.curve_secretkey = secret
- self._start_authentication(connection.socket)
+ socket.curve_serverkey = server_public
connection.addEndpoints([endpoint])
- return connection, port
+ return connection
def _start_authentication(self, socket):
@@ -150,6 +157,7 @@ class TxZmqComponent(object):
# tell authenticator to use the certificate in a directory
public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
auth_req.configure_curve(domain="*", location=public_keys_dir)
+ auth_req.shutdown()
# This has to be set before binding the socket, that's why this method
# has to be called before addEndpoints()