04f71e02120aa9971e8f1e49ad5c4fd282d69c9f
[leap_pycommon.git] / src / leap / common / events / zmq_components.py
1 # -*- coding: utf-8 -*-
2 # zmq.py
3 # Copyright (C) 2015 LEAP
4 #
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18
19 """
20 The server for the events mechanism.
21 """
22
23
24 import os
25 import logging
26 import txzmq
27 import re
28
29 from abc import ABCMeta
30
31 # XXX some distros don't package libsodium, so we have to be prepared for
32 #     absence of zmq.auth
33 try:
34     import zmq.auth
35     from zmq.auth.thread import ThreadAuthenticator
36 except ImportError:
37     pass
38
39 from leap.common.config import get_path_prefix
40 from leap.common.zmq_utils import zmq_has_curve
41 from leap.common.zmq_utils import maybe_create_and_get_certificates
42 from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX
43
44
45 logger = logging.getLogger(__name__)
46
47
48 ADDRESS_RE = re.compile("^([a-z]+)://([^:]+):?(\d+)?$")
49
50
51 class TxZmqComponent(object):
52     """
53     A twisted-powered zmq events component.
54     """
55
56     __metaclass__ = ABCMeta
57
58     _component_type = None
59
60     def __init__(self, path_prefix=None):
61         """
62         Initialize the txzmq component.
63         """
64         self._factory = txzmq.ZmqFactory()
65         self._factory.registerForShutdown()
66         if path_prefix is None:
67             path_prefix = get_path_prefix()
68         self._config_prefix = os.path.join(path_prefix, "leap", "events")
69         self._connections = []
70
71     @property
72     def component_type(self):
73         if not self._component_type:
74             raise Exception(
75                 "Make sure implementations of TxZmqComponent"
76                 "define a self._component_type!")
77         return self._component_type
78
79     def _zmq_connect(self, connClass, address):
80         """
81         Connect to an address.
82
83         :param connClass: The connection class to be used.
84         :type connClass: txzmq.ZmqConnection
85         :param address: The address to connect to.
86         :type address: str
87
88         :return: The binded connection.
89         :rtype: txzmq.ZmqConnection
90         """
91         connection = connClass(self._factory)
92         # create and configure socket
93         socket = connection.socket
94         if zmq_has_curve():
95             public, secret = maybe_create_and_get_certificates(
96                 self._config_prefix, self.component_type)
97             server_public_file = os.path.join(
98                 self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
99             server_public, _ = zmq.auth.load_certificate(server_public_file)
100             socket.curve_publickey = public
101             socket.curve_secretkey = secret
102             socket.curve_serverkey = server_public
103         socket.connect(address)
104         logger.debug("Connected %s to %s." % (connClass, address))
105         self._connections.append(connection)
106         return connection
107
108     def _zmq_bind(self, connClass, address):
109         """
110         Bind to an address.
111
112         :param connClass: The connection class to be used.
113         :type connClass: txzmq.ZmqConnection
114         :param address: The address to bind to.
115         :type address: str
116
117         :return: The binded connection and port.
118         :rtype: (txzmq.ZmqConnection, int)
119         """
120         connection = connClass(self._factory)
121         socket = connection.socket
122         if zmq_has_curve():
123             public, secret = maybe_create_and_get_certificates(
124                 self._config_prefix, self.component_type)
125             socket.curve_publickey = public
126             socket.curve_secretkey = secret
127             self._start_thread_auth(connection.socket)
128
129         proto, addr, port = ADDRESS_RE.search(address).groups()
130
131         if port is None or port is '0':
132             params = proto, addr
133             port = socket.bind_to_random_port("%s://%s" % params)
134             # XXX this log doesn't appear
135             logger.debug("Binded %s to %s://%s." % ((connClass,) + params))
136         else:
137             params = proto, addr, int(port)
138             socket.bind("%s://%s:%d" % params)
139             # XXX this log doesn't appear
140             logger.debug("Binded %s to %s://%s:%d." % ((connClass,) + params))
141         self._connections.append(connection)
142         return connection, port
143
144     def _start_thread_auth(self, socket):
145         """
146         Start the zmq curve thread authenticator.
147
148         :param socket: The socket in which to configure the authenticator.
149         :type socket: zmq.Socket
150         """
151         authenticator = ThreadAuthenticator(self._factory.context)
152         authenticator.start()
153         # XXX do not hardcode this here.
154         authenticator.allow('127.0.0.1')
155         # tell authenticator to use the certificate in a directory
156         public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
157         authenticator.configure_curve(domain="*", location=public_keys_dir)
158         socket.curve_server = True  # must come before bind
159
160     def shutdown(self):
161         """
162         Shutdown the component.
163         """
164         logger.debug("Shutting down component %s." % str(self))
165         for conn in self._connections:
166             conn.shutdown()
167         self._factory.shutdown()
168
169
170 class TxZmqServerComponent(TxZmqComponent):
171     """
172     A txZMQ server component.
173     """
174
175     _component_type = "server"
176
177
178 class TxZmqClientComponent(TxZmqComponent):
179     """
180     A txZMQ client component.
181     """
182
183     _component_type = "client"