diff options
Diffstat (limited to 'src/leap')
-rw-r--r-- | src/leap/soledad/client/_http.py | 74 |
1 files changed, 63 insertions, 11 deletions
diff --git a/src/leap/soledad/client/_http.py b/src/leap/soledad/client/_http.py index 1a1260b0..db681dd5 100644 --- a/src/leap/soledad/client/_http.py +++ b/src/leap/soledad/client/_http.py @@ -15,14 +15,21 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. """ -A twisted-based, TLS-pinned, token-authenticated HTTP client. +A twisted-based HTTP client that: + + - is pinned to a specific TLS certificate, + - does token authentication using the Authorization header, + - can do bandwidth throttling. """ import base64 import os +import sys from twisted.internet import reactor +from twisted.protocols.policies import ThrottlingFactory +from twisted.protocols.policies import ThrottlingProtocol from twisted.web.iweb import IAgent -from twisted.web.client import Agent +from twisted.web.client import Agent as _Agent from twisted.web.client import CookieAgent from twisted.web.client import HTTPConnectionPool from twisted.web.http_headers import Headers @@ -34,13 +41,13 @@ from zope.interface import implementer from leap.common.http import getPolicyForHTTPS -__all__ = ['HTTPClient', 'PinnedTokenAgent'] +__all__ = ['HTTPClient'] class HTTPClient(_HTTPClient): def __init__(self, uuid, token, cert_file): - agent = PinnedTokenAgent(uuid, token, cert_file) + agent = Agent(uuid, token, cert_file) jar = CookieJar() self._agent = CookieAgent(agent, jar) super(self.__class__, self).__init__(self._agent) @@ -49,19 +56,64 @@ class HTTPClient(_HTTPClient): self._agent.set_token(token) +class HTTPThrottlingProtocol(ThrottlingProtocol): + + def request(self, *args, **kwargs): + return self.wrappedProtocol.request(*args, **kwargs) + + def throttleWrites(self): + if hasattr(self, 'producer') and self.producer: + self.producer.pauseProducing() + + def unthrottleWrites(self): + if hasattr(self, 'producer') and self.producer: + self.producer.resumeProducing() + + +class HTTPThrottlingFactory(ThrottlingFactory): + + protocol = HTTPThrottlingProtocol + + +class ThrottlingHTTPConnectionPool(HTTPConnectionPool): + + maxPersistentPerHost = 1 # throttling happens "host-wise" + maxConnectionCount = sys.maxsize # max number of concurrent connections + readLimit = 1 * 10 ** 6 # max bytes we should read per second + writeLimit = 1 * 10 ** 6 # max bytes we should write per second + + def _newConnection(self, key, endpoint): + def quiescentCallback(protocol): + self._putConnection(key, protocol) + factory = self._factory(quiescentCallback, repr(endpoint)) + throttlingFactory = HTTPThrottlingFactory( + factory, + maxConnectionCount=self.maxConnectionCount, + readLimit=self.readLimit, + writeLimit=self.writeLimit) + return endpoint.connect(throttlingFactory) + + @implementer(IAgent) -class PinnedTokenAgent(Agent): +class Agent(_Agent): - def __init__(self, uuid, token, cert_file): + def __init__(self, uuid, token, cert_file, throttling=False): self._uuid = uuid self._token = None self._creds = None self.set_token(token) - # pin this agent with the platform TLS certificate factory = getPolicyForHTTPS(cert_file) - persistent = os.environ.get('SOLEDAD_HTTP_PERSIST', None) - pool = HTTPConnectionPool(reactor, persistent=bool(persistent)) - Agent.__init__(self, reactor, contextFactory=factory, pool=pool) + pool = self._get_pool() + _Agent.__init__(self, reactor, contextFactory=factory, pool=pool) + + def _get_pool(self): + throttling = bool(os.environ.get('SOLEDAD_THROTTLING')) + persistent = bool(os.environ.get('SOLEDAD_HTTP_PERSIST')) + if throttling: + klass = ThrottlingHTTPConnectionPool + else: + klass = HTTPConnectionPool + return klass(reactor, persistent=persistent) def set_token(self, token): self._token = token @@ -77,5 +129,5 @@ class PinnedTokenAgent(Agent): headers = headers or Headers() headers.addRawHeader('Authorization', self._creds) # perform the authenticated request - return Agent.request( + return _Agent.request( self, method, uri, headers=headers, bodyProducer=bodyProducer) |