summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRuben Pollan <meskio@sindominio.net>2017-10-06 11:50:36 +0200
committerKali Kaneko <kali@leap.se>2017-10-06 18:38:42 +0200
commita5cb9c9940b34252da66d43498d705980532f60c (patch)
tree8a4c700c34e09c2faf6f8fe11504cd7c9a8a0350
parentb66ec16f764be769e4a15dae783292ac4cd32f3b (diff)
[feat] use bonafide Provider object as a singleton
There was common situations where two provider instances where running in parallel. And was creating weird errors (like getting wrong api_uri) because the bootstrap deferreds were global but the Provider objects not. I don't like much singletons, but I think now is simpler than before. - Resolves: #9073
-rw-r--r--src/leap/bitmask/bonafide/_protocol.py22
-rw-r--r--src/leap/bitmask/bonafide/config.py53
-rw-r--r--src/leap/bitmask/core/mail_services.py4
-rw-r--r--tests/integration/bonafide/test_config.py46
4 files changed, 68 insertions, 57 deletions
diff --git a/src/leap/bitmask/bonafide/_protocol.py b/src/leap/bitmask/bonafide/_protocol.py
index 004359e..04c5d45 100644
--- a/src/leap/bitmask/bonafide/_protocol.py
+++ b/src/leap/bitmask/bonafide/_protocol.py
@@ -20,11 +20,6 @@ Bonafide protocol.
import os
from collections import defaultdict
-try:
- import resource
-except ImportError:
- pass
-
from leap.bitmask.bonafide import config
from leap.bitmask.bonafide.provider import Api
from leap.bitmask.bonafide.session import Session, OK
@@ -50,11 +45,9 @@ class BonafideProtocol(object):
log = Logger()
def _get_api(self, provider):
- # TODO should get deferred
if provider.domain in self._apis:
return self._apis[provider.domain]
- # TODO defer the autoconfig for the provider if needed...
api = Api(provider.api_uri, provider.version)
self._apis[provider.domain] = api
return api
@@ -64,7 +57,6 @@ class BonafideProtocol(object):
return self._sessions[full_id]
# TODO if password/username null, then pass AnonymousCreds
- # TODO use twisted.cred instead
username, provider_id = config.get_username_and_provider(full_id)
credentials = UsernamePassword(username, password)
api = self._get_api(provider)
@@ -84,7 +76,7 @@ class BonafideProtocol(object):
self.log.debug('SIGNUP for %s' % full_id)
_, provider_id = config.get_username_and_provider(full_id)
- provider = config.Provider(provider_id, autoconf=autoconf)
+ provider = config.Provider.get(provider_id, autoconf=autoconf)
d = provider.callWhenReady(
self._do_signup, provider, full_id, password, invite)
return d
@@ -92,23 +84,22 @@ class BonafideProtocol(object):
def _do_signup(self, provider, full_id, password, invite):
# XXX check it's unauthenticated
- def return_user(result, _session):
+ def return_user(result):
return_code, user = result
if return_code == OK:
return user
username, _ = config.get_username_and_provider(full_id)
- # XXX get deferred?
session = self._get_session(provider, full_id, password)
d = session.signup(username, password, invite)
- d.addCallback(return_user, session)
+ d.addCallback(return_user)
d.addErrback(self._del_session_errback, full_id)
return d
def do_authenticate(self, full_id, password, autoconf=False):
_, provider_id = config.get_username_and_provider(full_id)
- provider = config.Provider(provider_id, autoconf=autoconf)
+ provider = config.Provider.get(provider_id, autoconf=autoconf)
def maybe_finish_provider_bootstrap(result):
session = self._get_session(provider, full_id, password)
@@ -130,7 +121,6 @@ class BonafideProtocol(object):
self.log.debug('AUTH for %s' % full_id)
- # XXX get deferred?
session = self._get_session(provider, full_id, password)
d = session.authenticate()
d.addCallback(return_token_and_uuid, session)
@@ -170,11 +160,11 @@ class BonafideProtocol(object):
return session.change_password(new_password)
def do_get_provider(self, provider_id, autoconf=False):
- provider = config.Provider(provider_id, autoconf=autoconf)
+ provider = config.Provider.get(provider_id, autoconf=autoconf)
return provider.callWhenMainConfigReady(provider.config)
def do_get_service(self, provider_id, service, autoconf=False):
- provider = config.Provider(provider_id, autoconf=autoconf)
+ provider = config.Provider.get(provider_id, autoconf=autoconf)
return provider.callWhenMainConfigReady(provider.config, service)
def do_provider_delete(self, provider_id):
diff --git a/src/leap/bitmask/bonafide/config.py b/src/leap/bitmask/bonafide/config.py
index d0468a4..3417e49 100644
--- a/src/leap/bitmask/bonafide/config.py
+++ b/src/leap/bitmask/bonafide/config.py
@@ -18,7 +18,6 @@
Configuration for a LEAP provider.
"""
import binascii
-import datetime
import json
import os
import platform
@@ -31,10 +30,9 @@ from cryptography.hazmat.primitives import hashes
from cryptography.x509 import load_pem_x509_certificate
from urlparse import urlparse
-from twisted.internet import defer, reactor
+from twisted.internet import defer
from twisted.logger import Logger
from twisted.web.client import downloadPage
-from twisted.web.client import readBody
from leap.bitmask.bonafide._http import httpRequest
from leap.bitmask.bonafide.provider import Discovery
@@ -140,12 +138,7 @@ def delete_provider(domain):
raise NotConfiguredError("Provider %s is not configured, can't be "
"deleted" % (domain,))
shutil.rmtree(path)
-
- # FIXME: this feels hacky, can we find a better way??
- if domain in Provider.first_bootstrap:
- del Provider.first_bootstrap[domain]
- if domain in Provider.ongoing_bootstrap:
- del Provider.ongoing_bootstrap[domain]
+ Provider.providers[domain] = None
class Provider(object):
@@ -155,10 +148,15 @@ class Provider(object):
'mx': ['soledad', 'smtp']}
log = Logger()
+ providers = defaultdict(None)
- first_bootstrap = defaultdict(None)
- ongoing_bootstrap = defaultdict(None)
- stuck_bootstrap = defaultdict(None)
+ @classmethod
+ def get(self, domain, autoconf=False, basedir=None,
+ cert_path=None):
+ if domain not in self.providers:
+ self.providers[domain] = Provider(domain, autoconf, basedir,
+ cert_path)
+ return self.providers[domain]
def __init__(self, domain, autoconf=False, basedir=None,
cert_path=None):
@@ -169,6 +167,9 @@ class Provider(object):
self._disco = Discovery('https://%s' % domain)
self._provider_config = None
+ self.first_bootstrap = defer.Deferred()
+ self.stuck_bootstrap = None
+
is_configured = self.is_configured()
if not cert_path and is_configured:
cert_path = self._get_ca_cert_path()
@@ -198,7 +199,7 @@ class Provider(object):
@property
def api_uri(self):
if not self._provider_config:
- return 'https://api.%s:4430' % self._domain
+ return None
return self._provider_config.api_uri
@property
@@ -233,23 +234,15 @@ class Provider(object):
def bootstrap(self, replace_if_newer=False):
domain = self._domain
self.log.debug('Bootstrapping provider %s' % domain)
- ongoing = self.ongoing_bootstrap.get(domain)
- if ongoing:
- self.log.debug('Already bootstrapping this provider...')
- self.ongoing_bootstrap[domain].addCallback(
- self._reload_http_client)
- return
-
- self.first_bootstrap[self._domain] = defer.Deferred()
def first_bootstrap_done(ignored):
try:
- self.first_bootstrap[domain].callback('got config')
+ self.first_bootstrap.callback('got config')
except defer.AlreadyCalledError:
pass
def first_bootstrap_error(failure):
- self.first_bootstrap[domain].errback(failure)
+ self.first_bootstrap.errback(failure)
return failure
d = self.maybe_download_provider_info(replace=replace_if_newer)
@@ -257,15 +250,15 @@ class Provider(object):
d.addCallback(self.validate_ca_cert)
d.addCallbacks(first_bootstrap_done, first_bootstrap_error)
d.addCallback(self.maybe_download_services_config)
- self.ongoing_bootstrap[domain] = d
+ self.ongoing_bootstrap = d
def callWhenMainConfigReady(self, cb, *args, **kw):
- d = self.first_bootstrap[self._domain]
+ d = self.first_bootstrap
d.addCallback(lambda _: cb(*args, **kw))
return d
def callWhenReady(self, cb, *args, **kw):
- d = self.ongoing_bootstrap[self._domain]
+ d = self.ongoing_bootstrap
d.addCallback(lambda _: cb(*args, **kw))
return d
@@ -372,7 +365,7 @@ class Provider(object):
def further_bootstrap_needs_auth(ignored):
self.log.warn('Cannot download services config yet, need auth')
pending_deferred = defer.Deferred()
- self.stuck_bootstrap[self._domain] = pending_deferred
+ self.stuck_bootstrap = pending_deferred
return defer.succeed('ok for now')
uri, met, path = self._get_configs_download_params()
@@ -391,10 +384,10 @@ class Provider(object):
return True
def complete_bootstrapping(ignored):
- stuck = self.stuck_bootstrap.get(self._domain, None)
- if stuck:
+ if self.stuck_bootstrap:
d = self._get_config_for_all_services(session)
- d.addCallback(lambda _: stuck.callback('continue!'))
+ d.addCallback(lambda _:
+ self.stuck_bootstrap.callback('continue!'))
return d
self._load_provider_json()
diff --git a/src/leap/bitmask/core/mail_services.py b/src/leap/bitmask/core/mail_services.py
index 4a5f979..5337b31 100644
--- a/src/leap/bitmask/core/mail_services.py
+++ b/src/leap/bitmask/core/mail_services.py
@@ -150,7 +150,7 @@ def _get_provider_from_full_userid(userid):
_, provider_id = config.get_username_and_provider(userid)
# TODO -- this autoconf should be passed from the
# command flag. workaround to get cli workinf for now.
- return config.Provider(provider_id, autoconf=True)
+ return config.Provider.get(provider_id, autoconf=True)
def is_service_ready(service, provider):
@@ -337,7 +337,7 @@ class KeymanagerContainer(Container):
return keymanager
def _get_api_uri(self, provider):
- api_uri = config.Provider(provider).api_uri
+ api_uri = config.Provider.get(provider).api_uri
return api_uri
def _get_nicknym_uri(self, provider):
diff --git a/tests/integration/bonafide/test_config.py b/tests/integration/bonafide/test_config.py
index 5d45189..ee6cdc5 100644
--- a/tests/integration/bonafide/test_config.py
+++ b/tests/integration/bonafide/test_config.py
@@ -37,12 +37,14 @@ class ConfigTest(BaseHTTPSServerTestCase, unittest.TestCase, BaseLeapTest):
self.cacert = os.path.join(os.path.dirname(__file__),
"cacert.pem")
+ @defer.inlineCallbacks
def test_bootstrap_self_sign_cert_fails(self):
home = os.path.join(self.home, 'self_sign')
os.mkdir(home)
- provider = Provider(self.addr.domain, autoconf=True, basedir=home)
+ provider = Provider.get(self.addr.domain, autoconf=True, basedir=home)
d = provider.callWhenMainConfigReady(lambda: "Cert was accepted")
- return self.assertFailure(d, NetworkError)
+ yield self.assertFailure(d, NetworkError)
+ Provider.providers[self.addr.domain] = None
@defer.inlineCallbacks
def test_bootstrap_invalid_ca_cert(self):
@@ -58,15 +60,17 @@ class ConfigTest(BaseHTTPSServerTestCase, unittest.TestCase, BaseLeapTest):
provider._http.close()
try:
yield defer.gatherResults([
- d, provider.ongoing_bootstrap[provider._domain]])
+ d, provider.ongoing_bootstrap])
except:
pass
+ Provider.providers[self.addr.domain] = None
+ @defer.inlineCallbacks
def test_bootstrap_pinned_cert(self):
home = os.path.join(self.home, 'pinned')
os.mkdir(home)
- provider = Provider(self.addr.domain, autoconf=True, basedir=home,
- cert_path=self.cacert)
+ provider = Provider.get(self.addr.domain, autoconf=True, basedir=home,
+ cert_path=self.cacert)
def check_provider():
config = provider.config()
@@ -74,9 +78,31 @@ class ConfigTest(BaseHTTPSServerTestCase, unittest.TestCase, BaseLeapTest):
self.assertEqual(config["ca_cert_fingerprint"],
"SHA256: %s" % fingerprint)
- d = provider.callWhenMainConfigReady(check_provider)
- return defer.gatherResults([
- d, provider.ongoing_bootstrap[provider._domain]])
+ yield provider.callWhenMainConfigReady(check_provider)
+ provider._http.close()
+ yield provider.ongoing_bootstrap
+ Provider.providers[self.addr.domain] = None
+
+ @defer.inlineCallbacks
+ def test_api_uri(self):
+ api_uri = "api.example.com"
+ self.addr.api_uri = api_uri
+ home = os.path.join(self.home, 'api_uri')
+ os.mkdir(home)
+ provider = Provider.get(self.addr.domain, autoconf=True,
+ basedir=home, cert_path=self.cacert)
+
+ def check_api_uri():
+ parsed_uri = provider.api_uri
+ self.assertEqual(api_uri, parsed_uri)
+
+ yield provider.callWhenMainConfigReady(check_api_uri)
+ provider._http.close()
+ try:
+ yield provider.ongoing_bootstrap
+ except:
+ pass
+ Provider.providers[self.addr.domain] = None
class Addr(object):
@@ -84,6 +110,7 @@ class Addr(object):
self.host = host
self.port = port
self.fingerprint = fingerprint
+ self.api_uri = "https://%s:%s" % (host, port)
@property
def domain(self):
@@ -95,6 +122,7 @@ def request_handler(addr):
def do_GET(self):
if self.path == '/provider.json':
body = provider_json % {
+ 'api_uri': addr.api_uri,
'host': addr.host,
'port': addr.port,
'fingerprint': addr.fingerprint
@@ -126,7 +154,7 @@ fingerprint = \
"cd0131b3352b7a29c307156b24f09fe862b1f5a2e55be7cd888048b91770f220"
provider_json = """
{
- "api_uri": "https://%(host)s:%(port)s",
+ "api_uri": "%(api_uri)s",
"api_version": "1",
"ca_cert_fingerprint": "SHA256: %(fingerprint)s",
"ca_cert_uri": "https://%(host)s:%(port)s/ca.crt",