diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/leap/soledad/client/_http.py | 74 | 
1 files changed, 63 insertions, 11 deletions
| diff --git a/src/leap/soledad/client/_http.py b/src/leap/soledad/client/_http.py index 1a1260b0..db681dd5 100644 --- a/src/leap/soledad/client/_http.py +++ b/src/leap/soledad/client/_http.py @@ -15,14 +15,21 @@  # You should have received a copy of the GNU General Public License  # along with this program. If not, see <http://www.gnu.org/licenses/>.  """ -A twisted-based, TLS-pinned, token-authenticated HTTP client. +A twisted-based HTTP client that: + +    - is pinned to a specific TLS certificate, +    - does token authentication using the Authorization header, +    - can do bandwidth throttling.  """  import base64  import os +import sys  from twisted.internet import reactor +from twisted.protocols.policies import ThrottlingFactory +from twisted.protocols.policies import ThrottlingProtocol  from twisted.web.iweb import IAgent -from twisted.web.client import Agent +from twisted.web.client import Agent as _Agent  from twisted.web.client import CookieAgent  from twisted.web.client import HTTPConnectionPool  from twisted.web.http_headers import Headers @@ -34,13 +41,13 @@ from zope.interface import implementer  from leap.common.http import getPolicyForHTTPS -__all__ = ['HTTPClient', 'PinnedTokenAgent'] +__all__ = ['HTTPClient']  class HTTPClient(_HTTPClient):      def __init__(self, uuid, token, cert_file): -        agent = PinnedTokenAgent(uuid, token, cert_file) +        agent = Agent(uuid, token, cert_file)          jar = CookieJar()          self._agent = CookieAgent(agent, jar)          super(self.__class__, self).__init__(self._agent) @@ -49,19 +56,64 @@ class HTTPClient(_HTTPClient):          self._agent.set_token(token) +class HTTPThrottlingProtocol(ThrottlingProtocol): + +    def request(self, *args, **kwargs): +        return self.wrappedProtocol.request(*args, **kwargs) + +    def throttleWrites(self): +        if hasattr(self, 'producer') and self.producer: +            self.producer.pauseProducing() + +    def unthrottleWrites(self): +        if hasattr(self, 'producer') and self.producer: +            self.producer.resumeProducing() + + +class HTTPThrottlingFactory(ThrottlingFactory): + +    protocol = HTTPThrottlingProtocol + + +class ThrottlingHTTPConnectionPool(HTTPConnectionPool): + +    maxPersistentPerHost = 1          # throttling happens "host-wise" +    maxConnectionCount = sys.maxsize  # max number of concurrent connections +    readLimit = 1 * 10 ** 6           # max bytes we should read per second +    writeLimit = 1 * 10 ** 6          # max bytes we should write per second + +    def _newConnection(self, key, endpoint): +        def quiescentCallback(protocol): +            self._putConnection(key, protocol) +        factory = self._factory(quiescentCallback, repr(endpoint)) +        throttlingFactory = HTTPThrottlingFactory( +            factory, +            maxConnectionCount=self.maxConnectionCount, +            readLimit=self.readLimit, +            writeLimit=self.writeLimit) +        return endpoint.connect(throttlingFactory) + +  @implementer(IAgent) -class PinnedTokenAgent(Agent): +class Agent(_Agent): -    def __init__(self, uuid, token, cert_file): +    def __init__(self, uuid, token, cert_file, throttling=False):          self._uuid = uuid          self._token = None          self._creds = None          self.set_token(token) -        # pin this agent with the platform TLS certificate          factory = getPolicyForHTTPS(cert_file) -        persistent = os.environ.get('SOLEDAD_HTTP_PERSIST', None) -        pool = HTTPConnectionPool(reactor, persistent=bool(persistent)) -        Agent.__init__(self, reactor, contextFactory=factory, pool=pool) +        pool = self._get_pool() +        _Agent.__init__(self, reactor, contextFactory=factory, pool=pool) + +    def _get_pool(self): +        throttling = bool(os.environ.get('SOLEDAD_THROTTLING')) +        persistent = bool(os.environ.get('SOLEDAD_HTTP_PERSIST')) +        if throttling: +            klass = ThrottlingHTTPConnectionPool +        else: +            klass = HTTPConnectionPool +        return klass(reactor, persistent=persistent)      def set_token(self, token):          self._token = token @@ -77,5 +129,5 @@ class PinnedTokenAgent(Agent):          headers = headers or Headers()          headers.addRawHeader('Authorization', self._creds)          # perform the authenticated request -        return Agent.request( +        return _Agent.request(              self, method, uri, headers=headers, bodyProducer=bodyProducer) | 
