diff options
| -rw-r--r-- | src/leap/base/auth.py | 15 | ||||
| -rw-r--r-- | src/leap/base/config.py | 5 | ||||
| -rw-r--r-- | src/leap/eip/checks.py | 85 | ||||
| -rw-r--r-- | src/leap/gui/firstrun/connect.py | 24 | ||||
| -rw-r--r-- | src/leap/gui/firstrun/providersetup.py | 33 | ||||
| -rw-r--r-- | src/leap/gui/firstrun/register.py | 18 | ||||
| -rw-r--r-- | src/leap/gui/progress.py | 2 | 
7 files changed, 103 insertions, 79 deletions
| diff --git a/src/leap/base/auth.py b/src/leap/base/auth.py index 563a0b2a..f629972f 100644 --- a/src/leap/base/auth.py +++ b/src/leap/base/auth.py @@ -43,7 +43,6 @@ class LeapSRPRegister(object):      def __init__(self,                   schema="https",                   provider=None, -                 #port=None,                   verify=True,                   register_path="1/users.json",                   method="POST", @@ -56,11 +55,6 @@ class LeapSRPRegister(object):          self.schema = schema -        # XXX FIXME -        #self.provider = provider -        #self.port = port -        # XXX splitting server,port -        # deprecate port call.          domain, port = get_https_domain_and_port(provider)          self.provider = domain          self.port = port @@ -137,6 +131,9 @@ class SRPAuth(requests.auth.AuthBase):          self.server = server          self.verify = verify +        logger.debug('SRPAuth. verify=%s' % verify) +        logger.debug('server: %s. username=%s' % (server, username)) +          self.init_data = None          self.session = requests.session() @@ -168,6 +165,9 @@ class SRPAuth(requests.auth.AuthBase):          except requests.exceptions.ConnectionError:              raise SRPAuthenticationError(                  "No connection made (salt).") +        except: +            raise SRPAuthenticationError( +                "Unknown error (salt).")          if init_session.status_code not in (200, ):              raise SRPAuthenticationError(                  "No valid response (salt).") @@ -245,7 +245,6 @@ class SRPAuth(requests.auth.AuthBase):          try:              assert self.srp_usr.authenticated()              logger.debug('user is authenticated!') -            print 'user is authenticated!'          except (AssertionError):              raise SRPAuthenticationError(                  "Auth verification failed.") @@ -268,6 +267,8 @@ def srpauth_protected(user=None, passwd=None, server=None, verify=True):                  auth = SRPAuth(user, passwd, server, verify)                  kwargs['auth'] = auth                  kwargs['verify'] = verify +            if not args: +                logger.warning('attempting to get from empty uri!')              return fn(*args, **kwargs)          return wrapper      return srpauth diff --git a/src/leap/base/config.py b/src/leap/base/config.py index 438d1993..e235e5c3 100644 --- a/src/leap/base/config.py +++ b/src/leap/base/config.py @@ -155,7 +155,7 @@ class JSONLeapConfig(BaseLeapConfig):              return False      def load(self, fromfile=None, from_uri=None, fetcher=None, -             force_download=False, verify=False): +             force_download=False, verify=True):          if from_uri is not None:              fetched = self.fetch( @@ -177,8 +177,7 @@ class JSONLeapConfig(BaseLeapConfig):          if not fetcher:              fetcher = self.fetcher -        logger.debug('verify: %s', verify) -        logger.debug('uri: %s', uri) +        logger.debug('uri: %s (verify: %s)' % (uri, verify))          rargs = (uri, )          rkwargs = {'verify': verify} diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py index b14e5dd3..bd158e1e 100644 --- a/src/leap/eip/checks.py +++ b/src/leap/eip/checks.py @@ -1,5 +1,5 @@  import logging -import ssl +#import ssl  #import platform  import time  import os @@ -21,6 +21,8 @@ from leap.eip import constants as eipconstants  from leap.eip import exceptions as eipexceptions  from leap.eip import specs as eipspecs  from leap.util.fileutil import mkdir_p +from leap.util.web import get_https_domain_and_port +from leap.util.misc import null_check  logger = logging.getLogger(name=__name__) @@ -46,7 +48,7 @@ reachable and testable as a whole.  def get_branding_ca_cert(domain): -    # XXX deprecated +    # deprecated      ca_file = BRANDING.get('provider_ca_file')      if ca_file:          return leapcerts.where(ca_file) @@ -63,6 +65,10 @@ class ProviderCertChecker(object):          self.fetcher = fetcher          self.domain = domain +        #XXX needs some kind of autoinit +        #right now we set by hand +        #by loading and reading provider config +        self.apidomain = None          self.cacert = eipspecs.provider_ca_path(domain)      def run_all( @@ -159,7 +165,7 @@ class ProviderCertChecker(object):          if autocacert and verify is True and self.cacert is not None:              logger.debug('verify cert: %s', self.cacert)              verify = self.cacert -        logger.debug('is https working?') +        logger.debug('checking https connection')          logger.debug('uri: %s (verify:%s)', uri, verify)          try:              self.fetcher.get(uri, verify=verify) @@ -167,27 +173,24 @@ class ProviderCertChecker(object):          except requests.exceptions.SSLError:  # as exc:              logger.error("SSLError")              raise eipexceptions.HttpsBadCertError -            #logger.warning('BUG #638 CERT VERIFICATION FAILED! ' -                           #'(this should be CRITICAL)') -            #logger.warning('SSLError: %s', exc.message)          except requests.exceptions.ConnectionError:              logger.error('ConnectionError')              raise eipexceptions.HttpsNotSupported          else: -            logger.debug('True')              return True      def check_new_cert_needed(self, skip_download=False, verify=True): +        # XXX add autocacert          logger.debug('is new cert needed?')          if not self.is_cert_valid(do_raise=False): -            logger.debug('True') +            logger.debug('cert needed: true')              self.download_new_client_cert(                  skip_download=skip_download,                  verify=verify)              return True -        logger.debug('False') +        logger.debug('cert needed: false')          return False      def download_new_client_cert(self, uri=None, verify=True, @@ -199,20 +202,20 @@ class ProviderCertChecker(object):          if uri is None:              uri = self._get_client_cert_uri()          # XXX raise InsecureURI or something better -        assert uri.startswith('https') +        #assert uri.startswith('https')          if verify is True and self.cacert is not None:              verify = self.cacert +            logger.debug('verify = %s', verify)          fgetfn = self.fetcher.get          if credentials:              user, passwd = credentials - -            logger.debug('domain = %s', self.domain) +            logger.debug('apidomain = %s', self.apidomain)              @srpauth_protected(user, passwd, -                               server="https://%s" % self.domain, +                               server="https://%s" % self.apidomain,                                 verify=verify)              def getfn(*args, **kwargs):                  return fgetfn(*args, **kwargs) @@ -231,11 +234,16 @@ class ProviderCertChecker(object):              logger.warning('SSLError while fetching cert. '                             'Look below for stack trace.')              # XXX raise better exception -            raise +            return self.fail("SSLError") +        except Exception as exc: +            return self.fail(exc.message) +          try: +            logger.debug('validating cert...')              pemfile_content = req.content              valid = self.is_valid_pemfile(pemfile_content)              if not valid: +                logger.warning('invalid cert')                  return False              cert_path = self._get_client_cert_path()              self.write_cert(pemfile_content, to=cert_path) @@ -299,8 +307,7 @@ class ProviderCertChecker(object):          return u"https://%s/" % self.domain      def _get_client_cert_uri(self): -        # XXX get the whole thing from constants -        return "https://%s/1/cert" % self.domain +        return "https://%s/1/cert" % self.apidomain      def _get_client_cert_path(self):          return eipspecs.client_cert_path(domain=self.domain) @@ -327,6 +334,9 @@ class ProviderCertChecker(object):          with open(to, 'w') as cert_f:              cert_f.write(pemfile_content) +    def set_api_domain(self, domain): +        self.apidomain = domain +  class EIPConfigChecker(object):      """ @@ -346,10 +356,15 @@ class EIPConfigChecker(object):          # if not domain, get from config          self.domain = domain +        self.apidomain = None +        self.cacert = eipspecs.provider_ca_path(domain) -        self.eipconfig = eipconfig.EIPConfig(domain=domain)          self.defaultprovider = providers.LeapProviderDefinition(domain=domain) +        self.defaultprovider.load() +        self.eipconfig = eipconfig.EIPConfig(domain=domain) +        self.set_api_domain()          self.eipserviceconfig = eipconfig.EIPServiceConfig(domain=domain) +        self.eipserviceconfig.load()      def run_all(self, checker=None, skip_download=False):          """ @@ -433,31 +448,35 @@ class EIPConfigChecker(object):                  domain = config.get('provider', None)              uri = self._get_provider_definition_uri(domain=domain) -        # FIXME! Pass ca path verify!!! -        # BUG #638 -        # FIXME FIXME FIXME          self.defaultprovider.load(              from_uri=uri,              fetcher=self.fetcher) -            #verify=False)          self.defaultprovider.save()      def fetch_eip_service_config(self, skip_download=False,                                   force_download=False, -                                 config=None, uri=None, domain=None): +                                 config=None, uri=None,  # domain=None, +                                 autocacert=True):          if skip_download:              return True          if config is None: +            self.eipserviceconfig.load()              config = self.eipserviceconfig.config          if uri is None: -            if not domain: -                domain = self.domain or config.get('provider', None) -            uri = self._get_eip_service_uri(domain=domain) +            #XXX +            #if not domain: +                #domain = self.domain or config.get('provider', None) +            uri = self._get_eip_service_uri( +                domain=self.apidomain) + +        if autocacert and self.cacert is not None: +            verify = self.cacert          self.eipserviceconfig.load(              from_uri=uri,              fetcher=self.fetcher, -            force_download=force_download) +            force_download=force_download, +            verify=verify)          self.eipserviceconfig.save()      def check_complete_eip_config(self, config=None): @@ -465,7 +484,6 @@ class EIPConfigChecker(object):          if config is None:              config = self.eipconfig.config          try: -            'trying assertions'              assert 'provider' in config              assert config['provider'] is not None              # XXX assert there is gateway !! @@ -504,3 +522,16 @@ class EIPConfigChecker(object):          uri = "https://%s/%s" % (domain, path)          logger.debug('getting eip service file from %s', uri)          return uri + +    def set_api_domain(self): +        """sets api domain from defaultprovider config object""" +        api = self.defaultprovider.config.get('api_uri', None) +        # the caller is responsible for having loaded the config +        # object at this point +        if api: +            api_dom = get_https_domain_and_port(api) +            self.apidomain = "%s:%s" % api_dom + +    def get_api_domain(self): +        """gets api domain""" +        return self.apidomain diff --git a/src/leap/gui/firstrun/connect.py b/src/leap/gui/firstrun/connect.py index 920ada50..b7688380 100644 --- a/src/leap/gui/firstrun/connect.py +++ b/src/leap/gui/firstrun/connect.py @@ -44,9 +44,15 @@ class ConnectionPage(ValidationPage):          wizard = self.wizard()          full_domain = self.field('provider_domain')          domain, port = get_https_domain_and_port(full_domain) -        _domain = u"%s:%s" % (domain, port) if port != 443 else unicode(domain) -        verify = True +        pconfig = wizard.eipconfigchecker(domain=domain) +        # this should be persisted... +        pconfig.defaultprovider.load() +        pconfig.set_api_domain() + +        pCertChecker = wizard.providercertchecker( +            domain=domain) +        pCertChecker.set_api_domain(pconfig.apidomain)          ###########################################          # Set Credentials. @@ -63,11 +69,6 @@ class ConnectionPage(ValidationPage):          password = self.field(passwk)          credentials = username, password -        eipconfigchecker = wizard.eipconfigchecker(domain=_domain) -        #XXX change for _domain (sanitized) -        pCertChecker = wizard.providercertchecker( -            domain=full_domain) -          yield(("head_sentinel", 0), lambda: None)          ################################################## @@ -75,8 +76,7 @@ class ConnectionPage(ValidationPage):          ##################################################          def fetcheipconf():              try: -                eipconfigchecker.fetch_eip_service_config( -                    domain=full_domain) +                pconfig.fetch_eip_service_config()              # XXX get specific exception              except Exception as exc: @@ -92,8 +92,7 @@ class ConnectionPage(ValidationPage):          def fetcheipcert():              try:                  downloaded = pCertChecker.download_new_client_cert( -                    credentials=credentials, -                    verify=verify) +                    credentials=credentials)                  if not downloaded:                      logger.error('Could not download client cert.')                      return False @@ -101,6 +100,9 @@ class ConnectionPage(ValidationPage):              except auth.SRPAuthenticationError as exc:                  return self.fail(self.tr(                      "Authentication error: %s" % exc.message)) + +            except Exception as exc: +                return self.fail(exc.message)              else:                  return True diff --git a/src/leap/gui/firstrun/providersetup.py b/src/leap/gui/firstrun/providersetup.py index 48a89091..981e3214 100644 --- a/src/leap/gui/firstrun/providersetup.py +++ b/src/leap/gui/firstrun/providersetup.py @@ -4,6 +4,8 @@ used if First Run Wizard  """  import logging +import requests +  from PyQt4 import QtGui  from leap.base import exceptions as baseexceptions @@ -110,26 +112,15 @@ class ProviderSetupValidationPage(ValidationPage):          #########################          def validatecacert(): -            pass -            #api_uri = pconfig.get('api_uri', None) -            #try: -                #api_cert_verified = pCertChecker.verify_api_https(api_uri) -            #except requests.exceptions.SSLError as exc: -                #logger.error('BUG #638. %s' % exc.message) -                # XXX RAISE! See #638 -                # bypassing until the hostname is fixed. -                # We probably should raise yet-another-warning -                # here saying user that the hostname "XX.XX.XX.XX' does not -                # match 'foo.bar.baz' -                #api_cert_verified = True - -            #if not api_cert_verified: -                # XXX update validationMsg -                # should catch exception -                #return False - -            #??? -            #ca_cert_path = checker.ca_cert_path +            api_uri = pconfig.get('api_uri', None) +            try: +                pCertChecker.verify_api_https(api_uri) +            except requests.exceptions.SSLError as exc: +                return self.fail("Validation Error") +            except Exception as exc: +                return self.fail(exc.msg) +            else: +                return True          yield((self.tr('Validating api certificate'), 90), validatecacert) @@ -141,8 +132,8 @@ class ProviderSetupValidationPage(ValidationPage):          called after _do_checks has finished          (connected to checker thread finished signal)          """ -        prevpage = "providerselection" if self.is_signup else "login"          wizard = self.wizard() +        prevpage = "login" if wizard.from_login else "providerselection"          if self.errors:              logger.debug('going back with errors') diff --git a/src/leap/gui/firstrun/register.py b/src/leap/gui/firstrun/register.py index b04638e0..741b9267 100644 --- a/src/leap/gui/firstrun/register.py +++ b/src/leap/gui/firstrun/register.py @@ -224,11 +224,17 @@ class RegisterUserPage(InlineValidationPage, UserFormMixIn):          generator that yields actual checks          that are executed in a separate thread          """ +        wizard = self.wizard() +          provider = self.field('provider_domain')          username = self.userNameLineEdit.text()          password = self.userPasswordLineEdit.text()          password2 = self.userPassword2LineEdit.text() +        pconfig = wizard.eipconfigchecker(domain=provider) +        pconfig.defaultprovider.load() +        pconfig.set_api_domain() +          def checkpass():              # we better have here              # some call to a password checker... @@ -263,14 +269,11 @@ class RegisterUserPage(InlineValidationPage, UserFormMixIn):              self, "showStepsFrame")          def register(): -            # XXX FIXME! -            verify = False              signup = auth.LeapSRPRegister(                  schema="https", -                provider=provider, -                verify=verify) -            #import ipdb;ipdb.set_trace() +                provider=pconfig.apidomain, +                verify=pconfig.cacert)              try:                  ok, req = signup.register_user(                      username, password) @@ -381,7 +384,4 @@ class RegisterUserPage(InlineValidationPage, UserFormMixIn):      def nextId(self):          wizard = self.wizard() -        #if not wizard: -            #return -        # XXX this should be called connect -        return wizard.get_page_index('signupvalidation') +        return wizard.get_page_index('connect') diff --git a/src/leap/gui/progress.py b/src/leap/gui/progress.py index fceeb2f6..ca4f6cc3 100644 --- a/src/leap/gui/progress.py +++ b/src/leap/gui/progress.py @@ -287,7 +287,7 @@ class WithStepsMixIn(object):              pagename = getattr(self, 'prev_page', None)          if pagename is None:  # pragma: no cover              return -        logger.debug('cleaning wizard errors for %s' % pagename) +        #logger.debug('cleaning wizard errors for %s' % pagename)          self.wizard().set_validation_error(pagename, None)      def populateStepsTable(self): | 
