Workaround for deadlock problem in zmq auth
[leap_pycommon.git] / src / leap / common / events / zmq_components.py
1 # -*- coding: utf-8 -*-
2 # zmq.py
3 # Copyright (C) 2015 LEAP
4 #
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
19 """
20 The server for the events mechanism.
21 """
22
23
24 import os
25 import logging
26 import txzmq
27 import re
28 import time
29
30 from abc import ABCMeta
31
32 # XXX some distros don't package libsodium, so we have to be prepared for
33 #     absence of zmq.auth
34 try:
35     import zmq.auth
36     from zmq.auth.thread import ThreadAuthenticator
37 except ImportError:
38     pass
39
40 from leap.common.config import flags, get_path_prefix
41 from leap.common.zmq_utils import zmq_has_curve
42 from leap.common.zmq_utils import maybe_create_and_get_certificates
43 from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX
44
45
46 logger = logging.getLogger(__name__)
47
48
49 ADDRESS_RE = re.compile("^([a-z]+)://([^:]+):?(\d+)?$")
50
51
52 class TxZmqComponent(object):
53     """
54     A twisted-powered zmq events component.
55     """
56
57     __metaclass__ = ABCMeta
58
59     _component_type = None
60
61     def __init__(self, path_prefix=None):
62         """
63         Initialize the txzmq component.
64         """
65         self._factory = txzmq.ZmqFactory()
66         self._factory.registerForShutdown()
67         if path_prefix is None:
68             path_prefix = get_path_prefix(flags.STANDALONE)
69         self._config_prefix = os.path.join(path_prefix, "leap", "events")
70         self._connections = []
71
72     @property
73     def component_type(self):
74         if not self._component_type:
75             raise Exception(
76                 "Make sure implementations of TxZmqComponent"
77                 "define a self._component_type!")
78         return self._component_type
79
80     def _zmq_connect(self, connClass, address):
81         """
82         Connect to an address.
83
84         :param connClass: The connection class to be used.
85         :type connClass: txzmq.ZmqConnection
86         :param address: The address to connect to.
87         :type address: str
88
89         :return: The binded connection.
90         :rtype: txzmq.ZmqConnection
91         """
92         connection = connClass(self._factory)
93         # create and configure socket
94         socket = connection.socket
95         if zmq_has_curve():
96             public, secret = maybe_create_and_get_certificates(
97                 self._config_prefix, self.component_type)
98             server_public_file = os.path.join(
99                 self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
100             server_public, _ = zmq.auth.load_certificate(server_public_file)
101             socket.curve_publickey = public
102             socket.curve_secretkey = secret
103             socket.curve_serverkey = server_public
104         socket.connect(address)
105         logger.debug("Connected %s to %s." % (connClass, address))
106         self._connections.append(connection)
107         return connection
108
109     def _zmq_bind(self, connClass, address):
110         """
111         Bind to an address.
112
113         :param connClass: The connection class to be used.
114         :type connClass: txzmq.ZmqConnection
115         :param address: The address to bind to.
116         :type address: str
117
118         :return: The binded connection and port.
119         :rtype: (txzmq.ZmqConnection, int)
120         """
121         connection = connClass(self._factory)
122         socket = connection.socket
123         if zmq_has_curve():
124             public, secret = maybe_create_and_get_certificates(
125                 self._config_prefix, self.component_type)
126             socket.curve_publickey = public
127             socket.curve_secretkey = secret
128             self._start_thread_auth(connection.socket)
129
130         proto, addr, port = ADDRESS_RE.search(address).groups()
131
132         if proto == "tcp":
133             if port is None or port is '0':
134                 params = proto, addr
135                 port = socket.bind_to_random_port("%s://%s" % params)
136                 logger.debug("Binded %s to %s://%s." % ((connClass,) + params))
137             else:
138                 params = proto, addr, int(port)
139                 socket.bind("%s://%s:%d" % params)
140                 logger.debug(
141                     "Binded %s to %s://%s:%d." % ((connClass,) + params))
142         else:
143             params = proto, addr
144             socket.bind("%s://%s" % params)
145             logger.debug(
146                 "Binded %s to %s://%s" % ((connClass,) + params))
147         self._connections.append(connection)
148         return connection, port
149
150     def _start_thread_auth(self, socket):
151         """
152         Start the zmq curve thread authenticator.
153
154         :param socket: The socket in which to configure the authenticator.
155         :type socket: zmq.Socket
156         """
157         authenticator = ThreadAuthenticator(self._factory.context)
158
159         # Temporary fix until we understand what the problem is
160         # See https://leap.se/code/issues/7536
161         time.sleep(0.5)
162
163         authenticator.start()
164         # XXX do not hardcode this here.
165         authenticator.allow('127.0.0.1')
166         # tell authenticator to use the certificate in a directory
167         public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
168         authenticator.configure_curve(domain="*", location=public_keys_dir)
169         socket.curve_server = True  # must come before bind
170
171     def shutdown(self):
172         """
173         Shutdown the component.
174         """
175         logger.debug("Shutting down component %s." % str(self))
176         for conn in self._connections:
177             conn.shutdown()
178         self._factory.shutdown()
179
180
181 class TxZmqServerComponent(TxZmqComponent):
182     """
183     A txZMQ server component.
184     """
185
186     _component_type = "server"
187
188
189 class TxZmqClientComponent(TxZmqComponent):
190     """
191     A txZMQ client component.
192     """
193
194     _component_type = "client"