# -*- coding: utf-8 -*- # zmq.py # Copyright (C) 2015 LEAP # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . """ The server for the events mechanism. """ import os import logging import txzmq import re from abc import ABCMeta # XXX some distros don't package libsodium, so we have to be prepared for # absence of zmq.auth try: import zmq.auth from zmq.auth.thread import ThreadAuthenticator except ImportError: pass from leap.common.config import get_path_prefix from leap.common.zmq_utils import zmq_has_curve from leap.common.zmq_utils import maybe_create_and_get_certificates from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX logger = logging.getLogger(__name__) ADDRESS_RE = re.compile("^([a-z]+)://([^:]+):?(\d+)?$") class TxZmqComponent(object): """ A twisted-powered zmq events component. """ __metaclass__ = ABCMeta _component_type = None def __init__(self, path_prefix=None): """ Initialize the txzmq component. """ self._factory = txzmq.ZmqFactory() self._factory.registerForShutdown() if path_prefix is None: path_prefix = get_path_prefix() self._config_prefix = os.path.join(path_prefix, "leap", "events") self._connections = [] @property def component_type(self): if not self._component_type: raise Exception( "Make sure implementations of TxZmqComponent" "define a self._component_type!") return self._component_type def _zmq_connect(self, connClass, address): """ Connect to an address. :param connClass: The connection class to be used. :type connClass: txzmq.ZmqConnection :param address: The address to connect to. :type address: str :return: The binded connection. :rtype: txzmq.ZmqConnection """ connection = connClass(self._factory) # create and configure socket socket = connection.socket if zmq_has_curve(): public, secret = maybe_create_and_get_certificates( self._config_prefix, self.component_type) server_public_file = os.path.join( self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") server_public, _ = zmq.auth.load_certificate(server_public_file) socket.curve_publickey = public socket.curve_secretkey = secret socket.curve_serverkey = server_public socket.connect(address) logger.debug("Connected %s to %s." % (connClass, address)) self._connections.append(connection) return connection def _zmq_bind(self, connClass, address): """ Bind to an address. :param connClass: The connection class to be used. :type connClass: txzmq.ZmqConnection :param address: The address to bind to. :type address: str :return: The binded connection and port. :rtype: (txzmq.ZmqConnection, int) """ connection = connClass(self._factory) socket = connection.socket if zmq_has_curve(): public, secret = maybe_create_and_get_certificates( self._config_prefix, self.component_type) socket.curve_publickey = public socket.curve_secretkey = secret self._start_thread_auth(connection.socket) proto, addr, port = ADDRESS_RE.search(address).groups() if proto == "tcp": if port is None or port is '0': params = proto, addr port = socket.bind_to_random_port("%s://%s" % params) logger.debug("Binded %s to %s://%s." % ((connClass,) + params)) else: params = proto, addr, int(port) socket.bind("%s://%s:%d" % params) logger.debug( "Binded %s to %s://%s:%d." % ((connClass,) + params)) else: params = proto, addr socket.bind("%s://%s" % params) logger.debug( "Binded %s to %s://%s" % ((connClass,) + params)) self._connections.append(connection) return connection, port def _start_thread_auth(self, socket): """ Start the zmq curve thread authenticator. :param socket: The socket in which to configure the authenticator. :type socket: zmq.Socket """ authenticator = ThreadAuthenticator(self._factory.context) authenticator.start() # XXX do not hardcode this here. authenticator.allow('127.0.0.1') # tell authenticator to use the certificate in a directory public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) authenticator.configure_curve(domain="*", location=public_keys_dir) socket.curve_server = True # must come before bind def shutdown(self): """ Shutdown the component. """ logger.debug("Shutting down component %s." % str(self)) for conn in self._connections: conn.shutdown() self._factory.shutdown() class TxZmqServerComponent(TxZmqComponent): """ A txZMQ server component. """ _component_type = "server" class TxZmqClientComponent(TxZmqComponent): """ A txZMQ client component. """ _component_type = "client"