From bf18c2bc6e3f533187281a3b31febd37ef22f8c0 Mon Sep 17 00:00:00 2001 From: Victor Shyba Date: Wed, 27 May 2015 17:19:13 -0300 Subject: [feat] Make it optional to have a dedicated pool As @meskio pointed out, some cases could need a dedicated pool with different parameters. This is a suggested implementation where the pool is reused by default, creating a dedicated one just if needed/asked. This way we ensure that resources are under control and special cases are still handled. --- src/leap/common/http.py | 14 ++++++--- src/leap/common/tests/test_http.py | 62 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 4 deletions(-) create mode 100644 src/leap/common/tests/test_http.py (limited to 'src/leap/common') diff --git a/src/leap/common/http.py b/src/leap/common/http.py index 1dc5642..1e384e5 100644 --- a/src/leap/common/http.py +++ b/src/leap/common/http.py @@ -44,13 +44,21 @@ from twisted.web.http_headers import Headers from twisted.web.iweb import IBodyProducer +def createPool(maxPersistentPerHost=10, persistent=True): + pool = HTTPConnectionPool(reactor, persistent) + pool.maxPersistentPerHost = maxPersistentPerHost + return pool + +_pool = createPool() + + class HTTPClient(object): """ HTTP client done the twisted way, with a main focus on pinning the SSL certificate. """ - def __init__(self, cert_file=None): + def __init__(self, cert_file=None, pool=_pool): """ Init the HTTP client @@ -58,15 +66,13 @@ class HTTPClient(object): system's CAs will be used. :type cert_file: str """ - self._pool = HTTPConnectionPool(reactor, persistent=True) - self._pool.maxPersistentPerHost = 10 policy = get_compatible_ssl_context_factory(cert_file) self._agent = Agent( reactor, policy, - pool=self._pool) + pool=pool) def request(self, url, method='GET', body=None, headers={}): """ diff --git a/src/leap/common/tests/test_http.py b/src/leap/common/tests/test_http.py new file mode 100644 index 0000000..e240ca3 --- /dev/null +++ b/src/leap/common/tests/test_http.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# test_http.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +""" +Tests for: + * leap/common/http.py +""" +import os +try: + import unittest2 as unittest +except ImportError: + import unittest + +from leap.common import http +from leap.common.testing.basetest import BaseLeapTest + +TEST_CERT_PEM = os.path.join( + os.path.split(__file__)[0], + '..', 'testing', "leaptest_combined_keycert.pem") + + +class HTTPClientTest(BaseLeapTest): + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_agents_sharing_pool_by_default(self): + client = http.HTTPClient() + client2 = http.HTTPClient(TEST_CERT_PEM) + self.assertNotEquals(client._agent, client2._agent, "Expected dedicated agents") + self.assertEquals(client._agent._pool, client2._agent._pool, "Pool was not reused by default") + + def test_agent_can_have_dedicated_custom_pool(self): + custom_pool = http.createPool(maxPersistentPerHost=42, persistent=False) + self.assertEquals(custom_pool.maxPersistentPerHost, 42, + "Custom persistent connections limit is not being respected") + self.assertFalse(custom_pool.persistent, + "Custom persistence is not being respected") + default_client = http.HTTPClient() + custom_client = http.HTTPClient(pool=custom_pool) + + self.assertNotEquals(default_client._agent, custom_client._agent, "No agent reuse is expected") + self.assertEquals(custom_pool, custom_client._agent._pool, "Custom pool usage was not respected") + +if __name__ == "__main__": + unittest.main() -- cgit v1.2.3