summaryrefslogtreecommitdiff
path: root/src/leap/common/http.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/common/http.py')
-rw-r--r--src/leap/common/http.py342
1 files changed, 264 insertions, 78 deletions
diff --git a/src/leap/common/http.py b/src/leap/common/http.py
index 39f01ba..0dee3a2 100644
--- a/src/leap/common/http.py
+++ b/src/leap/common/http.py
@@ -18,72 +18,141 @@
Twisted HTTP/HTTPS client.
"""
-import os
+try:
+ import twisted
+ assert twisted
+except ImportError:
+ print "*******"
+ print "Twisted is needed to use leap.common.http module"
+ print ""
+ print "Install the extra requirement of the package:"
+ print "$ pip install leap.common[Twisted]"
+ import sys
+ sys.exit(1)
-from zope.interface import implements
-from OpenSSL.crypto import load_certificate
-from OpenSSL.crypto import FILETYPE_PEM
+from leap.common.certs import get_compatible_ssl_context_factory
+from leap.common.check import leap_assert
+
+from zope.interface import implements
from twisted.internet import reactor
-from twisted.internet.ssl import ClientContextFactory
-from twisted.internet.ssl import CertificateOptions
-from twisted.internet.defer import succeed
+from twisted.internet import defer
+from twisted.python import failure
from twisted.web.client import Agent
from twisted.web.client import HTTPConnectionPool
+from twisted.web.client import _HTTP11ClientFactory as HTTP11ClientFactory
from twisted.web.client import readBody
-from twisted.web.client import BrowserLikePolicyForHTTPS
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer
+from twisted.web._newclient import HTTP11ClientProtocol
+
+
+__all__ = ["HTTPClient"]
+
+
+# A default HTTP timeout is used for 2 distinct purposes:
+# 1. as HTTP connection timeout, prior to connection estabilshment.
+# 2. as data reception timeout, after the connection has been established.
+DEFAULT_HTTP_TIMEOUT = 30 # seconds
+
+
+class _HTTP11ClientFactory(HTTP11ClientFactory):
+ """
+ A timeout-able HTTP 1.1 client protocol factory.
+ """
+
+ def __init__(self, quiescentCallback, timeout):
+ """
+ :param quiescentCallback: The quiescent callback to be passed to
+ protocol instances, used to return them to
+ the connection pool.
+ :type quiescentCallback: callable(Protocol)
+ :param timeout: The timeout, in seconds, for requests made by
+ protocols created by this factory.
+ :type timeout: float
+ """
+ HTTP11ClientFactory.__init__(self, quiescentCallback)
+ self._timeout = timeout
+
+ def buildProtocol(self, _):
+ """
+ Build the HTTP 1.1 client protocol.
+ """
+ return _HTTP11ClientProtocol(self._quiescentCallback, self._timeout)
+
+
+class _HTTPConnectionPool(HTTPConnectionPool):
+ """
+ A timeout-able HTTP connection pool.
+ """
+
+ _factory = _HTTP11ClientFactory
+
+ def __init__(self, reactor, persistent, timeout, maxPersistentPerHost=10):
+ HTTPConnectionPool.__init__(self, reactor, persistent=persistent)
+ self.maxPersistentPerHost = maxPersistentPerHost
+ self._timeout = timeout
+
+ def _newConnection(self, key, endpoint):
+ def quiescentCallback(protocol):
+ self._putConnection(key, protocol)
+ factory = self._factory(quiescentCallback, timeout=self._timeout)
+ return endpoint.connect(factory)
class HTTPClient(object):
"""
HTTP client done the twisted way, with a main focus on pinning the SSL
certificate.
+
+ By default, it uses a shared connection pool. If you want a dedicated
+ one, create and pass on __init__ pool parameter.
+ Please note that this client will limit the maximum amount of connections
+ by using a DeferredSemaphore.
+ This limit is equal to the maxPersistentPerHost used on pool and is needed
+ in order to avoid resource abuse on huge requests batches.
"""
- def __init__(self, cert_file=None):
+ _pool = _HTTPConnectionPool(
+ reactor,
+ persistent=True,
+ timeout=DEFAULT_HTTP_TIMEOUT,
+ maxPersistentPerHost=10
+ )
+
+ def __init__(self, cert_file=None,
+ timeout=DEFAULT_HTTP_TIMEOUT, pool=None):
"""
Init the HTTP client
:param cert_file: The path to the certificate file, if None given the
system's CAs will be used.
:type cert_file: str
+ :param timeout: The amount of time that this Agent will wait for the
+ peer to accept a connection and for each request to be
+ finished. If a pool is passed, then this argument is
+ ignored.
+ :type timeout: float
"""
- self._pool = HTTPConnectionPool(reactor, persistent=True)
- self._pool.maxPersistentPerHost = 10
- if cert_file:
- cert = self._load_cert(cert_file)
- self._agent = Agent(
- reactor,
- HTTPClient.ClientContextFactory(cert),
- pool=self._pool)
- else:
- # trust the system's CAs
- self._agent = Agent(
- reactor,
- BrowserLikePolicyForHTTPS(),
- pool=self._pool)
+ self._timeout = timeout
+ self._pool = pool if pool is not None else self._pool
+ self._agent = Agent(
+ reactor,
+ get_compatible_ssl_context_factory(cert_file),
+ pool=self._pool,
+ connectTimeout=self._timeout)
+ self._semaphore = defer.DeferredSemaphore(
+ self._pool.maxPersistentPerHost)
- def _load_cert(self, cert_file):
- """
- Load a X509 certificate from a file.
+ def _createPool(self, maxPersistentPerHost=10, persistent=True):
+ pool = _HTTPConnectionPool(reactor, persistent, self._timeout)
+ pool.maxPersistentPerHost = maxPersistentPerHost
+ return pool
- :param cert_file: The path to the certificate file.
- :type cert_file: str
-
- :return: The X509 certificate.
- :rtype: OpenSSL.crypto.X509
- """
- if os.path.exists(cert_file):
- with open(cert_file) as f:
- data = f.read()
- return load_certificate(FILETYPE_PEM, data)
-
- def request(self, url, method='GET', body=None, headers={}):
+ def _request(self, url, method, body, headers, callback):
"""
Perform an HTTP request.
@@ -95,68 +164,185 @@ class HTTPClient(object):
:type body: str
:param headers: The headers of the request.
:type headers: dict
+ :param callback: A callback to be added to the request's deferred
+ callback chain.
+ :type callback: callable
:return: A deferred that fires with the body of the request.
:rtype: twisted.internet.defer.Deferred
"""
if body:
- body = HTTPClient.StringBodyProducer(body)
+ body = _StringBodyProducer(body)
d = self._agent.request(
method, url, headers=Headers(headers), bodyProducer=body)
- d.addCallback(readBody)
+ d.addCallback(callback)
return d
- class ClientContextFactory(ClientContextFactory):
+ def request(self, url, method='GET', body=None, headers={},
+ callback=readBody):
+ """
+ Perform an HTTP request, but limit the maximum amount of concurrent
+ connections.
+
+ May be passed a callback to be added to the request's deferred
+ callback chain. The callback is expected to receive the response of
+ the request and may do whatever it wants with the response. By
+ default, if no callback is passed, we will use a simple body reader
+ which returns a deferred that is fired with the body of the response.
+
+ :param url: The URL for the request.
+ :type url: str
+ :param method: The HTTP method of the request.
+ :type method: str
+ :param body: The body of the request, if any.
+ :type body: str
+ :param headers: The headers of the request.
+ :type headers: dict
+ :param callback: A callback to be added to the request's deferred
+ callback chain.
+ :type callback: callable
+
+ :return: A deferred that fires with the body of the request.
+ :rtype: twisted.internet.defer.Deferred
+ """
+ leap_assert(
+ callable(callback),
+ message="The callback parameter should be a callable!")
+ return self._semaphore.run(self._request, url, method, body, headers,
+ callback)
+
+ def close(self):
"""
- A context factory that will verify the server's certificate against a
- given CA certificate.
+ Close any cached connections.
"""
+ self._pool.closeCachedConnections()
+
+#
+# An IBodyProducer to write the body of an HTTP request as a string.
+#
+
+
+class _StringBodyProducer(object):
+ """
+ A producer that writes the body of a request to a consumer.
+ """
- def __init__(self, cacert):
- """
- Initialize the context factory.
+ implements(IBodyProducer)
- :param cacert: The CA certificate.
- :type cacert: OpenSSL.crypto.X509
- """
- self._cacert = cacert
+ def __init__(self, body):
+ """
+ Initialize the string produer.
- def getContext(self, hostname, port):
- opts = CertificateOptions(verify=True, caCerts=[self._cacert])
- return opts.getContext()
+ :param body: The body of the request.
+ :type body: str
+ """
+ self.body = body
+ self.length = len(body)
- class StringBodyProducer(object):
+ def startProducing(self, consumer):
"""
- A producer that writes the body of a request to a consumer.
+ Write the body to the consumer.
+
+ :param consumer: Any IConsumer provider.
+ :type consumer: twisted.internet.interfaces.IConsumer
+
+ :return: A successful deferred.
+ :rtype: twisted.internet.defer.Deferred
"""
+ consumer.write(self.body)
+ return defer.succeed(None)
- implements(IBodyProducer)
+ def pauseProducing(self):
+ pass
- def __init__(self, body):
- """
- Initialize the string produer.
+ def stopProducing(self):
+ pass
- :param body: The body of the request.
- :type body: str
- """
- self.body = body
- self.length = len(body)
- def startProducing(self, consumer):
- """
- Write the body to the consumer.
+#
+# Patched twisted.web classes
+#
- :param consumer: Any IConsumer provider.
- :type consumer: twisted.internet.interfaces.IConsumer
+class _HTTP11ClientProtocol(HTTP11ClientProtocol):
+ """
+ A timeout-able HTTP 1.1 client protocol, that is instantiated by the
+ _HTTP11ClientFactory below.
+ """
- :return: A successful deferred.
- :rtype: twisted.internet.defer.Deferred
- """
- consumer.write(self.body)
- return succeed(None)
+ def __init__(self, quiescentCallback, timeout):
+ """
+ Initialize the protocol.
- def pauseProducing(self):
- pass
+ :param quiescentCallback:
+ :type quiescentCallback: callable
+ :param timeout: A timeout, in seconds, for requests made by this
+ protocol.
+ :type timeout: float
+ """
+ HTTP11ClientProtocol.__init__(self, quiescentCallback)
+ self._timeout = timeout
+ self._timeoutCall = None
+
+ def request(self, request):
+ """
+ Issue request over self.transport and return a Deferred which
+ will fire with a Response instance or an error.
+
+ :param request: The object defining the parameters of the request to
+ issue.
+ :type request: twisted.web._newclient.Request
+
+ :return: A deferred which fires after the request has finished.
+ :rtype: Deferred
+ """
+ d = HTTP11ClientProtocol.request(self, request)
+ if self._timeout:
+ self._last_buffer_len = 0
+ timeoutCall = reactor.callLater(
+ self._timeout, self._doTimeout, request)
+ self._timeoutCall = timeoutCall
+ return d
+
+ def _doTimeout(self, request):
+ """
+ Give up the request because of a timeout.
+
+ :param request: The object defining the parameters of the request to
+ issue.
+ :type request: twisted.web._newclient.Request
+ """
+ self._giveUp(
+ failure.Failure(
+ defer.TimeoutError(
+ "Getting %s took longer than %s seconds."
+ % (request.absoluteURI, self._timeout))))
+
+ def _cancelTimeout(self):
+ """
+ Cancel the request timeout, when it's finished.
+ """
+ if self._timeoutCall and self._timeoutCall.active():
+ self._timeoutCall.cancel()
+ self._timeoutCall = None
+
+ def _finishResponse(self, rest):
+ """
+ Cancel the timeout when finished receiving the response.
+ """
+ self._cancelTimeout()
+ HTTP11ClientProtocol._finishResponse(self, rest)
+
+ def dataReceived(self, bytes):
+ """
+ Receive some data and extend the timeout period of this request.
+
+ :param bytes: A string of indeterminate length.
+ :type bytes: str
+ """
+ HTTP11ClientProtocol.dataReceived(self, bytes)
+ if self._timeoutCall and self._timeoutCall.active():
+ self._timeoutCall.reset(self._timeout)
- def stopProducing(self):
- pass
+ def connectionLost(self, reason):
+ self._cancelTimeout()
+ return HTTP11ClientProtocol.connectionLost(self, reason)