diff options
Diffstat (limited to 'src/leap/common/events')
-rw-r--r-- | src/leap/common/events/auth.py | 100 | ||||
-rw-r--r-- | src/leap/common/events/catalog.py | 85 | ||||
-rw-r--r-- | src/leap/common/events/client.py | 23 | ||||
-rw-r--r-- | src/leap/common/events/examples/README.txt | 49 | ||||
-rw-r--r-- | src/leap/common/events/examples/client.py | 2 | ||||
-rw-r--r-- | src/leap/common/events/examples/server.py | 4 | ||||
-rw-r--r-- | src/leap/common/events/server.py | 24 | ||||
-rw-r--r-- | src/leap/common/events/tests/test_auth.py | 64 | ||||
-rw-r--r-- | src/leap/common/events/tests/test_events.py | 203 | ||||
-rw-r--r-- | src/leap/common/events/txclient.py | 10 | ||||
-rw-r--r-- | src/leap/common/events/zmq_components.py | 147 |
11 files changed, 571 insertions, 140 deletions
diff --git a/src/leap/common/events/auth.py b/src/leap/common/events/auth.py new file mode 100644 index 0000000..db217ca --- /dev/null +++ b/src/leap/common/events/auth.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# auth.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +ZAP authentication, twisted style. +""" +from zmq import PAIR +from zmq.auth.base import Authenticator, VERSION +from txzmq.connection import ZmqConnection +from zmq.utils.strtypes import b, u + +from twisted.python import log + +from txzmq.connection import ZmqEndpoint, ZmqEndpointType + + +class TxAuthenticator(ZmqConnection): + + """ + This does not implement the whole ZAP protocol, but the bare minimum that + we need. + """ + + socketType = PAIR + address = 'inproc://zeromq.zap.01' + encoding = 'utf-8' + + def __init__(self, factory, *args, **kw): + super(TxAuthenticator, self).__init__(factory, *args, **kw) + self.authenticator = Authenticator(factory.context) + self.authenticator._send_zap_reply = self._send_zap_reply + + def start(self): + endpoint = ZmqEndpoint(ZmqEndpointType.bind, self.address) + self.addEndpoints([endpoint]) + + def messageReceived(self, msg): + + command = msg[0] + + if command == b'ALLOW': + addresses = [u(m, self.encoding) for m in msg[1:]] + try: + self.authenticator.allow(*addresses) + except Exception as e: + log.err("Failed to allow %s", addresses) + + elif command == b'CURVE': + domain = u(msg[1], self.encoding) + location = u(msg[2], self.encoding) + self.authenticator.configure_curve(domain, location) + + def _send_zap_reply(self, request_id, status_code, status_text, + user_id='user'): + """ + Send a ZAP reply to finish the authentication. + """ + user_id = user_id if status_code == b'200' else b'' + if isinstance(user_id, unicode): + user_id = user_id.encode(self.encoding, 'replace') + metadata = b'' # not currently used + reply = [VERSION, request_id, status_code, status_text, + user_id, metadata] + self.send(reply) + + def shutdown(self): + if self.factory: + super(TxAuthenticator, self).shutdown() + + +class TxAuthenticationRequest(ZmqConnection): + + socketType = PAIR + address = 'inproc://zeromq.zap.01' + encoding = 'utf-8' + + def start(self): + endpoint = ZmqEndpoint(ZmqEndpointType.connect, self.address) + self.addEndpoints([endpoint]) + + def allow(self, *addresses): + self.send([b'ALLOW'] + [b(a, self.encoding) for a in addresses]) + + def configure_curve(self, domain='*', location=''): + domain = b(domain, self.encoding) + location = b(location, self.encoding) + self.send([b'CURVE', domain, location]) diff --git a/src/leap/common/events/catalog.py b/src/leap/common/events/catalog.py index 8bddd2c..9a834b2 100644 --- a/src/leap/common/events/catalog.py +++ b/src/leap/common/events/catalog.py @@ -24,49 +24,54 @@ Events catalog. EVENTS = [ "CLIENT_SESSION_ID", "CLIENT_UID", - "IMAP_CLIENT_LOGIN", - "IMAP_SERVICE_FAILED_TO_START", - "IMAP_SERVICE_STARTED", - "IMAP_UNHANDLED_ERROR", - "KEYMANAGER_DONE_UPLOADING_KEYS", - "KEYMANAGER_FINISHED_KEY_GENERATION", - "KEYMANAGER_KEY_FOUND", - "KEYMANAGER_KEY_NOT_FOUND", - "KEYMANAGER_LOOKING_FOR_KEY", - "KEYMANAGER_STARTED_KEY_GENERATION", - "MAIL_FETCHED_INCOMING", - "MAIL_MSG_DECRYPTED", - "MAIL_MSG_DELETED_INCOMING", - "MAIL_MSG_PROCESSING", - "MAIL_MSG_SAVED_LOCALLY", - "MAIL_UNREAD_MESSAGES", "RAISE_WINDOW", - "SMTP_CONNECTION_LOST", - "SMTP_END_ENCRYPT_AND_SIGN", - "SMTP_END_SIGN", - "SMTP_RECIPIENT_ACCEPTED_ENCRYPTED", - "SMTP_RECIPIENT_ACCEPTED_UNENCRYPTED", - "SMTP_RECIPIENT_REJECTED", - "SMTP_SEND_MESSAGE_ERROR", - "SMTP_SEND_MESSAGE_START", - "SMTP_SEND_MESSAGE_SUCCESS", - "SMTP_SERVICE_FAILED_TO_START", - "SMTP_SERVICE_STARTED", - "SMTP_START_ENCRYPT_AND_SIGN", - "SMTP_START_SIGN", - "SOLEDAD_CREATING_KEYS", - "SOLEDAD_DONE_CREATING_KEYS", - "SOLEDAD_DONE_DATA_SYNC", - "SOLEDAD_DONE_DOWNLOADING_KEYS", - "SOLEDAD_DONE_UPLOADING_KEYS", - "SOLEDAD_DOWNLOADING_KEYS", - "SOLEDAD_INVALID_AUTH_TOKEN", - "SOLEDAD_NEW_DATA_TO_SYNC", - "SOLEDAD_SYNC_RECEIVE_STATUS", - "SOLEDAD_SYNC_SEND_STATUS", - "SOLEDAD_UPLOADING_KEYS", "UPDATER_DONE_UPDATING", "UPDATER_NEW_UPDATES", + + "KEYMANAGER_DONE_UPLOADING_KEYS", # (address) + "KEYMANAGER_FINISHED_KEY_GENERATION", # (address) + "KEYMANAGER_KEY_FOUND", # (address) + "KEYMANAGER_KEY_NOT_FOUND", # (address) + "KEYMANAGER_LOOKING_FOR_KEY", # (address) + "KEYMANAGER_STARTED_KEY_GENERATION", # (address) + + "SOLEDAD_CREATING_KEYS", # {uuid, userid} + "SOLEDAD_DONE_CREATING_KEYS", # {uuid, userid} + "SOLEDAD_DONE_DATA_SYNC", # {uuid, userid} + "SOLEDAD_DONE_DOWNLOADING_KEYS", # {uuid, userid} + "SOLEDAD_DONE_UPLOADING_KEYS", # {uuid, userid} + "SOLEDAD_DOWNLOADING_KEYS", # {uuid, userid} + "SOLEDAD_INVALID_AUTH_TOKEN", # {uuid, userid} + "SOLEDAD_SYNC_RECEIVE_STATUS", # {uuid, userid} + "SOLEDAD_SYNC_SEND_STATUS", # {uuid, userid} + "SOLEDAD_UPLOADING_KEYS", # {uuid, userid} + "SOLEDAD_NEW_DATA_TO_SYNC", + + "MAIL_FETCHED_INCOMING", # (userid) + "MAIL_MSG_DECRYPTED", # (userid) + "MAIL_MSG_DELETED_INCOMING", # (userid) + "MAIL_MSG_PROCESSING", # (userid) + "MAIL_MSG_SAVED_LOCALLY", # (userid) + "MAIL_UNREAD_MESSAGES", # (userid, number) + + "IMAP_SERVICE_STARTED", + "IMAP_SERVICE_FAILED_TO_START", + "IMAP_UNHANDLED_ERROR", + "IMAP_CLIENT_LOGIN", # (username) + + "SMTP_SERVICE_STARTED", + "SMTP_SERVICE_FAILED_TO_START", + "SMTP_START_ENCRYPT_AND_SIGN", # (from_addr) + "SMTP_END_ENCRYPT_AND_SIGN", # (from_addr) + "SMTP_START_SIGN", # (from_addr) + "SMTP_END_SIGN", # (from_addr) + "SMTP_SEND_MESSAGE_START", # (from_addr) + "SMTP_SEND_MESSAGE_SUCCESS", # (from_addr) + "SMTP_RECIPIENT_ACCEPTED_ENCRYPTED", # (userid, dest) + "SMTP_RECIPIENT_ACCEPTED_UNENCRYPTED", # (userid, dest) + "SMTP_CONNECTION_LOST", # (userid, dest) + "SMTP_RECIPIENT_REJECTED", # (userid, dest) + "SMTP_SEND_MESSAGE_ERROR", # (userid, dest) ] diff --git a/src/leap/common/events/client.py b/src/leap/common/events/client.py index 60d24bc..78617de 100644 --- a/src/leap/common/events/client.py +++ b/src/leap/common/events/client.py @@ -63,14 +63,18 @@ logger = logging.getLogger(__name__) _emit_addr = EMIT_ADDR _reg_addr = REG_ADDR +_factory = None +_enable_curve = True -def configure_client(emit_addr, reg_addr): - global _emit_addr, _reg_addr +def configure_client(emit_addr, reg_addr, factory=None, enable_curve=True): + global _emit_addr, _reg_addr, _factory, _enable_curve logger.debug("Configuring client with addresses: (%s, %s)" % (emit_addr, reg_addr)) _emit_addr = emit_addr _reg_addr = reg_addr + _factory = factory + _enable_curve = enable_curve class EventsClient(object): @@ -103,7 +107,9 @@ class EventsClient(object): """ with cls._instance_lock: if cls._instance is None: - cls._instance = cls(_emit_addr, _reg_addr) + cls._instance = cls( + _emit_addr, _reg_addr, factory=_factory, + enable_curve=_enable_curve) return cls._instance def register(self, event, callback, uid=None, replace=False): @@ -270,7 +276,7 @@ class EventsClientThread(threading.Thread, EventsClient): A threaded version of the events client. """ - def __init__(self, emit_addr, reg_addr): + def __init__(self, emit_addr, reg_addr, factory=None, enable_curve=True): """ Initialize the events client. """ @@ -281,15 +287,22 @@ class EventsClientThread(threading.Thread, EventsClient): self._config_prefix = os.path.join( get_path_prefix(flags.STANDALONE), "leap", "events") self._loop = None + self._factory = factory self._context = None self._push = None self._sub = None + if enable_curve: + self.use_curve = zmq_has_curve() + else: + self.use_curve = False + def _init_zmq(self): """ Initialize ZMQ connections. """ self._loop = EventsIOLoop() + # we need a new context for each thread self._context = zmq.Context() # connect SUB first, otherwise we might miss some event sent from this # same client @@ -311,7 +324,7 @@ class EventsClientThread(threading.Thread, EventsClient): logger.debug("Connecting %s to %s." % (socktype, address)) socket = self._context.socket(socktype) # configure curve authentication - if zmq_has_curve(): + if self.use_curve: public, private = maybe_create_and_get_certificates( self._config_prefix, "client") server_public_file = os.path.join( diff --git a/src/leap/common/events/examples/README.txt b/src/leap/common/events/examples/README.txt new file mode 100644 index 0000000..0bb0df6 --- /dev/null +++ b/src/leap/common/events/examples/README.txt @@ -0,0 +1,49 @@ +How to debug +----------------------------------------- +monitor the events socket: + sudo ngrep -W byline -d any port 9000 + +launch the server: + python server.py + +launch the client: + python client.py + +if zmq is available and enabled, you should see encrypted messages passing by +the socket. + +You should see something like the following: + +#### +T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP] +.......... +## +T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP] +........... +## +T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP] +..CURVE............................................... +# +T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP] +.CURVE............................................... +# +T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP] +...HELLO.............................................................................:....^...".....'.S...n......Y...................O.7.+.D.q".*..R...j.....8..qu..~......Ck.G\....:...m....Tg.s..M..x<.. +## +T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP] +...WELCOME..%.'.,Td... I..}...........`..Nm......./_.Je...4.....-.....f<v.|.".jJ...^.D...$lJ..U......g..../w.......\..W.....!........i.v....0...........3..a.5}.@F..v./..$ +# +T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP] +..........INITIATE......!.*.=0.-......D..]{...A\.tz...!2.....A./ +6.......Y.h.N....cb.U.|..f..)....W..3..X.2U.3PGl.........m..95.(......NJ....5.'..W.GQ..B/.....\%.,Q..r.'L5.......{.W<=._.$.(6j.G... +...37.H..Th...'.........0 ........,..q....U..G..M.`!_..w....f.".......... +.d.K.Y.>f.n.kV. +# +T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP] +.2.READY............A...e.)......*.8y....k.<.N1Z.4.. +# +T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP] +.+.MESSAGE........o...*M..,.... +.r..w..[.GwcU +### + diff --git a/src/leap/common/events/examples/client.py b/src/leap/common/events/examples/client.py new file mode 100644 index 0000000..d6d8985 --- /dev/null +++ b/src/leap/common/events/examples/client.py @@ -0,0 +1,2 @@ +from leap.common.events.txclient import emit +emit('stuff!') diff --git a/src/leap/common/events/examples/server.py b/src/leap/common/events/examples/server.py new file mode 100644 index 0000000..f40f8dc --- /dev/null +++ b/src/leap/common/events/examples/server.py @@ -0,0 +1,4 @@ +from twisted.internet import reactor +from leap.common.events.server import ensure_server +reactor.callWhenRunning(ensure_server) +reactor.run() diff --git a/src/leap/common/events/server.py b/src/leap/common/events/server.py index a69202e..05fc23e 100644 --- a/src/leap/common/events/server.py +++ b/src/leap/common/events/server.py @@ -14,33 +14,31 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. - - """ The server for the events mechanism. """ - - import logging +import platform + import txzmq from leap.common.zmq_utils import zmq_has_curve - from leap.common.events.zmq_components import TxZmqServerComponent -if zmq_has_curve(): +if zmq_has_curve() or platform.system() == "Windows": + # Windows doesn't have ipc sockets, we need to use always tcp EMIT_ADDR = "tcp://127.0.0.1:9000" REG_ADDR = "tcp://127.0.0.1:9001" else: EMIT_ADDR = "ipc:///tmp/leap.common.events.socket.0" REG_ADDR = "ipc:///tmp/leap.common.events.socket.1" - logger = logging.getLogger(__name__) -def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR): +def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, path_prefix=None, + factory=None, enable_curve=True): """ Make sure the server is running in the given addresses. @@ -52,7 +50,8 @@ def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR): :return: an events server instance :rtype: EventsServer """ - _server = EventsServer(emit_addr, reg_addr) + _server = EventsServer(emit_addr, reg_addr, path_prefix, factory=factory, + enable_curve=enable_curve) return _server @@ -62,7 +61,8 @@ class EventsServer(TxZmqServerComponent): events in another address. """ - def __init__(self, emit_addr, reg_addr): + def __init__(self, emit_addr, reg_addr, path_prefix=None, factory=None, + enable_curve=True): """ Initialize the events server. @@ -71,7 +71,9 @@ class EventsServer(TxZmqServerComponent): :param reg_addr: The address to which publish events to clients. :type reg_addr: str """ - TxZmqServerComponent.__init__(self) + TxZmqServerComponent.__init__(self, path_prefix=path_prefix, + factory=factory, + enable_curve=enable_curve) # bind PULL and PUB sockets self._pull, self.pull_port = self._zmq_bind( txzmq.ZmqPullConnection, emit_addr) diff --git a/src/leap/common/events/tests/test_auth.py b/src/leap/common/events/tests/test_auth.py new file mode 100644 index 0000000..78ffd9f --- /dev/null +++ b/src/leap/common/events/tests/test_auth.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# test_zmq_components.py +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Tests for the auth module. +""" +import os + +from twisted.trial import unittest +from txzmq import ZmqFactory + +from leap.common.events import auth +from leap.common.testing.basetest import BaseLeapTest +from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX +from leap.common.zmq_utils import maybe_create_and_get_certificates + +from txzmq.test import _wait + + +class ZmqAuthTestCase(unittest.TestCase, BaseLeapTest): + + def setUp(self): + self.setUpEnv(launch_events_server=False) + + self.factory = ZmqFactory() + self._config_prefix = os.path.join(self.tempdir, "leap", "events") + + self.public, self.secret = maybe_create_and_get_certificates( + self._config_prefix, 'server') + + self.authenticator = auth.TxAuthenticator(self.factory) + self.authenticator.start() + self.auth_req = auth.TxAuthenticationRequest(self.factory) + + def tearDown(self): + self.factory.shutdown() + self.tearDownEnv() + + def test_curve_auth(self): + self.auth_req.start() + self.auth_req.allow('127.0.0.1') + public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) + self.auth_req.configure_curve(domain="*", location=public_keys_dir) + + def check(ignored): + authenticator = self.authenticator.authenticator + certs = authenticator.certs['*'] + self.failUnlessEqual(authenticator.whitelist, set([u'127.0.0.1'])) + self.failUnlessEqual(certs[certs.keys()[0]], True) + + return _wait(0.1).addCallback(check) diff --git a/src/leap/common/events/tests/test_events.py b/src/leap/common/events/tests/test_events.py new file mode 100644 index 0000000..d8435c6 --- /dev/null +++ b/src/leap/common/events/tests/test_events.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +# test_events.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Tests for the events framework +""" +import os +import logging + +from twisted.internet.reactor import callFromThread +from twisted.trial import unittest +from twisted.internet import defer + +from txzmq import ZmqFactory + +from leap.common.events import server +from leap.common.events import client +from leap.common.events import flags +from leap.common.events import txclient +from leap.common.events import catalog +from leap.common.events.errors import CallbackAlreadyRegisteredError + + +if 'DEBUG' in os.environ: + logging.basicConfig(level=logging.DEBUG) + + +class EventsGenericClientTestCase(object): + + def setUp(self): + flags.set_events_enabled(True) + self.factory = ZmqFactory() + self._server = server.ensure_server( + emit_addr="tcp://127.0.0.1:0", + reg_addr="tcp://127.0.0.1:0", + factory=self.factory, + enable_curve=False) + + self._client.configure_client( + emit_addr="tcp://127.0.0.1:%d" % self._server.pull_port, + reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port, + factory=self.factory, enable_curve=False) + + def tearDown(self): + flags.set_events_enabled(False) + self.factory.shutdown() + self._client.instance().reset() + + def test_client_register(self): + """ + Ensure clients can register callbacks. + """ + callbacks = self._client.instance().callbacks + self.assertTrue(len(callbacks) == 0, + 'There should be no callback for this event.') + # register one event + event1 = catalog.CLIENT_UID + + def cbk1(event, _): + return True + + uid1 = self._client.register(event1, cbk1) + # assert for correct registration + self.assertTrue(len(callbacks) == 1) + self.assertTrue(callbacks[event1][uid1] == cbk1, + 'Could not register event in local client.') + # register another event + event2 = catalog.CLIENT_SESSION_ID + + def cbk2(event, _): + return True + + uid2 = self._client.register(event2, cbk2) + # assert for correct registration + self.assertTrue(len(callbacks) == 2) + self.assertTrue(callbacks[event2][uid2] == cbk2, + 'Could not register event in local client.') + + def test_register_signal_replace(self): + """ + Make sure clients can replace already registered callbacks. + """ + event = catalog.CLIENT_UID + d = defer.Deferred() + + def cbk_fail(event, _): + return callFromThread(d.errback, event) + + def cbk_succeed(event, _): + return callFromThread(d.callback, event) + + self._client.register(event, cbk_fail, uid=1) + self._client.register(event, cbk_succeed, uid=1, replace=True) + self._client.emit(event, None) + return d + + def test_register_signal_replace_fails_when_replace_is_false(self): + """ + Make sure clients trying to replace already registered callbacks fail + when replace=False + """ + event = catalog.CLIENT_UID + self._client.register(event, lambda event, _: None, uid=1) + self.assertRaises( + CallbackAlreadyRegisteredError, + self._client.register, + event, lambda event, _: None, uid=1, replace=False) + + def test_register_more_than_one_callback_works(self): + """ + Make sure clients can replace already registered callbacks. + """ + event = catalog.CLIENT_UID + d1 = defer.Deferred() + + def cbk1(event, _): + return callFromThread(d1.callback, event) + + d2 = defer.Deferred() + + def cbk2(event, _): + return d2.callback(event) + + self._client.register(event, cbk1) + self._client.register(event, cbk2) + self._client.emit(event, None) + d = defer.gatherResults([d1, d2]) + return d + + def test_client_receives_signal(self): + """ + Ensure clients can receive signals. + """ + event = catalog.CLIENT_UID + d = defer.Deferred() + + def cbk(events, _): + callFromThread(d.callback, event) + + self._client.register(event, cbk) + self._client.emit(event, None) + return d + + def test_client_unregister_all(self): + """ + Test that the client can unregister all events for one signal. + """ + event1 = catalog.CLIENT_UID + d = defer.Deferred() + # register more than one callback for the same event + self._client.register( + event1, lambda ev, _: callFromThread(d.errback, None)) + self._client.register( + event1, lambda ev, _: callFromThread(d.errback, None)) + # unregister and emit the event + self._client.unregister(event1) + self._client.emit(event1, None) + # register and emit another event so the deferred can succeed + event2 = catalog.CLIENT_SESSION_ID + self._client.register( + event2, lambda ev, _: callFromThread(d.callback, None)) + self._client.emit(event2, None) + return d + + def test_client_unregister_by_uid(self): + """ + Test that the client can unregister an event by uid. + """ + event = catalog.CLIENT_UID + d = defer.Deferred() + # register one callback that would fail + uid = self._client.register( + event, lambda ev, _: callFromThread(d.errback, None)) + # register one callback that will succeed + self._client.register( + event, lambda ev, _: callFromThread(d.callback, None)) + # unregister by uid and emit the event + self._client.unregister(event, uid=uid) + self._client.emit(event, None) + return d + + +class EventsTxClientTestCase(EventsGenericClientTestCase, unittest.TestCase): + + _client = txclient + + +class EventsClientTestCase(EventsGenericClientTestCase, unittest.TestCase): + + _client = client diff --git a/src/leap/common/events/txclient.py b/src/leap/common/events/txclient.py index dfd0533..63f12d7 100644 --- a/src/leap/common/events/txclient.py +++ b/src/leap/common/events/txclient.py @@ -58,16 +58,19 @@ class EventsTxClient(TxZmqClientComponent, EventsClient): """ def __init__(self, emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, - path_prefix=None): + path_prefix=None, factory=None, enable_curve=True): """ - Initialize the events server. + Initialize the events client. """ - TxZmqClientComponent.__init__(self, path_prefix=path_prefix) + TxZmqClientComponent.__init__( + self, path_prefix=path_prefix, factory=factory, + enable_curve=enable_curve) EventsClient.__init__(self, emit_addr, reg_addr) # connect SUB first, otherwise we might miss some event sent from this # same client self._sub = self._zmq_connect(txzmq.ZmqSubConnection, reg_addr) self._sub.gotMessage = self._gotMessage + self._push = self._zmq_connect(txzmq.ZmqPushConnection, emit_addr) def _gotMessage(self, msg, tag): @@ -122,7 +125,6 @@ class EventsTxClient(TxZmqClientComponent, EventsClient): callback(event, *content) def shutdown(self): - TxZmqClientComponent.shutdown(self) EventsClient.shutdown(self) diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py index 51de02c..c533a74 100644 --- a/src/leap/common/events/zmq_components.py +++ b/src/leap/common/events/zmq_components.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # zmq.py -# Copyright (C) 2015 LEAP +# Copyright (C) 2015, 2016 LEAP # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,60 +14,63 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. - - """ The server for the events mechanism. """ - - import os import logging import txzmq import re -import time from abc import ABCMeta -# XXX some distros don't package libsodium, so we have to be prepared for -# absence of zmq.auth try: import zmq.auth - from zmq.auth.thread import ThreadAuthenticator + from leap.common.events.auth import TxAuthenticator + from leap.common.events.auth import TxAuthenticationRequest except ImportError: pass +from txzmq.connection import ZmqEndpoint, ZmqEndpointType + from leap.common.config import flags, get_path_prefix from leap.common.zmq_utils import zmq_has_curve from leap.common.zmq_utils import maybe_create_and_get_certificates from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX - logger = logging.getLogger(__name__) - ADDRESS_RE = re.compile("^([a-z]+)://([^:]+):?(\d+)?$") +LOCALHOST_ALLOWED = '127.0.0.1' + class TxZmqComponent(object): """ A twisted-powered zmq events component. """ + _factory = txzmq.ZmqFactory() + _factory.registerForShutdown() + _auth = None __metaclass__ = ABCMeta _component_type = None - def __init__(self, path_prefix=None): + def __init__(self, path_prefix=None, enable_curve=True, factory=None): """ Initialize the txzmq component. """ - self._factory = txzmq.ZmqFactory() - self._factory.registerForShutdown() if path_prefix is None: path_prefix = get_path_prefix(flags.STANDALONE) + if factory is not None: + self._factory = factory self._config_prefix = os.path.join(path_prefix, "leap", "events") self._connections = [] + if enable_curve: + self.use_curve = zmq_has_curve() + else: + self.use_curve = False @property def component_type(self): @@ -77,105 +80,89 @@ class TxZmqComponent(object): "define a self._component_type!") return self._component_type - def _zmq_connect(self, connClass, address): + def _zmq_bind(self, connClass, address): """ - Connect to an address. + Bind to an address. :param connClass: The connection class to be used. :type connClass: txzmq.ZmqConnection - :param address: The address to connect to. + :param address: The address to bind to. :type address: str - :return: The binded connection. - :rtype: txzmq.ZmqConnection + :return: The binded connection and port. + :rtype: (txzmq.ZmqConnection, int) """ + proto, addr, port = ADDRESS_RE.search(address).groups() + + endpoint = ZmqEndpoint(ZmqEndpointType.bind, address) connection = connClass(self._factory) - # create and configure socket - socket = connection.socket - if zmq_has_curve(): + + if self.use_curve: + socket = connection.socket + public, secret = maybe_create_and_get_certificates( self._config_prefix, self.component_type) - server_public_file = os.path.join( - self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") - server_public, _ = zmq.auth.load_certificate(server_public_file) socket.curve_publickey = public socket.curve_secretkey = secret - socket.curve_serverkey = server_public - socket.connect(address) - logger.debug("Connected %s to %s." % (connClass, address)) - self._connections.append(connection) - return connection + self._start_authentication(connection.socket) - def _zmq_bind(self, connClass, address): + if proto == 'tcp' and int(port) == 0: + connection.endpoints.extend([endpoint]) + port = connection.socket.bind_to_random_port('tcp://%s' % addr) + else: + connection.addEndpoints([endpoint]) + + return connection, int(port) + + def _zmq_connect(self, connClass, address): """ - Bind to an address. + Connect to an address. :param connClass: The connection class to be used. :type connClass: txzmq.ZmqConnection - :param address: The address to bind to. + :param address: The address to connect to. :type address: str - :return: The binded connection and port. - :rtype: (txzmq.ZmqConnection, int) + :return: The binded connection. + :rtype: txzmq.ZmqConnection """ + endpoint = ZmqEndpoint(ZmqEndpointType.connect, address) connection = connClass(self._factory) - socket = connection.socket - if zmq_has_curve(): + + if self.use_curve: + socket = connection.socket public, secret = maybe_create_and_get_certificates( self._config_prefix, self.component_type) + server_public_file = os.path.join( + self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") + + server_public, _ = zmq.auth.load_certificate(server_public_file) socket.curve_publickey = public socket.curve_secretkey = secret - self._start_thread_auth(connection.socket) + socket.curve_serverkey = server_public - proto, addr, port = ADDRESS_RE.search(address).groups() + connection.addEndpoints([endpoint]) + return connection - if proto == "tcp": - if port is None or port is '0': - params = proto, addr - port = socket.bind_to_random_port("%s://%s" % params) - logger.debug("Binded %s to %s://%s." % ((connClass,) + params)) - else: - params = proto, addr, int(port) - socket.bind("%s://%s:%d" % params) - logger.debug( - "Binded %s to %s://%s:%d." % ((connClass,) + params)) - else: - params = proto, addr - socket.bind("%s://%s" % params) - logger.debug( - "Binded %s to %s://%s" % ((connClass,) + params)) - self._connections.append(connection) - return connection, port - - def _start_thread_auth(self, socket): - """ - Start the zmq curve thread authenticator. + def _start_authentication(self, socket): - :param socket: The socket in which to configure the authenticator. - :type socket: zmq.Socket - """ - authenticator = ThreadAuthenticator(self._factory.context) + if not TxZmqComponent._auth: + TxZmqComponent._auth = TxAuthenticator(self._factory) + TxZmqComponent._auth.start() - # Temporary fix until we understand what the problem is - # See https://leap.se/code/issues/7536 - time.sleep(0.5) + auth_req = TxAuthenticationRequest(self._factory) + auth_req.start() + auth_req.allow(LOCALHOST_ALLOWED) - authenticator.start() - # XXX do not hardcode this here. - authenticator.allow('127.0.0.1') # tell authenticator to use the certificate in a directory public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) - authenticator.configure_curve(domain="*", location=public_keys_dir) - socket.curve_server = True # must come before bind + auth_req.configure_curve(domain="*", location=public_keys_dir) + auth_req.shutdown() + TxZmqComponent._auth.shutdown() - def shutdown(self): - """ - Shutdown the component. - """ - logger.debug("Shutting down component %s." % str(self)) - for conn in self._connections: - conn.shutdown() - self._factory.shutdown() + # This has to be set before binding the socket, that's why this method + # has to be called before addEndpoints() + socket.curve_server = True class TxZmqServerComponent(TxZmqComponent): |