summaryrefslogtreecommitdiff
path: root/src/leap/common/events/zmq_components.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/common/events/zmq_components.py')
-rw-r--r--src/leap/common/events/zmq_components.py147
1 files changed, 67 insertions, 80 deletions
diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py
index 51de02c..c533a74 100644
--- a/src/leap/common/events/zmq_components.py
+++ b/src/leap/common/events/zmq_components.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# zmq.py
-# Copyright (C) 2015 LEAP
+# Copyright (C) 2015, 2016 LEAP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
@@ -14,60 +14,63 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
-
-
"""
The server for the events mechanism.
"""
-
-
import os
import logging
import txzmq
import re
-import time
from abc import ABCMeta
-# XXX some distros don't package libsodium, so we have to be prepared for
-# absence of zmq.auth
try:
import zmq.auth
- from zmq.auth.thread import ThreadAuthenticator
+ from leap.common.events.auth import TxAuthenticator
+ from leap.common.events.auth import TxAuthenticationRequest
except ImportError:
pass
+from txzmq.connection import ZmqEndpoint, ZmqEndpointType
+
from leap.common.config import flags, get_path_prefix
from leap.common.zmq_utils import zmq_has_curve
from leap.common.zmq_utils import maybe_create_and_get_certificates
from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX
-
logger = logging.getLogger(__name__)
-
ADDRESS_RE = re.compile("^([a-z]+)://([^:]+):?(\d+)?$")
+LOCALHOST_ALLOWED = '127.0.0.1'
+
class TxZmqComponent(object):
"""
A twisted-powered zmq events component.
"""
+ _factory = txzmq.ZmqFactory()
+ _factory.registerForShutdown()
+ _auth = None
__metaclass__ = ABCMeta
_component_type = None
- def __init__(self, path_prefix=None):
+ def __init__(self, path_prefix=None, enable_curve=True, factory=None):
"""
Initialize the txzmq component.
"""
- self._factory = txzmq.ZmqFactory()
- self._factory.registerForShutdown()
if path_prefix is None:
path_prefix = get_path_prefix(flags.STANDALONE)
+ if factory is not None:
+ self._factory = factory
self._config_prefix = os.path.join(path_prefix, "leap", "events")
self._connections = []
+ if enable_curve:
+ self.use_curve = zmq_has_curve()
+ else:
+ self.use_curve = False
@property
def component_type(self):
@@ -77,105 +80,89 @@ class TxZmqComponent(object):
"define a self._component_type!")
return self._component_type
- def _zmq_connect(self, connClass, address):
+ def _zmq_bind(self, connClass, address):
"""
- Connect to an address.
+ Bind to an address.
:param connClass: The connection class to be used.
:type connClass: txzmq.ZmqConnection
- :param address: The address to connect to.
+ :param address: The address to bind to.
:type address: str
- :return: The binded connection.
- :rtype: txzmq.ZmqConnection
+ :return: The binded connection and port.
+ :rtype: (txzmq.ZmqConnection, int)
"""
+ proto, addr, port = ADDRESS_RE.search(address).groups()
+
+ endpoint = ZmqEndpoint(ZmqEndpointType.bind, address)
connection = connClass(self._factory)
- # create and configure socket
- socket = connection.socket
- if zmq_has_curve():
+
+ if self.use_curve:
+ socket = connection.socket
+
public, secret = maybe_create_and_get_certificates(
self._config_prefix, self.component_type)
- server_public_file = os.path.join(
- self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
- server_public, _ = zmq.auth.load_certificate(server_public_file)
socket.curve_publickey = public
socket.curve_secretkey = secret
- socket.curve_serverkey = server_public
- socket.connect(address)
- logger.debug("Connected %s to %s." % (connClass, address))
- self._connections.append(connection)
- return connection
+ self._start_authentication(connection.socket)
- def _zmq_bind(self, connClass, address):
+ if proto == 'tcp' and int(port) == 0:
+ connection.endpoints.extend([endpoint])
+ port = connection.socket.bind_to_random_port('tcp://%s' % addr)
+ else:
+ connection.addEndpoints([endpoint])
+
+ return connection, int(port)
+
+ def _zmq_connect(self, connClass, address):
"""
- Bind to an address.
+ Connect to an address.
:param connClass: The connection class to be used.
:type connClass: txzmq.ZmqConnection
- :param address: The address to bind to.
+ :param address: The address to connect to.
:type address: str
- :return: The binded connection and port.
- :rtype: (txzmq.ZmqConnection, int)
+ :return: The binded connection.
+ :rtype: txzmq.ZmqConnection
"""
+ endpoint = ZmqEndpoint(ZmqEndpointType.connect, address)
connection = connClass(self._factory)
- socket = connection.socket
- if zmq_has_curve():
+
+ if self.use_curve:
+ socket = connection.socket
public, secret = maybe_create_and_get_certificates(
self._config_prefix, self.component_type)
+ server_public_file = os.path.join(
+ self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
+
+ server_public, _ = zmq.auth.load_certificate(server_public_file)
socket.curve_publickey = public
socket.curve_secretkey = secret
- self._start_thread_auth(connection.socket)
+ socket.curve_serverkey = server_public
- proto, addr, port = ADDRESS_RE.search(address).groups()
+ connection.addEndpoints([endpoint])
+ return connection
- if proto == "tcp":
- if port is None or port is '0':
- params = proto, addr
- port = socket.bind_to_random_port("%s://%s" % params)
- logger.debug("Binded %s to %s://%s." % ((connClass,) + params))
- else:
- params = proto, addr, int(port)
- socket.bind("%s://%s:%d" % params)
- logger.debug(
- "Binded %s to %s://%s:%d." % ((connClass,) + params))
- else:
- params = proto, addr
- socket.bind("%s://%s" % params)
- logger.debug(
- "Binded %s to %s://%s" % ((connClass,) + params))
- self._connections.append(connection)
- return connection, port
-
- def _start_thread_auth(self, socket):
- """
- Start the zmq curve thread authenticator.
+ def _start_authentication(self, socket):
- :param socket: The socket in which to configure the authenticator.
- :type socket: zmq.Socket
- """
- authenticator = ThreadAuthenticator(self._factory.context)
+ if not TxZmqComponent._auth:
+ TxZmqComponent._auth = TxAuthenticator(self._factory)
+ TxZmqComponent._auth.start()
- # Temporary fix until we understand what the problem is
- # See https://leap.se/code/issues/7536
- time.sleep(0.5)
+ auth_req = TxAuthenticationRequest(self._factory)
+ auth_req.start()
+ auth_req.allow(LOCALHOST_ALLOWED)
- authenticator.start()
- # XXX do not hardcode this here.
- authenticator.allow('127.0.0.1')
# tell authenticator to use the certificate in a directory
public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
- authenticator.configure_curve(domain="*", location=public_keys_dir)
- socket.curve_server = True # must come before bind
+ auth_req.configure_curve(domain="*", location=public_keys_dir)
+ auth_req.shutdown()
+ TxZmqComponent._auth.shutdown()
- def shutdown(self):
- """
- Shutdown the component.
- """
- logger.debug("Shutting down component %s." % str(self))
- for conn in self._connections:
- conn.shutdown()
- self._factory.shutdown()
+ # This has to be set before binding the socket, that's why this method
+ # has to be called before addEndpoints()
+ socket.curve_server = True
class TxZmqServerComponent(TxZmqComponent):