[feat] add close method for http agent
[leap_pycommon.git] / src / leap / common / http.py
1 # -*- coding: utf-8 -*-
2 # http.py
3 # Copyright (C) 2015 LEAP
4 #
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
17 """
18 Twisted HTTP/HTTPS client.
19 """
20
21 try:
22     import twisted
23 except ImportError:
24     print "*******"
25     print "Twisted is needed to use leap.common.http module"
26     print ""
27     print "Install the extra requirement of the package:"
28     print "$ pip install leap.common[Twisted]"
29     import sys
30     sys.exit(1)
31
32
33 from leap.common.certs import get_compatible_ssl_context_factory
34
35 from zope.interface import implements
36
37 from twisted.internet import reactor
38 from twisted.internet import defer
39 from twisted.internet.defer import succeed
40
41 from twisted.web.client import Agent
42 from twisted.web.client import HTTPConnectionPool
43 from twisted.web.client import readBody
44 from twisted.web.http_headers import Headers
45 from twisted.web.iweb import IBodyProducer
46
47
48 def createPool(maxPersistentPerHost=10, persistent=True):
49     pool = HTTPConnectionPool(reactor, persistent)
50     pool.maxPersistentPerHost = maxPersistentPerHost
51     return pool
52
53 _pool = createPool()
54
55
56 class HTTPClient(object):
57     """
58     HTTP client done the twisted way, with a main focus on pinning the SSL
59     certificate.
60
61     By default, it uses a shared connection pool. If you want a dedicated
62     one, create and pass on __init__ pool parameter.
63     Please note that this client will limit the maximum amount of connections
64     by using a DeferredSemaphore.
65     This limit is equal to the maxPersistentPerHost used on pool and is needed
66     in order to avoid resource abuse on huge requests batches.
67     """
68
69     def __init__(self, cert_file=None, pool=_pool):
70         """
71         Init the HTTP client
72
73         :param cert_file: The path to the certificate file, if None given the
74                           system's CAs will be used.
75         :type cert_file: str
76         :param pool: An optional dedicated connection pool to override the
77                      default shared one.
78         :type pool: HTTPConnectionPool
79         """
80
81         policy = get_compatible_ssl_context_factory(cert_file)
82
83         self._pool = pool
84         self._agent = Agent(
85             reactor,
86             policy,
87             pool=pool)
88         self._semaphore = defer.DeferredSemaphore(pool.maxPersistentPerHost)
89
90     def request(self, url, method='GET', body=None, headers={}):
91         """
92         Perform an HTTP request.
93
94         :param url: The URL for the request.
95         :type url: str
96         :param method: The HTTP method of the request.
97         :type method: str
98         :param body: The body of the request, if any.
99         :type body: str
100         :param headers: The headers of the request.
101         :type headers: dict
102
103         :return: A deferred that fires with the body of the request.
104         :rtype: twisted.internet.defer.Deferred
105         """
106         if body:
107             body = HTTPClient.StringBodyProducer(body)
108         d = self._semaphore.run(self._agent.request,
109                                 method, url, headers=Headers(headers),
110                                 bodyProducer=body)
111         d.addCallback(readBody)
112         return d
113
114     def close(self):
115         """
116         Close any cached connections.
117         """
118         self._pool.closeCachedConnections()
119
120     class StringBodyProducer(object):
121         """
122         A producer that writes the body of a request to a consumer.
123         """
124
125         implements(IBodyProducer)
126
127         def __init__(self, body):
128             """
129             Initialize the string produer.
130
131             :param body: The body of the request.
132             :type body: str
133             """
134             self.body = body
135             self.length = len(body)
136
137         def startProducing(self, consumer):
138             """
139             Write the body to the consumer.
140
141             :param consumer: Any IConsumer provider.
142             :type consumer: twisted.internet.interfaces.IConsumer
143
144             :return: A successful deferred.
145             :rtype: twisted.internet.defer.Deferred
146             """
147             consumer.write(self.body)
148             return succeed(None)
149
150         def pauseProducing(self):
151             pass
152
153         def stopProducing(self):
154             pass