[tests] adapt events tests to recent changes
[leap_pycommon.git] / src / leap / common / events / zmq_components.py
1 # -*- coding: utf-8 -*-
2 # zmq_components.py
3 # Copyright (C) 2015, 2016 LEAP
4 #
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
17 """
18 Base txZMQ components (server and client) for the events mechanism.
19 """
20 import os
21 import logging
22 import txzmq
23 import re
24
25 from abc import ABCMeta
26
27 try:
28     import zmq.auth
29     from leap.common.events.auth import TxAuthenticator
30     from leap.common.events.auth import TxAuthenticationRequest
31 except ImportError:
32     pass
33
34 from txzmq.connection import ZmqEndpoint, ZmqEndpointType
35
36 from leap.common.config import flags, get_path_prefix
37 from leap.common.zmq_utils import zmq_has_curve
38 from leap.common.zmq_utils import maybe_create_and_get_certificates
39 from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX
40
logger = logging.getLogger(__name__)

# Parses zmq addresses such as "tcp://127.0.0.1:9000" or "ipc:///tmp/sock"
# into (protocol, host-or-path, port); port is None when absent.
# Raw string so that "\d" is not treated as a (invalid) string escape.
ADDRESS_RE = re.compile(r"^([a-z]+)://([^:]+):?(\d+)?$")

# The only IP address the curve authenticator allows connections from.
LOCALHOST_ALLOWED = '127.0.0.1'
46
47
class TxZmqComponent(object):
    """
    A twisted-powered zmq events component.

    Base class for the events server and client components. Subclasses must
    define ``_component_type``; it selects which certificate pair is used when
    CurveZMQ encryption is enabled.
    """
    # One zmq factory shared by all components, unless __init__ is handed a
    # different one.
    _factory = txzmq.ZmqFactory()
    _factory.registerForShutdown()
    # Authenticator shared by all components; created lazily by
    # _start_authentication() the first time it is needed.
    _auth = None

    __metaclass__ = ABCMeta

    # Must be set by subclasses (e.g. "server" or "client").
    _component_type = None

    def __init__(self, path_prefix=None, enable_curve=True, factory=None):
        """
        Initialize the txzmq component.

        :param path_prefix: Base directory under which the "leap/events"
            config directory (where the curve certificates live) is located.
            Defaults to get_path_prefix(flags.STANDALONE).
        :type path_prefix: str
        :param enable_curve: Whether to use CurveZMQ encryption. It is only
            actually used if zmq_has_curve() also reports support.
        :type enable_curve: bool
        :param factory: A zmq factory to use instead of the shared,
            class-level one.
        :type factory: txzmq.ZmqFactory
        """
        if path_prefix is None:
            path_prefix = get_path_prefix(flags.STANDALONE)
        if factory is not None:
            self._factory = factory
        self._config_prefix = os.path.join(path_prefix, "leap", "events")
        self._connections = []
        if enable_curve:
            self.use_curve = zmq_has_curve()
        else:
            self.use_curve = False

    @property
    def component_type(self):
        """
        The component type, used to select this component's certificates.

        :raises Exception: If the subclass did not define _component_type.
        """
        if not self._component_type:
            raise Exception(
                "Make sure implementations of TxZmqComponent "
                "define a self._component_type!")
        return self._component_type

    def _zmq_bind(self, connClass, address):
        """
        Bind to an address.

        If the address is a tcp one with port 0, the socket is bound to a
        random free port, and that port is returned.

        :param connClass: The connection class to be used.
        :type connClass: txzmq.ZmqConnection
        :param address: The address to bind to, e.g. "tcp://127.0.0.1:9000".
        :type address: str

        :return: The bound connection and port. The port is None when the
            address does not carry one (e.g. an ipc:// address).
        :rtype: (txzmq.ZmqConnection, int or None)

        :raises ValueError: If the address cannot be parsed.
        """
        match = ADDRESS_RE.search(address)
        if match is None:
            raise ValueError("Invalid zmq address: %r" % (address,))
        proto, addr, port = match.groups()
        # Convert once, up front: an address without a port (ipc://...) must
        # not crash *after* the endpoint has already been bound.
        port = int(port) if port is not None else None

        endpoint = ZmqEndpoint(ZmqEndpointType.bind, address)
        connection = connClass(self._factory)

        if self.use_curve:
            socket = connection.socket

            public, secret = maybe_create_and_get_certificates(
                self._config_prefix, self.component_type)
            socket.curve_publickey = public
            socket.curve_secretkey = secret
            # Must happen before binding: it marks the socket as curve server.
            self._start_authentication(socket)

        if proto == 'tcp' and port == 0:
            # Register the endpoint without going through addEndpoints() and
            # let zmq pick a free port instead.
            connection.endpoints.extend([endpoint])
            port = connection.socket.bind_to_random_port('tcp://%s' % addr)
        else:
            connection.addEndpoints([endpoint])

        return connection, port

    def _zmq_connect(self, connClass, address):
        """
        Connect to an address.

        When curve is in use, the socket is configured with this component's
        key pair and with the server public key loaded from the
        "server.key" file in the public keys directory.

        :param connClass: The connection class to be used.
        :type connClass: txzmq.ZmqConnection
        :param address: The address to connect to.
        :type address: str

        :return: The connected connection.
        :rtype: txzmq.ZmqConnection
        """
        endpoint = ZmqEndpoint(ZmqEndpointType.connect, address)
        connection = connClass(self._factory)

        if self.use_curve:
            socket = connection.socket
            public, secret = maybe_create_and_get_certificates(
                self._config_prefix, self.component_type)
            server_public_file = os.path.join(
                self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")

            server_public, _ = zmq.auth.load_certificate(server_public_file)
            socket.curve_publickey = public
            socket.curve_secretkey = secret
            socket.curve_serverkey = server_public

        connection.addEndpoints([endpoint])
        return connection

    def _start_authentication(self, socket):
        """
        Enable CurveZMQ server-side authentication on a socket.

        Starts the authenticator shared by all components on first use. Then
        configures it to allow LOCALHOST_ALLOWED and to use the certificates
        found in the public keys directory.

        Must be called before the socket is bound.

        :param socket: The zmq socket that will act as curve server.
        :type socket: zmq.Socket
        """
        if not TxZmqComponent._auth:
            TxZmqComponent._auth = TxAuthenticator(self._factory)
            TxZmqComponent._auth.start()

        auth_req = TxAuthenticationRequest(self._factory)
        auth_req.start()
        auth_req.allow(LOCALHOST_ALLOWED)

        # tell authenticator to use the certificate in a directory
        public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
        auth_req.configure_curve(domain="*", location=public_keys_dir)
        auth_req.shutdown()

        # This has to be set before binding the socket, that's why this method
        # has to be called before addEndpoints()
        socket.curve_server = True
163         # has to be called before addEndpoints()
164         socket.curve_server = True
165
166
class TxZmqServerComponent(TxZmqComponent):
    """
    Server-side txZMQ events component.

    Declares its component type as "server"; the base class uses that type
    when looking up this component's curve certificates.
    """
    _component_type = "server"
173
174
class TxZmqClientComponent(TxZmqComponent):
    """
    Client-side txZMQ events component.

    Declares its component type as "client"; the base class uses that type
    when looking up this component's curve certificates.
    """
    _component_type = "client"