diff options
Diffstat (limited to 'src/leap')
-rw-r--r-- | src/leap/common/events/auth.py | 4 | ||||
-rw-r--r-- | src/leap/common/events/client.py | 23 | ||||
-rw-r--r-- | src/leap/common/events/server.py | 13 | ||||
-rw-r--r-- | src/leap/common/events/tests/test_auth.py | 64 | ||||
-rw-r--r-- | src/leap/common/events/tests/test_events.py (renamed from src/leap/common/tests/test_events.py) | 24 | ||||
-rw-r--r-- | src/leap/common/events/txclient.py | 8 | ||||
-rw-r--r-- | src/leap/common/events/zmq_components.py | 58 | ||||
-rw-r--r-- | src/leap/common/testing/basetest.py | 9 |
8 files changed, 151 insertions, 52 deletions
diff --git a/src/leap/common/events/auth.py b/src/leap/common/events/auth.py index 1a1bcab..5b71f2d 100644 --- a/src/leap/common/events/auth.py +++ b/src/leap/common/events/auth.py @@ -38,8 +38,8 @@ class TxAuthenticator(ZmqConnection): address = 'inproc://zeromq.zap.01' encoding = 'utf-8' - def __init__(self, factory): - super(TxAuthenticator, self).__init__(factory) + def __init__(self, factory, *args, **kw): + super(TxAuthenticator, self).__init__(factory, *args, **kw) self.authenticator = Authenticator(factory.context) self.authenticator._send_zap_reply = self._send_zap_reply diff --git a/src/leap/common/events/client.py b/src/leap/common/events/client.py index 60d24bc..78617de 100644 --- a/src/leap/common/events/client.py +++ b/src/leap/common/events/client.py @@ -63,14 +63,18 @@ logger = logging.getLogger(__name__) _emit_addr = EMIT_ADDR _reg_addr = REG_ADDR +_factory = None +_enable_curve = True -def configure_client(emit_addr, reg_addr): - global _emit_addr, _reg_addr +def configure_client(emit_addr, reg_addr, factory=None, enable_curve=True): + global _emit_addr, _reg_addr, _factory, _enable_curve logger.debug("Configuring client with addresses: (%s, %s)" % (emit_addr, reg_addr)) _emit_addr = emit_addr _reg_addr = reg_addr + _factory = factory + _enable_curve = enable_curve class EventsClient(object): @@ -103,7 +107,9 @@ class EventsClient(object): """ with cls._instance_lock: if cls._instance is None: - cls._instance = cls(_emit_addr, _reg_addr) + cls._instance = cls( + _emit_addr, _reg_addr, factory=_factory, + enable_curve=_enable_curve) return cls._instance def register(self, event, callback, uid=None, replace=False): @@ -270,7 +276,7 @@ class EventsClientThread(threading.Thread, EventsClient): A threaded version of the events client. """ - def __init__(self, emit_addr, reg_addr): + def __init__(self, emit_addr, reg_addr, factory=None, enable_curve=True): """ Initialize the events client. """ @@ -281,15 +287,22 @@ class EventsClientThread(threading.Thread, EventsClient): self._config_prefix = os.path.join( get_path_prefix(flags.STANDALONE), "leap", "events") self._loop = None + self._factory = factory self._context = None self._push = None self._sub = None + if enable_curve: + self.use_curve = zmq_has_curve() + else: + self.use_curve = False + def _init_zmq(self): """ Initialize ZMQ connections. """ self._loop = EventsIOLoop() + # we need a new context for each thread self._context = zmq.Context() # connect SUB first, otherwise we might miss some event sent from this # same client @@ -311,7 +324,7 @@ class EventsClientThread(threading.Thread, EventsClient): logger.debug("Connecting %s to %s." % (socktype, address)) socket = self._context.socket(socktype) # configure curve authentication - if zmq_has_curve(): + if self.use_curve: public, private = maybe_create_and_get_certificates( self._config_prefix, "client") server_public_file = os.path.join( diff --git a/src/leap/common/events/server.py b/src/leap/common/events/server.py index 6252853..ad79abe 100644 --- a/src/leap/common/events/server.py +++ b/src/leap/common/events/server.py @@ -37,7 +37,8 @@ else: logger = logging.getLogger(__name__) -def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR): +def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, path_prefix=None, + factory=None, enable_curve=True): """ Make sure the server is running in the given addresses. @@ -49,7 +50,8 @@ def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR): :return: an events server instance :rtype: EventsServer """ - _server = EventsServer(emit_addr, reg_addr) + _server = EventsServer(emit_addr, reg_addr, path_prefix, factory=factory, + enable_curve=enable_curve) return _server @@ -59,7 +61,8 @@ class EventsServer(TxZmqServerComponent): events in another address. """ - def __init__(self, emit_addr, reg_addr): + def __init__(self, emit_addr, reg_addr, path_prefix=None, factory=None, + enable_curve=True): """ Initialize the events server. @@ -68,7 +71,9 @@ class EventsServer(TxZmqServerComponent): :param reg_addr: The address to which publish events to clients. :type reg_addr: str """ - TxZmqServerComponent.__init__(self) + TxZmqServerComponent.__init__(self, path_prefix=path_prefix, + factory=factory, + enable_curve=enable_curve) # bind PULL and PUB sockets self._pull, self.pull_port = self._zmq_bind( txzmq.ZmqPullConnection, emit_addr) diff --git a/src/leap/common/events/tests/test_auth.py b/src/leap/common/events/tests/test_auth.py new file mode 100644 index 0000000..78ffd9f --- /dev/null +++ b/src/leap/common/events/tests/test_auth.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# test_zmq_components.py +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Tests for the auth module. +""" +import os + +from twisted.trial import unittest +from txzmq import ZmqFactory + +from leap.common.events import auth +from leap.common.testing.basetest import BaseLeapTest +from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX +from leap.common.zmq_utils import maybe_create_and_get_certificates + +from txzmq.test import _wait + + +class ZmqAuthTestCase(unittest.TestCase, BaseLeapTest): + + def setUp(self): + self.setUpEnv(launch_events_server=False) + + self.factory = ZmqFactory() + self._config_prefix = os.path.join(self.tempdir, "leap", "events") + + self.public, self.secret = maybe_create_and_get_certificates( + self._config_prefix, 'server') + + self.authenticator = auth.TxAuthenticator(self.factory) + self.authenticator.start() + self.auth_req = auth.TxAuthenticationRequest(self.factory) + + def tearDown(self): + self.factory.shutdown() + self.tearDownEnv() + + def test_curve_auth(self): + self.auth_req.start() + self.auth_req.allow('127.0.0.1') + public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) + self.auth_req.configure_curve(domain="*", location=public_keys_dir) + + def check(ignored): + authenticator = self.authenticator.authenticator + certs = authenticator.certs['*'] + self.failUnlessEqual(authenticator.whitelist, set([u'127.0.0.1'])) + self.failUnlessEqual(certs[certs.keys()[0]], True) + + return _wait(0.1).addCallback(check) diff --git a/src/leap/common/tests/test_events.py b/src/leap/common/events/tests/test_events.py index 2ad097e..c45601b 100644 --- a/src/leap/common/tests/test_events.py +++ b/src/leap/common/events/tests/test_events.py @@ -14,16 +14,18 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. - - +""" +Tests for the events framework +""" import os import logging -import time from twisted.internet.reactor import callFromThread from twisted.trial import unittest from twisted.internet import defer +from txzmq import ZmqFactory + from leap.common.events import server from leap.common.events import client from leap.common.events import flags @@ -40,19 +42,22 @@ class EventsGenericClientTestCase(object): def setUp(self): flags.set_events_enabled(True) + self.factory = ZmqFactory() self._server = server.ensure_server( emit_addr="tcp://127.0.0.1:0", - reg_addr="tcp://127.0.0.1:0") + reg_addr="tcp://127.0.0.1:0", + factory=self.factory, + enable_curve=False) + self._client.configure_client( emit_addr="tcp://127.0.0.1:%d" % self._server.pull_port, - reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port) + reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port, + factory=self.factory, enable_curve=False) def tearDown(self): - self._client.shutdown() - self._server.shutdown() flags.set_events_enabled(False) - # wait a bit for sockets to close properly - time.sleep(0.1) + self.factory.shutdown() + self._client.instance().reset() def test_client_register(self): """ @@ -84,6 +89,7 @@ class EventsGenericClientTestCase(object): self.assertTrue(callbacks[event2][uid2] == cbk2, 'Could not register event in local client.') + def test_register_signal_replace(self): """ Make sure clients can replace already registered callbacks. diff --git a/src/leap/common/events/txclient.py b/src/leap/common/events/txclient.py index a2b704d..63f12d7 100644 --- a/src/leap/common/events/txclient.py +++ b/src/leap/common/events/txclient.py @@ -58,11 +58,13 @@ class EventsTxClient(TxZmqClientComponent, EventsClient): """ def __init__(self, emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, - path_prefix=None): + path_prefix=None, factory=None, enable_curve=True): """ - Initialize the events server. + Initialize the events client. """ - TxZmqClientComponent.__init__(self, path_prefix=path_prefix) + TxZmqClientComponent.__init__( + self, path_prefix=path_prefix, factory=factory, + enable_curve=enable_curve) EventsClient.__init__(self, emit_addr, reg_addr) # connect SUB first, otherwise we might miss some event sent from this # same client diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py index 74abb76..8919cd9 100644 --- a/src/leap/common/events/zmq_components.py +++ b/src/leap/common/events/zmq_components.py @@ -57,12 +57,14 @@ class TxZmqComponent(object): _component_type = None - def __init__(self, path_prefix=None, enable_curve=True): + def __init__(self, path_prefix=None, enable_curve=True, factory=None): """ Initialize the txzmq component. """ if path_prefix is None: path_prefix = get_path_prefix(flags.STANDALONE) + if factory is not None: + self._factory = factory self._config_prefix = os.path.join(path_prefix, "leap", "events") self._connections = [] if enable_curve: @@ -78,64 +80,69 @@ class TxZmqComponent(object): "define a self._component_type!") return self._component_type - def _zmq_connect(self, connClass, address): + def _zmq_bind(self, connClass, address): """ - Connect to an address. + Bind to an address. :param connClass: The connection class to be used. :type connClass: txzmq.ZmqConnection - :param address: The address to connect to. + :param address: The address to bind to. :type address: str - :return: The binded connection. - :rtype: txzmq.ZmqConnection + :return: The binded connection and port. + :rtype: (txzmq.ZmqConnection, int) """ - endpoint = ZmqEndpoint(ZmqEndpointType.connect, address) + proto, addr, port = ADDRESS_RE.search(address).groups() + + endpoint = ZmqEndpoint(ZmqEndpointType.bind, address) connection = connClass(self._factory) if self.use_curve: socket = connection.socket + public, secret = maybe_create_and_get_certificates( self._config_prefix, self.component_type) - server_public_file = os.path.join( - self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") - - server_public, _ = zmq.auth.load_certificate(server_public_file) socket.curve_publickey = public socket.curve_secretkey = secret - socket.curve_serverkey = server_public + self._start_authentication(connection.socket) - connection.addEndpoints([endpoint]) - return connection + if proto == 'tcp' and int(port) == 0: + connection.endpoints.extend([endpoint]) + port = connection.socket.bind_to_random_port('tcp://%s' % addr) + else: + connection.addEndpoints([endpoint]) - def _zmq_bind(self, connClass, address): + return connection, int(port) + + def _zmq_connect(self, connClass, address): """ - Bind to an address. + Connect to an address. :param connClass: The connection class to be used. :type connClass: txzmq.ZmqConnection - :param address: The address to bind to. + :param address: The address to connect to. :type address: str - :return: The binded connection and port. - :rtype: (txzmq.ZmqConnection, int) + :return: The binded connection. + :rtype: txzmq.ZmqConnection """ - proto, addr, port = ADDRESS_RE.search(address).groups() - - endpoint = ZmqEndpoint(ZmqEndpointType.bind, address) + endpoint = ZmqEndpoint(ZmqEndpointType.connect, address) connection = connClass(self._factory) if self.use_curve: socket = connection.socket - public, secret = maybe_create_and_get_certificates( self._config_prefix, self.component_type) + server_public_file = os.path.join( + self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") + + server_public, _ = zmq.auth.load_certificate(server_public_file) socket.curve_publickey = public socket.curve_secretkey = secret - self._start_authentication(connection.socket) + socket.curve_serverkey = server_public connection.addEndpoints([endpoint]) - return connection, port + return connection def _start_authentication(self, socket): @@ -150,6 +157,7 @@ class TxZmqComponent(object): # tell authenticator to use the certificate in a directory public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) auth_req.configure_curve(domain="*", location=public_keys_dir) + auth_req.shutdown() # This has to be set before binding the socket, that's why this method # has to be called before addEndpoints() diff --git a/src/leap/common/testing/basetest.py b/src/leap/common/testing/basetest.py index 3d3cee0..2e84a25 100644 --- a/src/leap/common/testing/basetest.py +++ b/src/leap/common/testing/basetest.py @@ -52,7 +52,7 @@ class BaseLeapTest(unittest.TestCase): cls.tearDownEnv() @classmethod - def setUpEnv(cls): + def setUpEnv(cls, launch_events_server=True): """ Sets up common facilities for testing this TestCase: - custom PATH and HOME environmental variables @@ -72,14 +72,15 @@ class BaseLeapTest(unittest.TestCase): os.environ["PATH"] = bin_tdir os.environ["HOME"] = cls.tempdir os.environ["XDG_CONFIG_HOME"] = os.path.join(cls.tempdir, ".config") - cls._init_events() + if launch_events_server: + cls._init_events() @classmethod def _init_events(cls): if flags.EVENTS_ENABLED: cls._server = events_server.ensure_server( - emit_addr="tcp://127.0.0.1:0", - reg_addr="tcp://127.0.0.1:0") + emit_addr="tcp://127.0.0.1", + reg_addr="tcp://127.0.0.1") events_client.configure_client( emit_addr="tcp://127.0.0.1:%d" % cls._server.pull_port, reg_addr="tcp://127.0.0.1:%d" % cls._server.pub_port) |