diff options
-rw-r--r-- | src/leap/common/http.py | 98 | ||||
-rw-r--r-- | src/leap/common/tests/test_http.py | 7 |
2 files changed, 59 insertions, 46 deletions
diff --git a/src/leap/common/http.py b/src/leap/common/http.py index 56938b4..f67507d 100644 --- a/src/leap/common/http.py +++ b/src/leap/common/http.py @@ -56,6 +56,50 @@ __all__ = ["HTTPClient"] DEFAULT_HTTP_TIMEOUT = 30 # seconds +class _HTTP11ClientFactory(HTTP11ClientFactory): + """ + A timeout-able HTTP 1.1 client protocol factory. + """ + + def __init__(self, quiescentCallback, timeout): + """ + :param quiescentCallback: The quiescent callback to be passed to + protocol instances, used to return them to + the connection pool. + :type quiescentCallback: callable(Protocol) + :param timeout: The timeout, in seconds, for requests made by + protocols created by this factory. + :type timeout: float + """ + HTTP11ClientFactory.__init__(self, quiescentCallback) + self._timeout = timeout + + def buildProtocol(self, _): + """ + Build the HTTP 1.1 client protocol. + """ + return _HTTP11ClientProtocol(self._quiescentCallback, self._timeout) + + +class _HTTPConnectionPool(HTTPConnectionPool): + """ + A timeout-able HTTP connection pool. + """ + + _factory = _HTTP11ClientFactory + + def __init__(self, reactor, persistent, timeout, maxPersistentPerHost=10): + HTTPConnectionPool.__init__(self, reactor, persistent=persistent) + self.maxPersistentPerHost = maxPersistentPerHost + self._timeout = timeout + + def _newConnection(self, key, endpoint): + def quiescentCallback(protocol): + self._putConnection(key, protocol) + factory = self._factory(quiescentCallback, timeout=self._timeout) + return endpoint.connect(factory) + + class HTTPClient(object): """ HTTP client done the twisted way, with a main focus on pinning the SSL @@ -69,7 +113,14 @@ class HTTPClient(object): in order to avoid resource abuse on huge requests batches. """ - def __init__(self, cert_file=None, timeout=DEFAULT_HTTP_TIMEOUT): + _pool = _HTTPConnectionPool( + reactor, + persistent=True, + timeout = DEFAULT_HTTP_TIMEOUT, + maxPersistentPerHost=10 + ) + + def __init__(self, cert_file=None, timeout=DEFAULT_HTTP_TIMEOUT, pool=None): """ Init the HTTP client @@ -84,7 +135,7 @@ class HTTPClient(object): """ self._timeout = timeout - self._pool = self._createPool() + self._pool = pool if pool is not None else self._pool self._agent = Agent( reactor, get_compatible_ssl_context_factory(cert_file), @@ -278,46 +329,3 @@ class _HTTP11ClientProtocol(HTTP11ClientProtocol): HTTP11ClientProtocol.dataReceived(self, bytes) if self._timeoutCall and self._timeoutCall.active(): self._timeoutCall.reset(self._timeout) - - -class _HTTP11ClientFactory(HTTP11ClientFactory): - """ - A timeout-able HTTP 1.1 client protocol factory. - """ - - def __init__(self, quiescentCallback, timeout): - """ - :param quiescentCallback: The quiescent callback to be passed to - protocol instances, used to return them to - the connection pool. - :type quiescentCallback: callable(Protocol) - :param timeout: The timeout, in seconds, for requests made by - protocols created by this factory. - :type timeout: float - """ - HTTP11ClientFactory.__init__(self, quiescentCallback) - self._timeout = timeout - - def buildProtocol(self, _): - """ - Build the HTTP 1.1 client protocol. - """ - return _HTTP11ClientProtocol(self._quiescentCallback, self._timeout) - - -class _HTTPConnectionPool(HTTPConnectionPool): - """ - A timeout-able HTTP connection pool. - """ - - _factory = _HTTP11ClientFactory - - def __init__(self, reactor, persistent, timeout): - HTTPConnectionPool.__init__(self, reactor, persistent=persistent) - self._timeout = timeout - - def _newConnection(self, key, endpoint): - def quiescentCallback(protocol): - self._putConnection(key, protocol) - factory = self._factory(quiescentCallback, timeout=self._timeout) - return endpoint.connect(factory) diff --git a/src/leap/common/tests/test_http.py b/src/leap/common/tests/test_http.py index e240ca3..a586fd1 100644 --- a/src/leap/common/tests/test_http.py +++ b/src/leap/common/tests/test_http.py @@ -47,7 +47,12 @@ class HTTPClientTest(BaseLeapTest): self.assertEquals(client._agent._pool, client2._agent._pool, "Pool was not reused by default") def test_agent_can_have_dedicated_custom_pool(self): - custom_pool = http.createPool(maxPersistentPerHost=42, persistent=False) + custom_pool = http._HTTPConnectionPool( + None, + timeout=10, + maxPersistentPerHost=42, + persistent=False + ) self.assertEquals(custom_pool.maxPersistentPerHost, 42, "Custom persistent connections limit is not being respected") self.assertFalse(custom_pool.persistent, |