diff options
author | Kali Kaneko <kali@leap.se> | 2016-02-29 19:33:28 -0400 |
---|---|---|
committer | Kali Kaneko <kali@leap.se> | 2016-02-29 19:39:43 -0400 |
commit | 027ad7eed50947608738ce0009fccf776936e55c (patch) | |
tree | 6777d80f097a06bd9c560b24fc886504a693fae8 /src/leap/common/events | |
parent | 24977b744b42df912a23a2861453e7d4d5202310 (diff) |
[tests] adapt events tests to recent changes
Diffstat (limited to 'src/leap/common/events')
-rw-r--r-- | src/leap/common/events/auth.py | 4 | ||||
-rw-r--r-- | src/leap/common/events/client.py | 23 | ||||
-rw-r--r-- | src/leap/common/events/server.py | 13 | ||||
-rw-r--r-- | src/leap/common/events/tests/test_auth.py | 64 | ||||
-rw-r--r-- | src/leap/common/events/tests/test_events.py | 204 | ||||
-rw-r--r-- | src/leap/common/events/txclient.py | 8 | ||||
-rw-r--r-- | src/leap/common/events/zmq_components.py | 58 |
7 files changed, 335 insertions, 39 deletions
diff --git a/src/leap/common/events/auth.py b/src/leap/common/events/auth.py index 1a1bcab..5b71f2d 100644 --- a/src/leap/common/events/auth.py +++ b/src/leap/common/events/auth.py @@ -38,8 +38,8 @@ class TxAuthenticator(ZmqConnection): address = 'inproc://zeromq.zap.01' encoding = 'utf-8' - def __init__(self, factory): - super(TxAuthenticator, self).__init__(factory) + def __init__(self, factory, *args, **kw): + super(TxAuthenticator, self).__init__(factory, *args, **kw) self.authenticator = Authenticator(factory.context) self.authenticator._send_zap_reply = self._send_zap_reply diff --git a/src/leap/common/events/client.py b/src/leap/common/events/client.py index 60d24bc..78617de 100644 --- a/src/leap/common/events/client.py +++ b/src/leap/common/events/client.py @@ -63,14 +63,18 @@ logger = logging.getLogger(__name__) _emit_addr = EMIT_ADDR _reg_addr = REG_ADDR +_factory = None +_enable_curve = True -def configure_client(emit_addr, reg_addr): - global _emit_addr, _reg_addr +def configure_client(emit_addr, reg_addr, factory=None, enable_curve=True): + global _emit_addr, _reg_addr, _factory, _enable_curve logger.debug("Configuring client with addresses: (%s, %s)" % (emit_addr, reg_addr)) _emit_addr = emit_addr _reg_addr = reg_addr + _factory = factory + _enable_curve = enable_curve class EventsClient(object): @@ -103,7 +107,9 @@ class EventsClient(object): """ with cls._instance_lock: if cls._instance is None: - cls._instance = cls(_emit_addr, _reg_addr) + cls._instance = cls( + _emit_addr, _reg_addr, factory=_factory, + enable_curve=_enable_curve) return cls._instance def register(self, event, callback, uid=None, replace=False): @@ -270,7 +276,7 @@ class EventsClientThread(threading.Thread, EventsClient): A threaded version of the events client. """ - def __init__(self, emit_addr, reg_addr): + def __init__(self, emit_addr, reg_addr, factory=None, enable_curve=True): """ Initialize the events client. """ @@ -281,15 +287,22 @@ class EventsClientThread(threading.Thread, EventsClient): self._config_prefix = os.path.join( get_path_prefix(flags.STANDALONE), "leap", "events") self._loop = None + self._factory = factory self._context = None self._push = None self._sub = None + if enable_curve: + self.use_curve = zmq_has_curve() + else: + self.use_curve = False + def _init_zmq(self): """ Initialize ZMQ connections. """ self._loop = EventsIOLoop() + # we need a new context for each thread self._context = zmq.Context() # connect SUB first, otherwise we might miss some event sent from this # same client @@ -311,7 +324,7 @@ class EventsClientThread(threading.Thread, EventsClient): logger.debug("Connecting %s to %s." % (socktype, address)) socket = self._context.socket(socktype) # configure curve authentication - if zmq_has_curve(): + if self.use_curve: public, private = maybe_create_and_get_certificates( self._config_prefix, "client") server_public_file = os.path.join( diff --git a/src/leap/common/events/server.py b/src/leap/common/events/server.py index 6252853..ad79abe 100644 --- a/src/leap/common/events/server.py +++ b/src/leap/common/events/server.py @@ -37,7 +37,8 @@ else: logger = logging.getLogger(__name__) -def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR): +def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, path_prefix=None, + factory=None, enable_curve=True): """ Make sure the server is running in the given addresses. @@ -49,7 +50,8 @@ def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR): :return: an events server instance :rtype: EventsServer """ - _server = EventsServer(emit_addr, reg_addr) + _server = EventsServer(emit_addr, reg_addr, path_prefix, factory=factory, + enable_curve=enable_curve) return _server @@ -59,7 +61,8 @@ class EventsServer(TxZmqServerComponent): events in another address. """ - def __init__(self, emit_addr, reg_addr): + def __init__(self, emit_addr, reg_addr, path_prefix=None, factory=None, + enable_curve=True): """ Initialize the events server. @@ -68,7 +71,9 @@ class EventsServer(TxZmqServerComponent): :param reg_addr: The address to which publish events to clients. :type reg_addr: str """ - TxZmqServerComponent.__init__(self) + TxZmqServerComponent.__init__(self, path_prefix=path_prefix, + factory=factory, + enable_curve=enable_curve) # bind PULL and PUB sockets self._pull, self.pull_port = self._zmq_bind( txzmq.ZmqPullConnection, emit_addr) diff --git a/src/leap/common/events/tests/test_auth.py b/src/leap/common/events/tests/test_auth.py new file mode 100644 index 0000000..78ffd9f --- /dev/null +++ b/src/leap/common/events/tests/test_auth.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# test_zmq_components.py +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Tests for the auth module. +""" +import os + +from twisted.trial import unittest +from txzmq import ZmqFactory + +from leap.common.events import auth +from leap.common.testing.basetest import BaseLeapTest +from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX +from leap.common.zmq_utils import maybe_create_and_get_certificates + +from txzmq.test import _wait + + +class ZmqAuthTestCase(unittest.TestCase, BaseLeapTest): + + def setUp(self): + self.setUpEnv(launch_events_server=False) + + self.factory = ZmqFactory() + self._config_prefix = os.path.join(self.tempdir, "leap", "events") + + self.public, self.secret = maybe_create_and_get_certificates( + self._config_prefix, 'server') + + self.authenticator = auth.TxAuthenticator(self.factory) + self.authenticator.start() + self.auth_req = auth.TxAuthenticationRequest(self.factory) + + def tearDown(self): + self.factory.shutdown() + self.tearDownEnv() + + def test_curve_auth(self): + self.auth_req.start() + self.auth_req.allow('127.0.0.1') + public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) + self.auth_req.configure_curve(domain="*", location=public_keys_dir) + + def check(ignored): + authenticator = self.authenticator.authenticator + certs = authenticator.certs['*'] + self.failUnlessEqual(authenticator.whitelist, set([u'127.0.0.1'])) + self.failUnlessEqual(certs[certs.keys()[0]], True) + + return _wait(0.1).addCallback(check) diff --git a/src/leap/common/events/tests/test_events.py b/src/leap/common/events/tests/test_events.py new file mode 100644 index 0000000..c45601b --- /dev/null +++ b/src/leap/common/events/tests/test_events.py @@ -0,0 +1,204 @@ +# -*- coding: utf-8 -*- +# test_events.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Tests for the events framework +""" +import os +import logging + +from twisted.internet.reactor import callFromThread +from twisted.trial import unittest +from twisted.internet import defer + +from txzmq import ZmqFactory + +from leap.common.events import server +from leap.common.events import client +from leap.common.events import flags +from leap.common.events import txclient +from leap.common.events import catalog +from leap.common.events.errors import CallbackAlreadyRegisteredError + + +if 'DEBUG' in os.environ: + logging.basicConfig(level=logging.DEBUG) + + +class EventsGenericClientTestCase(object): + + def setUp(self): + flags.set_events_enabled(True) + self.factory = ZmqFactory() + self._server = server.ensure_server( + emit_addr="tcp://127.0.0.1:0", + reg_addr="tcp://127.0.0.1:0", + factory=self.factory, + enable_curve=False) + + self._client.configure_client( + emit_addr="tcp://127.0.0.1:%d" % self._server.pull_port, + reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port, + factory=self.factory, enable_curve=False) + + def tearDown(self): + flags.set_events_enabled(False) + self.factory.shutdown() + self._client.instance().reset() + + def test_client_register(self): + """ + Ensure clients can register callbacks. + """ + callbacks = self._client.instance().callbacks + self.assertTrue(len(callbacks) == 0, + 'There should be no callback for this event.') + # register one event + event1 = catalog.CLIENT_UID + + def cbk1(event, _): + return True + + uid1 = self._client.register(event1, cbk1) + # assert for correct registration + self.assertTrue(len(callbacks) == 1) + self.assertTrue(callbacks[event1][uid1] == cbk1, + 'Could not register event in local client.') + # register another event + event2 = catalog.CLIENT_SESSION_ID + + def cbk2(event, _): + return True + + uid2 = self._client.register(event2, cbk2) + # assert for correct registration + self.assertTrue(len(callbacks) == 2) + self.assertTrue(callbacks[event2][uid2] == cbk2, + 'Could not register event in local client.') + + + def test_register_signal_replace(self): + """ + Make sure clients can replace already registered callbacks. + """ + event = catalog.CLIENT_UID + d = defer.Deferred() + + def cbk_fail(event, _): + return callFromThread(d.errback, event) + + def cbk_succeed(event, _): + return callFromThread(d.callback, event) + + self._client.register(event, cbk_fail, uid=1) + self._client.register(event, cbk_succeed, uid=1, replace=True) + self._client.emit(event, None) + return d + + def test_register_signal_replace_fails_when_replace_is_false(self): + """ + Make sure clients trying to replace already registered callbacks fail + when replace=False + """ + event = catalog.CLIENT_UID + self._client.register(event, lambda event, _: None, uid=1) + self.assertRaises( + CallbackAlreadyRegisteredError, + self._client.register, + event, lambda event, _: None, uid=1, replace=False) + + def test_register_more_than_one_callback_works(self): + """ + Make sure clients can replace already registered callbacks. + """ + event = catalog.CLIENT_UID + d1 = defer.Deferred() + + def cbk1(event, _): + return callFromThread(d1.callback, event) + + d2 = defer.Deferred() + + def cbk2(event, _): + return d2.callback(event) + + self._client.register(event, cbk1) + self._client.register(event, cbk2) + self._client.emit(event, None) + d = defer.gatherResults([d1, d2]) + return d + + def test_client_receives_signal(self): + """ + Ensure clients can receive signals. + """ + event = catalog.CLIENT_UID + d = defer.Deferred() + + def cbk(events, _): + callFromThread(d.callback, event) + + self._client.register(event, cbk) + self._client.emit(event, None) + return d + + def test_client_unregister_all(self): + """ + Test that the client can unregister all events for one signal. + """ + event1 = catalog.CLIENT_UID + d = defer.Deferred() + # register more than one callback for the same event + self._client.register( + event1, lambda ev, _: callFromThread(d.errback, None)) + self._client.register( + event1, lambda ev, _: callFromThread(d.errback, None)) + # unregister and emit the event + self._client.unregister(event1) + self._client.emit(event1, None) + # register and emit another event so the deferred can succeed + event2 = catalog.CLIENT_SESSION_ID + self._client.register( + event2, lambda ev, _: callFromThread(d.callback, None)) + self._client.emit(event2, None) + return d + + def test_client_unregister_by_uid(self): + """ + Test that the client can unregister an event by uid. + """ + event = catalog.CLIENT_UID + d = defer.Deferred() + # register one callback that would fail + uid = self._client.register( + event, lambda ev, _: callFromThread(d.errback, None)) + # register one callback that will succeed + self._client.register( + event, lambda ev, _: callFromThread(d.callback, None)) + # unregister by uid and emit the event + self._client.unregister(event, uid=uid) + self._client.emit(event, None) + return d + + +class EventsTxClientTestCase(EventsGenericClientTestCase, unittest.TestCase): + + _client = txclient + + +class EventsClientTestCase(EventsGenericClientTestCase, unittest.TestCase): + + _client = client diff --git a/src/leap/common/events/txclient.py b/src/leap/common/events/txclient.py index a2b704d..63f12d7 100644 --- a/src/leap/common/events/txclient.py +++ b/src/leap/common/events/txclient.py @@ -58,11 +58,13 @@ class EventsTxClient(TxZmqClientComponent, EventsClient): """ def __init__(self, emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, - path_prefix=None): + path_prefix=None, factory=None, enable_curve=True): """ - Initialize the events server. + Initialize the events client. """ - TxZmqClientComponent.__init__(self, path_prefix=path_prefix) + TxZmqClientComponent.__init__( + self, path_prefix=path_prefix, factory=factory, + enable_curve=enable_curve) EventsClient.__init__(self, emit_addr, reg_addr) # connect SUB first, otherwise we might miss some event sent from this # same client diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py index 74abb76..8919cd9 100644 --- a/src/leap/common/events/zmq_components.py +++ b/src/leap/common/events/zmq_components.py @@ -57,12 +57,14 @@ class TxZmqComponent(object): _component_type = None - def __init__(self, path_prefix=None, enable_curve=True): + def __init__(self, path_prefix=None, enable_curve=True, factory=None): """ Initialize the txzmq component. """ if path_prefix is None: path_prefix = get_path_prefix(flags.STANDALONE) + if factory is not None: + self._factory = factory self._config_prefix = os.path.join(path_prefix, "leap", "events") self._connections = [] if enable_curve: @@ -78,64 +80,69 @@ class TxZmqComponent(object): "define a self._component_type!") return self._component_type - def _zmq_connect(self, connClass, address): + def _zmq_bind(self, connClass, address): """ - Connect to an address. + Bind to an address. :param connClass: The connection class to be used. :type connClass: txzmq.ZmqConnection - :param address: The address to connect to. + :param address: The address to bind to. :type address: str - :return: The binded connection. - :rtype: txzmq.ZmqConnection + :return: The binded connection and port. + :rtype: (txzmq.ZmqConnection, int) """ - endpoint = ZmqEndpoint(ZmqEndpointType.connect, address) + proto, addr, port = ADDRESS_RE.search(address).groups() + + endpoint = ZmqEndpoint(ZmqEndpointType.bind, address) connection = connClass(self._factory) if self.use_curve: socket = connection.socket + public, secret = maybe_create_and_get_certificates( self._config_prefix, self.component_type) - server_public_file = os.path.join( - self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") - - server_public, _ = zmq.auth.load_certificate(server_public_file) socket.curve_publickey = public socket.curve_secretkey = secret - socket.curve_serverkey = server_public + self._start_authentication(connection.socket) - connection.addEndpoints([endpoint]) - return connection + if proto == 'tcp' and int(port) == 0: + connection.endpoints.extend([endpoint]) + port = connection.socket.bind_to_random_port('tcp://%s' % addr) + else: + connection.addEndpoints([endpoint]) - def _zmq_bind(self, connClass, address): + return connection, int(port) + + def _zmq_connect(self, connClass, address): """ - Bind to an address. + Connect to an address. :param connClass: The connection class to be used. :type connClass: txzmq.ZmqConnection - :param address: The address to bind to. + :param address: The address to connect to. :type address: str - :return: The binded connection and port. - :rtype: (txzmq.ZmqConnection, int) + :return: The binded connection. + :rtype: txzmq.ZmqConnection """ - proto, addr, port = ADDRESS_RE.search(address).groups() - - endpoint = ZmqEndpoint(ZmqEndpointType.bind, address) + endpoint = ZmqEndpoint(ZmqEndpointType.connect, address) connection = connClass(self._factory) if self.use_curve: socket = connection.socket - public, secret = maybe_create_and_get_certificates( self._config_prefix, self.component_type) + server_public_file = os.path.join( + self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") + + server_public, _ = zmq.auth.load_certificate(server_public_file) socket.curve_publickey = public socket.curve_secretkey = secret - self._start_authentication(connection.socket) + socket.curve_serverkey = server_public connection.addEndpoints([endpoint]) - return connection, port + return connection def _start_authentication(self, socket): @@ -150,6 +157,7 @@ class TxZmqComponent(object): # tell authenticator to use the certificate in a directory public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) auth_req.configure_curve(domain="*", location=public_keys_dir) + auth_req.shutdown() # This has to be set before binding the socket, that's why this method # has to be called before addEndpoints() |