diff options
author | Ruben Pollan <meskio@sindominio.net> | 2017-10-06 11:50:36 +0200 |
---|---|---|
committer | Kali Kaneko <kali@leap.se> | 2017-10-06 18:38:42 +0200 |
commit | a5cb9c9940b34252da66d43498d705980532f60c (patch) | |
tree | 8a4c700c34e09c2faf6f8fe11504cd7c9a8a0350 | |
parent | b66ec16f764be769e4a15dae783292ac4cd32f3b (diff) |
[feat] use bonafide Provider object as a singleton
There was common situations where two provider instances where running
in parallel. And was creating weird errors (like getting wrong api_uri)
because the bootstrap deferreds were global but the Provider objects
not.
I don't like much singletons, but I think now is simpler than before.
- Resolves: #9073
-rw-r--r-- | src/leap/bitmask/bonafide/_protocol.py | 22 | ||||
-rw-r--r-- | src/leap/bitmask/bonafide/config.py | 53 | ||||
-rw-r--r-- | src/leap/bitmask/core/mail_services.py | 4 | ||||
-rw-r--r-- | tests/integration/bonafide/test_config.py | 46 |
4 files changed, 68 insertions, 57 deletions
diff --git a/src/leap/bitmask/bonafide/_protocol.py b/src/leap/bitmask/bonafide/_protocol.py index 004359e2..04c5d451 100644 --- a/src/leap/bitmask/bonafide/_protocol.py +++ b/src/leap/bitmask/bonafide/_protocol.py @@ -20,11 +20,6 @@ Bonafide protocol. import os from collections import defaultdict -try: - import resource -except ImportError: - pass - from leap.bitmask.bonafide import config from leap.bitmask.bonafide.provider import Api from leap.bitmask.bonafide.session import Session, OK @@ -50,11 +45,9 @@ class BonafideProtocol(object): log = Logger() def _get_api(self, provider): - # TODO should get deferred if provider.domain in self._apis: return self._apis[provider.domain] - # TODO defer the autoconfig for the provider if needed... api = Api(provider.api_uri, provider.version) self._apis[provider.domain] = api return api @@ -64,7 +57,6 @@ class BonafideProtocol(object): return self._sessions[full_id] # TODO if password/username null, then pass AnonymousCreds - # TODO use twisted.cred instead username, provider_id = config.get_username_and_provider(full_id) credentials = UsernamePassword(username, password) api = self._get_api(provider) @@ -84,7 +76,7 @@ class BonafideProtocol(object): self.log.debug('SIGNUP for %s' % full_id) _, provider_id = config.get_username_and_provider(full_id) - provider = config.Provider(provider_id, autoconf=autoconf) + provider = config.Provider.get(provider_id, autoconf=autoconf) d = provider.callWhenReady( self._do_signup, provider, full_id, password, invite) return d @@ -92,23 +84,22 @@ class BonafideProtocol(object): def _do_signup(self, provider, full_id, password, invite): # XXX check it's unauthenticated - def return_user(result, _session): + def return_user(result): return_code, user = result if return_code == OK: return user username, _ = config.get_username_and_provider(full_id) - # XXX get deferred? session = self._get_session(provider, full_id, password) d = session.signup(username, password, invite) - d.addCallback(return_user, session) + d.addCallback(return_user) d.addErrback(self._del_session_errback, full_id) return d def do_authenticate(self, full_id, password, autoconf=False): _, provider_id = config.get_username_and_provider(full_id) - provider = config.Provider(provider_id, autoconf=autoconf) + provider = config.Provider.get(provider_id, autoconf=autoconf) def maybe_finish_provider_bootstrap(result): session = self._get_session(provider, full_id, password) @@ -130,7 +121,6 @@ class BonafideProtocol(object): self.log.debug('AUTH for %s' % full_id) - # XXX get deferred? session = self._get_session(provider, full_id, password) d = session.authenticate() d.addCallback(return_token_and_uuid, session) @@ -170,11 +160,11 @@ class BonafideProtocol(object): return session.change_password(new_password) def do_get_provider(self, provider_id, autoconf=False): - provider = config.Provider(provider_id, autoconf=autoconf) + provider = config.Provider.get(provider_id, autoconf=autoconf) return provider.callWhenMainConfigReady(provider.config) def do_get_service(self, provider_id, service, autoconf=False): - provider = config.Provider(provider_id, autoconf=autoconf) + provider = config.Provider.get(provider_id, autoconf=autoconf) return provider.callWhenMainConfigReady(provider.config, service) def do_provider_delete(self, provider_id): diff --git a/src/leap/bitmask/bonafide/config.py b/src/leap/bitmask/bonafide/config.py index d0468a49..3417e498 100644 --- a/src/leap/bitmask/bonafide/config.py +++ b/src/leap/bitmask/bonafide/config.py @@ -18,7 +18,6 @@ Configuration for a LEAP provider. """ import binascii -import datetime import json import os import platform @@ -31,10 +30,9 @@ from cryptography.hazmat.primitives import hashes from cryptography.x509 import load_pem_x509_certificate from urlparse import urlparse -from twisted.internet import defer, reactor +from twisted.internet import defer from twisted.logger import Logger from twisted.web.client import downloadPage -from twisted.web.client import readBody from leap.bitmask.bonafide._http import httpRequest from leap.bitmask.bonafide.provider import Discovery @@ -140,12 +138,7 @@ def delete_provider(domain): raise NotConfiguredError("Provider %s is not configured, can't be " "deleted" % (domain,)) shutil.rmtree(path) - - # FIXME: this feels hacky, can we find a better way?? - if domain in Provider.first_bootstrap: - del Provider.first_bootstrap[domain] - if domain in Provider.ongoing_bootstrap: - del Provider.ongoing_bootstrap[domain] + Provider.providers[domain] = None class Provider(object): @@ -155,10 +148,15 @@ class Provider(object): 'mx': ['soledad', 'smtp']} log = Logger() + providers = defaultdict(None) - first_bootstrap = defaultdict(None) - ongoing_bootstrap = defaultdict(None) - stuck_bootstrap = defaultdict(None) + @classmethod + def get(self, domain, autoconf=False, basedir=None, + cert_path=None): + if domain not in self.providers: + self.providers[domain] = Provider(domain, autoconf, basedir, + cert_path) + return self.providers[domain] def __init__(self, domain, autoconf=False, basedir=None, cert_path=None): @@ -169,6 +167,9 @@ class Provider(object): self._disco = Discovery('https://%s' % domain) self._provider_config = None + self.first_bootstrap = defer.Deferred() + self.stuck_bootstrap = None + is_configured = self.is_configured() if not cert_path and is_configured: cert_path = self._get_ca_cert_path() @@ -198,7 +199,7 @@ class Provider(object): @property def api_uri(self): if not self._provider_config: - return 'https://api.%s:4430' % self._domain + return None return self._provider_config.api_uri @property @@ -233,23 +234,15 @@ class Provider(object): def bootstrap(self, replace_if_newer=False): domain = self._domain self.log.debug('Bootstrapping provider %s' % domain) - ongoing = self.ongoing_bootstrap.get(domain) - if ongoing: - self.log.debug('Already bootstrapping this provider...') - self.ongoing_bootstrap[domain].addCallback( - self._reload_http_client) - return - - self.first_bootstrap[self._domain] = defer.Deferred() def first_bootstrap_done(ignored): try: - self.first_bootstrap[domain].callback('got config') + self.first_bootstrap.callback('got config') except defer.AlreadyCalledError: pass def first_bootstrap_error(failure): - self.first_bootstrap[domain].errback(failure) + self.first_bootstrap.errback(failure) return failure d = self.maybe_download_provider_info(replace=replace_if_newer) @@ -257,15 +250,15 @@ class Provider(object): d.addCallback(self.validate_ca_cert) d.addCallbacks(first_bootstrap_done, first_bootstrap_error) d.addCallback(self.maybe_download_services_config) - self.ongoing_bootstrap[domain] = d + self.ongoing_bootstrap = d def callWhenMainConfigReady(self, cb, *args, **kw): - d = self.first_bootstrap[self._domain] + d = self.first_bootstrap d.addCallback(lambda _: cb(*args, **kw)) return d def callWhenReady(self, cb, *args, **kw): - d = self.ongoing_bootstrap[self._domain] + d = self.ongoing_bootstrap d.addCallback(lambda _: cb(*args, **kw)) return d @@ -372,7 +365,7 @@ class Provider(object): def further_bootstrap_needs_auth(ignored): self.log.warn('Cannot download services config yet, need auth') pending_deferred = defer.Deferred() - self.stuck_bootstrap[self._domain] = pending_deferred + self.stuck_bootstrap = pending_deferred return defer.succeed('ok for now') uri, met, path = self._get_configs_download_params() @@ -391,10 +384,10 @@ class Provider(object): return True def complete_bootstrapping(ignored): - stuck = self.stuck_bootstrap.get(self._domain, None) - if stuck: + if self.stuck_bootstrap: d = self._get_config_for_all_services(session) - d.addCallback(lambda _: stuck.callback('continue!')) + d.addCallback(lambda _: + self.stuck_bootstrap.callback('continue!')) return d self._load_provider_json() diff --git a/src/leap/bitmask/core/mail_services.py b/src/leap/bitmask/core/mail_services.py index 4a5f9798..5337b313 100644 --- a/src/leap/bitmask/core/mail_services.py +++ b/src/leap/bitmask/core/mail_services.py @@ -150,7 +150,7 @@ def _get_provider_from_full_userid(userid): _, provider_id = config.get_username_and_provider(userid) # TODO -- this autoconf should be passed from the # command flag. workaround to get cli workinf for now. - return config.Provider(provider_id, autoconf=True) + return config.Provider.get(provider_id, autoconf=True) def is_service_ready(service, provider): @@ -337,7 +337,7 @@ class KeymanagerContainer(Container): return keymanager def _get_api_uri(self, provider): - api_uri = config.Provider(provider).api_uri + api_uri = config.Provider.get(provider).api_uri return api_uri def _get_nicknym_uri(self, provider): diff --git a/tests/integration/bonafide/test_config.py b/tests/integration/bonafide/test_config.py index 5d45189b..ee6cdc51 100644 --- a/tests/integration/bonafide/test_config.py +++ b/tests/integration/bonafide/test_config.py @@ -37,12 +37,14 @@ class ConfigTest(BaseHTTPSServerTestCase, unittest.TestCase, BaseLeapTest): self.cacert = os.path.join(os.path.dirname(__file__), "cacert.pem") + @defer.inlineCallbacks def test_bootstrap_self_sign_cert_fails(self): home = os.path.join(self.home, 'self_sign') os.mkdir(home) - provider = Provider(self.addr.domain, autoconf=True, basedir=home) + provider = Provider.get(self.addr.domain, autoconf=True, basedir=home) d = provider.callWhenMainConfigReady(lambda: "Cert was accepted") - return self.assertFailure(d, NetworkError) + yield self.assertFailure(d, NetworkError) + Provider.providers[self.addr.domain] = None @defer.inlineCallbacks def test_bootstrap_invalid_ca_cert(self): @@ -58,15 +60,17 @@ class ConfigTest(BaseHTTPSServerTestCase, unittest.TestCase, BaseLeapTest): provider._http.close() try: yield defer.gatherResults([ - d, provider.ongoing_bootstrap[provider._domain]]) + d, provider.ongoing_bootstrap]) except: pass + Provider.providers[self.addr.domain] = None + @defer.inlineCallbacks def test_bootstrap_pinned_cert(self): home = os.path.join(self.home, 'pinned') os.mkdir(home) - provider = Provider(self.addr.domain, autoconf=True, basedir=home, - cert_path=self.cacert) + provider = Provider.get(self.addr.domain, autoconf=True, basedir=home, + cert_path=self.cacert) def check_provider(): config = provider.config() @@ -74,9 +78,31 @@ class ConfigTest(BaseHTTPSServerTestCase, unittest.TestCase, BaseLeapTest): self.assertEqual(config["ca_cert_fingerprint"], "SHA256: %s" % fingerprint) - d = provider.callWhenMainConfigReady(check_provider) - return defer.gatherResults([ - d, provider.ongoing_bootstrap[provider._domain]]) + yield provider.callWhenMainConfigReady(check_provider) + provider._http.close() + yield provider.ongoing_bootstrap + Provider.providers[self.addr.domain] = None + + @defer.inlineCallbacks + def test_api_uri(self): + api_uri = "api.example.com" + self.addr.api_uri = api_uri + home = os.path.join(self.home, 'api_uri') + os.mkdir(home) + provider = Provider.get(self.addr.domain, autoconf=True, + basedir=home, cert_path=self.cacert) + + def check_api_uri(): + parsed_uri = provider.api_uri + self.assertEqual(api_uri, parsed_uri) + + yield provider.callWhenMainConfigReady(check_api_uri) + provider._http.close() + try: + yield provider.ongoing_bootstrap + except: + pass + Provider.providers[self.addr.domain] = None class Addr(object): @@ -84,6 +110,7 @@ class Addr(object): self.host = host self.port = port self.fingerprint = fingerprint + self.api_uri = "https://%s:%s" % (host, port) @property def domain(self): @@ -95,6 +122,7 @@ def request_handler(addr): def do_GET(self): if self.path == '/provider.json': body = provider_json % { + 'api_uri': addr.api_uri, 'host': addr.host, 'port': addr.port, 'fingerprint': addr.fingerprint @@ -126,7 +154,7 @@ fingerprint = \ "cd0131b3352b7a29c307156b24f09fe862b1f5a2e55be7cd888048b91770f220" provider_json = """ { - "api_uri": "https://%(host)s:%(port)s", + "api_uri": "%(api_uri)s", "api_version": "1", "ca_cert_fingerprint": "SHA256: %(fingerprint)s", "ca_cert_uri": "https://%(host)s:%(port)s/ca.crt", |