summaryrefslogtreecommitdiff
path: root/src/leap/common/events
diff options
context:
space:
mode:
authorKali Kaneko <kali@leap.se>2016-02-29 19:33:28 -0400
committerKali Kaneko <kali@leap.se>2016-02-29 19:39:43 -0400
commit027ad7eed50947608738ce0009fccf776936e55c (patch)
tree6777d80f097a06bd9c560b24fc886504a693fae8 /src/leap/common/events
parent24977b744b42df912a23a2861453e7d4d5202310 (diff)
[tests] adapt events tests to recent changes
Diffstat (limited to 'src/leap/common/events')
-rw-r--r--src/leap/common/events/auth.py4
-rw-r--r--src/leap/common/events/client.py23
-rw-r--r--src/leap/common/events/server.py13
-rw-r--r--src/leap/common/events/tests/test_auth.py64
-rw-r--r--src/leap/common/events/tests/test_events.py204
-rw-r--r--src/leap/common/events/txclient.py8
-rw-r--r--src/leap/common/events/zmq_components.py58
7 files changed, 335 insertions, 39 deletions
diff --git a/src/leap/common/events/auth.py b/src/leap/common/events/auth.py
index 1a1bcab..5b71f2d 100644
--- a/src/leap/common/events/auth.py
+++ b/src/leap/common/events/auth.py
@@ -38,8 +38,8 @@ class TxAuthenticator(ZmqConnection):
address = 'inproc://zeromq.zap.01'
encoding = 'utf-8'
- def __init__(self, factory):
- super(TxAuthenticator, self).__init__(factory)
+ def __init__(self, factory, *args, **kw):
+ super(TxAuthenticator, self).__init__(factory, *args, **kw)
self.authenticator = Authenticator(factory.context)
self.authenticator._send_zap_reply = self._send_zap_reply
diff --git a/src/leap/common/events/client.py b/src/leap/common/events/client.py
index 60d24bc..78617de 100644
--- a/src/leap/common/events/client.py
+++ b/src/leap/common/events/client.py
@@ -63,14 +63,18 @@ logger = logging.getLogger(__name__)
_emit_addr = EMIT_ADDR
_reg_addr = REG_ADDR
+_factory = None
+_enable_curve = True
-def configure_client(emit_addr, reg_addr):
- global _emit_addr, _reg_addr
+def configure_client(emit_addr, reg_addr, factory=None, enable_curve=True):
+ global _emit_addr, _reg_addr, _factory, _enable_curve
logger.debug("Configuring client with addresses: (%s, %s)" %
(emit_addr, reg_addr))
_emit_addr = emit_addr
_reg_addr = reg_addr
+ _factory = factory
+ _enable_curve = enable_curve
class EventsClient(object):
@@ -103,7 +107,9 @@ class EventsClient(object):
"""
with cls._instance_lock:
if cls._instance is None:
- cls._instance = cls(_emit_addr, _reg_addr)
+ cls._instance = cls(
+ _emit_addr, _reg_addr, factory=_factory,
+ enable_curve=_enable_curve)
return cls._instance
def register(self, event, callback, uid=None, replace=False):
@@ -270,7 +276,7 @@ class EventsClientThread(threading.Thread, EventsClient):
A threaded version of the events client.
"""
- def __init__(self, emit_addr, reg_addr):
+ def __init__(self, emit_addr, reg_addr, factory=None, enable_curve=True):
"""
Initialize the events client.
"""
@@ -281,15 +287,22 @@ class EventsClientThread(threading.Thread, EventsClient):
self._config_prefix = os.path.join(
get_path_prefix(flags.STANDALONE), "leap", "events")
self._loop = None
+ self._factory = factory
self._context = None
self._push = None
self._sub = None
+ if enable_curve:
+ self.use_curve = zmq_has_curve()
+ else:
+ self.use_curve = False
+
def _init_zmq(self):
"""
Initialize ZMQ connections.
"""
self._loop = EventsIOLoop()
+ # we need a new context for each thread
self._context = zmq.Context()
# connect SUB first, otherwise we might miss some event sent from this
# same client
@@ -311,7 +324,7 @@ class EventsClientThread(threading.Thread, EventsClient):
logger.debug("Connecting %s to %s." % (socktype, address))
socket = self._context.socket(socktype)
# configure curve authentication
- if zmq_has_curve():
+ if self.use_curve:
public, private = maybe_create_and_get_certificates(
self._config_prefix, "client")
server_public_file = os.path.join(
diff --git a/src/leap/common/events/server.py b/src/leap/common/events/server.py
index 6252853..ad79abe 100644
--- a/src/leap/common/events/server.py
+++ b/src/leap/common/events/server.py
@@ -37,7 +37,8 @@ else:
logger = logging.getLogger(__name__)
-def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR):
+def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, path_prefix=None,
+ factory=None, enable_curve=True):
"""
Make sure the server is running in the given addresses.
@@ -49,7 +50,8 @@ def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR):
:return: an events server instance
:rtype: EventsServer
"""
- _server = EventsServer(emit_addr, reg_addr)
+ _server = EventsServer(emit_addr, reg_addr, path_prefix, factory=factory,
+ enable_curve=enable_curve)
return _server
@@ -59,7 +61,8 @@ class EventsServer(TxZmqServerComponent):
events in another address.
"""
- def __init__(self, emit_addr, reg_addr):
+ def __init__(self, emit_addr, reg_addr, path_prefix=None, factory=None,
+ enable_curve=True):
"""
Initialize the events server.
@@ -68,7 +71,9 @@ class EventsServer(TxZmqServerComponent):
:param reg_addr: The address to which publish events to clients.
:type reg_addr: str
"""
- TxZmqServerComponent.__init__(self)
+ TxZmqServerComponent.__init__(self, path_prefix=path_prefix,
+ factory=factory,
+ enable_curve=enable_curve)
# bind PULL and PUB sockets
self._pull, self.pull_port = self._zmq_bind(
txzmq.ZmqPullConnection, emit_addr)
diff --git a/src/leap/common/events/tests/test_auth.py b/src/leap/common/events/tests/test_auth.py
new file mode 100644
index 0000000..78ffd9f
--- /dev/null
+++ b/src/leap/common/events/tests/test_auth.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+# test_zmq_components.py
+# Copyright (C) 2014 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+Tests for the auth module.
+"""
+import os
+
+from twisted.trial import unittest
+from txzmq import ZmqFactory
+
+from leap.common.events import auth
+from leap.common.testing.basetest import BaseLeapTest
+from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX
+from leap.common.zmq_utils import maybe_create_and_get_certificates
+
+from txzmq.test import _wait
+
+
+class ZmqAuthTestCase(unittest.TestCase, BaseLeapTest):
+
+ def setUp(self):
+ self.setUpEnv(launch_events_server=False)
+
+ self.factory = ZmqFactory()
+ self._config_prefix = os.path.join(self.tempdir, "leap", "events")
+
+ self.public, self.secret = maybe_create_and_get_certificates(
+ self._config_prefix, 'server')
+
+ self.authenticator = auth.TxAuthenticator(self.factory)
+ self.authenticator.start()
+ self.auth_req = auth.TxAuthenticationRequest(self.factory)
+
+ def tearDown(self):
+ self.factory.shutdown()
+ self.tearDownEnv()
+
+ def test_curve_auth(self):
+ self.auth_req.start()
+ self.auth_req.allow('127.0.0.1')
+ public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
+ self.auth_req.configure_curve(domain="*", location=public_keys_dir)
+
+ def check(ignored):
+ authenticator = self.authenticator.authenticator
+ certs = authenticator.certs['*']
+ self.failUnlessEqual(authenticator.whitelist, set([u'127.0.0.1']))
+ self.failUnlessEqual(certs[certs.keys()[0]], True)
+
+ return _wait(0.1).addCallback(check)
diff --git a/src/leap/common/events/tests/test_events.py b/src/leap/common/events/tests/test_events.py
new file mode 100644
index 0000000..c45601b
--- /dev/null
+++ b/src/leap/common/events/tests/test_events.py
@@ -0,0 +1,204 @@
+# -*- coding: utf-8 -*-
+# test_events.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+Tests for the events framework
+"""
+import os
+import logging
+
+from twisted.internet.reactor import callFromThread
+from twisted.trial import unittest
+from twisted.internet import defer
+
+from txzmq import ZmqFactory
+
+from leap.common.events import server
+from leap.common.events import client
+from leap.common.events import flags
+from leap.common.events import txclient
+from leap.common.events import catalog
+from leap.common.events.errors import CallbackAlreadyRegisteredError
+
+
+if 'DEBUG' in os.environ:
+ logging.basicConfig(level=logging.DEBUG)
+
+
+class EventsGenericClientTestCase(object):
+
+ def setUp(self):
+ flags.set_events_enabled(True)
+ self.factory = ZmqFactory()
+ self._server = server.ensure_server(
+ emit_addr="tcp://127.0.0.1:0",
+ reg_addr="tcp://127.0.0.1:0",
+ factory=self.factory,
+ enable_curve=False)
+
+ self._client.configure_client(
+ emit_addr="tcp://127.0.0.1:%d" % self._server.pull_port,
+ reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port,
+ factory=self.factory, enable_curve=False)
+
+ def tearDown(self):
+ flags.set_events_enabled(False)
+ self.factory.shutdown()
+ self._client.instance().reset()
+
+ def test_client_register(self):
+ """
+ Ensure clients can register callbacks.
+ """
+ callbacks = self._client.instance().callbacks
+ self.assertTrue(len(callbacks) == 0,
+ 'There should be no callback for this event.')
+ # register one event
+ event1 = catalog.CLIENT_UID
+
+ def cbk1(event, _):
+ return True
+
+ uid1 = self._client.register(event1, cbk1)
+ # assert for correct registration
+ self.assertTrue(len(callbacks) == 1)
+ self.assertTrue(callbacks[event1][uid1] == cbk1,
+ 'Could not register event in local client.')
+ # register another event
+ event2 = catalog.CLIENT_SESSION_ID
+
+ def cbk2(event, _):
+ return True
+
+ uid2 = self._client.register(event2, cbk2)
+ # assert for correct registration
+ self.assertTrue(len(callbacks) == 2)
+ self.assertTrue(callbacks[event2][uid2] == cbk2,
+ 'Could not register event in local client.')
+
+
+ def test_register_signal_replace(self):
+ """
+ Make sure clients can replace already registered callbacks.
+ """
+ event = catalog.CLIENT_UID
+ d = defer.Deferred()
+
+ def cbk_fail(event, _):
+ return callFromThread(d.errback, event)
+
+ def cbk_succeed(event, _):
+ return callFromThread(d.callback, event)
+
+ self._client.register(event, cbk_fail, uid=1)
+ self._client.register(event, cbk_succeed, uid=1, replace=True)
+ self._client.emit(event, None)
+ return d
+
+ def test_register_signal_replace_fails_when_replace_is_false(self):
+ """
+ Make sure clients trying to replace already registered callbacks fail
+ when replace=False
+ """
+ event = catalog.CLIENT_UID
+ self._client.register(event, lambda event, _: None, uid=1)
+ self.assertRaises(
+ CallbackAlreadyRegisteredError,
+ self._client.register,
+ event, lambda event, _: None, uid=1, replace=False)
+
+ def test_register_more_than_one_callback_works(self):
+ """
+ Make sure clients can replace already registered callbacks.
+ """
+ event = catalog.CLIENT_UID
+ d1 = defer.Deferred()
+
+ def cbk1(event, _):
+ return callFromThread(d1.callback, event)
+
+ d2 = defer.Deferred()
+
+ def cbk2(event, _):
+ return d2.callback(event)
+
+ self._client.register(event, cbk1)
+ self._client.register(event, cbk2)
+ self._client.emit(event, None)
+ d = defer.gatherResults([d1, d2])
+ return d
+
+ def test_client_receives_signal(self):
+ """
+ Ensure clients can receive signals.
+ """
+ event = catalog.CLIENT_UID
+ d = defer.Deferred()
+
+ def cbk(events, _):
+ callFromThread(d.callback, event)
+
+ self._client.register(event, cbk)
+ self._client.emit(event, None)
+ return d
+
+ def test_client_unregister_all(self):
+ """
+ Test that the client can unregister all events for one signal.
+ """
+ event1 = catalog.CLIENT_UID
+ d = defer.Deferred()
+ # register more than one callback for the same event
+ self._client.register(
+ event1, lambda ev, _: callFromThread(d.errback, None))
+ self._client.register(
+ event1, lambda ev, _: callFromThread(d.errback, None))
+ # unregister and emit the event
+ self._client.unregister(event1)
+ self._client.emit(event1, None)
+ # register and emit another event so the deferred can succeed
+ event2 = catalog.CLIENT_SESSION_ID
+ self._client.register(
+ event2, lambda ev, _: callFromThread(d.callback, None))
+ self._client.emit(event2, None)
+ return d
+
+ def test_client_unregister_by_uid(self):
+ """
+ Test that the client can unregister an event by uid.
+ """
+ event = catalog.CLIENT_UID
+ d = defer.Deferred()
+ # register one callback that would fail
+ uid = self._client.register(
+ event, lambda ev, _: callFromThread(d.errback, None))
+ # register one callback that will succeed
+ self._client.register(
+ event, lambda ev, _: callFromThread(d.callback, None))
+ # unregister by uid and emit the event
+ self._client.unregister(event, uid=uid)
+ self._client.emit(event, None)
+ return d
+
+
+class EventsTxClientTestCase(EventsGenericClientTestCase, unittest.TestCase):
+
+ _client = txclient
+
+
+class EventsClientTestCase(EventsGenericClientTestCase, unittest.TestCase):
+
+ _client = client
diff --git a/src/leap/common/events/txclient.py b/src/leap/common/events/txclient.py
index a2b704d..63f12d7 100644
--- a/src/leap/common/events/txclient.py
+++ b/src/leap/common/events/txclient.py
@@ -58,11 +58,13 @@ class EventsTxClient(TxZmqClientComponent, EventsClient):
"""
def __init__(self, emit_addr=EMIT_ADDR, reg_addr=REG_ADDR,
- path_prefix=None):
+ path_prefix=None, factory=None, enable_curve=True):
"""
- Initialize the events server.
+ Initialize the events client.
"""
- TxZmqClientComponent.__init__(self, path_prefix=path_prefix)
+ TxZmqClientComponent.__init__(
+ self, path_prefix=path_prefix, factory=factory,
+ enable_curve=enable_curve)
EventsClient.__init__(self, emit_addr, reg_addr)
# connect SUB first, otherwise we might miss some event sent from this
# same client
diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py
index 74abb76..8919cd9 100644
--- a/src/leap/common/events/zmq_components.py
+++ b/src/leap/common/events/zmq_components.py
@@ -57,12 +57,14 @@ class TxZmqComponent(object):
_component_type = None
- def __init__(self, path_prefix=None, enable_curve=True):
+ def __init__(self, path_prefix=None, enable_curve=True, factory=None):
"""
Initialize the txzmq component.
"""
if path_prefix is None:
path_prefix = get_path_prefix(flags.STANDALONE)
+ if factory is not None:
+ self._factory = factory
self._config_prefix = os.path.join(path_prefix, "leap", "events")
self._connections = []
if enable_curve:
@@ -78,64 +80,69 @@ class TxZmqComponent(object):
"define a self._component_type!")
return self._component_type
- def _zmq_connect(self, connClass, address):
+ def _zmq_bind(self, connClass, address):
"""
- Connect to an address.
+ Bind to an address.
:param connClass: The connection class to be used.
:type connClass: txzmq.ZmqConnection
- :param address: The address to connect to.
+ :param address: The address to bind to.
:type address: str
- :return: The binded connection.
- :rtype: txzmq.ZmqConnection
+ :return: The binded connection and port.
+ :rtype: (txzmq.ZmqConnection, int)
"""
- endpoint = ZmqEndpoint(ZmqEndpointType.connect, address)
+ proto, addr, port = ADDRESS_RE.search(address).groups()
+
+ endpoint = ZmqEndpoint(ZmqEndpointType.bind, address)
connection = connClass(self._factory)
if self.use_curve:
socket = connection.socket
+
public, secret = maybe_create_and_get_certificates(
self._config_prefix, self.component_type)
- server_public_file = os.path.join(
- self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
-
- server_public, _ = zmq.auth.load_certificate(server_public_file)
socket.curve_publickey = public
socket.curve_secretkey = secret
- socket.curve_serverkey = server_public
+ self._start_authentication(connection.socket)
- connection.addEndpoints([endpoint])
- return connection
+ if proto == 'tcp' and int(port) == 0:
+ connection.endpoints.extend([endpoint])
+ port = connection.socket.bind_to_random_port('tcp://%s' % addr)
+ else:
+ connection.addEndpoints([endpoint])
- def _zmq_bind(self, connClass, address):
+ return connection, int(port)
+
+ def _zmq_connect(self, connClass, address):
"""
- Bind to an address.
+ Connect to an address.
:param connClass: The connection class to be used.
:type connClass: txzmq.ZmqConnection
- :param address: The address to bind to.
+ :param address: The address to connect to.
:type address: str
- :return: The binded connection and port.
- :rtype: (txzmq.ZmqConnection, int)
+ :return: The binded connection.
+ :rtype: txzmq.ZmqConnection
"""
- proto, addr, port = ADDRESS_RE.search(address).groups()
-
- endpoint = ZmqEndpoint(ZmqEndpointType.bind, address)
+ endpoint = ZmqEndpoint(ZmqEndpointType.connect, address)
connection = connClass(self._factory)
if self.use_curve:
socket = connection.socket
-
public, secret = maybe_create_and_get_certificates(
self._config_prefix, self.component_type)
+ server_public_file = os.path.join(
+ self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
+
+ server_public, _ = zmq.auth.load_certificate(server_public_file)
socket.curve_publickey = public
socket.curve_secretkey = secret
- self._start_authentication(connection.socket)
+ socket.curve_serverkey = server_public
connection.addEndpoints([endpoint])
- return connection, port
+ return connection
def _start_authentication(self, socket):
@@ -150,6 +157,7 @@ class TxZmqComponent(object):
# tell authenticator to use the certificate in a directory
public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
auth_req.configure_curve(domain="*", location=public_keys_dir)
+ auth_req.shutdown()
# This has to be set before binding the socket, that's why this method
# has to be called before addEndpoints()