[feature] allow passing callback to http client
[leap_pycommon.git] / src / leap / common / http.py
1 # -*- coding: utf-8 -*-
2 # http.py
3 # Copyright (C) 2015 LEAP
4 #
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation, either version 3 of the License, or
8 # (at your option) any later version.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
17 """
18 Twisted HTTP/HTTPS client.
19 """
20
21 try:
22     import twisted
23 except ImportError:
24     print "*******"
25     print "Twisted is needed to use leap.common.http module"
26     print ""
27     print "Install the extra requirement of the package:"
28     print "$ pip install leap.common[Twisted]"
29     import sys
30     sys.exit(1)
31
32
33 from leap.common.certs import get_compatible_ssl_context_factory
34 from leap.common.check import leap_assert
35
36 from zope.interface import implements
37
38 from twisted.internet import reactor
39 from twisted.internet import defer
40 from twisted.python import failure
41
42 from twisted.web.client import Agent
43 from twisted.web.client import HTTPConnectionPool
44 from twisted.web.client import _HTTP11ClientFactory as HTTP11ClientFactory
45 from twisted.web.client import readBody
46 from twisted.web.http_headers import Headers
47 from twisted.web.iweb import IBodyProducer
48 from twisted.web._newclient import HTTP11ClientProtocol
49
50
51 __all__ = ["HTTPClient"]
52
53
54 # A default HTTP timeout is used for 2 distinct purposes:
55 #   1. as HTTP connection timeout, prior to connection estabilshment.
56 #   2. as data reception timeout, after the connection has been established.
57 DEFAULT_HTTP_TIMEOUT = 30  # seconds
58
59
60 class _HTTP11ClientFactory(HTTP11ClientFactory):
61     """
62     A timeout-able HTTP 1.1 client protocol factory.
63     """
64
65     def __init__(self, quiescentCallback, timeout):
66         """
67         :param quiescentCallback: The quiescent callback to be passed to
68                                   protocol instances, used to return them to
69                                   the connection pool.
70         :type quiescentCallback: callable(Protocol)
71         :param timeout: The timeout, in seconds, for requests made by
72                         protocols created by this factory.
73         :type timeout: float
74         """
75         HTTP11ClientFactory.__init__(self, quiescentCallback)
76         self._timeout = timeout
77
78     def buildProtocol(self, _):
79         """
80         Build the HTTP 1.1 client protocol.
81         """
82         return _HTTP11ClientProtocol(self._quiescentCallback, self._timeout)
83
84
85 class _HTTPConnectionPool(HTTPConnectionPool):
86     """
87     A timeout-able HTTP connection pool.
88     """
89
90     _factory = _HTTP11ClientFactory
91
92     def __init__(self, reactor, persistent, timeout, maxPersistentPerHost=10):
93         HTTPConnectionPool.__init__(self, reactor, persistent=persistent)
94         self.maxPersistentPerHost = maxPersistentPerHost
95         self._timeout = timeout
96
97     def _newConnection(self, key, endpoint):
98         def quiescentCallback(protocol):
99             self._putConnection(key, protocol)
100         factory = self._factory(quiescentCallback, timeout=self._timeout)
101         return endpoint.connect(factory)
102
103
104 class HTTPClient(object):
105     """
106     HTTP client done the twisted way, with a main focus on pinning the SSL
107     certificate.
108
109     By default, it uses a shared connection pool. If you want a dedicated
110     one, create and pass on __init__ pool parameter.
111     Please note that this client will limit the maximum amount of connections
112     by using a DeferredSemaphore.
113     This limit is equal to the maxPersistentPerHost used on pool and is needed
114     in order to avoid resource abuse on huge requests batches.
115     """
116
117     _pool = _HTTPConnectionPool(
118         reactor,
119         persistent=True,
120         timeout=DEFAULT_HTTP_TIMEOUT,
121         maxPersistentPerHost=10
122     )
123
124     def __init__(self, cert_file=None,
125                  timeout=DEFAULT_HTTP_TIMEOUT, pool=None):
126         """
127         Init the HTTP client
128
129         :param cert_file: The path to the certificate file, if None given the
130                           system's CAs will be used.
131         :type cert_file: str
132         :param timeout: The amount of time that this Agent will wait for the
133                         peer to accept a connection and for each request to be
134                         finished. If a pool is passed, then this argument is
135                         ignored.
136         :type timeout: float
137         """
138
139         self._timeout = timeout
140         self._pool = pool if pool is not None else self._pool
141         self._agent = Agent(
142             reactor,
143             get_compatible_ssl_context_factory(cert_file),
144             pool=self._pool,
145             connectTimeout=self._timeout)
146         self._semaphore = defer.DeferredSemaphore(
147             self._pool.maxPersistentPerHost)
148
149     def _createPool(self, maxPersistentPerHost=10, persistent=True):
150         pool = _HTTPConnectionPool(reactor, persistent, self._timeout)
151         pool.maxPersistentPerHost = maxPersistentPerHost
152         return pool
153
154     def _request(self, url, method, body, headers, callback):
155         """
156         Perform an HTTP request.
157
158         :param url: The URL for the request.
159         :type url: str
160         :param method: The HTTP method of the request.
161         :type method: str
162         :param body: The body of the request, if any.
163         :type body: str
164         :param headers: The headers of the request.
165         :type headers: dict
166         :param callback: A callback to be added to the request's deferred
167                          callback chain.
168         :type callback: callable
169
170         :return: A deferred that fires with the body of the request.
171         :rtype: twisted.internet.defer.Deferred
172         """
173         if body:
174             body = _StringBodyProducer(body)
175         d = self._agent.request(
176             method, url, headers=Headers(headers), bodyProducer=body)
177         d.addCallback(callback)
178         return d
179
180     def request(self, url, method='GET', body=None, headers={},
181             callback=readBody):
182         """
183         Perform an HTTP request, but limit the maximum amount of concurrent
184         connections.
185
186         May be passed a callback to be added to the request's deferred
187         callback chain. The callback is expected to receive the response of
188         the request and may do whatever it wants with the response. By
189         default, if no callback is passed, we will use a simple body reader
190         which returns a deferred that is fired with the body of the response.
191
192         :param url: The URL for the request.
193         :type url: str
194         :param method: The HTTP method of the request.
195         :type method: str
196         :param body: The body of the request, if any.
197         :type body: str
198         :param headers: The headers of the request.
199         :type headers: dict
200         :param callback: A callback to be added to the request's deferred
201                          callback chain.
202         :type callback: callable
203
204         :return: A deferred that fires with the body of the request.
205         :rtype: twisted.internet.defer.Deferred
206         """
207         leap_assert(
208             callable(callback),
209             message="The callback parameter should be a callable!")
210         return self._semaphore.run(self._request, url, method, body, headers,
211             callback)
212
213     def close(self):
214         """
215         Close any cached connections.
216         """
217         self._pool.closeCachedConnections()
218
219 #
220 # An IBodyProducer to write the body of an HTTP request as a string.
221 #
222
223
224 class _StringBodyProducer(object):
225     """
226     A producer that writes the body of a request to a consumer.
227     """
228
229     implements(IBodyProducer)
230
231     def __init__(self, body):
232         """
233         Initialize the string produer.
234
235         :param body: The body of the request.
236         :type body: str
237         """
238         self.body = body
239         self.length = len(body)
240
241     def startProducing(self, consumer):
242         """
243         Write the body to the consumer.
244
245         :param consumer: Any IConsumer provider.
246         :type consumer: twisted.internet.interfaces.IConsumer
247
248         :return: A successful deferred.
249         :rtype: twisted.internet.defer.Deferred
250         """
251         consumer.write(self.body)
252         return defer.succeed(None)
253
254     def pauseProducing(self):
255         pass
256
257     def stopProducing(self):
258         pass
259
260
261 #
262 # Patched twisted.web classes
263 #
264
265 class _HTTP11ClientProtocol(HTTP11ClientProtocol):
266     """
267     A timeout-able HTTP 1.1 client protocol, that is instantiated by the
268     _HTTP11ClientFactory below.
269     """
270
271     def __init__(self, quiescentCallback, timeout):
272         """
273         Initialize the protocol.
274
275         :param quiescentCallback:
276         :type quiescentCallback: callable
277         :param timeout: A timeout, in seconds, for requests made by this
278                         protocol.
279         :type timeout: float
280         """
281         HTTP11ClientProtocol.__init__(self, quiescentCallback)
282         self._timeout = timeout
283         self._timeoutCall = None
284
285     def request(self, request):
286         """
287         Issue request over self.transport and return a Deferred which
288         will fire with a Response instance or an error.
289
290         :param request: The object defining the parameters of the request to
291                         issue.
292         :type request: twisted.web._newclient.Request
293
294         :return: A deferred which fires after the request has finished.
295         :rtype: Deferred
296         """
297         d = HTTP11ClientProtocol.request(self, request)
298         if self._timeout:
299             self._last_buffer_len = 0
300             timeoutCall = reactor.callLater(
301                 self._timeout, self._doTimeout, request)
302             self._timeoutCall = timeoutCall
303         return d
304
305     def _doTimeout(self, request):
306         """
307         Give up the request because of a timeout.
308
309         :param request: The object defining the parameters of the request to
310                         issue.
311         :type request: twisted.web._newclient.Request
312         """
313         self._giveUp(
314             failure.Failure(
315                 defer.TimeoutError(
316                     "Getting %s took longer than %s seconds."
317                     % (request.absoluteURI, self._timeout))))
318
319     def _cancelTimeout(self):
320         """
321         Cancel the request timeout, when it's finished.
322         """
323         if self._timeoutCall and self._timeoutCall.active():
324             self._timeoutCall.cancel()
325             self._timeoutCall = None
326
327     def _finishResponse(self, rest):
328         """
329         Cancel the timeout when finished receiving the response.
330         """
331         self._cancelTimeout()
332         HTTP11ClientProtocol._finishResponse(self, rest)
333
334     def dataReceived(self, bytes):
335         """
336         Receive some data and extend the timeout period of this request.
337
338         :param bytes: A string of indeterminate length.
339         :type bytes: str
340         """
341         HTTP11ClientProtocol.dataReceived(self, bytes)
342         if self._timeoutCall and self._timeoutCall.active():
343             self._timeoutCall.reset(self._timeout)
344
345     def connectionLost(self, reason):
346         self._cancelTimeout()
347         return HTTP11ClientProtocol.connectionLost(self, reason)