[tests] adapt events tests to recent changes
author Kali Kaneko <kali@leap.se>
Mon, 29 Feb 2016 23:33:28 +0000 (19:33 -0400)
committer Kali Kaneko <kali@leap.se>
Mon, 29 Feb 2016 23:39:43 +0000 (19:39 -0400)
src/leap/common/events/auth.py
src/leap/common/events/client.py
src/leap/common/events/server.py
src/leap/common/events/tests/test_auth.py [new file with mode: 0644]
src/leap/common/events/tests/test_events.py [moved from src/leap/common/tests/test_events.py with 94% similarity]
src/leap/common/events/txclient.py
src/leap/common/events/zmq_components.py
src/leap/common/testing/basetest.py

index 1a1bcab..5b71f2d 100644 (file)
@@ -38,8 +38,8 @@ class TxAuthenticator(ZmqConnection):
     address = 'inproc://zeromq.zap.01'
     encoding = 'utf-8'
 
-    def __init__(self, factory):
-        super(TxAuthenticator, self).__init__(factory)
+    def __init__(self, factory, *args, **kw):
+        super(TxAuthenticator, self).__init__(factory, *args, **kw)
         self.authenticator = Authenticator(factory.context)
         self.authenticator._send_zap_reply = self._send_zap_reply
 
index 60d24bc..78617de 100644 (file)
@@ -63,14 +63,18 @@ logger = logging.getLogger(__name__)
 
 _emit_addr = EMIT_ADDR
 _reg_addr = REG_ADDR
+_factory = None
+_enable_curve = True
 
 
-def configure_client(emit_addr, reg_addr):
-    global _emit_addr, _reg_addr
+def configure_client(emit_addr, reg_addr, factory=None, enable_curve=True):
+    global _emit_addr, _reg_addr, _factory, _enable_curve
     logger.debug("Configuring client with addresses: (%s, %s)" %
                  (emit_addr, reg_addr))
     _emit_addr = emit_addr
     _reg_addr = reg_addr
+    _factory = factory
+    _enable_curve = enable_curve
 
 
 class EventsClient(object):
@@ -103,7 +107,9 @@ class EventsClient(object):
         """
         with cls._instance_lock:
             if cls._instance is None:
-                cls._instance = cls(_emit_addr, _reg_addr)
+                cls._instance = cls(
+                    _emit_addr, _reg_addr, factory=_factory,
+                    enable_curve=_enable_curve)
         return cls._instance
 
     def register(self, event, callback, uid=None, replace=False):
@@ -270,7 +276,7 @@ class EventsClientThread(threading.Thread, EventsClient):
     A threaded version of the events client.
     """
 
-    def __init__(self, emit_addr, reg_addr):
+    def __init__(self, emit_addr, reg_addr, factory=None, enable_curve=True):
         """
         Initialize the events client.
         """
@@ -281,15 +287,22 @@ class EventsClientThread(threading.Thread, EventsClient):
         self._config_prefix = os.path.join(
             get_path_prefix(flags.STANDALONE), "leap", "events")
         self._loop = None
+        self._factory = factory
         self._context = None
         self._push = None
         self._sub = None
 
+        if enable_curve:
+            self.use_curve = zmq_has_curve()
+        else:
+            self.use_curve = False
+
     def _init_zmq(self):
         """
         Initialize ZMQ connections.
         """
         self._loop = EventsIOLoop()
+        # we need a new context for each thread
         self._context = zmq.Context()
         # connect SUB first, otherwise we might miss some event sent from this
         # same client
@@ -311,7 +324,7 @@ class EventsClientThread(threading.Thread, EventsClient):
         logger.debug("Connecting %s to %s." % (socktype, address))
         socket = self._context.socket(socktype)
         # configure curve authentication
-        if zmq_has_curve():
+        if self.use_curve:
             public, private = maybe_create_and_get_certificates(
                 self._config_prefix, "client")
             server_public_file = os.path.join(
index 6252853..ad79abe 100644 (file)
@@ -37,7 +37,8 @@ else:
 logger = logging.getLogger(__name__)
 
 
-def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR):
+def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, path_prefix=None,
+                  factory=None, enable_curve=True):
     """
     Make sure the server is running in the given addresses.
 
@@ -49,7 +50,8 @@ def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR):
     :return: an events server instance
     :rtype: EventsServer
     """
-    _server = EventsServer(emit_addr, reg_addr)
+    _server = EventsServer(emit_addr, reg_addr, path_prefix, factory=factory,
+            enable_curve=enable_curve)
     return _server
 
 
@@ -59,7 +61,8 @@ class EventsServer(TxZmqServerComponent):
     events in another address.
     """
 
-    def __init__(self, emit_addr, reg_addr):
+    def __init__(self, emit_addr, reg_addr, path_prefix=None, factory=None,
+                 enable_curve=True):
         """
         Initialize the events server.
 
@@ -68,7 +71,9 @@ class EventsServer(TxZmqServerComponent):
         :param reg_addr: The address to which publish events to clients.
         :type reg_addr: str
         """
-        TxZmqServerComponent.__init__(self)
+        TxZmqServerComponent.__init__(self, path_prefix=path_prefix,
+                                      factory=factory,
+                                      enable_curve=enable_curve)
         # bind PULL and PUB sockets
         self._pull, self.pull_port = self._zmq_bind(
             txzmq.ZmqPullConnection, emit_addr)
diff --git a/src/leap/common/events/tests/test_auth.py b/src/leap/common/events/tests/test_auth.py
new file mode 100644 (file)
index 0000000..78ffd9f
--- /dev/null
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# test_auth.py
+# Copyright (C) 2014 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+"""
+Tests for the auth module.
+"""
+import os
+
+from twisted.trial import unittest
+from txzmq import ZmqFactory
+
+from leap.common.events import auth
+from leap.common.testing.basetest import BaseLeapTest
+from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX
+from leap.common.zmq_utils import maybe_create_and_get_certificates
+
+from txzmq.test import _wait
+
+
+class ZmqAuthTestCase(unittest.TestCase, BaseLeapTest):
+    """
+    Tests for the authenticator and authentication request in auth.
+    """
+
+    def setUp(self):
+        """
+        Create the server certificates and start a TxAuthenticator.
+        """
+        # these tests do not launch the events server
+        self.setUpEnv(launch_events_server=False)
+
+        self.factory = ZmqFactory()
+        self._config_prefix = os.path.join(self.tempdir, "leap", "events")
+
+        self.public, self.secret = maybe_create_and_get_certificates(
+            self._config_prefix, 'server')
+
+        self.authenticator = auth.TxAuthenticator(self.factory)
+        self.authenticator.start()
+        self.auth_req = auth.TxAuthenticationRequest(self.factory)
+
+    def tearDown(self):
+        """
+        Shut down the zmq factory and tear down the test environment.
+        """
+        self.factory.shutdown()
+        self.tearDownEnv()
+
+    def test_curve_auth(self):
+        """
+        Make sure allow and configure_curve requests reach the authenticator.
+        """
+        self.auth_req.start()
+        self.auth_req.allow('127.0.0.1')
+        public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
+        self.auth_req.configure_curve(domain="*", location=public_keys_dir)
+
+        def check(ignored):
+            authenticator = self.authenticator.authenticator
+            certs = authenticator.certs['*']
+            self.assertEqual(authenticator.whitelist, set([u'127.0.0.1']))
+            # read the first stored value without indexing keys(), which
+            # only works where keys() returns a list
+            self.assertEqual(next(iter(certs.values())), True)
+
+        # the assertions run from the callback attached to _wait(0.1)
+        return _wait(0.1).addCallback(check)
similarity index 94%
rename from src/leap/common/tests/test_events.py
rename to src/leap/common/events/tests/test_events.py
index 2ad097e..c45601b 100644 (file)
 #
 # You should have received a copy of the GNU General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
-
-
+"""
+Tests for the events framework
+"""
 import os
 import logging
-import time
 
 from twisted.internet.reactor import callFromThread
 from twisted.trial import unittest
 from twisted.internet import defer
 
+from txzmq import ZmqFactory
+
 from leap.common.events import server
 from leap.common.events import client
 from leap.common.events import flags
@@ -40,19 +42,22 @@ class EventsGenericClientTestCase(object):
 
     def setUp(self):
         flags.set_events_enabled(True)
+        self.factory = ZmqFactory()
         self._server = server.ensure_server(
             emit_addr="tcp://127.0.0.1:0",
-            reg_addr="tcp://127.0.0.1:0")
+            reg_addr="tcp://127.0.0.1:0",
+            factory=self.factory,
+            enable_curve=False)
+
         self._client.configure_client(
             emit_addr="tcp://127.0.0.1:%d" % self._server.pull_port,
-            reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port)
+            reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port,
+            factory=self.factory, enable_curve=False)
 
     def tearDown(self):
-        self._client.shutdown()
-        self._server.shutdown()
         flags.set_events_enabled(False)
-        # wait a bit for sockets to close properly
-        time.sleep(0.1)
+        self.factory.shutdown()
+        self._client.instance().reset()
 
     def test_client_register(self):
         """
@@ -84,6 +89,7 @@ class EventsGenericClientTestCase(object):
         self.assertTrue(callbacks[event2][uid2] == cbk2,
                         'Could not register event in local client.')
 
+
     def test_register_signal_replace(self):
         """
         Make sure clients can replace already registered callbacks.
index a2b704d..63f12d7 100644 (file)
@@ -58,11 +58,13 @@ class EventsTxClient(TxZmqClientComponent, EventsClient):
     """
 
     def __init__(self, emit_addr=EMIT_ADDR, reg_addr=REG_ADDR,
-                 path_prefix=None):
+                 path_prefix=None, factory=None, enable_curve=True):
         """
-        Initialize the events server.
+        Initialize the events client.
         """
-        TxZmqClientComponent.__init__(self, path_prefix=path_prefix)
+        TxZmqClientComponent.__init__(
+            self, path_prefix=path_prefix, factory=factory,
+            enable_curve=enable_curve)
         EventsClient.__init__(self, emit_addr, reg_addr)
         # connect SUB first, otherwise we might miss some event sent from this
         # same client
index 74abb76..8919cd9 100644 (file)
@@ -57,12 +57,14 @@ class TxZmqComponent(object):
 
     _component_type = None
 
-    def __init__(self, path_prefix=None, enable_curve=True):
+    def __init__(self, path_prefix=None, enable_curve=True, factory=None):
         """
         Initialize the txzmq component.
         """
         if path_prefix is None:
             path_prefix = get_path_prefix(flags.STANDALONE)
+        if factory is not None:
+            self._factory = factory
         self._config_prefix = os.path.join(path_prefix, "leap", "events")
         self._connections = []
         if enable_curve:
@@ -78,64 +80,69 @@ class TxZmqComponent(object):
                 "define a self._component_type!")
         return self._component_type
 
-    def _zmq_connect(self, connClass, address):
+    def _zmq_bind(self, connClass, address):
         """
-        Connect to an address.
+        Bind to an address.
 
         :param connClass: The connection class to be used.
         :type connClass: txzmq.ZmqConnection
-        :param address: The address to connect to.
+        :param address: The address to bind to.
         :type address: str
 
-        :return: The binded connection.
-        :rtype: txzmq.ZmqConnection
+        :return: The bound connection and port.
+        :rtype: (txzmq.ZmqConnection, int)
         """
-        endpoint = ZmqEndpoint(ZmqEndpointType.connect, address)
+        proto, addr, port = ADDRESS_RE.search(address).groups()
+
+        endpoint = ZmqEndpoint(ZmqEndpointType.bind, address)
         connection = connClass(self._factory)
 
         if self.use_curve:
             socket = connection.socket
+
             public, secret = maybe_create_and_get_certificates(
                 self._config_prefix, self.component_type)
-            server_public_file = os.path.join(
-                self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
-
-            server_public, _ = zmq.auth.load_certificate(server_public_file)
             socket.curve_publickey = public
             socket.curve_secretkey = secret
-            socket.curve_serverkey = server_public
+            self._start_authentication(connection.socket)
 
-        connection.addEndpoints([endpoint])
-        return connection
+        if proto == 'tcp' and int(port) == 0:
+            connection.endpoints.extend([endpoint])
+            port = connection.socket.bind_to_random_port('tcp://%s' % addr)
+        else:
+            connection.addEndpoints([endpoint])
 
-    def _zmq_bind(self, connClass, address):
+        return connection, int(port)
+
+    def _zmq_connect(self, connClass, address):
         """
-        Bind to an address.
+        Connect to an address.
 
         :param connClass: The connection class to be used.
         :type connClass: txzmq.ZmqConnection
-        :param address: The address to bind to.
+        :param address: The address to connect to.
         :type address: str
 
-        :return: The binded connection and port.
-        :rtype: (txzmq.ZmqConnection, int)
+        :return: The connection, connected to the given address.
+        :rtype: txzmq.ZmqConnection
         """
-        proto, addr, port = ADDRESS_RE.search(address).groups()
-
-        endpoint = ZmqEndpoint(ZmqEndpointType.bind, address)
+        endpoint = ZmqEndpoint(ZmqEndpointType.connect, address)
         connection = connClass(self._factory)
 
         if self.use_curve:
             socket = connection.socket
-
             public, secret = maybe_create_and_get_certificates(
                 self._config_prefix, self.component_type)
+            server_public_file = os.path.join(
+                self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
+
+            server_public, _ = zmq.auth.load_certificate(server_public_file)
             socket.curve_publickey = public
             socket.curve_secretkey = secret
-            self._start_authentication(connection.socket)
+            socket.curve_serverkey = server_public
 
         connection.addEndpoints([endpoint])
-        return connection, port
+        return connection
 
     def _start_authentication(self, socket):
 
@@ -150,6 +157,7 @@ class TxZmqComponent(object):
         # tell authenticator to use the certificate in a directory
         public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
         auth_req.configure_curve(domain="*", location=public_keys_dir)
+        auth_req.shutdown()
 
         # This has to be set before binding the socket, that's why this method
         # has to be called before addEndpoints()
index 3d3cee0..2e84a25 100644 (file)
@@ -52,7 +52,7 @@ class BaseLeapTest(unittest.TestCase):
         cls.tearDownEnv()
 
     @classmethod
-    def setUpEnv(cls):
+    def setUpEnv(cls, launch_events_server=True):
         """
         Sets up common facilities for testing this TestCase:
         - custom PATH and HOME environmental variables
@@ -72,14 +72,15 @@ class BaseLeapTest(unittest.TestCase):
         os.environ["PATH"] = bin_tdir
         os.environ["HOME"] = cls.tempdir
         os.environ["XDG_CONFIG_HOME"] = os.path.join(cls.tempdir, ".config")
-        cls._init_events()
+        if launch_events_server:
+            cls._init_events()
 
     @classmethod
     def _init_events(cls):
         if flags.EVENTS_ENABLED:
             cls._server = events_server.ensure_server(
-                emit_addr="tcp://127.0.0.1:0",
-                reg_addr="tcp://127.0.0.1:0")
+                emit_addr="tcp://127.0.0.1",
+                reg_addr="tcp://127.0.0.1")
             events_client.configure_client(
                 emit_addr="tcp://127.0.0.1:%d" % cls._server.pull_port,
                 reg_addr="tcp://127.0.0.1:%d" % cls._server.pub_port)