diff options
-rw-r--r-- | bonafide/src/leap/bonafide/_protocol.py | 45 | ||||
-rw-r--r-- | bonafide/src/leap/bonafide/config.py | 19 | ||||
-rw-r--r-- | bonafide/src/leap/bonafide/provider.py | 9 |
3 files changed, 49 insertions, 24 deletions
diff --git a/bonafide/src/leap/bonafide/_protocol.py b/bonafide/src/leap/bonafide/_protocol.py index 9e964af..e50d020 100644 --- a/bonafide/src/leap/bonafide/_protocol.py +++ b/bonafide/src/leap/bonafide/_protocol.py @@ -22,7 +22,7 @@ import resource from collections import defaultdict from leap.bonafide import config -from leap.bonafide import provider +from leap.bonafide.provider import Api from leap.bonafide.session import Session, OK from leap.common.config import get_path_prefix @@ -46,18 +46,17 @@ class BonafideProtocol(object): _apis = defaultdict(None) _sessions = defaultdict(None) - def _get_api(self, provider_id): + def _get_api(self, provider): # TODO should get deferred - if provider_id in self._apis: - return self._apis[provider_id] + if provider.domain in self._apis: + return self._apis[provider.domain] - # XXX lookup the provider config instead # TODO defer the autoconfig for the provider if needed... - api = provider.Api('https://api.%s:4430' % provider_id) - self._apis[provider_id] = api + api = Api(provider.api_uri, provider.version) + self._apis[provider.domain] = api return api - def _get_session(self, full_id, password=""): + def _get_session(self, provider, full_id, password=""): if full_id in self._sessions: return self._sessions[full_id] @@ -65,7 +64,7 @@ class BonafideProtocol(object): # TODO use twisted.cred instead username, provider_id = config.get_username_and_provider(full_id) credentials = UsernamePassword(username, password) - api = self._get_api(provider_id) + api = self._get_api(provider) provider_pem = _get_provider_ca_path(provider_id) session = Session(credentials, api, provider_pem) self._sessions[full_id] = session @@ -78,10 +77,11 @@ class BonafideProtocol(object): _, provider_id = config.get_username_and_provider(full_id) provider = config.Provider(provider_id) - d = provider.callWhenReady(self._do_signup, full_id, password) + d = provider.callWhenReady( + self._do_signup, provider, full_id, password) return d - def _do_signup(self, full_id, password): + def _do_signup(self, provider, full_id, password): # XXX check it's unauthenticated def return_user(result, _session): @@ -91,7 +91,7 @@ class BonafideProtocol(object): username, _ = config.get_username_and_provider(full_id) # XXX get deferred? - session = self._get_session(full_id, password) + session = self._get_session(provider, full_id, password) d = session.signup(username, password) d.addCallback(return_user, session) return d @@ -102,17 +102,17 @@ class BonafideProtocol(object): provider = config.Provider(provider_id) def maybe_finish_provider_bootstrap(result, provider): - session = self._get_session(full_id, password) + session = self._get_session(provider, full_id, password) d = provider.download_services_config_with_auth(session) d.addCallback(lambda _: result) return d - d = provider.callWhenMainConfigReady( - self._do_authenticate, full_id, password) + d = provider.callWhenReady( + self._do_authenticate, provider, full_id, password) d.addCallback(maybe_finish_provider_bootstrap, provider) return d - def _do_authenticate(self, full_id, password): + def _do_authenticate(self, provider, full_id, password): def return_token_and_uuid(result, _session): if result == OK: @@ -122,7 +122,7 @@ class BonafideProtocol(object): log.msg('AUTH for %s' % full_id) # XXX get deferred? - session = self._get_session(full_id, password) + session = self._get_session(provider, full_id, password) d = session.authenticate() d.addCallback(return_token_and_uuid, session) return d @@ -130,9 +130,10 @@ class BonafideProtocol(object): def do_logout(self, full_id): # XXX use the AVATAR here log.msg('LOGOUT for %s' % full_id) - session = self._get_session(full_id) - if not session.is_authenticated: + if (full_id not in self._sessions or + not self._sessions[full_id].is_authenticated): return fail(RuntimeError("There is no session for such user")) + session = self._sessions[full_id] d = session.logout() d.addCallback(lambda _: self._sessions.pop(full_id)) @@ -140,8 +141,10 @@ class BonafideProtocol(object): return d def do_get_smtp_cert(self, full_id): - session = self._get_session(full_id) - d = session.get_smtp_cert() + if (full_id not in self._sessions or + not self._sessions[full_id].is_authenticated): + return fail(RuntimeError("There is no session for such user")) + d = self._sessions[full_id].get_smtp_cert() return d def do_get_vpn_cert(self): diff --git a/bonafide/src/leap/bonafide/config.py b/bonafide/src/leap/bonafide/config.py index 496c9a8..8c34bc1 100644 --- a/bonafide/src/leap/bonafide/config.py +++ b/bonafide/src/leap/bonafide/config.py @@ -174,6 +174,22 @@ class Provider(object): self.ongoing_bootstrap[self._domain] = defer.succeed( 'already_initialized') + @property + def domain(self): + return self._domain + + @property + def api_uri(self): + if not self._provider_config: + return 'https://api.%s:4430' % self._domain + return self._provider_config.api_uri + + @property + def version(self): + if not self._provider_config: + return 1 + return int(self._provider_config.api_version) + def is_configured(self): provider_json = self._get_provider_json_path() # XXX check if all the services are there @@ -473,9 +489,6 @@ class Provider(object): # XXX pass if-modified-since header return httpRequest(self._agent, *args, **kw) - def _get_api_uri(self): - pass - class Record(object): def __init__(self, **kw): diff --git a/bonafide/src/leap/bonafide/provider.py b/bonafide/src/leap/bonafide/provider.py index 7e78196..82824e9 100644 --- a/bonafide/src/leap/bonafide/provider.py +++ b/bonafide/src/leap/bonafide/provider.py @@ -24,6 +24,12 @@ import re from urlparse import urlparse +""" +Maximum API version number supported by bonafide +""" +MAX_API_VERSION = 1 + + class _MetaActionDispatcher(type): """ @@ -84,7 +90,10 @@ class BaseProvider(object): raise ValueError( 'ProviderApi needs to be passed a url with https scheme') self.netloc = parsed.netloc + self.version = version + if version > MAX_API_VERSION: + self.version = MAX_API_VERSION def get_hostname(self): return urlparse(self._get_base_url()).hostname |