[tests] adapt events tests to recent changes
[leap_pycommon.git] / src / leap / common / http.py
1 # -*- coding: utf-8 -*-
2 # http.py
3 # Copyright (C) 2015 LEAP
4 #
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
17 """
18 Twisted HTTP/HTTPS client.
19 """
20
21 try:
22     import twisted
23     assert twisted
24 except ImportError:
25     print "*******"
26     print "Twisted is needed to use leap.common.http module"
27     print ""
28     print "Install the extra requirement of the package:"
29     print "$ pip install leap.common[Twisted]"
30     import sys
31     sys.exit(1)
32
33
34 from leap.common.certs import get_compatible_ssl_context_factory
35 from leap.common.check import leap_assert
36
37 from zope.interface import implements
38
39 from twisted.internet import reactor
40 from twisted.internet import defer
41 from twisted.python import failure
42
43 from twisted.web.client import Agent
44 from twisted.web.client import HTTPConnectionPool
45 from twisted.web.client import _HTTP11ClientFactory as HTTP11ClientFactory
46 from twisted.web.client import readBody
47 from twisted.web.http_headers import Headers
48 from twisted.web.iweb import IBodyProducer
49 from twisted.web._newclient import HTTP11ClientProtocol
50
51
52 __all__ = ["HTTPClient"]
53
54
55 # A default HTTP timeout is used for 2 distinct purposes:
56 #   1. as HTTP connection timeout, prior to connection estabilshment.
57 #   2. as data reception timeout, after the connection has been established.
58 DEFAULT_HTTP_TIMEOUT = 30  # seconds
59
60
61 class _HTTP11ClientFactory(HTTP11ClientFactory):
62     """
63     A timeout-able HTTP 1.1 client protocol factory.
64     """
65
66     def __init__(self, quiescentCallback, timeout):
67         """
68         :param quiescentCallback: The quiescent callback to be passed to
69                                   protocol instances, used to return them to
70                                   the connection pool.
71         :type quiescentCallback: callable(Protocol)
72         :param timeout: The timeout, in seconds, for requests made by
73                         protocols created by this factory.
74         :type timeout: float
75         """
76         HTTP11ClientFactory.__init__(self, quiescentCallback)
77         self._timeout = timeout
78
79     def buildProtocol(self, _):
80         """
81         Build the HTTP 1.1 client protocol.
82         """
83         return _HTTP11ClientProtocol(self._quiescentCallback, self._timeout)
84
85
86 class _HTTPConnectionPool(HTTPConnectionPool):
87     """
88     A timeout-able HTTP connection pool.
89     """
90
91     _factory = _HTTP11ClientFactory
92
93     def __init__(self, reactor, persistent, timeout, maxPersistentPerHost=10):
94         HTTPConnectionPool.__init__(self, reactor, persistent=persistent)
95         self.maxPersistentPerHost = maxPersistentPerHost
96         self._timeout = timeout
97
98     def _newConnection(self, key, endpoint):
99         def quiescentCallback(protocol):
100             self._putConnection(key, protocol)
101         factory = self._factory(quiescentCallback, timeout=self._timeout)
102         return endpoint.connect(factory)
103
104
105 class HTTPClient(object):
106     """
107     HTTP client done the twisted way, with a main focus on pinning the SSL
108     certificate.
109
110     By default, it uses a shared connection pool. If you want a dedicated
111     one, create and pass on __init__ pool parameter.
112     Please note that this client will limit the maximum amount of connections
113     by using a DeferredSemaphore.
114     This limit is equal to the maxPersistentPerHost used on pool and is needed
115     in order to avoid resource abuse on huge requests batches.
116     """
117
118     _pool = _HTTPConnectionPool(
119         reactor,
120         persistent=True,
121         timeout=DEFAULT_HTTP_TIMEOUT,
122         maxPersistentPerHost=10
123     )
124
125     def __init__(self, cert_file=None,
126                  timeout=DEFAULT_HTTP_TIMEOUT, pool=None):
127         """
128         Init the HTTP client
129
130         :param cert_file: The path to the certificate file, if None given the
131                           system's CAs will be used.
132         :type cert_file: str
133         :param timeout: The amount of time that this Agent will wait for the
134                         peer to accept a connection and for each request to be
135                         finished. If a pool is passed, then this argument is
136                         ignored.
137         :type timeout: float
138         """
139
140         self._timeout = timeout
141         self._pool = pool if pool is not None else self._pool
142         self._agent = Agent(
143             reactor,
144             get_compatible_ssl_context_factory(cert_file),
145             pool=self._pool,
146             connectTimeout=self._timeout)
147         self._semaphore = defer.DeferredSemaphore(
148             self._pool.maxPersistentPerHost)
149
150     def _createPool(self, maxPersistentPerHost=10, persistent=True):
151         pool = _HTTPConnectionPool(reactor, persistent, self._timeout)
152         pool.maxPersistentPerHost = maxPersistentPerHost
153         return pool
154
155     def _request(self, url, method, body, headers, callback):
156         """
157         Perform an HTTP request.
158
159         :param url: The URL for the request.
160         :type url: str
161         :param method: The HTTP method of the request.
162         :type method: str
163         :param body: The body of the request, if any.
164         :type body: str
165         :param headers: The headers of the request.
166         :type headers: dict
167         :param callback: A callback to be added to the request's deferred
168                          callback chain.
169         :type callback: callable
170
171         :return: A deferred that fires with the body of the request.
172         :rtype: twisted.internet.defer.Deferred
173         """
174         if body:
175             body = _StringBodyProducer(body)
176         d = self._agent.request(
177             method, url, headers=Headers(headers), bodyProducer=body)
178         d.addCallback(callback)
179         return d
180
181     def request(self, url, method='GET', body=None, headers={},
182                 callback=readBody):
183         """
184         Perform an HTTP request, but limit the maximum amount of concurrent
185         connections.
186
187         May be passed a callback to be added to the request's deferred
188         callback chain. The callback is expected to receive the response of
189         the request and may do whatever it wants with the response. By
190         default, if no callback is passed, we will use a simple body reader
191         which returns a deferred that is fired with the body of the response.
192
193         :param url: The URL for the request.
194         :type url: str
195         :param method: The HTTP method of the request.
196         :type method: str
197         :param body: The body of the request, if any.
198         :type body: str
199         :param headers: The headers of the request.
200         :type headers: dict
201         :param callback: A callback to be added to the request's deferred
202                          callback chain.
203         :type callback: callable
204
205         :return: A deferred that fires with the body of the request.
206         :rtype: twisted.internet.defer.Deferred
207         """
208         leap_assert(
209             callable(callback),
210             message="The callback parameter should be a callable!")
211         return self._semaphore.run(self._request, url, method, body, headers,
212                                    callback)
213
214     def close(self):
215         """
216         Close any cached connections.
217         """
218         self._pool.closeCachedConnections()
219
220 #
221 # An IBodyProducer to write the body of an HTTP request as a string.
222 #
223
224
225 class _StringBodyProducer(object):
226     """
227     A producer that writes the body of a request to a consumer.
228     """
229
230     implements(IBodyProducer)
231
232     def __init__(self, body):
233         """
234         Initialize the string produer.
235
236         :param body: The body of the request.
237         :type body: str
238         """
239         self.body = body
240         self.length = len(body)
241
242     def startProducing(self, consumer):
243         """
244         Write the body to the consumer.
245
246         :param consumer: Any IConsumer provider.
247         :type consumer: twisted.internet.interfaces.IConsumer
248
249         :return: A successful deferred.
250         :rtype: twisted.internet.defer.Deferred
251         """
252         consumer.write(self.body)
253         return defer.succeed(None)
254
255     def pauseProducing(self):
256         pass
257
258     def stopProducing(self):
259         pass
260
261
262 #
263 # Patched twisted.web classes
264 #
265
266 class _HTTP11ClientProtocol(HTTP11ClientProtocol):
267     """
268     A timeout-able HTTP 1.1 client protocol, that is instantiated by the
269     _HTTP11ClientFactory below.
270     """
271
272     def __init__(self, quiescentCallback, timeout):
273         """
274         Initialize the protocol.
275
276         :param quiescentCallback:
277         :type quiescentCallback: callable
278         :param timeout: A timeout, in seconds, for requests made by this
279                         protocol.
280         :type timeout: float
281         """
282         HTTP11ClientProtocol.__init__(self, quiescentCallback)
283         self._timeout = timeout
284         self._timeoutCall = None
285
286     def request(self, request):
287         """
288         Issue request over self.transport and return a Deferred which
289         will fire with a Response instance or an error.
290
291         :param request: The object defining the parameters of the request to
292                         issue.
293         :type request: twisted.web._newclient.Request
294
295         :return: A deferred which fires after the request has finished.
296         :rtype: Deferred
297         """
298         d = HTTP11ClientProtocol.request(self, request)
299         if self._timeout:
300             self._last_buffer_len = 0
301             timeoutCall = reactor.callLater(
302                 self._timeout, self._doTimeout, request)
303             self._timeoutCall = timeoutCall
304         return d
305
306     def _doTimeout(self, request):
307         """
308         Give up the request because of a timeout.
309
310         :param request: The object defining the parameters of the request to
311                         issue.
312         :type request: twisted.web._newclient.Request
313         """
314         self._giveUp(
315             failure.Failure(
316                 defer.TimeoutError(
317                     "Getting %s took longer than %s seconds."
318                     % (request.absoluteURI, self._timeout))))
319
320     def _cancelTimeout(self):
321         """
322         Cancel the request timeout, when it's finished.
323         """
324         if self._timeoutCall and self._timeoutCall.active():
325             self._timeoutCall.cancel()
326             self._timeoutCall = None
327
328     def _finishResponse(self, rest):
329         """
330         Cancel the timeout when finished receiving the response.
331         """
332         self._cancelTimeout()
333         HTTP11ClientProtocol._finishResponse(self, rest)
334
335     def dataReceived(self, bytes):
336         """
337         Receive some data and extend the timeout period of this request.
338
339         :param bytes: A string of indeterminate length.
340         :type bytes: str
341         """
342         HTTP11ClientProtocol.dataReceived(self, bytes)
343         if self._timeoutCall and self._timeoutCall.active():
344             self._timeoutCall.reset(self._timeout)
345
346     def connectionLost(self, reason):
347         self._cancelTimeout()
348         return HTTP11ClientProtocol.connectionLost(self, reason)