summaryrefslogtreecommitdiff
path: root/src/leap
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap')
-rw-r--r--src/leap/base/config.py9
-rw-r--r--src/leap/eip/checks.py67
-rw-r--r--src/leap/eip/specs.py21
-rwxr-xr-xsrc/leap/gui/firstrunwizard.py110
4 files changed, 186 insertions, 21 deletions
diff --git a/src/leap/base/config.py b/src/leap/base/config.py
index cf01d1aa..9ce2e9f0 100644
--- a/src/leap/base/config.py
+++ b/src/leap/base/config.py
@@ -252,6 +252,15 @@ def get_default_provider_path():
return default_provider_path
+def get_provider_path(domain):
+ # XXX if not domain, return get_default_provider_path
+ default_subpath = os.path.join("providers", domain)
+ provider_path = get_config_file(
+ '',
+ folder=default_subpath)
+ return provider_path
+
+
def validate_ip(ip_str):
"""
raises exception if the ip_str is
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py
index 74afd677..635308bb 100644
--- a/src/leap/eip/checks.py
+++ b/src/leap/eip/checks.py
@@ -11,6 +11,7 @@ import requests
from leap import __branding as BRANDING
from leap import certs as leapcerts
+from leap.base.auth import srpauth_protected
from leap.base import config as baseconfig
from leap.base import constants as baseconstants
from leap.base import providers
@@ -98,6 +99,17 @@ class ProviderCertChecker(object):
def check_ca_cert_fingerprint(
self, hash_type="SHA256",
fingerprint=None):
+ """
+ compares the fingerprint in
+ the ca cert with a string
+ we are passed
+ returns True if they are equal, False if not.
+ @param hash_type: digest function
+ @type hash_type: str
+ @param fingerprint: the fingerprint to compare with.
+ @type fingerprint: str (with : separator)
+ @rtype bool
+ """
ca_cert_path = self.ca_cert_path
ca_cert_fpr = certs.get_cert_fingerprint(
filepath=ca_cert_path)
@@ -185,7 +197,8 @@ class ProviderCertChecker(object):
return False
def download_new_client_cert(self, uri=None, verify=True,
- skip_download=False):
+ skip_download=False,
+ credentials=None):
logger.debug('download new client cert')
if skip_download:
return True
@@ -193,18 +206,34 @@ class ProviderCertChecker(object):
uri = self._get_client_cert_uri()
# XXX raise InsecureURI or something better
assert uri.startswith('https')
+
if verify is True and self.cacert is not None:
verify = self.cacert
+
+ fgetfn = self.fetcher.get
+
+ if credentials:
+ user, passwd = credentials
+
+ @srpauth_protected(user, passwd)
+ def getfn(*args, **kwargs):
+ return fgetfn(*args, **kwargs)
+
+ else:
+ # XXX use magic_srpauth decorator instead,
+ # merge with the branch above
+ def getfn(*args, **kwargs):
+ return fgetfn(*args, **kwargs)
try:
+
# XXX FIXME!!!!
# verify=verify
# Workaround for #638. return to verification
# when That's done!!!
-
- # XXX HOOK SRP here...
- # will have to be more generic in the future.
- req = self.fetcher.get(uri, verify=False)
+ #req = self.fetcher.get(uri, verify=False)
+ req = getfn(uri, verify=False)
req.raise_for_status()
+
except requests.exceptions.SSLError:
logger.warning('SSLError while fetching cert. '
'Look below for stack trace.')
@@ -283,23 +312,26 @@ class ProviderCertChecker(object):
return self._get_ca_cert_path(self.domain)
def _get_root_uri(self):
- return u"https://%s/" % baseconstants.DEFAULT_PROVIDER
+ return u"https://%s/" % self.domain
def _get_client_cert_uri(self):
# XXX get the whole thing from constants
- return "https://%s/1/cert" % (baseconstants.DEFAULT_PROVIDER)
+ return "https://%s/1/cert" % self.domain
def _get_client_cert_path(self):
# MVS+ : get provider path
- return eipspecs.client_cert_path()
+ return eipspecs.client_cert_path(domain=self.domain)
def _get_ca_cert_path(self, domain):
# XXX this folder path will be broken for win
# and this should be moved to eipspecs.ca_path
+ # XXX use baseconfig.get_provider_path(folder=Foo)
+ # !!!
+
capath = baseconfig.get_config_file(
'cacert.pem',
- folder='providers/%s/certs/ca' % domain)
+ folder='providers/%s/keys/ca' % domain)
folder, fname = os.path.split(capath)
if not os.path.isdir(folder):
mkdir_p(folder)
@@ -321,16 +353,20 @@ class EIPConfigChecker(object):
use run_all to run all checks.
"""
- def __init__(self, fetcher=requests):
+ def __init__(self, fetcher=requests, domain=None):
# we do not want to accept too many
# argument on init.
# we want tests
# to be explicitely run.
+
self.fetcher = fetcher
- self.eipconfig = eipconfig.EIPConfig()
- self.defaultprovider = providers.LeapProviderDefinition()
- self.eipserviceconfig = eipconfig.EIPServiceConfig()
+ # if not domain, get from config
+ self.domain = domain
+
+ self.eipconfig = eipconfig.EIPConfig(domain=domain)
+ self.defaultprovider = providers.LeapProviderDefinition(domain=domain)
+ self.eipserviceconfig = eipconfig.EIPServiceConfig(domain=domain)
def run_all(self, checker=None, skip_download=False):
"""
@@ -421,13 +457,14 @@ class EIPConfigChecker(object):
self.defaultprovider.save()
def fetch_eip_service_config(self, skip_download=False,
- config=None, uri=None):
+ config=None, uri=None, domain=None):
if skip_download:
return True
if config is None:
config = self.eipserviceconfig.config
if uri is None:
- domain = config.get('provider', None)
+ if not domain:
+ domain = config.get('provider', None)
uri = self._get_eip_service_uri(domain=domain)
self.eipserviceconfig.load(from_uri=uri, fetcher=self.fetcher)
diff --git a/src/leap/eip/specs.py b/src/leap/eip/specs.py
index 1a670b0e..4014b7c9 100644
--- a/src/leap/eip/specs.py
+++ b/src/leap/eip/specs.py
@@ -8,7 +8,14 @@ PROVIDER_CA_CERT = __branding.get(
'provider_ca_file',
'testprovider-ca-cert.pem')
-provider_ca_path = lambda: str(os.path.join(
+provider_ca_path = lambda domain: str(os.path.join(
+ #baseconfig.get_default_provider_path(),
+ baseconfig.get_provider_path(domain),
+ 'keys', 'ca',
+ 'cacert.pem'
+))
+
+default_provider_ca_path = lambda: str(os.path.join(
baseconfig.get_default_provider_path(),
'keys', 'ca',
PROVIDER_CA_CERT
@@ -17,7 +24,13 @@ provider_ca_path = lambda: str(os.path.join(
PROVIDER_DOMAIN = __branding.get('provider_domain', 'testprovider.example.org')
-client_cert_path = lambda: unicode(os.path.join(
+client_cert_path = lambda domain: unicode(os.path.join(
+ baseconfig.get_provider_path(domain),
+ 'keys', 'client',
+ 'openvpn.pem'
+))
+
+default_client_cert_path = lambda: unicode(os.path.join(
baseconfig.get_default_provider_path(),
'keys', 'client',
'openvpn.pem'
@@ -46,11 +59,11 @@ eipconfig_spec = {
},
'openvpn_ca_certificate': {
'type': unicode, # path
- 'default': provider_ca_path
+ 'default': default_provider_ca_path
},
'openvpn_client_certificate': {
'type': unicode, # path
- 'default': client_cert_path
+ 'default': default_client_cert_path
},
'connect_on_login': {
'type': bool,
diff --git a/src/leap/gui/firstrunwizard.py b/src/leap/gui/firstrunwizard.py
index 8bb40cdc..68cd4253 100755
--- a/src/leap/gui/firstrunwizard.py
+++ b/src/leap/gui/firstrunwizard.py
@@ -584,7 +584,7 @@ class ProviderSetupPage(QtGui.QWizardPage):
#ca_cert_path = checker.ca_cert_path
self.progress.setValue(100)
- time.sleep(0.2)
+ time.sleep(1)
# pagewizard methods
@@ -634,7 +634,6 @@ class UserFormMixIn(object):
# I guess it is because there is no delay...
logger.debug('registering........')
self.validationMsg.setText('registering...')
- # need to call update somehow???
# XXX refactor set_status_foo
@@ -774,6 +773,10 @@ class RegisterUserPage(QtGui.QWizardPage, UserFormMixIn):
self.registerField('userName*', self.userNameLineEdit)
self.registerField('userPassword*', self.userPasswordLineEdit)
+
+ # XXX missing password confirmation
+ # XXX validator!
+
self.registerField('rememberPassword', rememberPasswordCheckBox)
layout = QtGui.QGridLayout()
@@ -898,6 +901,109 @@ class ConnectingPage(QtGui.QWizardPage):
QtGui.QWizard.LogoPixmap,
QtGui.QPixmap(APP_LOGO))
+ self.status = QtGui.QLabel("")
+ self.status.setWordWrap(True)
+ self.progress = QtGui.QProgressBar()
+ self.progress.setMaximum(100)
+ self.progress.hide()
+
+ self.status_line_1 = QtGui.QLabel()
+ self.status_line_2 = QtGui.QLabel()
+ self.status_line_3 = QtGui.QLabel()
+ self.status_line_4 = QtGui.QLabel()
+
+ layout = QtGui.QGridLayout()
+ layout.addWidget(self.status, 0, 1)
+ layout.addWidget(self.progress, 5, 1)
+ layout.addWidget(self.status_line_1, 8, 1)
+ layout.addWidget(self.status_line_2, 9, 1)
+ layout.addWidget(self.status_line_3, 10, 1)
+ layout.addWidget(self.status_line_4, 11, 1)
+
+ self.setLayout(layout)
+
+ def set_status(self, status):
+ self.status.setText(status)
+ self.status.setWordWrap(True)
+
+ def get_donemsg(self, msg):
+ return "%s ... done" % msg
+
+ def fetch_and_validate(self):
+ # Fake... till you make it...
+ import time
+ domain = self.field('provider_domain')
+ wizard = self.wizard()
+ #pconfig = wizard.providerconfig
+ eipconfigchecker = wizard.eipconfigchecker()
+ pCertChecker = wizard.providercertchecker(
+ domain=domain)
+
+ # XXX get from log_in page if we came that way
+ # instead
+
+ username = self.field('userName')
+ password = self.field('userPassword')
+
+ credentials = username, password
+
+ self.progress.show()
+
+ fetching_eip_conf_msg = 'Fetching eip service configuration'
+ self.set_status(fetching_eip_conf_msg)
+ self.progress.setValue(30)
+
+ # Fetching eip service
+ eipconfigchecker.fetch_eip_service_config(
+ domain=domain)
+
+ self.status_line_1.setText(
+ self.get_donemsg(fetching_eip_conf_msg))
+
+ getting_client_cert_msg = 'Getting client certificate'
+ self.set_status(getting_client_cert_msg)
+ self.progress.setValue(66)
+
+ # Download cert
+ pCertChecker.download_new_client_cert(
+ credentials=credentials)
+
+ time.sleep(2)
+ self.status_line_2.setText(
+ self.get_donemsg(getting_client_cert_msg))
+
+ validating_clientcert_msg = 'Validating client certificate'
+ self.set_status(validating_clientcert_msg)
+ self.progress.setValue(90)
+ time.sleep(2)
+ self.status_line_3.setText(
+ self.get_donemsg(validating_clientcert_msg))
+
+ self.progress.setValue(100)
+ time.sleep(3)
+
+ return True
+
+ # pagewizard methods
+
+ def initializePage(self):
+ # XXX if we're coming from signup page
+ # we could say something like
+ # 'registration successful!'
+ self.status.setText(
+ "We have "
+ "all we need to connect with the provider.<br><br> "
+ "Click <i>next</i> to continue. ")
+ self.progress.setValue(0)
+ self.progress.hide()
+ self.status_line_1.setText('')
+ self.status_line_2.setText('')
+ self.status_line_3.setText('')
+
+ def validatePage(self):
+ validated = self.fetch_and_validate()
+ return validated
+
class LastPage(QtGui.QWizardPage):
def __init__(self, parent=None):