Initializing the events server raises an error when the given port is not free.
[leap_pycommon.git] / src / leap / common / tests / test_events.py
index 687195f..90124b4 100644 (file)
 import unittest
 import sets
 import time
+import socket
+import threading
+import random
+
+
+from mock import Mock
 from protobuf.socketrpc import RpcService
 from leap.common import events
 from leap.common.events import (
@@ -28,9 +34,11 @@ from leap.common.events import (
 from leap.common.events.events_pb2 import (
     EventsServerService,
     EventsServerService_Stub,
+    EventsClientService_Stub,
     EventResponse,
     SignalRequest,
     RegisterRequest,
+    PingRequest,
     SOLEDAD_CREATING_KEYS,
     CLIENT_UID,
 )
@@ -39,11 +47,6 @@ from leap.common.events.events_pb2 import (
 port = 8090
 
 received = False
-local_callback_executed = False
-
-
-def callback(request, reponse):
-    return True
 
 
 class EventsTestCase(unittest.TestCase):
@@ -120,17 +123,28 @@ class EventsTestCase(unittest.TestCase):
         response = service.signal(request, timeout=1000)
         self.assertEqual(EventResponse.OK, response.status,
                          'Wrong response status.')
-        # test asynch
 
-        def local_callback(request, response):
-            global local_callback_executed
-            local_callback_executed = True
+    def test_signal_executes_callback(self):
+        """
+        Ensure a registered callback runs when its signal is sent to the server.
+        """
+        sig = CLIENT_UID
+        request = SignalRequest()
+        request.event = sig
+        request.content = 'my signal contents'
+        request.mac_method = mac_auth.MacMethod.MAC_NONE
+        request.mac = ""
+        service = RpcService(EventsServerService_Stub, port, 'localhost')
 
-        events.register(sig, local_callback)
-        service.signal(request, callback=local_callback)
-        time.sleep(0.1)
-        self.assertTrue(local_callback_executed,
-                        'Local callback did not execute.')
+        flag = Mock()  # records the event id each callback invocation sees
+        events.register(sig, lambda req: flag(req.event))
+        response = service.signal(request, timeout=1000)
+        self.assertEqual(EventResponse.OK, response.status,
+                         'Wrong response status.')
+        deadline = time.time() + 5  # callbacks arrive asynchronously
+        while not flag.called and time.time() < deadline:
+            time.sleep(0.05)
+        flag.assert_called_once_with(sig)
 
     def test_events_server_service_register(self):
         """
@@ -173,12 +187,9 @@ class EventsTestCase(unittest.TestCase):
         Ensure clients can receive signals.
         """
         sig = 7
+        flag = Mock()
 
-        def getsig(param=None):
-            global received
-            received = True
-
-        events.register(sig, getsig)
+        events.register(sig, lambda req: flag(req.event))
         request = SignalRequest()
         request.event = sig
         request.content = ""
@@ -188,7 +199,7 @@ class EventsTestCase(unittest.TestCase):
         response = service.signal(request, timeout=1000)
         self.assertTrue(response is not None, 'Did not receive response.')
         time.sleep(0.5)
-        self.assertTrue(received, 'Did not receive signal back.')
+        flag.assert_called_once_with(sig)
 
     def test_client_send_signal(self):
         """
@@ -231,3 +242,98 @@ class EventsTestCase(unittest.TestCase):
         self.assertTrue(
             client.registered_callbacks[sig].pop()[0] == 'cbkuid2')
         self.assertTrue(port in complist[sig])
+
+    def test_server_replies_ping(self):
+        """
+        Ensure the events server answers a ping with an OK response.
+        """
+        ping = PingRequest()
+        server_rpc = RpcService(EventsServerService_Stub, port, 'localhost')
+        reply = server_rpc.ping(ping, timeout=1000)
+        self.assertIsNotNone(reply)
+        self.assertEqual(EventResponse.OK, reply.status,
+                         'Wrong response status.')
+
+    def test_client_replies_ping(self):
+        """
+        Ensure a client daemon replies to a ping sent to its own port.
+        """
+        client_port = client.ensure_client_daemon().get_port()
+        service = RpcService(
+            EventsClientService_Stub, client_port, 'localhost')
+        response = service.ping(PingRequest(), timeout=1000)
+        self.assertIsNotNone(response)
+        self.assertEqual(EventResponse.OK, response.status,
+                         'Wrong response status.')
+
+    def test_server_ping(self):
+        """
+        Ensure the server module's ping() helper returns an OK response.
+        """
+        reply = server.ping()
+        self.assertIsNotNone(reply)
+        self.assertEqual(EventResponse.OK, reply.status,
+                         'Wrong response status.')
+
+    def test_client_ping(self):
+        """
+        Ensure client.ping() gets an OK reply from the client daemon.
+        """
+        client_port = client.ensure_client_daemon().get_port()
+        reply = client.ping(client_port)
+        self.assertIsNotNone(reply)
+        self.assertEqual(EventResponse.OK, reply.status,
+                         'Wrong response status.')
+
+    def test_module_ping_server(self):
+        """
+        Ensure events.ping_server() gets an OK reply from the server.
+        """
+        reply = events.ping_server()
+        self.assertIsNotNone(reply)
+        self.assertEqual(EventResponse.OK, reply.status,
+                         'Wrong response status.')
+
+    def test_module_ping_client(self):
+        """
+        Ensure events.ping_client() gets an OK reply from a client daemon.
+        """
+        client_port = client.ensure_client_daemon().get_port()
+        reply = events.ping_client(client_port)
+        self.assertIsNotNone(reply)
+        self.assertEqual(EventResponse.OK, reply.status,
+                         'Wrong response status.')
+
+    def test_ensure_server_raises_if_port_taken(self):
+        """
+        Verify that server raises an exception if port is already taken.
+
+        A plain TCP listener (not an events server) holds an OS-assigned
+        port; ensure_server() must raise PortAlreadyTaken for that port.
+        """
+        # Port 0 makes the OS pick a free port atomically (no probe race).
+        blocker = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.addCleanup(blocker.close)
+        blocker.bind(('localhost', 0))
+        blocker.listen(5)
+        blocker.settimeout(0.1)  # lets the accept loop notice `done`
+        taken_port = blocker.getsockname()[1]
+        done = threading.Event()
+        self.addCleanup(done.set)  # cleanups run LIFO: stop loop, then close
+
+        def accept_and_drop():
+            # Accept and drop connections: the peer is not an events server.
+            while not done.is_set():
+                try:
+                    conn, _ = blocker.accept()
+                except socket.timeout:
+                    continue
+                except socket.error:
+                    return  # listener closed during cleanup
+                conn.close()
+
+        taker = threading.Thread(target=accept_and_drop)
+        taker.daemon = True  # never keep the test process alive
+        taker.start()
+        self.assertRaises(
+            server.PortAlreadyTaken, server.ensure_server, taken_port)