From e80d3d3d93fd3aeeffba7a5b6a05569695fc0f6e Mon Sep 17 00:00:00 2001 From: Ruben Pollan Date: Thu, 12 Nov 2015 19:34:38 +0100 Subject: [style] fix pep8 --- src/leap/common/certs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/leap/common/certs.py b/src/leap/common/certs.py index c49015a..95704a6 100644 --- a/src/leap/common/certs.py +++ b/src/leap/common/certs.py @@ -192,8 +192,8 @@ def get_compatible_ssl_context_factory(cert_path=None): class WebClientContextFactory(ssl.ClientContextFactory): """ - A web context factory which ignores the hostname and port and does no - certificate verification. + A web context factory which ignores the hostname and port and does + no certificate verification. """ def getContext(self, hostname, port): return ssl.ClientContextFactory.getContext(self) -- cgit v1.2.3 From a42d67584fa70abb59b932471b4df3606b8294d0 Mon Sep 17 00:00:00 2001 From: Kali Kaneko Date: Mon, 30 Nov 2015 16:17:44 -0400 Subject: [refactor] reorder and comment events Reorder blocks of events, and comment about which user-specific info it's being emitted with them. --- src/leap/common/events/catalog.py | 85 +++++++++++++++++++++------------------ 1 file changed, 45 insertions(+), 40 deletions(-) (limited to 'src') diff --git a/src/leap/common/events/catalog.py b/src/leap/common/events/catalog.py index 8bddd2c..6a09495 100644 --- a/src/leap/common/events/catalog.py +++ b/src/leap/common/events/catalog.py @@ -24,49 +24,54 @@ Events catalog. EVENTS = [ "CLIENT_SESSION_ID", "CLIENT_UID", - "IMAP_CLIENT_LOGIN", - "IMAP_SERVICE_FAILED_TO_START", - "IMAP_SERVICE_STARTED", - "IMAP_UNHANDLED_ERROR", - "KEYMANAGER_DONE_UPLOADING_KEYS", - "KEYMANAGER_FINISHED_KEY_GENERATION", - "KEYMANAGER_KEY_FOUND", - "KEYMANAGER_KEY_NOT_FOUND", - "KEYMANAGER_LOOKING_FOR_KEY", - "KEYMANAGER_STARTED_KEY_GENERATION", - "MAIL_FETCHED_INCOMING", - "MAIL_MSG_DECRYPTED", - "MAIL_MSG_DELETED_INCOMING", - "MAIL_MSG_PROCESSING", - "MAIL_MSG_SAVED_LOCALLY", - "MAIL_UNREAD_MESSAGES", "RAISE_WINDOW", - "SMTP_CONNECTION_LOST", - "SMTP_END_ENCRYPT_AND_SIGN", - "SMTP_END_SIGN", - "SMTP_RECIPIENT_ACCEPTED_ENCRYPTED", - "SMTP_RECIPIENT_ACCEPTED_UNENCRYPTED", - "SMTP_RECIPIENT_REJECTED", - "SMTP_SEND_MESSAGE_ERROR", - "SMTP_SEND_MESSAGE_START", - "SMTP_SEND_MESSAGE_SUCCESS", - "SMTP_SERVICE_FAILED_TO_START", - "SMTP_SERVICE_STARTED", - "SMTP_START_ENCRYPT_AND_SIGN", - "SMTP_START_SIGN", - "SOLEDAD_CREATING_KEYS", - "SOLEDAD_DONE_CREATING_KEYS", - "SOLEDAD_DONE_DATA_SYNC", - "SOLEDAD_DONE_DOWNLOADING_KEYS", - "SOLEDAD_DONE_UPLOADING_KEYS", - "SOLEDAD_DOWNLOADING_KEYS", - "SOLEDAD_INVALID_AUTH_TOKEN", - "SOLEDAD_NEW_DATA_TO_SYNC", - "SOLEDAD_SYNC_RECEIVE_STATUS", - "SOLEDAD_SYNC_SEND_STATUS", - "SOLEDAD_UPLOADING_KEYS", "UPDATER_DONE_UPDATING", "UPDATER_NEW_UPDATES", + + "KEYMANAGER_DONE_UPLOADING_KEYS", # (address) + "KEYMANAGER_FINISHED_KEY_GENERATION", # (address) + "KEYMANAGER_KEY_FOUND", # (address) + "KEYMANAGER_KEY_NOT_FOUND", # (address) + "KEYMANAGER_LOOKING_FOR_KEY", # (address) + "KEYMANAGER_STARTED_KEY_GENERATION", # (address) + + "SOLEDAD_CREATING_KEYS", # {uuid, userid} + "SOLEDAD_DONE_CREATING_KEYS", # {uuid, userid} + "SOLEDAD_DONE_DATA_SYNC", # {uuid, userid} + "SOLEDAD_DONE_DOWNLOADING_KEYS", # {uuid, userid} + "SOLEDAD_DONE_UPLOADING_KEYS", # {uuid, userid} + "SOLEDAD_DOWNLOADING_KEYS", # {uuid, userid} + "SOLEDAD_INVALID_AUTH_TOKEN", # {uuid, userid} + "SOLEDAD_SYNC_RECEIVE_STATUS", # {uuid, userid} + "SOLEDAD_SYNC_SEND_STATUS", # {uuid, userid} + "SOLEDAD_UPLOADING_KEYS", # {uuid, userid} + "SOLEDAD_NEW_DATA_TO_SYNC", + + "MAIL_FETCHED_INCOMING", # (userid) + "MAIL_MSG_DECRYPTED", # (userid) + "MAIL_MSG_DELETED_INCOMING", # (userid) + "MAIL_MSG_PROCESSING", # (userid) + "MAIL_MSG_SAVED_LOCALLY", # (userid) + "MAIL_UNREAD_MESSAGES", # (userid) + + "IMAP_SERVICE_STARTED", + "IMAP_SERVICE_FAILED_TO_START", + "IMAP_UNHANDLED_ERROR", + "IMAP_CLIENT_LOGIN", # (username) + + "SMTP_SERVICE_STARTED", + "SMTP_SERVICE_FAILED_TO_START", + "SMTP_START_ENCRYPT_AND_SIGN", # (from_addr) + "SMTP_END_ENCRYPT_AND_SIGN", # (from_addr) + "SMTP_START_SIGN", # (from_addr) + "SMTP_END_SIGN", # (from_addr) + "SMTP_SEND_MESSAGE_START", # (from_addr) + "SMTP_SEND_MESSAGE_SUCCESS", # (from_addr) + "SMTP_RECIPIENT_ACCEPTED_ENCRYPTED", # (userid, dest) + "SMTP_RECIPIENT_ACCEPTED_UNENCRYPTED", # (userid, dest) + "SMTP_CONNECTION_LOST", # (userid, dest) + "SMTP_RECIPIENT_REJECTED", # (userid, dest) + "SMTP_SEND_MESSAGE_ERROR", # (userid, dest) ] -- cgit v1.2.3 From a36fdc8ebec9c42c61a2c733ea280a3fa9103598 Mon Sep 17 00:00:00 2001 From: meskio on windows Date: Thu, 11 Feb 2016 11:26:06 +0100 Subject: [feat] Get events working on windows Always use tcp channels and disable curve encryption on the zmq connections. - Closes: #7899, #7239 - Related: #7919 --- src/leap/common/events/server.py | 4 +++- src/leap/common/zmq_utils.py | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/leap/common/events/server.py b/src/leap/common/events/server.py index a69202e..7126723 100644 --- a/src/leap/common/events/server.py +++ b/src/leap/common/events/server.py @@ -22,6 +22,7 @@ The server for the events mechanism. import logging +import platform import txzmq from leap.common.zmq_utils import zmq_has_curve @@ -29,7 +30,8 @@ from leap.common.zmq_utils import zmq_has_curve from leap.common.events.zmq_components import TxZmqServerComponent -if zmq_has_curve(): +if zmq_has_curve() or platform.system() == "Windows": + # Windows doesn't have icp sockets, we need to use always tcp EMIT_ADDR = "tcp://127.0.0.1:9000" REG_ADDR = "tcp://127.0.0.1:9001" else: diff --git a/src/leap/common/zmq_utils.py b/src/leap/common/zmq_utils.py index 0a781de..39a49c7 100644 --- a/src/leap/common/zmq_utils.py +++ b/src/leap/common/zmq_utils.py @@ -19,6 +19,7 @@ Utilities to handle ZMQ certificates. """ import os import logging +import platform import stat import shutil @@ -52,6 +53,10 @@ def zmq_has_curve(): `zmq.auth` module is new in version 14.1 `zmq.has()` is new in version 14.1, new in version libzmq-4.1. """ + if platform.system() == "Windows": + # TODO: curve is not working on windows #7919 + return False + zmq_version = zmq.zmq_version_info() pyzmq_version = zmq.pyzmq_version_info() -- cgit v1.2.3 From 88941164243ce1ac6f30c790120165c04ea4a041 Mon Sep 17 00:00:00 2001 From: Kali Kaneko Date: Mon, 22 Feb 2016 19:25:21 -0400 Subject: [feature] optional flag to disable curve authentication --- src/leap/common/events/zmq_components.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'src') diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py index 51de02c..2c40f62 100644 --- a/src/leap/common/events/zmq_components.py +++ b/src/leap/common/events/zmq_components.py @@ -58,7 +58,7 @@ class TxZmqComponent(object): _component_type = None - def __init__(self, path_prefix=None): + def __init__(self, path_prefix=None, enable_curve=True): """ Initialize the txzmq component. """ @@ -68,6 +68,10 @@ class TxZmqComponent(object): path_prefix = get_path_prefix(flags.STANDALONE) self._config_prefix = os.path.join(path_prefix, "leap", "events") self._connections = [] + if enable_curve: + self.use_curve = zmq_has_curve() + else: + self.use_curve = False @property def component_type(self): -- cgit v1.2.3 From b940cfc29b88374ce57b101a39bc012bb903f6e8 Mon Sep 17 00:00:00 2001 From: Kali Kaneko Date: Mon, 22 Feb 2016 19:26:45 -0400 Subject: [bug] avoid the events server to block twistd daemon 1. refactor the zmq_connect/bind methods to use the txzmq addEndpoints mechanism, which cleans up the code a bit. it uses the underlying bindOrConnect method. 2. wrap the addEndpoints call in a helper function that ensures that doRead is called afterward. I'm not fully comfortable with us still using the AuthenticatorThread, I believe we could go witha txzmq-based authenticator for curve. --- src/leap/common/events/server.py | 10 ++--- src/leap/common/events/txclient.py | 5 ++- src/leap/common/events/zmq_components.py | 66 ++++++++++++-------------------- 3 files changed, 30 insertions(+), 51 deletions(-) (limited to 'src') diff --git a/src/leap/common/events/server.py b/src/leap/common/events/server.py index 7126723..30a0c44 100644 --- a/src/leap/common/events/server.py +++ b/src/leap/common/events/server.py @@ -14,31 +14,27 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see . - - """ The server for the events mechanism. """ - - import logging import platform -import txzmq from leap.common.zmq_utils import zmq_has_curve from leap.common.events.zmq_components import TxZmqServerComponent +import txzmq + if zmq_has_curve() or platform.system() == "Windows": - # Windows doesn't have icp sockets, we need to use always tcp + # Windows doesn't have ipc sockets, we need to use always tcp EMIT_ADDR = "tcp://127.0.0.1:9000" REG_ADDR = "tcp://127.0.0.1:9001" else: EMIT_ADDR = "ipc:///tmp/leap.common.events.socket.0" REG_ADDR = "ipc:///tmp/leap.common.events.socket.1" - logger = logging.getLogger(__name__) diff --git a/src/leap/common/events/txclient.py b/src/leap/common/events/txclient.py index dfd0533..ca247ca 100644 --- a/src/leap/common/events/txclient.py +++ b/src/leap/common/events/txclient.py @@ -28,9 +28,10 @@ some other client. import logging import pickle +from leap.common.events.zmq_components import TxZmqClientComponent + import txzmq -from leap.common.events.zmq_components import TxZmqClientComponent from leap.common.events.client import EventsClient from leap.common.events.client import configure_client from leap.common.events.server import EMIT_ADDR @@ -68,6 +69,7 @@ class EventsTxClient(TxZmqClientComponent, EventsClient): # same client self._sub = self._zmq_connect(txzmq.ZmqSubConnection, reg_addr) self._sub.gotMessage = self._gotMessage + self._push = self._zmq_connect(txzmq.ZmqPushConnection, emit_addr) def _gotMessage(self, msg, tag): @@ -122,7 +124,6 @@ class EventsTxClient(TxZmqClientComponent, EventsClient): callback(event, *content) def shutdown(self): - TxZmqClientComponent.shutdown(self) EventsClient.shutdown(self) diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py index 2c40f62..1e0d52a 100644 --- a/src/leap/common/events/zmq_components.py +++ b/src/leap/common/events/zmq_components.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # zmq.py -# Copyright (C) 2015 LEAP +# Copyright (C) 2015, 2016 LEAP # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -14,19 +14,16 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see . - - """ The server for the events mechanism. """ - - import os import logging import txzmq import re import time + from abc import ABCMeta # XXX some distros don't package libsodium, so we have to be prepared for @@ -37,8 +34,11 @@ try: except ImportError: pass +from txzmq.connection import ZmqEndpoint, ZmqEndpointType + from leap.common.config import flags, get_path_prefix from leap.common.zmq_utils import zmq_has_curve + from leap.common.zmq_utils import maybe_create_and_get_certificates from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX @@ -53,6 +53,8 @@ class TxZmqComponent(object): """ A twisted-powered zmq events component. """ + _factory = txzmq.ZmqFactory() + _factory.registerForShutdown() __metaclass__ = ABCMeta @@ -62,8 +64,6 @@ class TxZmqComponent(object): """ Initialize the txzmq component. """ - self._factory = txzmq.ZmqFactory() - self._factory.registerForShutdown() if path_prefix is None: path_prefix = get_path_prefix(flags.STANDALONE) self._config_prefix = os.path.join(path_prefix, "leap", "events") @@ -93,21 +93,22 @@ class TxZmqComponent(object): :return: The binded connection. :rtype: txzmq.ZmqConnection """ + endpoint = ZmqEndpoint(ZmqEndpointType.connect, address) connection = connClass(self._factory) - # create and configure socket - socket = connection.socket - if zmq_has_curve(): + + if self.use_curve: + socket = connection.socket public, secret = maybe_create_and_get_certificates( self._config_prefix, self.component_type) server_public_file = os.path.join( self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") + server_public, _ = zmq.auth.load_certificate(server_public_file) socket.curve_publickey = public socket.curve_secretkey = secret socket.curve_serverkey = server_public - socket.connect(address) - logger.debug("Connected %s to %s." % (connClass, address)) - self._connections.append(connection) + + connection.addEndpoints([endpoint]) return connection def _zmq_bind(self, connClass, address): @@ -122,33 +123,21 @@ class TxZmqComponent(object): :return: The binded connection and port. :rtype: (txzmq.ZmqConnection, int) """ + proto, addr, port = ADDRESS_RE.search(address).groups() + + endpoint = ZmqEndpoint(ZmqEndpointType.bind, address) connection = connClass(self._factory) - socket = connection.socket - if zmq_has_curve(): + + if self.use_curve: + socket = connection.socket + public, secret = maybe_create_and_get_certificates( self._config_prefix, self.component_type) socket.curve_publickey = public socket.curve_secretkey = secret self._start_thread_auth(connection.socket) - proto, addr, port = ADDRESS_RE.search(address).groups() - - if proto == "tcp": - if port is None or port is '0': - params = proto, addr - port = socket.bind_to_random_port("%s://%s" % params) - logger.debug("Binded %s to %s://%s." % ((connClass,) + params)) - else: - params = proto, addr, int(port) - socket.bind("%s://%s:%d" % params) - logger.debug( - "Binded %s to %s://%s:%d." % ((connClass,) + params)) - else: - params = proto, addr - socket.bind("%s://%s" % params) - logger.debug( - "Binded %s to %s://%s" % ((connClass,) + params)) - self._connections.append(connection) + connection.addEndpoints([endpoint]) return connection, port def _start_thread_auth(self, socket): @@ -158,6 +147,8 @@ class TxZmqComponent(object): :param socket: The socket in which to configure the authenticator. :type socket: zmq.Socket """ + # TODO re-implement without threads. + logger.debug("Starting thread authenticator...") authenticator = ThreadAuthenticator(self._factory.context) # Temporary fix until we understand what the problem is @@ -172,15 +163,6 @@ class TxZmqComponent(object): authenticator.configure_curve(domain="*", location=public_keys_dir) socket.curve_server = True # must come before bind - def shutdown(self): - """ - Shutdown the component. - """ - logger.debug("Shutting down component %s." % str(self)) - for conn in self._connections: - conn.shutdown() - self._factory.shutdown() - class TxZmqServerComponent(TxZmqComponent): """ -- cgit v1.2.3 From 24977b744b42df912a23a2861453e7d4d5202310 Mon Sep 17 00:00:00 2001 From: Kali Kaneko Date: Tue, 23 Feb 2016 19:28:05 -0400 Subject: [feature] reactor-based authenticator We don't really need a thread to make use of the ZAP authenticator. Document bug fix after authenticator thread is gone --- src/leap/common/events/auth.py | 96 ++++++++++++++++++++++++++++++ src/leap/common/events/examples/README.txt | 49 +++++++++++++++ src/leap/common/events/examples/client.py | 2 + src/leap/common/events/examples/server.py | 4 ++ src/leap/common/events/server.py | 5 +- src/leap/common/events/txclient.py | 3 +- src/leap/common/events/zmq_components.py | 44 ++++++-------- 7 files changed, 172 insertions(+), 31 deletions(-) create mode 100644 src/leap/common/events/auth.py create mode 100644 src/leap/common/events/examples/README.txt create mode 100644 src/leap/common/events/examples/client.py create mode 100644 src/leap/common/events/examples/server.py (limited to 'src') diff --git a/src/leap/common/events/auth.py b/src/leap/common/events/auth.py new file mode 100644 index 0000000..1a1bcab --- /dev/null +++ b/src/leap/common/events/auth.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# auth.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +""" +ZAP authentication, twisted style. +""" +from zmq import PAIR +from zmq.auth.base import Authenticator, VERSION +from txzmq.connection import ZmqConnection +from zmq.utils.strtypes import b, u + +from twisted.python import log + +from txzmq.connection import ZmqEndpoint, ZmqEndpointType + + +class TxAuthenticator(ZmqConnection): + + """ + This does not implement the whole ZAP protocol, but the bare minimum that + we need. + """ + + socketType = PAIR + address = 'inproc://zeromq.zap.01' + encoding = 'utf-8' + + def __init__(self, factory): + super(TxAuthenticator, self).__init__(factory) + self.authenticator = Authenticator(factory.context) + self.authenticator._send_zap_reply = self._send_zap_reply + + def start(self): + endpoint = ZmqEndpoint(ZmqEndpointType.bind, self.address) + self.addEndpoints([endpoint]) + + def messageReceived(self, msg): + + command = msg[0] + + if command == b'ALLOW': + addresses = [u(m, self.encoding) for m in msg[1:]] + try: + self.authenticator.allow(*addresses) + except Exception as e: + log.err("Failed to allow %s", addresses) + + elif command == b'CURVE': + domain = u(msg[1], self.encoding) + location = u(msg[2], self.encoding) + self.authenticator.configure_curve(domain, location) + + def _send_zap_reply(self, request_id, status_code, status_text, + user_id='user'): + """ + Send a ZAP reply to finish the authentication. + """ + user_id = user_id if status_code == b'200' else b'' + if isinstance(user_id, unicode): + user_id = user_id.encode(self.encoding, 'replace') + metadata = b'' # not currently used + reply = [VERSION, request_id, status_code, status_text, + user_id, metadata] + self.send(reply) + + +class TxAuthenticationRequest(ZmqConnection): + + socketType = PAIR + address = 'inproc://zeromq.zap.01' + encoding = 'utf-8' + + def start(self): + endpoint = ZmqEndpoint(ZmqEndpointType.connect, self.address) + self.addEndpoints([endpoint]) + + def allow(self, *addresses): + self.send([b'ALLOW'] + [b(a, self.encoding) for a in addresses]) + + def configure_curve(self, domain='*', location=''): + domain = b(domain, self.encoding) + location = b(location, self.encoding) + self.send([b'CURVE', domain, location]) diff --git a/src/leap/common/events/examples/README.txt b/src/leap/common/events/examples/README.txt new file mode 100644 index 0000000..0bb0df6 --- /dev/null +++ b/src/leap/common/events/examples/README.txt @@ -0,0 +1,49 @@ +How to debug +----------------------------------------- +monitor the events socket: + sudo ngrep -W byline -d any port 9000 + +launch the server: + python server.py + +launch the client: + python client.py + +if zmq is available and enabled, you should see encrypted messages passing by +the socket. + +You should see something like the following: + +#### +T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP] +.......... +## +T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP] +........... +## +T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP] +..CURVE............................................... +# +T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP] +.CURVE............................................... +# +T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP] +...HELLO.............................................................................:....^...".....'.S...n......Y...................O.7.+.D.q".*..R...j.....8..qu..~......Ck.G\....:...m....Tg.s..M..x<.. +## +T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP] +...WELCOME..%.'.,Td... I..}...........`..Nm......./_.Je...4.....-.....f 127.0.0.1:9000 [AP] +..........INITIATE......!.*.=0.-......D..]{...A\.tz...!2.....A./ +6.......Y.h.N....cb.U.|..f..)....W..3..X.2U.3PGl.........m..95.(......NJ....5.'..W.GQ..B/.....\%.,Q..r.'L5.......{.W<=._.$.(6j.G... +...37.H..Th...'.........0 ........,..q....U..G..M.`!_..w....f.".......... +.d.K.Y.>f.n.kV. +# +T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP] +.2.READY............A...e.)......*.8y....k.<.N1Z.4.. +# +T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP] +.+.MESSAGE........o...*M..,.... +.r..w..[.GwcU +### + diff --git a/src/leap/common/events/examples/client.py b/src/leap/common/events/examples/client.py new file mode 100644 index 0000000..d6d8985 --- /dev/null +++ b/src/leap/common/events/examples/client.py @@ -0,0 +1,2 @@ +from leap.common.events.txclient import emit +emit('stuff!') diff --git a/src/leap/common/events/examples/server.py b/src/leap/common/events/examples/server.py new file mode 100644 index 0000000..f40f8dc --- /dev/null +++ b/src/leap/common/events/examples/server.py @@ -0,0 +1,4 @@ +from twisted.internet import reactor +from leap.common.events.server import ensure_server +reactor.callWhenRunning(ensure_server) +reactor.run() diff --git a/src/leap/common/events/server.py b/src/leap/common/events/server.py index 30a0c44..6252853 100644 --- a/src/leap/common/events/server.py +++ b/src/leap/common/events/server.py @@ -20,12 +20,11 @@ The server for the events mechanism. import logging import platform -from leap.common.zmq_utils import zmq_has_curve +import txzmq +from leap.common.zmq_utils import zmq_has_curve from leap.common.events.zmq_components import TxZmqServerComponent -import txzmq - if zmq_has_curve() or platform.system() == "Windows": # Windows doesn't have ipc sockets, we need to use always tcp diff --git a/src/leap/common/events/txclient.py b/src/leap/common/events/txclient.py index ca247ca..a2b704d 100644 --- a/src/leap/common/events/txclient.py +++ b/src/leap/common/events/txclient.py @@ -28,10 +28,9 @@ some other client. import logging import pickle -from leap.common.events.zmq_components import TxZmqClientComponent - import txzmq +from leap.common.events.zmq_components import TxZmqClientComponent from leap.common.events.client import EventsClient from leap.common.events.client import configure_client from leap.common.events.server import EMIT_ADDR diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py index 1e0d52a..74abb76 100644 --- a/src/leap/common/events/zmq_components.py +++ b/src/leap/common/events/zmq_components.py @@ -21,16 +21,13 @@ import os import logging import txzmq import re -import time - from abc import ABCMeta -# XXX some distros don't package libsodium, so we have to be prepared for -# absence of zmq.auth try: import zmq.auth - from zmq.auth.thread import ThreadAuthenticator + from leap.common.events.auth import TxAuthenticator + from leap.common.events.auth import TxAuthenticationRequest except ImportError: pass @@ -38,16 +35,15 @@ from txzmq.connection import ZmqEndpoint, ZmqEndpointType from leap.common.config import flags, get_path_prefix from leap.common.zmq_utils import zmq_has_curve - from leap.common.zmq_utils import maybe_create_and_get_certificates from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX - logger = logging.getLogger(__name__) - ADDRESS_RE = re.compile("^([a-z]+)://([^:]+):?(\d+)?$") +LOCALHOST_ALLOWED = '127.0.0.1' + class TxZmqComponent(object): """ @@ -55,6 +51,7 @@ class TxZmqComponent(object): """ _factory = txzmq.ZmqFactory() _factory.registerForShutdown() + _auth = None __metaclass__ = ABCMeta @@ -135,33 +132,28 @@ class TxZmqComponent(object): self._config_prefix, self.component_type) socket.curve_publickey = public socket.curve_secretkey = secret - self._start_thread_auth(connection.socket) + self._start_authentication(connection.socket) connection.addEndpoints([endpoint]) return connection, port - def _start_thread_auth(self, socket): - """ - Start the zmq curve thread authenticator. + def _start_authentication(self, socket): - :param socket: The socket in which to configure the authenticator. - :type socket: zmq.Socket - """ - # TODO re-implement without threads. - logger.debug("Starting thread authenticator...") - authenticator = ThreadAuthenticator(self._factory.context) + if not TxZmqComponent._auth: + TxZmqComponent._auth = TxAuthenticator(self._factory) + TxZmqComponent._auth.start() - # Temporary fix until we understand what the problem is - # See https://leap.se/code/issues/7536 - time.sleep(0.5) + auth_req = TxAuthenticationRequest(self._factory) + auth_req.start() + auth_req.allow(LOCALHOST_ALLOWED) - authenticator.start() - # XXX do not hardcode this here. - authenticator.allow('127.0.0.1') # tell authenticator to use the certificate in a directory public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) - authenticator.configure_curve(domain="*", location=public_keys_dir) - socket.curve_server = True # must come before bind + auth_req.configure_curve(domain="*", location=public_keys_dir) + + # This has to be set before binding the socket, that's why this method + # has to be called before addEndpoints() + socket.curve_server = True class TxZmqServerComponent(TxZmqComponent): -- cgit v1.2.3 From 027ad7eed50947608738ce0009fccf776936e55c Mon Sep 17 00:00:00 2001 From: Kali Kaneko Date: Mon, 29 Feb 2016 19:33:28 -0400 Subject: [tests] adapt events tests to recent changes --- src/leap/common/events/auth.py | 4 +- src/leap/common/events/client.py | 23 +++- src/leap/common/events/server.py | 13 +- src/leap/common/events/tests/test_auth.py | 64 +++++++++ src/leap/common/events/tests/test_events.py | 204 ++++++++++++++++++++++++++++ src/leap/common/events/txclient.py | 8 +- src/leap/common/events/zmq_components.py | 58 ++++---- src/leap/common/testing/basetest.py | 9 +- src/leap/common/tests/test_events.py | 198 --------------------------- 9 files changed, 340 insertions(+), 241 deletions(-) create mode 100644 src/leap/common/events/tests/test_auth.py create mode 100644 src/leap/common/events/tests/test_events.py delete mode 100644 src/leap/common/tests/test_events.py (limited to 'src') diff --git a/src/leap/common/events/auth.py b/src/leap/common/events/auth.py index 1a1bcab..5b71f2d 100644 --- a/src/leap/common/events/auth.py +++ b/src/leap/common/events/auth.py @@ -38,8 +38,8 @@ class TxAuthenticator(ZmqConnection): address = 'inproc://zeromq.zap.01' encoding = 'utf-8' - def __init__(self, factory): - super(TxAuthenticator, self).__init__(factory) + def __init__(self, factory, *args, **kw): + super(TxAuthenticator, self).__init__(factory, *args, **kw) self.authenticator = Authenticator(factory.context) self.authenticator._send_zap_reply = self._send_zap_reply diff --git a/src/leap/common/events/client.py b/src/leap/common/events/client.py index 60d24bc..78617de 100644 --- a/src/leap/common/events/client.py +++ b/src/leap/common/events/client.py @@ -63,14 +63,18 @@ logger = logging.getLogger(__name__) _emit_addr = EMIT_ADDR _reg_addr = REG_ADDR +_factory = None +_enable_curve = True -def configure_client(emit_addr, reg_addr): - global _emit_addr, _reg_addr +def configure_client(emit_addr, reg_addr, factory=None, enable_curve=True): + global _emit_addr, _reg_addr, _factory, _enable_curve logger.debug("Configuring client with addresses: (%s, %s)" % (emit_addr, reg_addr)) _emit_addr = emit_addr _reg_addr = reg_addr + _factory = factory + _enable_curve = enable_curve class EventsClient(object): @@ -103,7 +107,9 @@ class EventsClient(object): """ with cls._instance_lock: if cls._instance is None: - cls._instance = cls(_emit_addr, _reg_addr) + cls._instance = cls( + _emit_addr, _reg_addr, factory=_factory, + enable_curve=_enable_curve) return cls._instance def register(self, event, callback, uid=None, replace=False): @@ -270,7 +276,7 @@ class EventsClientThread(threading.Thread, EventsClient): A threaded version of the events client. """ - def __init__(self, emit_addr, reg_addr): + def __init__(self, emit_addr, reg_addr, factory=None, enable_curve=True): """ Initialize the events client. """ @@ -281,15 +287,22 @@ class EventsClientThread(threading.Thread, EventsClient): self._config_prefix = os.path.join( get_path_prefix(flags.STANDALONE), "leap", "events") self._loop = None + self._factory = factory self._context = None self._push = None self._sub = None + if enable_curve: + self.use_curve = zmq_has_curve() + else: + self.use_curve = False + def _init_zmq(self): """ Initialize ZMQ connections. """ self._loop = EventsIOLoop() + # we need a new context for each thread self._context = zmq.Context() # connect SUB first, otherwise we might miss some event sent from this # same client @@ -311,7 +324,7 @@ class EventsClientThread(threading.Thread, EventsClient): logger.debug("Connecting %s to %s." % (socktype, address)) socket = self._context.socket(socktype) # configure curve authentication - if zmq_has_curve(): + if self.use_curve: public, private = maybe_create_and_get_certificates( self._config_prefix, "client") server_public_file = os.path.join( diff --git a/src/leap/common/events/server.py b/src/leap/common/events/server.py index 6252853..ad79abe 100644 --- a/src/leap/common/events/server.py +++ b/src/leap/common/events/server.py @@ -37,7 +37,8 @@ else: logger = logging.getLogger(__name__) -def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR): +def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, path_prefix=None, + factory=None, enable_curve=True): """ Make sure the server is running in the given addresses. @@ -49,7 +50,8 @@ def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR): :return: an events server instance :rtype: EventsServer """ - _server = EventsServer(emit_addr, reg_addr) + _server = EventsServer(emit_addr, reg_addr, path_prefix, factory=factory, + enable_curve=enable_curve) return _server @@ -59,7 +61,8 @@ class EventsServer(TxZmqServerComponent): events in another address. """ - def __init__(self, emit_addr, reg_addr): + def __init__(self, emit_addr, reg_addr, path_prefix=None, factory=None, + enable_curve=True): """ Initialize the events server. @@ -68,7 +71,9 @@ class EventsServer(TxZmqServerComponent): :param reg_addr: The address to which publish events to clients. :type reg_addr: str """ - TxZmqServerComponent.__init__(self) + TxZmqServerComponent.__init__(self, path_prefix=path_prefix, + factory=factory, + enable_curve=enable_curve) # bind PULL and PUB sockets self._pull, self.pull_port = self._zmq_bind( txzmq.ZmqPullConnection, emit_addr) diff --git a/src/leap/common/events/tests/test_auth.py b/src/leap/common/events/tests/test_auth.py new file mode 100644 index 0000000..78ffd9f --- /dev/null +++ b/src/leap/common/events/tests/test_auth.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# test_zmq_components.py +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +""" +Tests for the auth module. +""" +import os + +from twisted.trial import unittest +from txzmq import ZmqFactory + +from leap.common.events import auth +from leap.common.testing.basetest import BaseLeapTest +from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX +from leap.common.zmq_utils import maybe_create_and_get_certificates + +from txzmq.test import _wait + + +class ZmqAuthTestCase(unittest.TestCase, BaseLeapTest): + + def setUp(self): + self.setUpEnv(launch_events_server=False) + + self.factory = ZmqFactory() + self._config_prefix = os.path.join(self.tempdir, "leap", "events") + + self.public, self.secret = maybe_create_and_get_certificates( + self._config_prefix, 'server') + + self.authenticator = auth.TxAuthenticator(self.factory) + self.authenticator.start() + self.auth_req = auth.TxAuthenticationRequest(self.factory) + + def tearDown(self): + self.factory.shutdown() + self.tearDownEnv() + + def test_curve_auth(self): + self.auth_req.start() + self.auth_req.allow('127.0.0.1') + public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) + self.auth_req.configure_curve(domain="*", location=public_keys_dir) + + def check(ignored): + authenticator = self.authenticator.authenticator + certs = authenticator.certs['*'] + self.failUnlessEqual(authenticator.whitelist, set([u'127.0.0.1'])) + self.failUnlessEqual(certs[certs.keys()[0]], True) + + return _wait(0.1).addCallback(check) diff --git a/src/leap/common/events/tests/test_events.py b/src/leap/common/events/tests/test_events.py new file mode 100644 index 0000000..c45601b --- /dev/null +++ b/src/leap/common/events/tests/test_events.py @@ -0,0 +1,204 @@ +# -*- coding: utf-8 -*- +# test_events.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +""" +Tests for the events framework +""" +import os +import logging + +from twisted.internet.reactor import callFromThread +from twisted.trial import unittest +from twisted.internet import defer + +from txzmq import ZmqFactory + +from leap.common.events import server +from leap.common.events import client +from leap.common.events import flags +from leap.common.events import txclient +from leap.common.events import catalog +from leap.common.events.errors import CallbackAlreadyRegisteredError + + +if 'DEBUG' in os.environ: + logging.basicConfig(level=logging.DEBUG) + + +class EventsGenericClientTestCase(object): + + def setUp(self): + flags.set_events_enabled(True) + self.factory = ZmqFactory() + self._server = server.ensure_server( + emit_addr="tcp://127.0.0.1:0", + reg_addr="tcp://127.0.0.1:0", + factory=self.factory, + enable_curve=False) + + self._client.configure_client( + emit_addr="tcp://127.0.0.1:%d" % self._server.pull_port, + reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port, + factory=self.factory, enable_curve=False) + + def tearDown(self): + flags.set_events_enabled(False) + self.factory.shutdown() + self._client.instance().reset() + + def test_client_register(self): + """ + Ensure clients can register callbacks. + """ + callbacks = self._client.instance().callbacks + self.assertTrue(len(callbacks) == 0, + 'There should be no callback for this event.') + # register one event + event1 = catalog.CLIENT_UID + + def cbk1(event, _): + return True + + uid1 = self._client.register(event1, cbk1) + # assert for correct registration + self.assertTrue(len(callbacks) == 1) + self.assertTrue(callbacks[event1][uid1] == cbk1, + 'Could not register event in local client.') + # register another event + event2 = catalog.CLIENT_SESSION_ID + + def cbk2(event, _): + return True + + uid2 = self._client.register(event2, cbk2) + # assert for correct registration + self.assertTrue(len(callbacks) == 2) + self.assertTrue(callbacks[event2][uid2] == cbk2, + 'Could not register event in local client.') + + + def test_register_signal_replace(self): + """ + Make sure clients can replace already registered callbacks. + """ + event = catalog.CLIENT_UID + d = defer.Deferred() + + def cbk_fail(event, _): + return callFromThread(d.errback, event) + + def cbk_succeed(event, _): + return callFromThread(d.callback, event) + + self._client.register(event, cbk_fail, uid=1) + self._client.register(event, cbk_succeed, uid=1, replace=True) + self._client.emit(event, None) + return d + + def test_register_signal_replace_fails_when_replace_is_false(self): + """ + Make sure clients trying to replace already registered callbacks fail + when replace=False + """ + event = catalog.CLIENT_UID + self._client.register(event, lambda event, _: None, uid=1) + self.assertRaises( + CallbackAlreadyRegisteredError, + self._client.register, + event, lambda event, _: None, uid=1, replace=False) + + def test_register_more_than_one_callback_works(self): + """ + Make sure clients can replace already registered callbacks. + """ + event = catalog.CLIENT_UID + d1 = defer.Deferred() + + def cbk1(event, _): + return callFromThread(d1.callback, event) + + d2 = defer.Deferred() + + def cbk2(event, _): + return d2.callback(event) + + self._client.register(event, cbk1) + self._client.register(event, cbk2) + self._client.emit(event, None) + d = defer.gatherResults([d1, d2]) + return d + + def test_client_receives_signal(self): + """ + Ensure clients can receive signals. + """ + event = catalog.CLIENT_UID + d = defer.Deferred() + + def cbk(events, _): + callFromThread(d.callback, event) + + self._client.register(event, cbk) + self._client.emit(event, None) + return d + + def test_client_unregister_all(self): + """ + Test that the client can unregister all events for one signal. + """ + event1 = catalog.CLIENT_UID + d = defer.Deferred() + # register more than one callback for the same event + self._client.register( + event1, lambda ev, _: callFromThread(d.errback, None)) + self._client.register( + event1, lambda ev, _: callFromThread(d.errback, None)) + # unregister and emit the event + self._client.unregister(event1) + self._client.emit(event1, None) + # register and emit another event so the deferred can succeed + event2 = catalog.CLIENT_SESSION_ID + self._client.register( + event2, lambda ev, _: callFromThread(d.callback, None)) + self._client.emit(event2, None) + return d + + def test_client_unregister_by_uid(self): + """ + Test that the client can unregister an event by uid. + """ + event = catalog.CLIENT_UID + d = defer.Deferred() + # register one callback that would fail + uid = self._client.register( + event, lambda ev, _: callFromThread(d.errback, None)) + # register one callback that will succeed + self._client.register( + event, lambda ev, _: callFromThread(d.callback, None)) + # unregister by uid and emit the event + self._client.unregister(event, uid=uid) + self._client.emit(event, None) + return d + + +class EventsTxClientTestCase(EventsGenericClientTestCase, unittest.TestCase): + + _client = txclient + + +class EventsClientTestCase(EventsGenericClientTestCase, unittest.TestCase): + + _client = client diff --git a/src/leap/common/events/txclient.py b/src/leap/common/events/txclient.py index a2b704d..63f12d7 100644 --- a/src/leap/common/events/txclient.py +++ b/src/leap/common/events/txclient.py @@ -58,11 +58,13 @@ class EventsTxClient(TxZmqClientComponent, EventsClient): """ def __init__(self, emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, - path_prefix=None): + path_prefix=None, factory=None, enable_curve=True): """ - Initialize the events server. + Initialize the events client. """ - TxZmqClientComponent.__init__(self, path_prefix=path_prefix) + TxZmqClientComponent.__init__( + self, path_prefix=path_prefix, factory=factory, + enable_curve=enable_curve) EventsClient.__init__(self, emit_addr, reg_addr) # connect SUB first, otherwise we might miss some event sent from this # same client diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py index 74abb76..8919cd9 100644 --- a/src/leap/common/events/zmq_components.py +++ b/src/leap/common/events/zmq_components.py @@ -57,12 +57,14 @@ class TxZmqComponent(object): _component_type = None - def __init__(self, path_prefix=None, enable_curve=True): + def __init__(self, path_prefix=None, enable_curve=True, factory=None): """ Initialize the txzmq component. """ if path_prefix is None: path_prefix = get_path_prefix(flags.STANDALONE) + if factory is not None: + self._factory = factory self._config_prefix = os.path.join(path_prefix, "leap", "events") self._connections = [] if enable_curve: @@ -78,64 +80,69 @@ class TxZmqComponent(object): "define a self._component_type!") return self._component_type - def _zmq_connect(self, connClass, address): + def _zmq_bind(self, connClass, address): """ - Connect to an address. + Bind to an address. :param connClass: The connection class to be used. :type connClass: txzmq.ZmqConnection - :param address: The address to connect to. + :param address: The address to bind to. :type address: str - :return: The binded connection. - :rtype: txzmq.ZmqConnection + :return: The binded connection and port. + :rtype: (txzmq.ZmqConnection, int) """ - endpoint = ZmqEndpoint(ZmqEndpointType.connect, address) + proto, addr, port = ADDRESS_RE.search(address).groups() + + endpoint = ZmqEndpoint(ZmqEndpointType.bind, address) connection = connClass(self._factory) if self.use_curve: socket = connection.socket + public, secret = maybe_create_and_get_certificates( self._config_prefix, self.component_type) - server_public_file = os.path.join( - self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") - - server_public, _ = zmq.auth.load_certificate(server_public_file) socket.curve_publickey = public socket.curve_secretkey = secret - socket.curve_serverkey = server_public + self._start_authentication(connection.socket) - connection.addEndpoints([endpoint]) - return connection + if proto == 'tcp' and int(port) == 0: + connection.endpoints.extend([endpoint]) + port = connection.socket.bind_to_random_port('tcp://%s' % addr) + else: + connection.addEndpoints([endpoint]) - def _zmq_bind(self, connClass, address): + return connection, int(port) + + def _zmq_connect(self, connClass, address): """ - Bind to an address. + Connect to an address. :param connClass: The connection class to be used. :type connClass: txzmq.ZmqConnection - :param address: The address to bind to. + :param address: The address to connect to. :type address: str - :return: The binded connection and port. - :rtype: (txzmq.ZmqConnection, int) + :return: The binded connection. + :rtype: txzmq.ZmqConnection """ - proto, addr, port = ADDRESS_RE.search(address).groups() - - endpoint = ZmqEndpoint(ZmqEndpointType.bind, address) + endpoint = ZmqEndpoint(ZmqEndpointType.connect, address) connection = connClass(self._factory) if self.use_curve: socket = connection.socket - public, secret = maybe_create_and_get_certificates( self._config_prefix, self.component_type) + server_public_file = os.path.join( + self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") + + server_public, _ = zmq.auth.load_certificate(server_public_file) socket.curve_publickey = public socket.curve_secretkey = secret - self._start_authentication(connection.socket) + socket.curve_serverkey = server_public connection.addEndpoints([endpoint]) - return connection, port + return connection def _start_authentication(self, socket): @@ -150,6 +157,7 @@ class TxZmqComponent(object): # tell authenticator to use the certificate in a directory public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) auth_req.configure_curve(domain="*", location=public_keys_dir) + auth_req.shutdown() # This has to be set before binding the socket, that's why this method # has to be called before addEndpoints() diff --git a/src/leap/common/testing/basetest.py b/src/leap/common/testing/basetest.py index 3d3cee0..2e84a25 100644 --- a/src/leap/common/testing/basetest.py +++ b/src/leap/common/testing/basetest.py @@ -52,7 +52,7 @@ class BaseLeapTest(unittest.TestCase): cls.tearDownEnv() @classmethod - def setUpEnv(cls): + def setUpEnv(cls, launch_events_server=True): """ Sets up common facilities for testing this TestCase: - custom PATH and HOME environmental variables @@ -72,14 +72,15 @@ class BaseLeapTest(unittest.TestCase): os.environ["PATH"] = bin_tdir os.environ["HOME"] = cls.tempdir os.environ["XDG_CONFIG_HOME"] = os.path.join(cls.tempdir, ".config") - cls._init_events() + if launch_events_server: + cls._init_events() @classmethod def _init_events(cls): if flags.EVENTS_ENABLED: cls._server = events_server.ensure_server( - emit_addr="tcp://127.0.0.1:0", - reg_addr="tcp://127.0.0.1:0") + emit_addr="tcp://127.0.0.1", + reg_addr="tcp://127.0.0.1") events_client.configure_client( emit_addr="tcp://127.0.0.1:%d" % cls._server.pull_port, reg_addr="tcp://127.0.0.1:%d" % cls._server.pub_port) diff --git a/src/leap/common/tests/test_events.py b/src/leap/common/tests/test_events.py deleted file mode 100644 index 2ad097e..0000000 --- a/src/leap/common/tests/test_events.py +++ /dev/null @@ -1,198 +0,0 @@ -# -*- coding: utf-8 -*- -# test_events.py -# Copyright (C) 2013 LEAP -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . - - -import os -import logging -import time - -from twisted.internet.reactor import callFromThread -from twisted.trial import unittest -from twisted.internet import defer - -from leap.common.events import server -from leap.common.events import client -from leap.common.events import flags -from leap.common.events import txclient -from leap.common.events import catalog -from leap.common.events.errors import CallbackAlreadyRegisteredError - - -if 'DEBUG' in os.environ: - logging.basicConfig(level=logging.DEBUG) - - -class EventsGenericClientTestCase(object): - - def setUp(self): - flags.set_events_enabled(True) - self._server = server.ensure_server( - emit_addr="tcp://127.0.0.1:0", - reg_addr="tcp://127.0.0.1:0") - self._client.configure_client( - emit_addr="tcp://127.0.0.1:%d" % self._server.pull_port, - reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port) - - def tearDown(self): - self._client.shutdown() - self._server.shutdown() - flags.set_events_enabled(False) - # wait a bit for sockets to close properly - time.sleep(0.1) - - def test_client_register(self): - """ - Ensure clients can register callbacks. - """ - callbacks = self._client.instance().callbacks - self.assertTrue(len(callbacks) == 0, - 'There should be no callback for this event.') - # register one event - event1 = catalog.CLIENT_UID - - def cbk1(event, _): - return True - - uid1 = self._client.register(event1, cbk1) - # assert for correct registration - self.assertTrue(len(callbacks) == 1) - self.assertTrue(callbacks[event1][uid1] == cbk1, - 'Could not register event in local client.') - # register another event - event2 = catalog.CLIENT_SESSION_ID - - def cbk2(event, _): - return True - - uid2 = self._client.register(event2, cbk2) - # assert for correct registration - self.assertTrue(len(callbacks) == 2) - self.assertTrue(callbacks[event2][uid2] == cbk2, - 'Could not register event in local client.') - - def test_register_signal_replace(self): - """ - Make sure clients can replace already registered callbacks. - """ - event = catalog.CLIENT_UID - d = defer.Deferred() - - def cbk_fail(event, _): - return callFromThread(d.errback, event) - - def cbk_succeed(event, _): - return callFromThread(d.callback, event) - - self._client.register(event, cbk_fail, uid=1) - self._client.register(event, cbk_succeed, uid=1, replace=True) - self._client.emit(event, None) - return d - - def test_register_signal_replace_fails_when_replace_is_false(self): - """ - Make sure clients trying to replace already registered callbacks fail - when replace=False - """ - event = catalog.CLIENT_UID - self._client.register(event, lambda event, _: None, uid=1) - self.assertRaises( - CallbackAlreadyRegisteredError, - self._client.register, - event, lambda event, _: None, uid=1, replace=False) - - def test_register_more_than_one_callback_works(self): - """ - Make sure clients can replace already registered callbacks. - """ - event = catalog.CLIENT_UID - d1 = defer.Deferred() - - def cbk1(event, _): - return callFromThread(d1.callback, event) - - d2 = defer.Deferred() - - def cbk2(event, _): - return d2.callback(event) - - self._client.register(event, cbk1) - self._client.register(event, cbk2) - self._client.emit(event, None) - d = defer.gatherResults([d1, d2]) - return d - - def test_client_receives_signal(self): - """ - Ensure clients can receive signals. - """ - event = catalog.CLIENT_UID - d = defer.Deferred() - - def cbk(events, _): - callFromThread(d.callback, event) - - self._client.register(event, cbk) - self._client.emit(event, None) - return d - - def test_client_unregister_all(self): - """ - Test that the client can unregister all events for one signal. - """ - event1 = catalog.CLIENT_UID - d = defer.Deferred() - # register more than one callback for the same event - self._client.register( - event1, lambda ev, _: callFromThread(d.errback, None)) - self._client.register( - event1, lambda ev, _: callFromThread(d.errback, None)) - # unregister and emit the event - self._client.unregister(event1) - self._client.emit(event1, None) - # register and emit another event so the deferred can succeed - event2 = catalog.CLIENT_SESSION_ID - self._client.register( - event2, lambda ev, _: callFromThread(d.callback, None)) - self._client.emit(event2, None) - return d - - def test_client_unregister_by_uid(self): - """ - Test that the client can unregister an event by uid. - """ - event = catalog.CLIENT_UID - d = defer.Deferred() - # register one callback that would fail - uid = self._client.register( - event, lambda ev, _: callFromThread(d.errback, None)) - # register one callback that will succeed - self._client.register( - event, lambda ev, _: callFromThread(d.callback, None)) - # unregister by uid and emit the event - self._client.unregister(event, uid=uid) - self._client.emit(event, None) - return d - - -class EventsTxClientTestCase(EventsGenericClientTestCase, unittest.TestCase): - - _client = txclient - - -class EventsClientTestCase(EventsGenericClientTestCase, unittest.TestCase): - - _client = client -- cgit v1.2.3 From 07dff4d010b284d8d46eb3b8a859083013c7441f Mon Sep 17 00:00:00 2001 From: Kali Kaneko Date: Wed, 9 Mar 2016 10:48:23 -0400 Subject: [style] pep8 --- src/leap/common/events/server.py | 2 +- src/leap/common/events/tests/test_events.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) (limited to 'src') diff --git a/src/leap/common/events/server.py b/src/leap/common/events/server.py index ad79abe..05fc23e 100644 --- a/src/leap/common/events/server.py +++ b/src/leap/common/events/server.py @@ -51,7 +51,7 @@ def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, path_prefix=None, :rtype: EventsServer """ _server = EventsServer(emit_addr, reg_addr, path_prefix, factory=factory, - enable_curve=enable_curve) + enable_curve=enable_curve) return _server diff --git a/src/leap/common/events/tests/test_events.py b/src/leap/common/events/tests/test_events.py index c45601b..d8435c6 100644 --- a/src/leap/common/events/tests/test_events.py +++ b/src/leap/common/events/tests/test_events.py @@ -89,7 +89,6 @@ class EventsGenericClientTestCase(object): self.assertTrue(callbacks[event2][uid2] == cbk2, 'Could not register event in local client.') - def test_register_signal_replace(self): """ Make sure clients can replace already registered callbacks. -- cgit v1.2.3 From ecf025e3d6065c9729ac72489efcdc0218fdffe1 Mon Sep 17 00:00:00 2001 From: Kali Kaneko Date: Wed, 2 Mar 2016 11:53:50 -0400 Subject: [feature] HookableService ad-hoc register/trigger mechanism used for service composition. to be used in bitmask.core and bitmask.bonafide in the first place. --- src/leap/common/service_hooks.py | 75 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 src/leap/common/service_hooks.py (limited to 'src') diff --git a/src/leap/common/service_hooks.py b/src/leap/common/service_hooks.py new file mode 100644 index 0000000..96e95cc --- /dev/null +++ b/src/leap/common/service_hooks.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# service_hooks.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +""" +Hooks for service composition. +""" +from collections import defaultdict + +from twisted.application.service import IService, Service +from twisted.python import log + +from zope.interface import implementer + + +@implementer(IService) +class HookableService(Service): + + """ + This service allows for other services in a Twisted Service tree to be + notified whenever a certain kind of hook is triggered. + + During the service composition, one is expected to register + a hook name with the name of the service that wants to react to the + triggering of the hook. All the services, both hooked and listeners, should + be registered against the same parent service. + + Upon the hook being triggered, the method "hook_" will be called with + the passed data in the listener service. + """ + + def register_hook(self, name, listener): + if not hasattr(self, 'event_listeners'): + self.event_listeners = defaultdict(list) + log.msg("Registering hook %s->%s" % (name, listener)) + self.event_listeners[name].append(listener) + + def trigger_hook(self, name, **data): + + def react_to_hook(listener, name, **kw): + try: + getattr(listener, 'hook_' + name)(**kw) + except AttributeError: + raise RuntimeError( + "Tried to notify a hook, but the listener service class %s" + "has not defined the proper method" % listener.__class__) + + if not hasattr(self, 'event_listeners'): + self.event_listeners = defaultdict(list) + listeners = self._get_listener_services(name) + + for listener in listeners: + react_to_hook(listener, name, **data) + + def _get_sibling_service(self, name): + return self.parent.getServiceNamed(name) + + def _get_listener_services(self, hook): + if hook in self.event_listeners: + service_names = self.event_listeners[hook] + services = [ + self._get_sibling_service(name) for name in service_names] + return services -- cgit v1.2.3 From 3a317f04bfa55698a7064ea3d5c5a1b4cc5ead36 Mon Sep 17 00:00:00 2001 From: Christoph Kluenter Date: Wed, 16 Mar 2016 17:03:52 +0100 Subject: [bug] close TxAuthenticator properly otherwise the context.term() does not return --- src/leap/common/events/auth.py | 4 ++++ src/leap/common/events/zmq_components.py | 1 + 2 files changed, 5 insertions(+) (limited to 'src') diff --git a/src/leap/common/events/auth.py b/src/leap/common/events/auth.py index 5b71f2d..db217ca 100644 --- a/src/leap/common/events/auth.py +++ b/src/leap/common/events/auth.py @@ -76,6 +76,10 @@ class TxAuthenticator(ZmqConnection): user_id, metadata] self.send(reply) + def shutdown(self): + if self.factory: + super(TxAuthenticator, self).shutdown() + class TxAuthenticationRequest(ZmqConnection): diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py index 8919cd9..c533a74 100644 --- a/src/leap/common/events/zmq_components.py +++ b/src/leap/common/events/zmq_components.py @@ -158,6 +158,7 @@ class TxZmqComponent(object): public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) auth_req.configure_curve(domain="*", location=public_keys_dir) auth_req.shutdown() + TxZmqComponent._auth.shutdown() # This has to be set before binding the socket, that's why this method # has to be called before addEndpoints() -- cgit v1.2.3 From 9365e03b6490e0b86ddde0fe35854431bf17d94c Mon Sep 17 00:00:00 2001 From: Kali Kaneko Date: Thu, 24 Mar 2016 09:24:37 -0400 Subject: [doc] update event annotation --- src/leap/common/events/catalog.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/leap/common/events/catalog.py b/src/leap/common/events/catalog.py index 6a09495..9a834b2 100644 --- a/src/leap/common/events/catalog.py +++ b/src/leap/common/events/catalog.py @@ -52,7 +52,7 @@ EVENTS = [ "MAIL_MSG_DELETED_INCOMING", # (userid) "MAIL_MSG_PROCESSING", # (userid) "MAIL_MSG_SAVED_LOCALLY", # (userid) - "MAIL_UNREAD_MESSAGES", # (userid) + "MAIL_UNREAD_MESSAGES", # (userid, number) "IMAP_SERVICE_STARTED", "IMAP_SERVICE_FAILED_TO_START", -- cgit v1.2.3 From 334fe8d2d38466ad309e1214d003f977f603dfb9 Mon Sep 17 00:00:00 2001 From: Kali Kaneko Date: Fri, 1 Apr 2016 17:33:26 -0400 Subject: [pkg] update to versioneer 0.16 --- src/leap/common/__init__.py | 2 +- src/leap/common/_version.py | 541 +++++++++++++++++++++++++++++++++----------- 2 files changed, 412 insertions(+), 131 deletions(-) (limited to 'src') diff --git a/src/leap/common/__init__.py b/src/leap/common/__init__.py index 383e198..3b07cf8 100644 --- a/src/leap/common/__init__.py +++ b/src/leap/common/__init__.py @@ -4,7 +4,6 @@ from leap.common import certs from leap.common import check from leap.common import files from leap.common import events -from ._version import get_versions logger = logging.getLogger(__name__) @@ -17,5 +16,6 @@ except ImportError: __all__ = ["certs", "check", "files", "events"] +from ._version import get_versions __version__ = get_versions()['version'] del get_versions diff --git a/src/leap/common/_version.py b/src/leap/common/_version.py index de94ba8..e29d969 100644 --- a/src/leap/common/_version.py +++ b/src/leap/common/_version.py @@ -1,73 +1,157 @@ + # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (build by setup.py sdist) and build +# feature). Distribution tarballs (built by setup.py sdist) and build # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. # This file is released into the public domain. Generated by -# versioneer-0.7+ (https://github.com/warner/python-versioneer) +# versioneer-0.16 (https://github.com/warner/python-versioneer) -# these strings will be replaced by git during git-archive +"""Git implementation of _version.py.""" +import errno +import os +import re import subprocess import sys -import re -import os.path -IN_LONG_VERSION_PY = True -git_refnames = "$Format:%d$" -git_full = "$Format:%H$" +def get_keywords(): + """Get the keywords needed to look up the version information.""" + # these strings will be replaced by git during git-archive. + # setup.py/versioneer.py will grep for the variable names, so they must + # each be defined on a line of their own. _version.py will just call + # get_keywords(). + git_refnames = "$Format:%d$" + git_full = "$Format:%H$" + keywords = {"refnames": git_refnames, "full": git_full} + return keywords -def run_command(args, cwd=None, verbose=False): - try: - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen(args, stdout=subprocess.PIPE, cwd=cwd) - except EnvironmentError: - e = sys.exc_info()[1] + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_config(): + """Create, populate and return the VersioneerConfig() object.""" + # these strings are filled in when 'setup.py versioneer' creates + # _version.py + cfg = VersioneerConfig() + cfg.VCS = "git" + cfg.style = "pep440" + cfg.tag_prefix = "" + cfg.parentdir_prefix = "None" + cfg.versionfile_source = "src/leap/common/_version.py" + cfg.verbose = False + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method): # decorator + """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False): + """Call the given command(s).""" + assert isinstance(commands, list) + p = None + for c in commands: + try: + dispcmd = str([c] + args) + # remember shell=False, so use git.cmd on windows, not just git + p = subprocess.Popen([c] + args, cwd=cwd, stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None)) + break + except EnvironmentError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %s" % dispcmd) + print(e) + return None + else: if verbose: - print("unable to run %s" % args[0]) - print(e) + print("unable to find command, tried %s" % (commands,)) return None stdout = p.communicate()[0].strip() - if sys.version >= '3': + if sys.version_info[0] >= 3: stdout = stdout.decode() if p.returncode != 0: if verbose: - print("unable to run %s (error)" % args[0]) + print("unable to run %s (error)" % dispcmd) return None return stdout -def get_expanded_variables(versionfile_source): +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes + both the project name and a version string. + """ + dirname = os.path.basename(root) + if not dirname.startswith(parentdir_prefix): + if verbose: + print("guessing rootdir is '%s', but '%s' doesn't start with " + "prefix '%s'" % (root, dirname, parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None} + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these - # variables. When used from setup.py, we don't want to import - # _version.py, so we do it with a regexp instead. This function is not - # used from _version.py. - variables = {} + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. + keywords = {} try: - f = open(versionfile_source, "r") + f = open(versionfile_abs, "r") for line in f.readlines(): if line.strip().startswith("git_refnames ="): mo = re.search(r'=\s*"(.*)"', line) if mo: - variables["refnames"] = mo.group(1) + keywords["refnames"] = mo.group(1) if line.strip().startswith("git_full ="): mo = re.search(r'=\s*"(.*)"', line) if mo: - variables["full"] = mo.group(1) + keywords["full"] = mo.group(1) f.close() except EnvironmentError: pass - return variables + return keywords -def versions_from_expanded_variables(variables, tag_prefix, verbose=False): - refnames = variables["refnames"].strip() +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if not keywords: + raise NotThisMethod("no keywords at all, weird") + refnames = keywords["refnames"].strip() if refnames.startswith("$Format"): if verbose: - print("variables are unexpanded, not using") - return {} # unexpanded, so not in an unpacked git-archive tarball + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") refs = set([r.strip() for r in refnames.strip("()").split(",")]) # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. @@ -83,7 +167,7 @@ def versions_from_expanded_variables(variables, tag_prefix, verbose=False): # "stabilization", as well as "HEAD" and "master". tags = set([r for r in refs if re.search(r'\d', r)]) if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) + print("discarding '%s', no digits" % ",".join(refs-tags)) if verbose: print("likely tags: %s" % ",".join(sorted(tags))) for ref in sorted(tags): @@ -93,111 +177,308 @@ def versions_from_expanded_variables(variables, tag_prefix, verbose=False): if verbose: print("picking %s" % r) return {"version": r, - "full": variables["full"].strip()} - # no suitable tags, so we use the full revision id + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None + } + # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: - print("no suitable tags, using full revision id") - return {"version": variables["full"].strip(), - "full": variables["full"].strip()} - - -def versions_from_vcs(tag_prefix, versionfile_source, verbose=False): - # this runs 'git' from the root of the source tree. That either means - # someone ran a setup.py command (and this code is in versioneer.py, so - # IN_LONG_VERSION_PY=False, thus the containing directory is the root of - # the source tree), or someone ran a project-specific entry point (and - # this code is in _version.py, so IN_LONG_VERSION_PY=True, thus the - # containing directory is somewhere deeper in the source tree). This only - # gets called if the git-archive 'subst' variables were *not* expanded, - # and _version.py hasn't already been rewritten with a short version - # string, meaning we're inside a checked out source tree. + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags"} - try: - here = os.path.abspath(__file__) - except NameError: - # some py2exe/bbfreeze/non-CPython implementations don't do __file__ - return {} # not always correct - - # versionfile_source is the relative path from the top of the source tree - # (where the .git directory might live) to this file. Invert this to find - # the root from __file__. - root = here - if IN_LONG_VERSION_PY: - for i in range(len(versionfile_source.split("/"))): - root = os.path.dirname(root) - else: - root = os.path.dirname(here) + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. + """ if not os.path.exists(os.path.join(root, ".git")): if verbose: print("no .git in %s" % root) - return {} + raise NotThisMethod("no .git directory") - GIT = "git" + GITS = ["git"] if sys.platform == "win32": - GIT = "git.cmd" - stdout = run_command([GIT, "describe", "--tags", "--dirty", "--always"], - cwd=root) - if stdout is None: - return {} - if not stdout.startswith(tag_prefix): - if verbose: - print("tag '%s' doesn't start with prefix '%s'" % (stdout, tag_prefix)) - return {} - tag = stdout[len(tag_prefix):] - stdout = run_command([GIT, "rev-parse", "HEAD"], cwd=root) - if stdout is None: - return {} - full = stdout.strip() - if tag.endswith("-dirty"): - full += "-dirty" - return {"version": tag, "full": full} - - -def versions_from_parentdir(parentdir_prefix, versionfile_source, verbose=False): - if IN_LONG_VERSION_PY: - # We're running from _version.py. If it's from a source tree - # (execute-in-place), we can work upwards to find the root of the - # tree, and then check the parent directory for a version string. If - # it's in an installed application, there's no hope. - try: - here = os.path.abspath(__file__) - except NameError: - # py2exe/bbfreeze/non-CPython don't have __file__ - return {} # without __file__, we have no hope + GITS = ["git.cmd", "git.exe"] + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out = run_command(GITS, ["describe", "--tags", "--dirty", + "--always", "--long", + "--match", "%s*" % tag_prefix], + cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparseable. Maybe git-describe is misbehaving? + pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + count_out = run_command(GITS, ["rev-list", "HEAD", "--count"], + cwd=root) + pieces["distance"] = int(count_out) # total number of commits + + return pieces + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_pre(pieces): + """TAG[.post.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post.devDISTANCE + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += ".post.dev%d" % pieces["distance"] + else: + # exception #1 + rendered = "0.post.dev%d" % pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Eexceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"]} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%s'" % style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None} + + +def get_versions(): + """Get version information or return default if unable to do so.""" + # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have + # __file__, we can work backwards from there to the root. Some + # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which + # case we can only use expanded keywords. + + cfg = get_config() + verbose = cfg.verbose + + try: + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) + except NotThisMethod: + pass + + try: + root = os.path.realpath(__file__) # versionfile_source is the relative path from the top of the source - # tree to _version.py. Invert this to find the root from __file__. - root = here - for i in range(len(versionfile_source.split("/"))): + # tree (where the .git directory might live) to this file. Invert + # this to find the root from __file__. + for i in cfg.versionfile_source.split('/'): root = os.path.dirname(root) - else: - # we're running from versioneer.py, which means we're running from - # the setup.py in a source tree. sys.argv[0] is setup.py in the root. - here = os.path.abspath(sys.argv[0]) - root = os.path.dirname(here) + except NameError: + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree"} - # Source tarballs conventionally unpack into a directory that includes - # both the project name and a version string. - dirname = os.path.basename(root) - if not dirname.startswith(parentdir_prefix): - if verbose: - print("guessing rootdir is '%s', but '%s' doesn't start with prefix '%s'" % - (root, dirname, parentdir_prefix)) - return None - return {"version": dirname[len(parentdir_prefix):], "full": ""} - -tag_prefix = "" -parentdir_prefix = "leap.common-" -versionfile_source = "src/leap/common/_version.py" - - -def get_versions(default={"version": "unknown", "full": ""}, verbose=False): - variables = {"refnames": git_refnames, "full": git_full} - ver = versions_from_expanded_variables(variables, tag_prefix, verbose) - if not ver: - ver = versions_from_vcs(tag_prefix, versionfile_source, verbose) - if not ver: - ver = versions_from_parentdir(parentdir_prefix, versionfile_source, - verbose) - if not ver: - ver = default - return ver + try: + pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) + return render(pieces, cfg.style) + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + except NotThisMethod: + pass + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version"} -- cgit v1.2.3