summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bonafide/src/leap/bonafide/_protocol.py45
-rw-r--r--bonafide/src/leap/bonafide/config.py19
-rw-r--r--bonafide/src/leap/bonafide/provider.py9
3 files changed, 49 insertions, 24 deletions
diff --git a/bonafide/src/leap/bonafide/_protocol.py b/bonafide/src/leap/bonafide/_protocol.py
index 9e964af..e50d020 100644
--- a/bonafide/src/leap/bonafide/_protocol.py
+++ b/bonafide/src/leap/bonafide/_protocol.py
@@ -22,7 +22,7 @@ import resource
from collections import defaultdict
from leap.bonafide import config
-from leap.bonafide import provider
+from leap.bonafide.provider import Api
from leap.bonafide.session import Session, OK
from leap.common.config import get_path_prefix
@@ -46,18 +46,17 @@ class BonafideProtocol(object):
_apis = defaultdict(None)
_sessions = defaultdict(None)
- def _get_api(self, provider_id):
+ def _get_api(self, provider):
# TODO should get deferred
- if provider_id in self._apis:
- return self._apis[provider_id]
+ if provider.domain in self._apis:
+ return self._apis[provider.domain]
- # XXX lookup the provider config instead
# TODO defer the autoconfig for the provider if needed...
- api = provider.Api('https://api.%s:4430' % provider_id)
- self._apis[provider_id] = api
+ api = Api(provider.api_uri, provider.version)
+ self._apis[provider.domain] = api
return api
- def _get_session(self, full_id, password=""):
+ def _get_session(self, provider, full_id, password=""):
if full_id in self._sessions:
return self._sessions[full_id]
@@ -65,7 +64,7 @@ class BonafideProtocol(object):
# TODO use twisted.cred instead
username, provider_id = config.get_username_and_provider(full_id)
credentials = UsernamePassword(username, password)
- api = self._get_api(provider_id)
+ api = self._get_api(provider)
provider_pem = _get_provider_ca_path(provider_id)
session = Session(credentials, api, provider_pem)
self._sessions[full_id] = session
@@ -78,10 +77,11 @@ class BonafideProtocol(object):
_, provider_id = config.get_username_and_provider(full_id)
provider = config.Provider(provider_id)
- d = provider.callWhenReady(self._do_signup, full_id, password)
+ d = provider.callWhenReady(
+ self._do_signup, provider, full_id, password)
return d
- def _do_signup(self, full_id, password):
+ def _do_signup(self, provider, full_id, password):
# XXX check it's unauthenticated
def return_user(result, _session):
@@ -91,7 +91,7 @@ class BonafideProtocol(object):
username, _ = config.get_username_and_provider(full_id)
# XXX get deferred?
- session = self._get_session(full_id, password)
+ session = self._get_session(provider, full_id, password)
d = session.signup(username, password)
d.addCallback(return_user, session)
return d
@@ -102,17 +102,17 @@ class BonafideProtocol(object):
provider = config.Provider(provider_id)
def maybe_finish_provider_bootstrap(result, provider):
- session = self._get_session(full_id, password)
+ session = self._get_session(provider, full_id, password)
d = provider.download_services_config_with_auth(session)
d.addCallback(lambda _: result)
return d
- d = provider.callWhenMainConfigReady(
- self._do_authenticate, full_id, password)
+ d = provider.callWhenReady(
+ self._do_authenticate, provider, full_id, password)
d.addCallback(maybe_finish_provider_bootstrap, provider)
return d
- def _do_authenticate(self, full_id, password):
+ def _do_authenticate(self, provider, full_id, password):
def return_token_and_uuid(result, _session):
if result == OK:
@@ -122,7 +122,7 @@ class BonafideProtocol(object):
log.msg('AUTH for %s' % full_id)
# XXX get deferred?
- session = self._get_session(full_id, password)
+ session = self._get_session(provider, full_id, password)
d = session.authenticate()
d.addCallback(return_token_and_uuid, session)
return d
@@ -130,9 +130,10 @@ class BonafideProtocol(object):
def do_logout(self, full_id):
# XXX use the AVATAR here
log.msg('LOGOUT for %s' % full_id)
- session = self._get_session(full_id)
- if not session.is_authenticated:
+ if (full_id not in self._sessions or
+ not self._sessions[full_id].is_authenticated):
return fail(RuntimeError("There is no session for such user"))
+ session = self._sessions[full_id]
d = session.logout()
d.addCallback(lambda _: self._sessions.pop(full_id))
@@ -140,8 +141,10 @@ class BonafideProtocol(object):
return d
def do_get_smtp_cert(self, full_id):
- session = self._get_session(full_id)
- d = session.get_smtp_cert()
+ if (full_id not in self._sessions or
+ not self._sessions[full_id].is_authenticated):
+ return fail(RuntimeError("There is no session for such user"))
+ d = self._sessions[full_id].get_smtp_cert()
return d
def do_get_vpn_cert(self):
diff --git a/bonafide/src/leap/bonafide/config.py b/bonafide/src/leap/bonafide/config.py
index 496c9a8..8c34bc1 100644
--- a/bonafide/src/leap/bonafide/config.py
+++ b/bonafide/src/leap/bonafide/config.py
@@ -174,6 +174,22 @@ class Provider(object):
self.ongoing_bootstrap[self._domain] = defer.succeed(
'already_initialized')
+ @property
+ def domain(self):
+ return self._domain
+
+ @property
+ def api_uri(self):
+ if not self._provider_config:
+ return 'https://api.%s:4430' % self._domain
+ return self._provider_config.api_uri
+
+ @property
+ def version(self):
+ if not self._provider_config:
+ return 1
+ return int(self._provider_config.api_version)
+
def is_configured(self):
provider_json = self._get_provider_json_path()
# XXX check if all the services are there
@@ -473,9 +489,6 @@ class Provider(object):
# XXX pass if-modified-since header
return httpRequest(self._agent, *args, **kw)
- def _get_api_uri(self):
- pass
-
class Record(object):
def __init__(self, **kw):
diff --git a/bonafide/src/leap/bonafide/provider.py b/bonafide/src/leap/bonafide/provider.py
index 7e78196..82824e9 100644
--- a/bonafide/src/leap/bonafide/provider.py
+++ b/bonafide/src/leap/bonafide/provider.py
@@ -24,6 +24,12 @@ import re
from urlparse import urlparse
+"""
+Maximum API version number supported by bonafide
+"""
+MAX_API_VERSION = 1
+
+
class _MetaActionDispatcher(type):
"""
@@ -84,7 +90,10 @@ class BaseProvider(object):
raise ValueError(
'ProviderApi needs to be passed a url with https scheme')
self.netloc = parsed.netloc
+
self.version = version
+ if version > MAX_API_VERSION:
+ self.version = MAX_API_VERSION
def get_hostname(self):
return urlparse(self._get_base_url()).hostname