diff options
Diffstat (limited to 'src/leap/common/http.py')
-rw-r--r-- | src/leap/common/http.py | 342 |
1 files changed, 264 insertions, 78 deletions
diff --git a/src/leap/common/http.py b/src/leap/common/http.py index 39f01ba..0dee3a2 100644 --- a/src/leap/common/http.py +++ b/src/leap/common/http.py @@ -18,72 +18,141 @@ Twisted HTTP/HTTPS client. """ -import os +try: + import twisted + assert twisted +except ImportError: + print "*******" + print "Twisted is needed to use leap.common.http module" + print "" + print "Install the extra requirement of the package:" + print "$ pip install leap.common[Twisted]" + import sys + sys.exit(1) -from zope.interface import implements -from OpenSSL.crypto import load_certificate -from OpenSSL.crypto import FILETYPE_PEM +from leap.common.certs import get_compatible_ssl_context_factory +from leap.common.check import leap_assert + +from zope.interface import implements from twisted.internet import reactor -from twisted.internet.ssl import ClientContextFactory -from twisted.internet.ssl import CertificateOptions -from twisted.internet.defer import succeed +from twisted.internet import defer +from twisted.python import failure from twisted.web.client import Agent from twisted.web.client import HTTPConnectionPool +from twisted.web.client import _HTTP11ClientFactory as HTTP11ClientFactory from twisted.web.client import readBody -from twisted.web.client import BrowserLikePolicyForHTTPS from twisted.web.http_headers import Headers from twisted.web.iweb import IBodyProducer +from twisted.web._newclient import HTTP11ClientProtocol + + +__all__ = ["HTTPClient"] + + +# A default HTTP timeout is used for 2 distinct purposes: +# 1. as HTTP connection timeout, prior to connection estabilshment. +# 2. as data reception timeout, after the connection has been established. +DEFAULT_HTTP_TIMEOUT = 30 # seconds + + +class _HTTP11ClientFactory(HTTP11ClientFactory): + """ + A timeout-able HTTP 1.1 client protocol factory. + """ + + def __init__(self, quiescentCallback, timeout): + """ + :param quiescentCallback: The quiescent callback to be passed to + protocol instances, used to return them to + the connection pool. + :type quiescentCallback: callable(Protocol) + :param timeout: The timeout, in seconds, for requests made by + protocols created by this factory. + :type timeout: float + """ + HTTP11ClientFactory.__init__(self, quiescentCallback) + self._timeout = timeout + + def buildProtocol(self, _): + """ + Build the HTTP 1.1 client protocol. + """ + return _HTTP11ClientProtocol(self._quiescentCallback, self._timeout) + + +class _HTTPConnectionPool(HTTPConnectionPool): + """ + A timeout-able HTTP connection pool. + """ + + _factory = _HTTP11ClientFactory + + def __init__(self, reactor, persistent, timeout, maxPersistentPerHost=10): + HTTPConnectionPool.__init__(self, reactor, persistent=persistent) + self.maxPersistentPerHost = maxPersistentPerHost + self._timeout = timeout + + def _newConnection(self, key, endpoint): + def quiescentCallback(protocol): + self._putConnection(key, protocol) + factory = self._factory(quiescentCallback, timeout=self._timeout) + return endpoint.connect(factory) class HTTPClient(object): """ HTTP client done the twisted way, with a main focus on pinning the SSL certificate. + + By default, it uses a shared connection pool. If you want a dedicated + one, create and pass on __init__ pool parameter. + Please note that this client will limit the maximum amount of connections + by using a DeferredSemaphore. + This limit is equal to the maxPersistentPerHost used on pool and is needed + in order to avoid resource abuse on huge requests batches. """ - def __init__(self, cert_file=None): + _pool = _HTTPConnectionPool( + reactor, + persistent=True, + timeout=DEFAULT_HTTP_TIMEOUT, + maxPersistentPerHost=10 + ) + + def __init__(self, cert_file=None, + timeout=DEFAULT_HTTP_TIMEOUT, pool=None): """ Init the HTTP client :param cert_file: The path to the certificate file, if None given the system's CAs will be used. :type cert_file: str + :param timeout: The amount of time that this Agent will wait for the + peer to accept a connection and for each request to be + finished. If a pool is passed, then this argument is + ignored. + :type timeout: float """ - self._pool = HTTPConnectionPool(reactor, persistent=True) - self._pool.maxPersistentPerHost = 10 - if cert_file: - cert = self._load_cert(cert_file) - self._agent = Agent( - reactor, - HTTPClient.ClientContextFactory(cert), - pool=self._pool) - else: - # trust the system's CAs - self._agent = Agent( - reactor, - BrowserLikePolicyForHTTPS(), - pool=self._pool) + self._timeout = timeout + self._pool = pool if pool is not None else self._pool + self._agent = Agent( + reactor, + get_compatible_ssl_context_factory(cert_file), + pool=self._pool, + connectTimeout=self._timeout) + self._semaphore = defer.DeferredSemaphore( + self._pool.maxPersistentPerHost) - def _load_cert(self, cert_file): - """ - Load a X509 certificate from a file. + def _createPool(self, maxPersistentPerHost=10, persistent=True): + pool = _HTTPConnectionPool(reactor, persistent, self._timeout) + pool.maxPersistentPerHost = maxPersistentPerHost + return pool - :param cert_file: The path to the certificate file. - :type cert_file: str - - :return: The X509 certificate. - :rtype: OpenSSL.crypto.X509 - """ - if os.path.exists(cert_file): - with open(cert_file) as f: - data = f.read() - return load_certificate(FILETYPE_PEM, data) - - def request(self, url, method='GET', body=None, headers={}): + def _request(self, url, method, body, headers, callback): """ Perform an HTTP request. @@ -95,68 +164,185 @@ class HTTPClient(object): :type body: str :param headers: The headers of the request. :type headers: dict + :param callback: A callback to be added to the request's deferred + callback chain. + :type callback: callable :return: A deferred that fires with the body of the request. :rtype: twisted.internet.defer.Deferred """ if body: - body = HTTPClient.StringBodyProducer(body) + body = _StringBodyProducer(body) d = self._agent.request( method, url, headers=Headers(headers), bodyProducer=body) - d.addCallback(readBody) + d.addCallback(callback) return d - class ClientContextFactory(ClientContextFactory): + def request(self, url, method='GET', body=None, headers={}, + callback=readBody): + """ + Perform an HTTP request, but limit the maximum amount of concurrent + connections. + + May be passed a callback to be added to the request's deferred + callback chain. The callback is expected to receive the response of + the request and may do whatever it wants with the response. By + default, if no callback is passed, we will use a simple body reader + which returns a deferred that is fired with the body of the response. + + :param url: The URL for the request. + :type url: str + :param method: The HTTP method of the request. + :type method: str + :param body: The body of the request, if any. + :type body: str + :param headers: The headers of the request. + :type headers: dict + :param callback: A callback to be added to the request's deferred + callback chain. + :type callback: callable + + :return: A deferred that fires with the body of the request. + :rtype: twisted.internet.defer.Deferred + """ + leap_assert( + callable(callback), + message="The callback parameter should be a callable!") + return self._semaphore.run(self._request, url, method, body, headers, + callback) + + def close(self): """ - A context factory that will verify the server's certificate against a - given CA certificate. + Close any cached connections. """ + self._pool.closeCachedConnections() + +# +# An IBodyProducer to write the body of an HTTP request as a string. +# + + +class _StringBodyProducer(object): + """ + A producer that writes the body of a request to a consumer. + """ - def __init__(self, cacert): - """ - Initialize the context factory. + implements(IBodyProducer) - :param cacert: The CA certificate. - :type cacert: OpenSSL.crypto.X509 - """ - self._cacert = cacert + def __init__(self, body): + """ + Initialize the string produer. - def getContext(self, hostname, port): - opts = CertificateOptions(verify=True, caCerts=[self._cacert]) - return opts.getContext() + :param body: The body of the request. + :type body: str + """ + self.body = body + self.length = len(body) - class StringBodyProducer(object): + def startProducing(self, consumer): """ - A producer that writes the body of a request to a consumer. + Write the body to the consumer. + + :param consumer: Any IConsumer provider. + :type consumer: twisted.internet.interfaces.IConsumer + + :return: A successful deferred. + :rtype: twisted.internet.defer.Deferred """ + consumer.write(self.body) + return defer.succeed(None) - implements(IBodyProducer) + def pauseProducing(self): + pass - def __init__(self, body): - """ - Initialize the string produer. + def stopProducing(self): + pass - :param body: The body of the request. - :type body: str - """ - self.body = body - self.length = len(body) - def startProducing(self, consumer): - """ - Write the body to the consumer. +# +# Patched twisted.web classes +# - :param consumer: Any IConsumer provider. - :type consumer: twisted.internet.interfaces.IConsumer +class _HTTP11ClientProtocol(HTTP11ClientProtocol): + """ + A timeout-able HTTP 1.1 client protocol, that is instantiated by the + _HTTP11ClientFactory below. + """ - :return: A successful deferred. - :rtype: twisted.internet.defer.Deferred - """ - consumer.write(self.body) - return succeed(None) + def __init__(self, quiescentCallback, timeout): + """ + Initialize the protocol. - def pauseProducing(self): - pass + :param quiescentCallback: + :type quiescentCallback: callable + :param timeout: A timeout, in seconds, for requests made by this + protocol. + :type timeout: float + """ + HTTP11ClientProtocol.__init__(self, quiescentCallback) + self._timeout = timeout + self._timeoutCall = None + + def request(self, request): + """ + Issue request over self.transport and return a Deferred which + will fire with a Response instance or an error. + + :param request: The object defining the parameters of the request to + issue. + :type request: twisted.web._newclient.Request + + :return: A deferred which fires after the request has finished. + :rtype: Deferred + """ + d = HTTP11ClientProtocol.request(self, request) + if self._timeout: + self._last_buffer_len = 0 + timeoutCall = reactor.callLater( + self._timeout, self._doTimeout, request) + self._timeoutCall = timeoutCall + return d + + def _doTimeout(self, request): + """ + Give up the request because of a timeout. + + :param request: The object defining the parameters of the request to + issue. + :type request: twisted.web._newclient.Request + """ + self._giveUp( + failure.Failure( + defer.TimeoutError( + "Getting %s took longer than %s seconds." + % (request.absoluteURI, self._timeout)))) + + def _cancelTimeout(self): + """ + Cancel the request timeout, when it's finished. + """ + if self._timeoutCall and self._timeoutCall.active(): + self._timeoutCall.cancel() + self._timeoutCall = None + + def _finishResponse(self, rest): + """ + Cancel the timeout when finished receiving the response. + """ + self._cancelTimeout() + HTTP11ClientProtocol._finishResponse(self, rest) + + def dataReceived(self, bytes): + """ + Receive some data and extend the timeout period of this request. + + :param bytes: A string of indeterminate length. + :type bytes: str + """ + HTTP11ClientProtocol.dataReceived(self, bytes) + if self._timeoutCall and self._timeoutCall.active(): + self._timeoutCall.reset(self._timeout) - def stopProducing(self): - pass + def connectionLost(self, reason): + self._cancelTimeout() + return HTTP11ClientProtocol.connectionLost(self, reason) |