[bug] Make HTTPClient.request respect a hard limit on concurrent requests
[leap_pycommon.git] / src / leap / common / http.py
index 39f01ba..d4a214c 100644 (file)
 Twisted HTTP/HTTPS client.
 """
 
-import os
+try:
+    import twisted  # noqa -- imported only to check availability
+except ImportError:
+    # Fail with ImportError rather than printing and calling sys.exit(1):
+    # a library import must not terminate the importing application, and
+    # callers can catch ImportError to degrade gracefully.
+    raise ImportError(
+        "Twisted is needed to use leap.common.http module. "
+        "Install the extra requirement of the package: "
+        "$ pip install leap.common[Twisted]")
 
-from zope.interface import implements
 
-from OpenSSL.crypto import load_certificate
-from OpenSSL.crypto import FILETYPE_PEM
+from leap.common.certs import get_compatible_ssl_context_factory
+
+from zope.interface import implements
 
 from twisted.internet import reactor
-from twisted.internet.ssl import ClientContextFactory
-from twisted.internet.ssl import CertificateOptions
+from twisted.internet import defer
 from twisted.internet.defer import succeed
 
 from twisted.web.client import Agent
 from twisted.web.client import HTTPConnectionPool
 from twisted.web.client import readBody
-from twisted.web.client import BrowserLikePolicyForHTTPS
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IBodyProducer
 
 
+def createPool(maxPersistentPerHost=10, persistent=True):
+    connection_pool = HTTPConnectionPool(reactor, persistent=persistent)
+    connection_pool.maxPersistentPerHost = maxPersistentPerHost
+    return connection_pool
+
+_pool = createPool()
+
+
 class HTTPClient(object):
     """
     HTTP client done the twisted way, with a main focus on pinning the SSL
     certificate.
+
+    By default, all clients share one module-level connection pool; to use
+    a dedicated pool, create it and pass it as the ``pool`` argument.
+    Each client also caps its concurrent requests with a DeferredSemaphore
+    sized to the pool's maxPersistentPerHost, to avoid exhausting resources
+    on large batches of requests. The semaphore belongs to the client, not
+    to the pool, so clients sharing a pool are each capped separately.
     """
 
-    def __init__(self, cert_file=None):
+    def __init__(self, cert_file=None, pool=_pool):
         """
         Init the HTTP client
 
         :param cert_file: The path to the certificate file, if None given the
                           system's CAs will be used.
         :type cert_file: str
+        :param pool: Connection pool to use instead of the shared default;
+                     its maxPersistentPerHost also sizes this client's semaphore.
+        :type pool: HTTPConnectionPool
         """
-        self._pool = HTTPConnectionPool(reactor, persistent=True)
-        self._pool.maxPersistentPerHost = 10
-
-        if cert_file:
-            cert = self._load_cert(cert_file)
-            self._agent = Agent(
-                reactor,
-                HTTPClient.ClientContextFactory(cert),
-                pool=self._pool)
-        else:
-            # trust the system's CAs
-            self._agent = Agent(
-                reactor,
-                BrowserLikePolicyForHTTPS(),
-                pool=self._pool)
-
-    def _load_cert(self, cert_file):
-        """
-        Load a X509 certificate from a file.
 
-        :param cert_file: The path to the certificate file.
-        :type cert_file: str
+        policy = get_compatible_ssl_context_factory(cert_file)
 
-        :return: The X509 certificate.
-        :rtype: OpenSSL.crypto.X509
-        """
-        if os.path.exists(cert_file):
-            with open(cert_file) as f:
-                data = f.read()
-                return load_certificate(FILETYPE_PEM, data)
+        self._agent = Agent(
+            reactor,
+            policy,
+            pool=pool)
+        self._semaphore = defer.DeferredSemaphore(pool.maxPersistentPerHost)
 
     def request(self, url, method='GET', body=None, headers={}):
         """
@@ -101,30 +104,12 @@ class HTTPClient(object):
         """
         if body:
             body = HTTPClient.StringBodyProducer(body)
-        d = self._agent.request(
-            method, url, headers=Headers(headers), bodyProducer=body)
+        d = self._semaphore.run(self._agent.request,
+                                method, url, headers=Headers(headers),
+                                bodyProducer=body)
         d.addCallback(readBody)
         return d
 
-    class ClientContextFactory(ClientContextFactory):
-        """
-        A context factory that will verify the server's certificate against a
-        given CA certificate.
-        """
-
-        def __init__(self, cacert):
-            """
-            Initialize the context factory.
-
-            :param cacert: The CA certificate.
-            :type cacert: OpenSSL.crypto.X509
-            """
-            self._cacert = cacert
-
-        def getContext(self, hostname, port):
-            opts = CertificateOptions(verify=True, caCerts=[self._cacert])
-            return opts.getContext()
-
     class StringBodyProducer(object):
         """
         A producer that writes the body of a request to a consumer.