diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/leap/common/events/auth.py | 4 | ||||
| -rw-r--r-- | src/leap/common/events/client.py | 23 | ||||
| -rw-r--r-- | src/leap/common/events/server.py | 13 | ||||
| -rw-r--r-- | src/leap/common/events/tests/test_auth.py | 64 | ||||
| -rw-r--r-- | src/leap/common/events/tests/test_events.py (renamed from src/leap/common/tests/test_events.py) | 24 | ||||
| -rw-r--r-- | src/leap/common/events/txclient.py | 8 | ||||
| -rw-r--r-- | src/leap/common/events/zmq_components.py | 58 | ||||
| -rw-r--r-- | src/leap/common/testing/basetest.py | 9 | 
8 files changed, 151 insertions, 52 deletions
| diff --git a/src/leap/common/events/auth.py b/src/leap/common/events/auth.py index 1a1bcab..5b71f2d 100644 --- a/src/leap/common/events/auth.py +++ b/src/leap/common/events/auth.py @@ -38,8 +38,8 @@ class TxAuthenticator(ZmqConnection):      address = 'inproc://zeromq.zap.01'      encoding = 'utf-8' -    def __init__(self, factory): -        super(TxAuthenticator, self).__init__(factory) +    def __init__(self, factory, *args, **kw): +        super(TxAuthenticator, self).__init__(factory, *args, **kw)          self.authenticator = Authenticator(factory.context)          self.authenticator._send_zap_reply = self._send_zap_reply diff --git a/src/leap/common/events/client.py b/src/leap/common/events/client.py index 60d24bc..78617de 100644 --- a/src/leap/common/events/client.py +++ b/src/leap/common/events/client.py @@ -63,14 +63,18 @@ logger = logging.getLogger(__name__)  _emit_addr = EMIT_ADDR  _reg_addr = REG_ADDR +_factory = None +_enable_curve = True -def configure_client(emit_addr, reg_addr): -    global _emit_addr, _reg_addr +def configure_client(emit_addr, reg_addr, factory=None, enable_curve=True): +    global _emit_addr, _reg_addr, _factory, _enable_curve      logger.debug("Configuring client with addresses: (%s, %s)" %                   (emit_addr, reg_addr))      _emit_addr = emit_addr      _reg_addr = reg_addr +    _factory = factory +    _enable_curve = enable_curve  class EventsClient(object): @@ -103,7 +107,9 @@ class EventsClient(object):          """          with cls._instance_lock:              if cls._instance is None: -                cls._instance = cls(_emit_addr, _reg_addr) +                cls._instance = cls( +                    _emit_addr, _reg_addr, factory=_factory, +                    enable_curve=_enable_curve)          return cls._instance      def register(self, event, callback, uid=None, replace=False): @@ -270,7 +276,7 @@ class EventsClientThread(threading.Thread, EventsClient):      A threaded version of the events client.      """ -    def __init__(self, emit_addr, reg_addr): +    def __init__(self, emit_addr, reg_addr, factory=None, enable_curve=True):          """          Initialize the events client.          """ @@ -281,15 +287,22 @@ class EventsClientThread(threading.Thread, EventsClient):          self._config_prefix = os.path.join(              get_path_prefix(flags.STANDALONE), "leap", "events")          self._loop = None +        self._factory = factory          self._context = None          self._push = None          self._sub = None +        if enable_curve: +            self.use_curve = zmq_has_curve() +        else: +            self.use_curve = False +      def _init_zmq(self):          """          Initialize ZMQ connections.          """          self._loop = EventsIOLoop() +        # we need a new context for each thread          self._context = zmq.Context()          # connect SUB first, otherwise we might miss some event sent from this          # same client @@ -311,7 +324,7 @@ class EventsClientThread(threading.Thread, EventsClient):          logger.debug("Connecting %s to %s." % (socktype, address))          socket = self._context.socket(socktype)          # configure curve authentication -        if zmq_has_curve(): +        if self.use_curve:              public, private = maybe_create_and_get_certificates(                  self._config_prefix, "client")              server_public_file = os.path.join( diff --git a/src/leap/common/events/server.py b/src/leap/common/events/server.py index 6252853..ad79abe 100644 --- a/src/leap/common/events/server.py +++ b/src/leap/common/events/server.py @@ -37,7 +37,8 @@ else:  logger = logging.getLogger(__name__) -def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR): +def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, path_prefix=None, +                  factory=None, enable_curve=True):      """      Make sure the server is running in the given addresses. @@ -49,7 +50,8 @@ def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR):      :return: an events server instance      :rtype: EventsServer      """ -    _server = EventsServer(emit_addr, reg_addr) +    _server = EventsServer(emit_addr, reg_addr, path_prefix, factory=factory, +            enable_curve=enable_curve)      return _server @@ -59,7 +61,8 @@ class EventsServer(TxZmqServerComponent):      events in another address.      """ -    def __init__(self, emit_addr, reg_addr): +    def __init__(self, emit_addr, reg_addr, path_prefix=None, factory=None, +                 enable_curve=True):          """          Initialize the events server. @@ -68,7 +71,9 @@ class EventsServer(TxZmqServerComponent):          :param reg_addr: The address to which publish events to clients.          :type reg_addr: str          """ -        TxZmqServerComponent.__init__(self) +        TxZmqServerComponent.__init__(self, path_prefix=path_prefix, +                                      factory=factory, +                                      enable_curve=enable_curve)          # bind PULL and PUB sockets          self._pull, self.pull_port = self._zmq_bind(              txzmq.ZmqPullConnection, emit_addr) diff --git a/src/leap/common/events/tests/test_auth.py b/src/leap/common/events/tests/test_auth.py new file mode 100644 index 0000000..78ffd9f --- /dev/null +++ b/src/leap/common/events/tests/test_auth.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# test_zmq_components.py +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program.  If not, see <http://www.gnu.org/licenses/>. +""" +Tests for the auth module. +""" +import os + +from twisted.trial import unittest +from txzmq import ZmqFactory + +from leap.common.events import auth +from leap.common.testing.basetest import BaseLeapTest +from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX +from leap.common.zmq_utils import maybe_create_and_get_certificates + +from txzmq.test import _wait + + +class ZmqAuthTestCase(unittest.TestCase, BaseLeapTest): + +    def setUp(self): +        self.setUpEnv(launch_events_server=False) + +        self.factory = ZmqFactory() +        self._config_prefix = os.path.join(self.tempdir, "leap", "events") + +        self.public, self.secret = maybe_create_and_get_certificates( +            self._config_prefix, 'server') + +        self.authenticator = auth.TxAuthenticator(self.factory) +        self.authenticator.start() +        self.auth_req = auth.TxAuthenticationRequest(self.factory) + +    def tearDown(self): +        self.factory.shutdown() +        self.tearDownEnv() + +    def test_curve_auth(self): +        self.auth_req.start() +        self.auth_req.allow('127.0.0.1') +        public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) +        self.auth_req.configure_curve(domain="*", location=public_keys_dir) + +        def check(ignored): +            authenticator = self.authenticator.authenticator +            certs = authenticator.certs['*'] +            self.failUnlessEqual(authenticator.whitelist, set([u'127.0.0.1'])) +            self.failUnlessEqual(certs[certs.keys()[0]], True) + +        return _wait(0.1).addCallback(check) diff --git a/src/leap/common/tests/test_events.py b/src/leap/common/events/tests/test_events.py index 2ad097e..c45601b 100644 --- a/src/leap/common/tests/test_events.py +++ b/src/leap/common/events/tests/test_events.py @@ -14,16 +14,18 @@  #  # You should have received a copy of the GNU General Public License  # along with this program. If not, see <http://www.gnu.org/licenses/>. - - +""" +Tests for the events framework +"""  import os  import logging -import time  from twisted.internet.reactor import callFromThread  from twisted.trial import unittest  from twisted.internet import defer +from txzmq import ZmqFactory +  from leap.common.events import server  from leap.common.events import client  from leap.common.events import flags @@ -40,19 +42,22 @@ class EventsGenericClientTestCase(object):      def setUp(self):          flags.set_events_enabled(True) +        self.factory = ZmqFactory()          self._server = server.ensure_server(              emit_addr="tcp://127.0.0.1:0", -            reg_addr="tcp://127.0.0.1:0") +            reg_addr="tcp://127.0.0.1:0", +            factory=self.factory, +            enable_curve=False) +          self._client.configure_client(              emit_addr="tcp://127.0.0.1:%d" % self._server.pull_port, -            reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port) +            reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port, +            factory=self.factory, enable_curve=False)      def tearDown(self): -        self._client.shutdown() -        self._server.shutdown()          flags.set_events_enabled(False) -        # wait a bit for sockets to close properly -        time.sleep(0.1) +        self.factory.shutdown() +        self._client.instance().reset()      def test_client_register(self):          """ @@ -84,6 +89,7 @@ class EventsGenericClientTestCase(object):          self.assertTrue(callbacks[event2][uid2] == cbk2,                          'Could not register event in local client.') +      def test_register_signal_replace(self):          """          Make sure clients can replace already registered callbacks. diff --git a/src/leap/common/events/txclient.py b/src/leap/common/events/txclient.py index a2b704d..63f12d7 100644 --- a/src/leap/common/events/txclient.py +++ b/src/leap/common/events/txclient.py @@ -58,11 +58,13 @@ class EventsTxClient(TxZmqClientComponent, EventsClient):      """      def __init__(self, emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, -                 path_prefix=None): +                 path_prefix=None, factory=None, enable_curve=True):          """ -        Initialize the events server. +        Initialize the events client.          """ -        TxZmqClientComponent.__init__(self, path_prefix=path_prefix) +        TxZmqClientComponent.__init__( +            self, path_prefix=path_prefix, factory=factory, +            enable_curve=enable_curve)          EventsClient.__init__(self, emit_addr, reg_addr)          # connect SUB first, otherwise we might miss some event sent from this          # same client diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py index 74abb76..8919cd9 100644 --- a/src/leap/common/events/zmq_components.py +++ b/src/leap/common/events/zmq_components.py @@ -57,12 +57,14 @@ class TxZmqComponent(object):      _component_type = None -    def __init__(self, path_prefix=None, enable_curve=True): +    def __init__(self, path_prefix=None, enable_curve=True, factory=None):          """          Initialize the txzmq component.          """          if path_prefix is None:              path_prefix = get_path_prefix(flags.STANDALONE) +        if factory is not None: +            self._factory = factory          self._config_prefix = os.path.join(path_prefix, "leap", "events")          self._connections = []          if enable_curve: @@ -78,64 +80,69 @@ class TxZmqComponent(object):                  "define a self._component_type!")          return self._component_type -    def _zmq_connect(self, connClass, address): +    def _zmq_bind(self, connClass, address):          """ -        Connect to an address. +        Bind to an address.          :param connClass: The connection class to be used.          :type connClass: txzmq.ZmqConnection -        :param address: The address to connect to. +        :param address: The address to bind to.          :type address: str -        :return: The binded connection. -        :rtype: txzmq.ZmqConnection +        :return: The binded connection and port. +        :rtype: (txzmq.ZmqConnection, int)          """ -        endpoint = ZmqEndpoint(ZmqEndpointType.connect, address) +        proto, addr, port = ADDRESS_RE.search(address).groups() + +        endpoint = ZmqEndpoint(ZmqEndpointType.bind, address)          connection = connClass(self._factory)          if self.use_curve:              socket = connection.socket +              public, secret = maybe_create_and_get_certificates(                  self._config_prefix, self.component_type) -            server_public_file = os.path.join( -                self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") - -            server_public, _ = zmq.auth.load_certificate(server_public_file)              socket.curve_publickey = public              socket.curve_secretkey = secret -            socket.curve_serverkey = server_public +            self._start_authentication(connection.socket) -        connection.addEndpoints([endpoint]) -        return connection +        if proto == 'tcp' and int(port) == 0: +            connection.endpoints.extend([endpoint]) +            port = connection.socket.bind_to_random_port('tcp://%s' % addr) +        else: +            connection.addEndpoints([endpoint]) -    def _zmq_bind(self, connClass, address): +        return connection, int(port) + +    def _zmq_connect(self, connClass, address):          """ -        Bind to an address. +        Connect to an address.          :param connClass: The connection class to be used.          :type connClass: txzmq.ZmqConnection -        :param address: The address to bind to. +        :param address: The address to connect to.          :type address: str -        :return: The binded connection and port. -        :rtype: (txzmq.ZmqConnection, int) +        :return: The binded connection. +        :rtype: txzmq.ZmqConnection          """ -        proto, addr, port = ADDRESS_RE.search(address).groups() - -        endpoint = ZmqEndpoint(ZmqEndpointType.bind, address) +        endpoint = ZmqEndpoint(ZmqEndpointType.connect, address)          connection = connClass(self._factory)          if self.use_curve:              socket = connection.socket -              public, secret = maybe_create_and_get_certificates(                  self._config_prefix, self.component_type) +            server_public_file = os.path.join( +                self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") + +            server_public, _ = zmq.auth.load_certificate(server_public_file)              socket.curve_publickey = public              socket.curve_secretkey = secret -            self._start_authentication(connection.socket) +            socket.curve_serverkey = server_public          connection.addEndpoints([endpoint]) -        return connection, port +        return connection      def _start_authentication(self, socket): @@ -150,6 +157,7 @@ class TxZmqComponent(object):          # tell authenticator to use the certificate in a directory          public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)          auth_req.configure_curve(domain="*", location=public_keys_dir) +        auth_req.shutdown()          # This has to be set before binding the socket, that's why this method          # has to be called before addEndpoints() diff --git a/src/leap/common/testing/basetest.py b/src/leap/common/testing/basetest.py index 3d3cee0..2e84a25 100644 --- a/src/leap/common/testing/basetest.py +++ b/src/leap/common/testing/basetest.py @@ -52,7 +52,7 @@ class BaseLeapTest(unittest.TestCase):          cls.tearDownEnv()      @classmethod -    def setUpEnv(cls): +    def setUpEnv(cls, launch_events_server=True):          """          Sets up common facilities for testing this TestCase:          - custom PATH and HOME environmental variables @@ -72,14 +72,15 @@ class BaseLeapTest(unittest.TestCase):          os.environ["PATH"] = bin_tdir          os.environ["HOME"] = cls.tempdir          os.environ["XDG_CONFIG_HOME"] = os.path.join(cls.tempdir, ".config") -        cls._init_events() +        if launch_events_server: +            cls._init_events()      @classmethod      def _init_events(cls):          if flags.EVENTS_ENABLED:              cls._server = events_server.ensure_server( -                emit_addr="tcp://127.0.0.1:0", -                reg_addr="tcp://127.0.0.1:0") +                emit_addr="tcp://127.0.0.1", +                reg_addr="tcp://127.0.0.1")              events_client.configure_client(                  emit_addr="tcp://127.0.0.1:%d" % cls._server.pull_port,                  reg_addr="tcp://127.0.0.1:%d" % cls._server.pub_port) | 
