diff options
author | Micah Anderson <micah@riseup.net> | 2015-11-10 17:34:54 -0500 |
---|---|---|
committer | Micah Anderson <micah@riseup.net> | 2015-11-10 17:34:54 -0500 |
commit | 93ac9288e301643a4b9c31e2750a231f4e7bc8d8 (patch) | |
tree | 92d7ea2d18f288cd7dfd852ee6af8787c722c87c /src/leap | |
parent | 42814b5bf836a83724d2c74d6bb32bc168b7a81c (diff) | |
parent | e074eac10c6e08757857c770cb190cdb9d3a4583 (diff) |
Merge branch 'debian/experimental' into debian/platform-0.8
Diffstat (limited to 'src/leap')
22 files changed, 752 insertions, 182 deletions
diff --git a/src/leap/common/__init__.py b/src/leap/common/__init__.py index 5619900..383e198 100644 --- a/src/leap/common/__init__.py +++ b/src/leap/common/__init__.py @@ -4,6 +4,7 @@ from leap.common import certs from leap.common import check from leap.common import files from leap.common import events +from ._version import get_versions logger = logging.getLogger(__name__) @@ -11,11 +12,10 @@ try: import pygeoip HAS_GEOIP = True except ImportError: - #logger.debug('PyGeoIP not found. Disabled Geo support.') + # logger.debug('PyGeoIP not found. Disabled Geo support.') HAS_GEOIP = False __all__ = ["certs", "check", "files", "events"] -from ._version import get_versions __version__ = get_versions()['version'] del get_versions diff --git a/src/leap/common/_version.py b/src/leap/common/_version.py index 410c404..f5738ea 100644 --- a/src/leap/common/_version.py +++ b/src/leap/common/_version.py @@ -5,8 +5,8 @@ # unpacked source archive. Distribution tarballs contain a pre-generated copy # of this file. -version_version = '0.4.0' -version_full = 'ba00824758e1d37620ab605e87899c2b6650263e' +version_version = '0.4.4' +version_full = 'ee0e9cadccd00cb62032d8fc4b322bb6fe3dc7ed' def get_versions(default={}, verbose=False): diff --git a/src/leap/common/ca_bundle.py b/src/leap/common/ca_bundle.py index d8c72a6..e2a624d 100644 --- a/src/leap/common/ca_bundle.py +++ b/src/leap/common/ca_bundle.py @@ -21,23 +21,24 @@ If you are packaging Requests, e.g., for a Linux distribution or a managed environment, you can change the definition of where() to return a separately packaged CA bundle. """ -import platform import os.path +import platform +import sys _system = platform.system() IS_MAC = _system == "Darwin" + def where(): """ Return the preferred certificate bundle. :rtype: str """ - # vendored bundle inside Requests, plus some additions of ours - if IS_MAC: - return os.path.join("/Applications", "Bitmask.app", - "Contents", "Resources", - "cacert.pem") + if getattr(sys, 'frozen', False): + # we are running in a |PyInstaller| bundle + path = sys._MEIPASS + return os.path.join(path, 'cacert.pem') return os.path.join(os.path.dirname(__file__), 'cacert.pem') if __name__ == '__main__': diff --git a/src/leap/common/certs.py b/src/leap/common/certs.py index db513f6..37ede8e 100644 --- a/src/leap/common/certs.py +++ b/src/leap/common/certs.py @@ -178,3 +178,21 @@ def should_redownload(certfile, now=time.gmtime): return True return False + + +def get_compatible_ssl_context_factory(cert_path=None): + import twisted + cert = None + if twisted.version.base() > '14.0.1': + from twisted.web.client import BrowserLikePolicyForHTTPS + from twisted.internet import ssl + if cert_path: + cert = ssl.Certificate.loadPEM(open(cert_path).read()) + policy = BrowserLikePolicyForHTTPS(cert) + return policy + else: + raise Exception((""" + Twisted 14.0.2 is needed in order to have secure + Client Web SSL Contexts, not %s + See: http://twistedmatrix.com/trac/ticket/7647 + """) % (twisted.version.base())) diff --git a/src/leap/common/config/flags.py b/src/leap/common/config/flags.py new file mode 100644 index 0000000..6fd43f6 --- /dev/null +++ b/src/leap/common/config/flags.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# flags.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +This file is meant to be used to store global flags that affect the +application. + +WARNING: You should NOT use this kind of flags unless you're sure of what + you're doing, and someone else tells you that you're right. + Most of the times there is a better and safer alternative. +""" + +# The STANDALONE flag is used to: +# - use a relative or system wide path to find the configuration files. +STANDALONE = False diff --git a/src/leap/common/config/pluggableconfig.py b/src/leap/common/config/pluggableconfig.py index 8535fa6..1a98427 100644 --- a/src/leap/common/config/pluggableconfig.py +++ b/src/leap/common/config/pluggableconfig.py @@ -27,7 +27,7 @@ import urlparse import jsonschema -#from leap.base.util.translations import LEAPTranslatable +# from leap.base.util.translations import LEAPTranslatable from leap.common.check import leap_assert @@ -163,8 +163,8 @@ class TranslatableType(object): return data # LEAPTranslatable(data) # needed? we already have an extended dict... - #def get_prep_value(self, data): - #return dict(data) + # def get_prep_value(self, data): + # return dict(data) class URIType(object): @@ -283,9 +283,13 @@ class PluggableConfig(object): except BaseException, e: raise TypeCastException( "Could not coerce %s, %s, " - "to format %s: %s" % (key, value, - _ftype.__class__.__name__, - e)) + "to format %s: %s" % ( + key, + value, + _ftype.__class__.__name__, + e + ) + ) return config @@ -303,9 +307,12 @@ class PluggableConfig(object): except BaseException, e: raise TypeCastException( "Could not serialize %s, %s, " - "by format %s: %s" % (key, value, - _ftype.__class__.__name__, - e)) + "by format %s: %s" % ( + key, + value, + _ftype.__class__.__name__, + e) + ) else: config[key] = value return config @@ -435,7 +442,7 @@ class PluggableConfig(object): content = self.deserialize(string) if not string and fromfile is not None: - #import ipdb;ipdb.set_trace() + # import ipdb;ipdb.set_trace() content = self.deserialize(fromfile=fromfile) if not content: diff --git a/src/leap/common/config/tests/test_baseconfig.py b/src/leap/common/config/tests/test_baseconfig.py index 8bdf4d0..e17e82d 100644 --- a/src/leap/common/config/tests/test_baseconfig.py +++ b/src/leap/common/config/tests/test_baseconfig.py @@ -29,21 +29,21 @@ from mock import Mock # reduced eipconfig sample config sample_config = { "gateways": [ - { - "capabilities": { - "adblock": False, - "transport": ["openvpn"], - "user_ips": False - }, - "host": "host.dev.example.org", - }, { - "capabilities": { - "adblock": False, - "transport": ["openvpn"], - "user_ips": False - }, - "host": "host2.dev.example.org", - } + { + "capabilities": { + "adblock": False, + "transport": ["openvpn"], + "user_ips": False + }, + "host": "host.dev.example.org", + }, { + "capabilities": { + "adblock": False, + "transport": ["openvpn"], + "user_ips": False + }, + "host": "host2.dev.example.org", + } ], "default_language": "en", "languages": [ diff --git a/src/leap/common/events/__init__.py b/src/leap/common/events/__init__.py index 9269b9a..f9ad5fa 100644 --- a/src/leap/common/events/__init__.py +++ b/src/leap/common/events/__init__.py @@ -14,8 +14,6 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. - - """ This is an events mechanism that uses a server to allow for emitting events between clients. @@ -37,13 +35,15 @@ To emit an event, use leap.common.events.emit(): >>> from leap.common.events import catalog >>> emit(catalog.CLIENT_UID) """ - - import logging import argparse from leap.common.events import client +from leap.common.events import txclient from leap.common.events import server +from leap.common.events import flags +from leap.common.events.flags import set_events_enabled + from leap.common.events import catalog @@ -52,6 +52,7 @@ __all__ = [ "unregister", "emit", "catalog", + "set_events_enabled" ] @@ -78,7 +79,13 @@ def register(event, callback, uid=None, replace=False): :raises CallbackAlreadyRegistered: when there's already a callback identified by the given uid and replace is False. """ - return client.register(event, callback, uid, replace) + if flags.EVENTS_ENABLED: + return client.register(event, callback, uid, replace) + + +def register_async(event, callback, uid=None, replace=False): + if flags.EVENTS_ENABLED: + return txclient.register(event, callback, uid, replace) def unregister(event, uid=None): @@ -93,7 +100,13 @@ def unregister(event, uid=None): :param uid: The callback uid. :type uid: str """ - return client.unregister(event, uid) + if flags.EVENTS_ENABLED: + return client.unregister(event, uid) + + +def unregister_async(event, uid=None): + if flags.EVENTS_ENABLED: + return txclient.unregister(event, uid) def emit(event, *content): @@ -105,7 +118,13 @@ def emit(event, *content): :param content: The content of the event. :type content: list """ - return client.emit(event, *content) + if flags.EVENTS_ENABLED: + return client.emit(event, *content) + + +def emit_async(event, *content): + if flags.EVENTS_ENABLED: + return txclient.emit(event, *content) if __name__ == "__main__": diff --git a/src/leap/common/events/client.py b/src/leap/common/events/client.py index 0706fe3..60d24bc 100644 --- a/src/leap/common/events/client.py +++ b/src/leap/common/events/client.py @@ -14,8 +14,6 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. - - """ The client end point of the events mechanism. @@ -27,8 +25,6 @@ When a client registers a callback for a given event, it also tells the server that it wants to be notified whenever events of that type are sent by some other client. """ - - import logging import collections import uuid @@ -51,7 +47,7 @@ try: except ImportError: pass -from leap.common.config import get_path_prefix +from leap.common.config import flags, get_path_prefix from leap.common.zmq_utils import zmq_has_curve from leap.common.zmq_utils import maybe_create_and_get_certificates from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX @@ -173,21 +169,38 @@ class EventsClient(object): :param content: The content of the event. :type content: list """ - logger.debug("Sending event: (%s, %s)" % (event, content)) - self._send(str(event) + b'\0' + pickle.dumps(content)) + logger.debug("Emitting event: (%s, %s)" % (event, content)) + payload = str(event) + b'\0' + pickle.dumps(content) + self._send(payload) def _handle_event(self, event, content): """ Handle an incoming event. - :param msg: The incoming message. - :type msg: list(str) + :param event: The event to be sent. + :type event: Event + :param content: The content of the event. + :type content: list """ logger.debug("Handling event %s..." % event) - for uid in self._callbacks[event].keys(): + for uid in self._callbacks[event]: callback = self._callbacks[event][uid] logger.debug("Executing callback %s." % uid) - callback(event, *content) + self._run_callback(callback, event, content) + + @abstractmethod + def _run_callback(self, callback, event, content): + """ + Run a callback. + + :param callback: The callback to be run. + :type callback: callable(event, *content) + :param event: The event to be sent. + :type event: Event + :param content: The content of the event. + :type content: list + """ + pass @abstractmethod def _subscribe(self, tag): @@ -266,7 +279,7 @@ class EventsClientThread(threading.Thread, EventsClient): self._lock = threading.Lock() self._initialized = threading.Event() self._config_prefix = os.path.join( - get_path_prefix(), "leap", "events") + get_path_prefix(flags.STANDALONE), "leap", "events") self._loop = None self._context = None self._push = None @@ -368,10 +381,22 @@ class EventsClientThread(threading.Thread, EventsClient): :param data: The data to be sent. :type event: str """ - logger.debug("Sending data: %s" % data) # add send() as a callback for ioloop so it works between threads self._loop.add_callback(lambda: self._push.send(data)) + def _run_callback(self, callback, event, content): + """ + Run a callback. + + :param callback: The callback to be run. + :type callback: callable(event, *content) + :param event: The event to be sent. + :type event: Event + :param content: The content of the event. + :type content: list + """ + self._loop.add_callback(lambda: callback(event, *content)) + def register(self, event, callback, uid=None, replace=False): """ Register a callback to be executed when an event is received. @@ -393,7 +418,8 @@ class EventsClientThread(threading.Thread, EventsClient): callback identified by the given uid and replace is False. """ self.ensure_client() - return EventsClient.register(self, event, callback, uid=uid, replace=replace) + return EventsClient.register( + self, event, callback, uid=uid, replace=replace) def unregister(self, event, uid=None): """ diff --git a/src/leap/common/events/flags.py b/src/leap/common/events/flags.py new file mode 100644 index 0000000..137f663 --- /dev/null +++ b/src/leap/common/events/flags.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# __init__.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Flags for the events framework. +""" +from leap.common.check import leap_assert + +EVENTS_ENABLED = True + + +def set_events_enabled(flag): + leap_assert(isinstance(flag, bool)) + global EVENTS_ENABLED + EVENTS_ENABLED = flag diff --git a/src/leap/common/events/tests/__init__.py b/src/leap/common/events/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/leap/common/events/tests/__init__.py diff --git a/src/leap/common/events/tests/test_zmq_components.py b/src/leap/common/events/tests/test_zmq_components.py new file mode 100644 index 0000000..c51e37e --- /dev/null +++ b/src/leap/common/events/tests/test_zmq_components.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# test_zmq_components.py +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Tests for the zmq_components module. +""" +try: + import unittest2 as unittest +except ImportError: + import unittest + +from leap.common.events import zmq_components + + +class AddrParseTestCase(unittest.TestCase): + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_addr_parsing(self): + addr_re = zmq_components.ADDRESS_RE + + self.assertEqual( + addr_re.search("ipc:///tmp/foo/bar/baaz-2/foo.0").groups(), + ("ipc", "/tmp/foo/bar/baaz-2/foo.0", None)) + self.assertEqual( + addr_re.search("tcp://localhost:9000").groups(), + ("tcp", "localhost", "9000")) + self.assertEqual( + addr_re.search("tcp://127.0.0.1:9000").groups(), + ("tcp", "127.0.0.1", "9000")) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/common/events/txclient.py b/src/leap/common/events/txclient.py index 8206ed5..dfd0533 100644 --- a/src/leap/common/events/txclient.py +++ b/src/leap/common/events/txclient.py @@ -14,8 +14,6 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. - - """ The client end point of the events mechanism, implemented using txzmq. @@ -27,8 +25,6 @@ When a client registers a callback for a given event, it also tells the server that it wants to be notified whenever events of that type are sent by some other client. """ - - import logging import pickle @@ -62,7 +58,7 @@ class EventsTxClient(TxZmqClientComponent, EventsClient): """ def __init__(self, emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, - path_prefix=None): + path_prefix=None): """ Initialize the events server. """ @@ -112,6 +108,19 @@ class EventsTxClient(TxZmqClientComponent, EventsClient): """ self._push.send(data) + def _run_callback(self, callback, event, content): + """ + Run a callback. + + :param callback: The callback to be run. + :type callback: callable(event, *content) + :param event: The event to be sent. + :type event: Event + :param content: The content of the event. + :type content: list + """ + callback(event, *content) + def shutdown(self): TxZmqClientComponent.shutdown(self) EventsClient.shutdown(self) diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py index 4fb95d3..51de02c 100644 --- a/src/leap/common/events/zmq_components.py +++ b/src/leap/common/events/zmq_components.py @@ -25,6 +25,7 @@ import os import logging import txzmq import re +import time from abc import ABCMeta @@ -36,7 +37,7 @@ try: except ImportError: pass -from leap.common.config import get_path_prefix +from leap.common.config import flags, get_path_prefix from leap.common.zmq_utils import zmq_has_curve from leap.common.zmq_utils import maybe_create_and_get_certificates from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX @@ -45,7 +46,7 @@ from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX logger = logging.getLogger(__name__) -ADDRESS_RE = re.compile("(.+)://(.+):([0-9]+)") +ADDRESS_RE = re.compile("^([a-z]+)://([^:]+):?(\d+)?$") class TxZmqComponent(object): @@ -63,8 +64,8 @@ class TxZmqComponent(object): """ self._factory = txzmq.ZmqFactory() self._factory.registerForShutdown() - if path_prefix == None: - path_prefix = get_path_prefix() + if path_prefix is None: + path_prefix = get_path_prefix(flags.STANDALONE) self._config_prefix = os.path.join(path_prefix, "leap", "events") self._connections = [] @@ -125,15 +126,24 @@ class TxZmqComponent(object): socket.curve_publickey = public socket.curve_secretkey = secret self._start_thread_auth(connection.socket) - # check if port was given - protocol, addr, port = ADDRESS_RE.match(address).groups() - if port == "0": - port = socket.bind_to_random_port("%s://%s" % (protocol, addr)) + + proto, addr, port = ADDRESS_RE.search(address).groups() + + if proto == "tcp": + if port is None or port is '0': + params = proto, addr + port = socket.bind_to_random_port("%s://%s" % params) + logger.debug("Binded %s to %s://%s." % ((connClass,) + params)) + else: + params = proto, addr, int(port) + socket.bind("%s://%s:%d" % params) + logger.debug( + "Binded %s to %s://%s:%d." % ((connClass,) + params)) else: - socket.bind(address) - port = int(port) - logger.debug("Binded %s to %s://%s:%d." - % (connClass, protocol, addr, port)) + params = proto, addr + socket.bind("%s://%s" % params) + logger.debug( + "Binded %s to %s://%s" % ((connClass,) + params)) self._connections.append(connection) return connection, port @@ -145,6 +155,11 @@ class TxZmqComponent(object): :type socket: zmq.Socket """ authenticator = ThreadAuthenticator(self._factory.context) + + # Temporary fix until we understand what the problem is + # See https://leap.se/code/issues/7536 + time.sleep(0.5) + authenticator.start() # XXX do not hardcode this here. authenticator.allow('127.0.0.1') diff --git a/src/leap/common/http.py b/src/leap/common/http.py index 39f01ba..0dee3a2 100644 --- a/src/leap/common/http.py +++ b/src/leap/common/http.py @@ -18,72 +18,141 @@ Twisted HTTP/HTTPS client. """ -import os +try: + import twisted + assert twisted +except ImportError: + print "*******" + print "Twisted is needed to use leap.common.http module" + print "" + print "Install the extra requirement of the package:" + print "$ pip install leap.common[Twisted]" + import sys + sys.exit(1) -from zope.interface import implements -from OpenSSL.crypto import load_certificate -from OpenSSL.crypto import FILETYPE_PEM +from leap.common.certs import get_compatible_ssl_context_factory +from leap.common.check import leap_assert + +from zope.interface import implements from twisted.internet import reactor -from twisted.internet.ssl import ClientContextFactory -from twisted.internet.ssl import CertificateOptions -from twisted.internet.defer import succeed +from twisted.internet import defer +from twisted.python import failure from twisted.web.client import Agent from twisted.web.client import HTTPConnectionPool +from twisted.web.client import _HTTP11ClientFactory as HTTP11ClientFactory from twisted.web.client import readBody -from twisted.web.client import BrowserLikePolicyForHTTPS from twisted.web.http_headers import Headers from twisted.web.iweb import IBodyProducer +from twisted.web._newclient import HTTP11ClientProtocol + + +__all__ = ["HTTPClient"] + + +# A default HTTP timeout is used for 2 distinct purposes: +# 1. as HTTP connection timeout, prior to connection estabilshment. +# 2. as data reception timeout, after the connection has been established. +DEFAULT_HTTP_TIMEOUT = 30 # seconds + + +class _HTTP11ClientFactory(HTTP11ClientFactory): + """ + A timeout-able HTTP 1.1 client protocol factory. + """ + + def __init__(self, quiescentCallback, timeout): + """ + :param quiescentCallback: The quiescent callback to be passed to + protocol instances, used to return them to + the connection pool. + :type quiescentCallback: callable(Protocol) + :param timeout: The timeout, in seconds, for requests made by + protocols created by this factory. + :type timeout: float + """ + HTTP11ClientFactory.__init__(self, quiescentCallback) + self._timeout = timeout + + def buildProtocol(self, _): + """ + Build the HTTP 1.1 client protocol. + """ + return _HTTP11ClientProtocol(self._quiescentCallback, self._timeout) + + +class _HTTPConnectionPool(HTTPConnectionPool): + """ + A timeout-able HTTP connection pool. + """ + + _factory = _HTTP11ClientFactory + + def __init__(self, reactor, persistent, timeout, maxPersistentPerHost=10): + HTTPConnectionPool.__init__(self, reactor, persistent=persistent) + self.maxPersistentPerHost = maxPersistentPerHost + self._timeout = timeout + + def _newConnection(self, key, endpoint): + def quiescentCallback(protocol): + self._putConnection(key, protocol) + factory = self._factory(quiescentCallback, timeout=self._timeout) + return endpoint.connect(factory) class HTTPClient(object): """ HTTP client done the twisted way, with a main focus on pinning the SSL certificate. + + By default, it uses a shared connection pool. If you want a dedicated + one, create and pass on __init__ pool parameter. + Please note that this client will limit the maximum amount of connections + by using a DeferredSemaphore. + This limit is equal to the maxPersistentPerHost used on pool and is needed + in order to avoid resource abuse on huge requests batches. """ - def __init__(self, cert_file=None): + _pool = _HTTPConnectionPool( + reactor, + persistent=True, + timeout=DEFAULT_HTTP_TIMEOUT, + maxPersistentPerHost=10 + ) + + def __init__(self, cert_file=None, + timeout=DEFAULT_HTTP_TIMEOUT, pool=None): """ Init the HTTP client :param cert_file: The path to the certificate file, if None given the system's CAs will be used. :type cert_file: str + :param timeout: The amount of time that this Agent will wait for the + peer to accept a connection and for each request to be + finished. If a pool is passed, then this argument is + ignored. + :type timeout: float """ - self._pool = HTTPConnectionPool(reactor, persistent=True) - self._pool.maxPersistentPerHost = 10 - if cert_file: - cert = self._load_cert(cert_file) - self._agent = Agent( - reactor, - HTTPClient.ClientContextFactory(cert), - pool=self._pool) - else: - # trust the system's CAs - self._agent = Agent( - reactor, - BrowserLikePolicyForHTTPS(), - pool=self._pool) + self._timeout = timeout + self._pool = pool if pool is not None else self._pool + self._agent = Agent( + reactor, + get_compatible_ssl_context_factory(cert_file), + pool=self._pool, + connectTimeout=self._timeout) + self._semaphore = defer.DeferredSemaphore( + self._pool.maxPersistentPerHost) - def _load_cert(self, cert_file): - """ - Load a X509 certificate from a file. + def _createPool(self, maxPersistentPerHost=10, persistent=True): + pool = _HTTPConnectionPool(reactor, persistent, self._timeout) + pool.maxPersistentPerHost = maxPersistentPerHost + return pool - :param cert_file: The path to the certificate file. - :type cert_file: str - - :return: The X509 certificate. - :rtype: OpenSSL.crypto.X509 - """ - if os.path.exists(cert_file): - with open(cert_file) as f: - data = f.read() - return load_certificate(FILETYPE_PEM, data) - - def request(self, url, method='GET', body=None, headers={}): + def _request(self, url, method, body, headers, callback): """ Perform an HTTP request. @@ -95,68 +164,185 @@ class HTTPClient(object): :type body: str :param headers: The headers of the request. :type headers: dict + :param callback: A callback to be added to the request's deferred + callback chain. + :type callback: callable :return: A deferred that fires with the body of the request. :rtype: twisted.internet.defer.Deferred """ if body: - body = HTTPClient.StringBodyProducer(body) + body = _StringBodyProducer(body) d = self._agent.request( method, url, headers=Headers(headers), bodyProducer=body) - d.addCallback(readBody) + d.addCallback(callback) return d - class ClientContextFactory(ClientContextFactory): + def request(self, url, method='GET', body=None, headers={}, + callback=readBody): + """ + Perform an HTTP request, but limit the maximum amount of concurrent + connections. + + May be passed a callback to be added to the request's deferred + callback chain. The callback is expected to receive the response of + the request and may do whatever it wants with the response. By + default, if no callback is passed, we will use a simple body reader + which returns a deferred that is fired with the body of the response. + + :param url: The URL for the request. + :type url: str + :param method: The HTTP method of the request. + :type method: str + :param body: The body of the request, if any. + :type body: str + :param headers: The headers of the request. + :type headers: dict + :param callback: A callback to be added to the request's deferred + callback chain. + :type callback: callable + + :return: A deferred that fires with the body of the request. + :rtype: twisted.internet.defer.Deferred + """ + leap_assert( + callable(callback), + message="The callback parameter should be a callable!") + return self._semaphore.run(self._request, url, method, body, headers, + callback) + + def close(self): """ - A context factory that will verify the server's certificate against a - given CA certificate. + Close any cached connections. """ + self._pool.closeCachedConnections() + +# +# An IBodyProducer to write the body of an HTTP request as a string. +# + + +class _StringBodyProducer(object): + """ + A producer that writes the body of a request to a consumer. + """ - def __init__(self, cacert): - """ - Initialize the context factory. + implements(IBodyProducer) - :param cacert: The CA certificate. - :type cacert: OpenSSL.crypto.X509 - """ - self._cacert = cacert + def __init__(self, body): + """ + Initialize the string produer. - def getContext(self, hostname, port): - opts = CertificateOptions(verify=True, caCerts=[self._cacert]) - return opts.getContext() + :param body: The body of the request. + :type body: str + """ + self.body = body + self.length = len(body) - class StringBodyProducer(object): + def startProducing(self, consumer): """ - A producer that writes the body of a request to a consumer. + Write the body to the consumer. + + :param consumer: Any IConsumer provider. + :type consumer: twisted.internet.interfaces.IConsumer + + :return: A successful deferred. + :rtype: twisted.internet.defer.Deferred """ + consumer.write(self.body) + return defer.succeed(None) - implements(IBodyProducer) + def pauseProducing(self): + pass - def __init__(self, body): - """ - Initialize the string produer. + def stopProducing(self): + pass - :param body: The body of the request. - :type body: str - """ - self.body = body - self.length = len(body) - def startProducing(self, consumer): - """ - Write the body to the consumer. +# +# Patched twisted.web classes +# - :param consumer: Any IConsumer provider. - :type consumer: twisted.internet.interfaces.IConsumer +class _HTTP11ClientProtocol(HTTP11ClientProtocol): + """ + A timeout-able HTTP 1.1 client protocol, that is instantiated by the + _HTTP11ClientFactory below. + """ - :return: A successful deferred. - :rtype: twisted.internet.defer.Deferred - """ - consumer.write(self.body) - return succeed(None) + def __init__(self, quiescentCallback, timeout): + """ + Initialize the protocol. - def pauseProducing(self): - pass + :param quiescentCallback: + :type quiescentCallback: callable + :param timeout: A timeout, in seconds, for requests made by this + protocol. + :type timeout: float + """ + HTTP11ClientProtocol.__init__(self, quiescentCallback) + self._timeout = timeout + self._timeoutCall = None + + def request(self, request): + """ + Issue request over self.transport and return a Deferred which + will fire with a Response instance or an error. + + :param request: The object defining the parameters of the request to + issue. + :type request: twisted.web._newclient.Request + + :return: A deferred which fires after the request has finished. + :rtype: Deferred + """ + d = HTTP11ClientProtocol.request(self, request) + if self._timeout: + self._last_buffer_len = 0 + timeoutCall = reactor.callLater( + self._timeout, self._doTimeout, request) + self._timeoutCall = timeoutCall + return d + + def _doTimeout(self, request): + """ + Give up the request because of a timeout. + + :param request: The object defining the parameters of the request to + issue. + :type request: twisted.web._newclient.Request + """ + self._giveUp( + failure.Failure( + defer.TimeoutError( + "Getting %s took longer than %s seconds." + % (request.absoluteURI, self._timeout)))) + + def _cancelTimeout(self): + """ + Cancel the request timeout, when it's finished. + """ + if self._timeoutCall and self._timeoutCall.active(): + self._timeoutCall.cancel() + self._timeoutCall = None + + def _finishResponse(self, rest): + """ + Cancel the timeout when finished receiving the response. + """ + self._cancelTimeout() + HTTP11ClientProtocol._finishResponse(self, rest) + + def dataReceived(self, bytes): + """ + Receive some data and extend the timeout period of this request. + + :param bytes: A string of indeterminate length. + :type bytes: str + """ + HTTP11ClientProtocol.dataReceived(self, bytes) + if self._timeoutCall and self._timeoutCall.active(): + self._timeoutCall.reset(self._timeout) - def stopProducing(self): - pass + def connectionLost(self, reason): + self._cancelTimeout() + return HTTP11ClientProtocol.connectionLost(self, reason) diff --git a/src/leap/common/plugins.py b/src/leap/common/plugins.py new file mode 100644 index 0000000..04152f9 --- /dev/null +++ b/src/leap/common/plugins.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +# plugins.py +# Copyright (C) 2015 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Twisted plugins leap utilities. +""" +import os.path + +from twisted.plugin import getPlugins + +from leap.common.config import get_path_prefix + +# A whitelist of modules from where to collect plugins dynamically. +# For the moment restricted to leap namespace, but the idea is that we can pass +# other "trusted" modules as options to the initialization of soledad. + +# TODO discover all the namespace automagically + +PLUGGABLE_LEAP_MODULES = ('mail', 'keymanager') + +_preffix = get_path_prefix() +rc_file = os.path.join(_preffix, "leap", "leap.cfg") + + +def _get_extra_pluggable_modules(): + import ConfigParser + config = ConfigParser.RawConfigParser() + config.read(rc_file) + try: + modules = eval( + config.get('plugins', 'extra_pluggable_modules'), {}, {}) + except (ConfigParser.NoSectionError, ConfigParser.NoOptionError, + ConfigParser.MissingSectionHeaderError): + modules = [] + return modules + +if os.path.isfile(rc_file): + # TODO in the case of being called from the standalone client, + # we should pass the flag in some other way. + EXTRA_PLUGGABLE_MODULES = _get_extra_pluggable_modules() +else: + EXTRA_PLUGGABLE_MODULES = [] + + +def collect_plugins(interface): + """ + Traverse a whitelist of modules and collect all the plugins that implement + the passed interface. + """ + plugins = [] + for namespace in PLUGGABLE_LEAP_MODULES: + try: + module = __import__('leap.%s.plugins' % namespace, fromlist='.') + plugins = plugins + list(getPlugins(interface, module)) + except ImportError: + pass + for namespace in EXTRA_PLUGGABLE_MODULES: + try: + module = __import__('%s.plugins' % namespace, fromlist='.') + plugins = plugins + list(getPlugins(interface, module)) + except ImportError: + pass + return plugins diff --git a/src/leap/common/testing/basetest.py b/src/leap/common/testing/basetest.py index 3fbcf33..3d3cee0 100644 --- a/src/leap/common/testing/basetest.py +++ b/src/leap/common/testing/basetest.py @@ -30,8 +30,11 @@ except ImportError: from leap.common.check import leap_assert from leap.common.events import server as events_server from leap.common.events import client as events_client +from leap.common.events import flags, set_events_enabled from leap.common.files import mkdir_p, check_and_fix_urw_only +set_events_enabled(False) + class BaseLeapTest(unittest.TestCase): """ @@ -73,12 +76,13 @@ class BaseLeapTest(unittest.TestCase): @classmethod def _init_events(cls): - cls._server = events_server.ensure_server( - emit_addr="tcp://127.0.0.1:0", - reg_addr="tcp://127.0.0.1:0") - events_client.configure_client( - emit_addr="tcp://127.0.0.1:%d" % cls._server.pull_port, - reg_addr="tcp://127.0.0.1:%d" % cls._server.pub_port) + if flags.EVENTS_ENABLED: + cls._server = events_server.ensure_server( + emit_addr="tcp://127.0.0.1:0", + reg_addr="tcp://127.0.0.1:0") + events_client.configure_client( + emit_addr="tcp://127.0.0.1:%d" % cls._server.pull_port, + reg_addr="tcp://127.0.0.1:%d" % cls._server.pub_port) @classmethod def tearDownEnv(cls): @@ -87,8 +91,9 @@ class BaseLeapTest(unittest.TestCase): - restores the default PATH and HOME variables - removes the temporal folder """ - events_client.shutdown() - cls._server.shutdown() + if flags.EVENTS_ENABLED: + events_client.shutdown() + cls._server.shutdown() os.environ["PATH"] = cls.old_path os.environ["HOME"] = cls.old_home diff --git a/src/leap/common/testing/test_basetest.py b/src/leap/common/testing/test_basetest.py index cf0962d..ec42a62 100644 --- a/src/leap/common/testing/test_basetest.py +++ b/src/leap/common/testing/test_basetest.py @@ -83,12 +83,10 @@ class TestInitBaseLeapTest(BaseLeapTest): """ def setUp(self): - """nuke it""" - pass + self.setUpEnv() def tearDown(self): - """nuke it""" - pass + self.tearDownEnv() def test_path_is_changed(self): """tests whether we have changed the PATH env var""" diff --git a/src/leap/common/tests/test_certs.py b/src/leap/common/tests/test_certs.py index 999071f..8ebc0f4 100644 --- a/src/leap/common/tests/test_certs.py +++ b/src/leap/common/tests/test_certs.py @@ -43,10 +43,10 @@ CERT_NOT_AFTER = (2023, 9, 1, 17, 52, 16, 4, 244, 0) class CertsTest(BaseLeapTest): def setUp(self): - pass + self.setUpEnv() def tearDown(self): - pass + self.tearDownEnv() def test_should_redownload_if_no_cert(self): self.assertTrue(certs.should_redownload(certfile="")) @@ -60,11 +60,13 @@ class CertsTest(BaseLeapTest): self.assertTrue(certs.should_redownload(cert_path)) def test_should_redownload_if_before(self): - new_now = lambda: time.struct_time(CERT_NOT_BEFORE) + def new_now(): + time.struct_time(CERT_NOT_BEFORE) self.assertTrue(certs.should_redownload(TEST_CERT_PEM, now=new_now)) def test_should_redownload_if_after(self): - new_now = lambda: time.struct_time(CERT_NOT_AFTER) + def new_now(): + time.struct_time(CERT_NOT_AFTER) self.assertTrue(certs.should_redownload(TEST_CERT_PEM, now=new_now)) def test_not_should_redownload(self): diff --git a/src/leap/common/tests/test_events.py b/src/leap/common/tests/test_events.py index 7ef3e1b..2ad097e 100644 --- a/src/leap/common/tests/test_events.py +++ b/src/leap/common/tests/test_events.py @@ -1,4 +1,4 @@ -## -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- # test_events.py # Copyright (C) 2013 LEAP # @@ -20,11 +20,13 @@ import os import logging import time +from twisted.internet.reactor import callFromThread from twisted.trial import unittest from twisted.internet import defer from leap.common.events import server from leap.common.events import client +from leap.common.events import flags from leap.common.events import txclient from leap.common.events import catalog from leap.common.events.errors import CallbackAlreadyRegisteredError @@ -37,6 +39,7 @@ if 'DEBUG' in os.environ: class EventsGenericClientTestCase(object): def setUp(self): + flags.set_events_enabled(True) self._server = server.ensure_server( emit_addr="tcp://127.0.0.1:0", reg_addr="tcp://127.0.0.1:0") @@ -47,6 +50,7 @@ class EventsGenericClientTestCase(object): def tearDown(self): self._client.shutdown() self._server.shutdown() + flags.set_events_enabled(False) # wait a bit for sockets to close properly time.sleep(0.1) @@ -59,7 +63,10 @@ class EventsGenericClientTestCase(object): 'There should be no callback for this event.') # register one event event1 = catalog.CLIENT_UID - cbk1 = lambda event, _: True + + def cbk1(event, _): + return True + uid1 = self._client.register(event1, cbk1) # assert for correct registration self.assertTrue(len(callbacks) == 1) @@ -67,7 +74,10 @@ class EventsGenericClientTestCase(object): 'Could not register event in local client.') # register another event event2 = catalog.CLIENT_SESSION_ID - cbk2 = lambda event, _: True + + def cbk2(event, _): + return True + uid2 = self._client.register(event2, cbk2) # assert for correct registration self.assertTrue(len(callbacks) == 2) @@ -80,8 +90,13 @@ class EventsGenericClientTestCase(object): """ event = catalog.CLIENT_UID d = defer.Deferred() - cbk_fail = lambda event, _: d.errback(event) - cbk_succeed = lambda event, _: d.callback(event) + + def cbk_fail(event, _): + return callFromThread(d.errback, event) + + def cbk_succeed(event, _): + return callFromThread(d.callback, event) + self._client.register(event, cbk_fail, uid=1) self._client.register(event, cbk_succeed, uid=1, replace=True) self._client.emit(event, None) @@ -105,9 +120,15 @@ class EventsGenericClientTestCase(object): """ event = catalog.CLIENT_UID d1 = defer.Deferred() - cbk1 = lambda event, _: d1.callback(event) + + def cbk1(event, _): + return callFromThread(d1.callback, event) + d2 = defer.Deferred() - cbk2 = lambda event, _: d2.callback(event) + + def cbk2(event, _): + return d2.callback(event) + self._client.register(event, cbk1) self._client.register(event, cbk2) self._client.emit(event, None) @@ -120,8 +141,10 @@ class EventsGenericClientTestCase(object): """ event = catalog.CLIENT_UID d = defer.Deferred() + def cbk(events, _): - d.callback(event) + callFromThread(d.callback, event) + self._client.register(event, cbk) self._client.emit(event, None) return d @@ -133,14 +156,17 @@ class EventsGenericClientTestCase(object): event1 = catalog.CLIENT_UID d = defer.Deferred() # register more than one callback for the same event - self._client.register(event1, lambda ev, _: d.errback(None)) - self._client.register(event1, lambda ev, _: d.errback(None)) + self._client.register( + event1, lambda ev, _: callFromThread(d.errback, None)) + self._client.register( + event1, lambda ev, _: callFromThread(d.errback, None)) # unregister and emit the event self._client.unregister(event1) self._client.emit(event1, None) # register and emit another event so the deferred can succeed event2 = catalog.CLIENT_SESSION_ID - self._client.register(event2, lambda ev, _: d.callback(None)) + self._client.register( + event2, lambda ev, _: callFromThread(d.callback, None)) self._client.emit(event2, None) return d @@ -151,9 +177,11 @@ class EventsGenericClientTestCase(object): event = catalog.CLIENT_UID d = defer.Deferred() # register one callback that would fail - uid = self._client.register(event, lambda ev, _: d.errback(None)) + uid = self._client.register( + event, lambda ev, _: callFromThread(d.errback, None)) # register one callback that will succeed - self._client.register(event, lambda ev, _: d.callback(None)) + self._client.register( + event, lambda ev, _: callFromThread(d.callback, None)) # unregister by uid and emit the event self._client.unregister(event, uid=uid) self._client.emit(event, None) diff --git a/src/leap/common/tests/test_http.py b/src/leap/common/tests/test_http.py new file mode 100644 index 0000000..f44550f --- /dev/null +++ b/src/leap/common/tests/test_http.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# test_http.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Tests for: + * leap/common/http.py +""" +import os +try: + import unittest2 as unittest +except ImportError: + import unittest + +from leap.common import http +from leap.common.testing.basetest import BaseLeapTest + +TEST_CERT_PEM = os.path.join( + os.path.split(__file__)[0], + '..', 'testing', "leaptest_combined_keycert.pem") + + +class HTTPClientTest(BaseLeapTest): + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_agents_sharing_pool_by_default(self): + client = http.HTTPClient() + client2 = http.HTTPClient(TEST_CERT_PEM) + self.assertNotEquals( + client._agent, client2._agent, "Expected dedicated agents") + self.assertEquals( + client._agent._pool, client2._agent._pool, + "Pool was not reused by default") + + def test_agent_can_have_dedicated_custom_pool(self): + custom_pool = http._HTTPConnectionPool( + None, + timeout=10, + maxPersistentPerHost=42, + persistent=False + ) + self.assertEquals(custom_pool.maxPersistentPerHost, 42, + "Custom persistent connections " + "limit is not being respected") + self.assertFalse(custom_pool.persistent, + "Custom persistence is not being respected") + default_client = http.HTTPClient() + custom_client = http.HTTPClient(pool=custom_pool) + + self.assertNotEquals( + default_client._agent, custom_client._agent, + "No agent reuse is expected") + self.assertEquals( + custom_pool, custom_client._agent._pool, + "Custom pool usage was not respected") + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/common/zmq_utils.py b/src/leap/common/zmq_utils.py index 19625b9..0a781de 100644 --- a/src/leap/common/zmq_utils.py +++ b/src/leap/common/zmq_utils.py @@ -101,5 +101,3 @@ def maybe_create_and_get_certificates(basedir, name): mkdir_p(public_keys_dir) shutil.move(old_public_key, new_public_key) return zmq.auth.load_certificate(private_key) - - |