diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/leap/base/config.py | 9 | ||||
| -rw-r--r-- | src/leap/eip/checks.py | 67 | ||||
| -rw-r--r-- | src/leap/eip/specs.py | 21 | ||||
| -rwxr-xr-x | src/leap/gui/firstrunwizard.py | 110 | 
4 files changed, 186 insertions, 21 deletions
| diff --git a/src/leap/base/config.py b/src/leap/base/config.py index cf01d1aa..9ce2e9f0 100644 --- a/src/leap/base/config.py +++ b/src/leap/base/config.py @@ -252,6 +252,15 @@ def get_default_provider_path():      return default_provider_path +def get_provider_path(domain): +    # XXX if not domain, return get_default_provider_path +    default_subpath = os.path.join("providers", domain) +    provider_path = get_config_file( +        '', +        folder=default_subpath) +    return provider_path + +  def validate_ip(ip_str):      """      raises exception if the ip_str is diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py index 74afd677..635308bb 100644 --- a/src/leap/eip/checks.py +++ b/src/leap/eip/checks.py @@ -11,6 +11,7 @@ import requests  from leap import __branding as BRANDING  from leap import certs as leapcerts +from leap.base.auth import srpauth_protected  from leap.base import config as baseconfig  from leap.base import constants as baseconstants  from leap.base import providers @@ -98,6 +99,17 @@ class ProviderCertChecker(object):      def check_ca_cert_fingerprint(              self, hash_type="SHA256",              fingerprint=None): +        """ +        compares the fingerprint in +        the ca cert with a string +        we are passed +        returns True if they are equal, False if not. +        @param hash_type: digest function +        @type hash_type: str +        @param fingerprint: the fingerprint to compare with. +        @type fingerprint: str (with : separator) +        @rtype bool +        """          ca_cert_path = self.ca_cert_path          ca_cert_fpr = certs.get_cert_fingerprint(              filepath=ca_cert_path) @@ -185,7 +197,8 @@ class ProviderCertChecker(object):          return False      def download_new_client_cert(self, uri=None, verify=True, -                                 skip_download=False): +                                 skip_download=False, +                                 credentials=None):          logger.debug('download new client cert')          if skip_download:              return True @@ -193,18 +206,34 @@ class ProviderCertChecker(object):              uri = self._get_client_cert_uri()          # XXX raise InsecureURI or something better          assert uri.startswith('https') +          if verify is True and self.cacert is not None:              verify = self.cacert + +        fgetfn = self.fetcher.get + +        if credentials: +            user, passwd = credentials + +            @srpauth_protected(user, passwd) +            def getfn(*args, **kwargs): +                return fgetfn(*args, **kwargs) + +        else: +            # XXX use magic_srpauth decorator instead, +            # merge with the branch above +            def getfn(*args, **kwargs): +                return fgetfn(*args, **kwargs)          try: +              # XXX FIXME!!!!              # verify=verify              # Workaround for #638. return to verification              # when That's done!!! - -            # XXX HOOK SRP here... -            # will have to be more generic in the future. -            req = self.fetcher.get(uri, verify=False) +            #req = self.fetcher.get(uri, verify=False) +            req = getfn(uri, verify=False)              req.raise_for_status() +          except requests.exceptions.SSLError:              logger.warning('SSLError while fetching cert. '                             'Look below for stack trace.') @@ -283,23 +312,26 @@ class ProviderCertChecker(object):          return self._get_ca_cert_path(self.domain)      def _get_root_uri(self): -        return u"https://%s/" % baseconstants.DEFAULT_PROVIDER +        return u"https://%s/" % self.domain      def _get_client_cert_uri(self):          # XXX get the whole thing from constants -        return "https://%s/1/cert" % (baseconstants.DEFAULT_PROVIDER) +        return "https://%s/1/cert" % self.domain      def _get_client_cert_path(self):          # MVS+ : get provider path -        return eipspecs.client_cert_path() +        return eipspecs.client_cert_path(domain=self.domain)      def _get_ca_cert_path(self, domain):          # XXX this folder path will be broken for win          # and this should be moved to eipspecs.ca_path +        # XXX use baseconfig.get_provider_path(folder=Foo) +        # !!! +          capath = baseconfig.get_config_file(              'cacert.pem', -            folder='providers/%s/certs/ca' % domain) +            folder='providers/%s/keys/ca' % domain)          folder, fname = os.path.split(capath)          if not os.path.isdir(folder):              mkdir_p(folder) @@ -321,16 +353,20 @@ class EIPConfigChecker(object):      use run_all to run all checks.      """ -    def __init__(self, fetcher=requests): +    def __init__(self, fetcher=requests, domain=None):          # we do not want to accept too many          # argument on init.          # we want tests          # to be explicitely run. +          self.fetcher = fetcher -        self.eipconfig = eipconfig.EIPConfig() -        self.defaultprovider = providers.LeapProviderDefinition() -        self.eipserviceconfig = eipconfig.EIPServiceConfig() +        # if not domain, get from config +        self.domain = domain + +        self.eipconfig = eipconfig.EIPConfig(domain=domain) +        self.defaultprovider = providers.LeapProviderDefinition(domain=domain) +        self.eipserviceconfig = eipconfig.EIPServiceConfig(domain=domain)      def run_all(self, checker=None, skip_download=False):          """ @@ -421,13 +457,14 @@ class EIPConfigChecker(object):          self.defaultprovider.save()      def fetch_eip_service_config(self, skip_download=False, -                                 config=None, uri=None): +                                 config=None, uri=None, domain=None):          if skip_download:              return True          if config is None:              config = self.eipserviceconfig.config          if uri is None: -            domain = config.get('provider', None) +            if not domain: +                domain = config.get('provider', None)              uri = self._get_eip_service_uri(domain=domain)          self.eipserviceconfig.load(from_uri=uri, fetcher=self.fetcher) diff --git a/src/leap/eip/specs.py b/src/leap/eip/specs.py index 1a670b0e..4014b7c9 100644 --- a/src/leap/eip/specs.py +++ b/src/leap/eip/specs.py @@ -8,7 +8,14 @@ PROVIDER_CA_CERT = __branding.get(      'provider_ca_file',      'testprovider-ca-cert.pem') -provider_ca_path = lambda: str(os.path.join( +provider_ca_path = lambda domain: str(os.path.join( +    #baseconfig.get_default_provider_path(), +    baseconfig.get_provider_path(domain), +    'keys', 'ca', +    'cacert.pem' +)) + +default_provider_ca_path = lambda: str(os.path.join(      baseconfig.get_default_provider_path(),      'keys', 'ca',      PROVIDER_CA_CERT @@ -17,7 +24,13 @@ provider_ca_path = lambda: str(os.path.join(  PROVIDER_DOMAIN = __branding.get('provider_domain', 'testprovider.example.org') -client_cert_path = lambda: unicode(os.path.join( +client_cert_path = lambda domain: unicode(os.path.join( +    baseconfig.get_provider_path(domain), +    'keys', 'client', +    'openvpn.pem' +)) + +default_client_cert_path = lambda: unicode(os.path.join(      baseconfig.get_default_provider_path(),      'keys', 'client',      'openvpn.pem' @@ -46,11 +59,11 @@ eipconfig_spec = {          },          'openvpn_ca_certificate': {              'type': unicode,  # path -            'default': provider_ca_path +            'default': default_provider_ca_path          },          'openvpn_client_certificate': {              'type': unicode,  # path -            'default': client_cert_path +            'default': default_client_cert_path          },          'connect_on_login': {              'type': bool, diff --git a/src/leap/gui/firstrunwizard.py b/src/leap/gui/firstrunwizard.py index 8bb40cdc..68cd4253 100755 --- a/src/leap/gui/firstrunwizard.py +++ b/src/leap/gui/firstrunwizard.py @@ -584,7 +584,7 @@ class ProviderSetupPage(QtGui.QWizardPage):          #ca_cert_path = checker.ca_cert_path          self.progress.setValue(100) -        time.sleep(0.2) +        time.sleep(1)      # pagewizard methods @@ -634,7 +634,6 @@ class UserFormMixIn(object):          # I guess it is because there is no delay...          logger.debug('registering........')          self.validationMsg.setText('registering...') -        # need to call update somehow???      # XXX refactor set_status_foo @@ -774,6 +773,10 @@ class RegisterUserPage(QtGui.QWizardPage, UserFormMixIn):          self.registerField('userName*', self.userNameLineEdit)          self.registerField('userPassword*', self.userPasswordLineEdit) + +        # XXX missing password confirmation +        # XXX validator! +          self.registerField('rememberPassword', rememberPasswordCheckBox)          layout = QtGui.QGridLayout() @@ -898,6 +901,109 @@ class ConnectingPage(QtGui.QWizardPage):              QtGui.QWizard.LogoPixmap,              QtGui.QPixmap(APP_LOGO)) +        self.status = QtGui.QLabel("") +        self.status.setWordWrap(True) +        self.progress = QtGui.QProgressBar() +        self.progress.setMaximum(100) +        self.progress.hide() + +        self.status_line_1 = QtGui.QLabel() +        self.status_line_2 = QtGui.QLabel() +        self.status_line_3 = QtGui.QLabel() +        self.status_line_4 = QtGui.QLabel() + +        layout = QtGui.QGridLayout() +        layout.addWidget(self.status, 0, 1) +        layout.addWidget(self.progress, 5, 1) +        layout.addWidget(self.status_line_1, 8, 1) +        layout.addWidget(self.status_line_2, 9, 1) +        layout.addWidget(self.status_line_3, 10, 1) +        layout.addWidget(self.status_line_4, 11, 1) + +        self.setLayout(layout) + +    def set_status(self, status): +        self.status.setText(status) +        self.status.setWordWrap(True) + +    def get_donemsg(self, msg): +        return "%s ... done" % msg + +    def fetch_and_validate(self): +        # Fake... till you make it... +        import time +        domain = self.field('provider_domain') +        wizard = self.wizard() +        #pconfig = wizard.providerconfig +        eipconfigchecker = wizard.eipconfigchecker() +        pCertChecker = wizard.providercertchecker( +            domain=domain) + +        # XXX get from log_in page if we came that way +        # instead + +        username = self.field('userName') +        password = self.field('userPassword') + +        credentials = username, password + +        self.progress.show() + +        fetching_eip_conf_msg = 'Fetching eip service configuration' +        self.set_status(fetching_eip_conf_msg) +        self.progress.setValue(30) + +        # Fetching eip service +        eipconfigchecker.fetch_eip_service_config( +            domain=domain) + +        self.status_line_1.setText( +            self.get_donemsg(fetching_eip_conf_msg)) + +        getting_client_cert_msg = 'Getting client certificate' +        self.set_status(getting_client_cert_msg) +        self.progress.setValue(66) + +        # Download cert +        pCertChecker.download_new_client_cert( +            credentials=credentials) + +        time.sleep(2) +        self.status_line_2.setText( +            self.get_donemsg(getting_client_cert_msg)) + +        validating_clientcert_msg = 'Validating client certificate' +        self.set_status(validating_clientcert_msg) +        self.progress.setValue(90) +        time.sleep(2) +        self.status_line_3.setText( +            self.get_donemsg(validating_clientcert_msg)) + +        self.progress.setValue(100) +        time.sleep(3) + +        return True + +    # pagewizard methods + +    def initializePage(self): +        # XXX if we're coming from signup page +        # we could say something like +        # 'registration successful!' +        self.status.setText( +            "We have " +            "all we need to connect with the provider.<br><br> " +            "Click <i>next</i> to continue. ") +        self.progress.setValue(0) +        self.progress.hide() +        self.status_line_1.setText('') +        self.status_line_2.setText('') +        self.status_line_3.setText('') + +    def validatePage(self): +        validated = self.fetch_and_validate() +        return validated +  class LastPage(QtGui.QWizardPage):      def __init__(self, parent=None): | 
