[bug] Makes request method respect a hard limit
[leap_pycommon.git] / src / leap / common / http.py
index 1dc5642..d4a214c 100644 (file)
@@ -35,6 +35,7 @@ from leap.common.certs import get_compatible_ssl_context_factory
 from zope.interface import implements
 
 from twisted.internet import reactor
+from twisted.internet import defer
 from twisted.internet.defer import succeed
 
 from twisted.web.client import Agent
@@ -44,29 +45,46 @@ from twisted.web.http_headers import Headers
 from twisted.web.iweb import IBodyProducer
 
 
+def createPool(maxPersistentPerHost=10, persistent=True):
+    pool = HTTPConnectionPool(reactor, persistent)
+    pool.maxPersistentPerHost = maxPersistentPerHost
+    return pool
+
+_pool = createPool()
+
+
 class HTTPClient(object):
     """
     HTTP client done the twisted way, with a main focus on pinning the SSL
     certificate.
+
+    By default, it uses a shared connection pool. If you want a dedicated
+    one, create and pass on __init__ pool parameter.
+    Please note that this client will limit the maximum amount of connections
+    by using a DeferredSemaphore.
+    This limit is equal to the maxPersistentPerHost used on pool and is needed
+    in order to avoid resource abuse on huge requests batches.
     """
 
-    def __init__(self, cert_file=None):
+    def __init__(self, cert_file=None, pool=_pool):
         """
         Init the HTTP client
 
         :param cert_file: The path to the certificate file, if None given the
                           system's CAs will be used.
         :type cert_file: str
+        :param pool: An optional dedicated connection pool to override the
+                     default shared one.
+        :type pool: HTTPConnectionPool
         """
-        self._pool = HTTPConnectionPool(reactor, persistent=True)
-        self._pool.maxPersistentPerHost = 10
 
         policy = get_compatible_ssl_context_factory(cert_file)
 
         self._agent = Agent(
             reactor,
             policy,
-            pool=self._pool)
+            pool=pool)
+        self._semaphore = defer.DeferredSemaphore(pool.maxPersistentPerHost)
 
     def request(self, url, method='GET', body=None, headers={}):
         """
@@ -86,8 +104,9 @@ class HTTPClient(object):
         """
         if body:
             body = HTTPClient.StringBodyProducer(body)
-        d = self._agent.request(
-            method, url, headers=Headers(headers), bodyProducer=body)
+        d = self._semaphore.run(self._agent.request,
+                                method, url, headers=Headers(headers),
+                                bodyProducer=body)
         d.addCallback(readBody)
         return d