summaryrefslogtreecommitdiff
path: root/src/leap/soledad/client/_http.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/soledad/client/_http.py')
-rw-r--r--src/leap/soledad/client/_http.py74
1 files changed, 63 insertions, 11 deletions
diff --git a/src/leap/soledad/client/_http.py b/src/leap/soledad/client/_http.py
index 1a1260b0..db681dd5 100644
--- a/src/leap/soledad/client/_http.py
+++ b/src/leap/soledad/client/_http.py
@@ -15,14 +15,21 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
-A twisted-based, TLS-pinned, token-authenticated HTTP client.
+A twisted-based HTTP client that:
+
+ - is pinned to a specific TLS certificate,
+ - does token authentication using the Authorization header,
+ - can do bandwidth throttling.
"""
import base64
import os
+import sys
from twisted.internet import reactor
+from twisted.protocols.policies import ThrottlingFactory
+from twisted.protocols.policies import ThrottlingProtocol
from twisted.web.iweb import IAgent
-from twisted.web.client import Agent
+from twisted.web.client import Agent as _Agent
from twisted.web.client import CookieAgent
from twisted.web.client import HTTPConnectionPool
from twisted.web.http_headers import Headers
@@ -34,13 +41,13 @@ from zope.interface import implementer
from leap.common.http import getPolicyForHTTPS
-__all__ = ['HTTPClient', 'PinnedTokenAgent']
+__all__ = ['HTTPClient']
class HTTPClient(_HTTPClient):
def __init__(self, uuid, token, cert_file):
- agent = PinnedTokenAgent(uuid, token, cert_file)
+ agent = Agent(uuid, token, cert_file)
jar = CookieJar()
self._agent = CookieAgent(agent, jar)
super(self.__class__, self).__init__(self._agent)
@@ -49,19 +56,64 @@ class HTTPClient(_HTTPClient):
self._agent.set_token(token)
+class HTTPThrottlingProtocol(ThrottlingProtocol):
+
+ def request(self, *args, **kwargs):
+ return self.wrappedProtocol.request(*args, **kwargs)
+
+ def throttleWrites(self):
+ if hasattr(self, 'producer') and self.producer:
+ self.producer.pauseProducing()
+
+ def unthrottleWrites(self):
+ if hasattr(self, 'producer') and self.producer:
+ self.producer.resumeProducing()
+
+
+class HTTPThrottlingFactory(ThrottlingFactory):
+
+ protocol = HTTPThrottlingProtocol
+
+
+class ThrottlingHTTPConnectionPool(HTTPConnectionPool):
+
+ maxPersistentPerHost = 1 # throttling happens "host-wise"
+ maxConnectionCount = sys.maxsize # max number of concurrent connections
+ readLimit = 1 * 10 ** 6 # max bytes we should read per second
+ writeLimit = 1 * 10 ** 6 # max bytes we should write per second
+
+ def _newConnection(self, key, endpoint):
+ def quiescentCallback(protocol):
+ self._putConnection(key, protocol)
+ factory = self._factory(quiescentCallback, repr(endpoint))
+ throttlingFactory = HTTPThrottlingFactory(
+ factory,
+ maxConnectionCount=self.maxConnectionCount,
+ readLimit=self.readLimit,
+ writeLimit=self.writeLimit)
+ return endpoint.connect(throttlingFactory)
+
+
@implementer(IAgent)
-class PinnedTokenAgent(Agent):
+class Agent(_Agent):
- def __init__(self, uuid, token, cert_file):
+ def __init__(self, uuid, token, cert_file, throttling=False):
self._uuid = uuid
self._token = None
self._creds = None
self.set_token(token)
- # pin this agent with the platform TLS certificate
factory = getPolicyForHTTPS(cert_file)
- persistent = os.environ.get('SOLEDAD_HTTP_PERSIST', None)
- pool = HTTPConnectionPool(reactor, persistent=bool(persistent))
- Agent.__init__(self, reactor, contextFactory=factory, pool=pool)
+ pool = self._get_pool()
+ _Agent.__init__(self, reactor, contextFactory=factory, pool=pool)
+
+ def _get_pool(self):
+ throttling = bool(os.environ.get('SOLEDAD_THROTTLING'))
+ persistent = bool(os.environ.get('SOLEDAD_HTTP_PERSIST'))
+ if throttling:
+ klass = ThrottlingHTTPConnectionPool
+ else:
+ klass = HTTPConnectionPool
+ return klass(reactor, persistent=persistent)
def set_token(self, token):
self._token = token
@@ -77,5 +129,5 @@ class PinnedTokenAgent(Agent):
headers = headers or Headers()
headers.addRawHeader('Authorization', self._creds)
# perform the authenticated request
- return Agent.request(
+ return _Agent.request(
self, method, uri, headers=headers, bodyProducer=bodyProducer)