2c40f623398275ab576ff397416c8aeb3ced0713
[leap_pycommon.git] / src / leap / common / events / zmq_components.py
1 # -*- coding: utf-8 -*-
2 # zmq.py
3 # Copyright (C) 2015 LEAP
4 #
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
19 """
20 The server for the events mechanism.
21 """
22
23
24 import os
25 import logging
26 import txzmq
27 import re
28 import time
29
30 from abc import ABCMeta
31
32 # XXX some distros don't package libsodium, so we have to be prepared for
33 #     absence of zmq.auth
34 try:
35     import zmq.auth
36     from zmq.auth.thread import ThreadAuthenticator
37 except ImportError:
38     pass
39
40 from leap.common.config import flags, get_path_prefix
41 from leap.common.zmq_utils import zmq_has_curve
42 from leap.common.zmq_utils import maybe_create_and_get_certificates
43 from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX
44
45
46 logger = logging.getLogger(__name__)
47
48
49 ADDRESS_RE = re.compile("^([a-z]+)://([^:]+):?(\d+)?$")
50
51
52 class TxZmqComponent(object):
53     """
54     A twisted-powered zmq events component.
55     """
56
57     __metaclass__ = ABCMeta
58
59     _component_type = None
60
61     def __init__(self, path_prefix=None, enable_curve=True):
62         """
63         Initialize the txzmq component.
64         """
65         self._factory = txzmq.ZmqFactory()
66         self._factory.registerForShutdown()
67         if path_prefix is None:
68             path_prefix = get_path_prefix(flags.STANDALONE)
69         self._config_prefix = os.path.join(path_prefix, "leap", "events")
70         self._connections = []
71         if enable_curve:
72             self.use_curve = zmq_has_curve()
73         else:
74             self.use_curve = False
75
76     @property
77     def component_type(self):
78         if not self._component_type:
79             raise Exception(
80                 "Make sure implementations of TxZmqComponent"
81                 "define a self._component_type!")
82         return self._component_type
83
84     def _zmq_connect(self, connClass, address):
85         """
86         Connect to an address.
87
88         :param connClass: The connection class to be used.
89         :type connClass: txzmq.ZmqConnection
90         :param address: The address to connect to.
91         :type address: str
92
93         :return: The binded connection.
94         :rtype: txzmq.ZmqConnection
95         """
96         connection = connClass(self._factory)
97         # create and configure socket
98         socket = connection.socket
99         if zmq_has_curve():
100             public, secret = maybe_create_and_get_certificates(
101                 self._config_prefix, self.component_type)
102             server_public_file = os.path.join(
103                 self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
104             server_public, _ = zmq.auth.load_certificate(server_public_file)
105             socket.curve_publickey = public
106             socket.curve_secretkey = secret
107             socket.curve_serverkey = server_public
108         socket.connect(address)
109         logger.debug("Connected %s to %s." % (connClass, address))
110         self._connections.append(connection)
111         return connection
112
113     def _zmq_bind(self, connClass, address):
114         """
115         Bind to an address.
116
117         :param connClass: The connection class to be used.
118         :type connClass: txzmq.ZmqConnection
119         :param address: The address to bind to.
120         :type address: str
121
122         :return: The binded connection and port.
123         :rtype: (txzmq.ZmqConnection, int)
124         """
125         connection = connClass(self._factory)
126         socket = connection.socket
127         if zmq_has_curve():
128             public, secret = maybe_create_and_get_certificates(
129                 self._config_prefix, self.component_type)
130             socket.curve_publickey = public
131             socket.curve_secretkey = secret
132             self._start_thread_auth(connection.socket)
133
134         proto, addr, port = ADDRESS_RE.search(address).groups()
135
136         if proto == "tcp":
137             if port is None or port is '0':
138                 params = proto, addr
139                 port = socket.bind_to_random_port("%s://%s" % params)
140                 logger.debug("Binded %s to %s://%s." % ((connClass,) + params))
141             else:
142                 params = proto, addr, int(port)
143                 socket.bind("%s://%s:%d" % params)
144                 logger.debug(
145                     "Binded %s to %s://%s:%d." % ((connClass,) + params))
146         else:
147             params = proto, addr
148             socket.bind("%s://%s" % params)
149             logger.debug(
150                 "Binded %s to %s://%s" % ((connClass,) + params))
151         self._connections.append(connection)
152         return connection, port
153
154     def _start_thread_auth(self, socket):
155         """
156         Start the zmq curve thread authenticator.
157
158         :param socket: The socket in which to configure the authenticator.
159         :type socket: zmq.Socket
160         """
161         authenticator = ThreadAuthenticator(self._factory.context)
162
163         # Temporary fix until we understand what the problem is
164         # See https://leap.se/code/issues/7536
165         time.sleep(0.5)
166
167         authenticator.start()
168         # XXX do not hardcode this here.
169         authenticator.allow('127.0.0.1')
170         # tell authenticator to use the certificate in a directory
171         public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
172         authenticator.configure_curve(domain="*", location=public_keys_dir)
173         socket.curve_server = True  # must come before bind
174
175     def shutdown(self):
176         """
177         Shutdown the component.
178         """
179         logger.debug("Shutting down component %s." % str(self))
180         for conn in self._connections:
181             conn.shutdown()
182         self._factory.shutdown()
183
184
185 class TxZmqServerComponent(TxZmqComponent):
186     """
187     A txZMQ server component.
188     """
189
190     _component_type = "server"
191
192
193 class TxZmqClientComponent(TxZmqComponent):
194     """
195     A txZMQ client component.
196     """
197
198     _component_type = "client"