diff options
| -rw-r--r-- | changes/feature_twisted_threads | 1 | ||||
| -rw-r--r-- | src/leap/config/providerconfig.py | 9 | ||||
| -rw-r--r-- | src/leap/gui/mainwindow.py | 78 | ||||
| -rw-r--r-- | src/leap/gui/wizard.py | 28 | ||||
| -rw-r--r-- | src/leap/services/abstractbootstrapper.py | 155 | ||||
| -rw-r--r-- | src/leap/services/eip/eipbootstrapper.py | 264 | ||||
| -rw-r--r-- | src/leap/services/eip/providerbootstrapper.py | 384 | ||||
| -rw-r--r-- | src/leap/services/eip/vpnlaunchers.py | 2 | ||||
| -rw-r--r-- | src/leap/services/mail/smtpbootstrapper.py | 142 | ||||
| -rw-r--r-- | src/leap/services/soledad/soledadbootstrapper.py | 233 | ||||
| -rw-r--r-- | src/leap/util/checkerthread.py | 109 | 
11 files changed, 542 insertions, 863 deletions
| diff --git a/changes/feature_twisted_threads b/changes/feature_twisted_threads new file mode 100644 index 00000000..364d1132 --- /dev/null +++ b/changes/feature_twisted_threads @@ -0,0 +1 @@ +  o Use twisted's deferToThread and Deferreds to handle parallel tasks
\ No newline at end of file diff --git a/src/leap/config/providerconfig.py b/src/leap/config/providerconfig.py index 8f75d4fe..68099ad4 100644 --- a/src/leap/config/providerconfig.py +++ b/src/leap/config/providerconfig.py @@ -130,6 +130,15 @@ class ProviderConfig(BaseConfig):          """          return "openvpn" in self.get_services() +    def provides_mx(self): +        """ +        Returns True if this particular provider has the MX service, +        False otherwise. + +        :rtype: bool +        """ +        return "mx" in self.get_services() +  if __name__ == "__main__":      logger = logging.getLogger(name='leap') diff --git a/src/leap/gui/mainwindow.py b/src/leap/gui/mainwindow.py index 12187f51..25478aa1 100644 --- a/src/leap/gui/mainwindow.py +++ b/src/leap/gui/mainwindow.py @@ -27,7 +27,7 @@ from functools import partial  import keyring  from PySide import QtCore, QtGui -from mock import Mock +from twisted.internet import threads  from leap.common.check import leap_assert  from leap.common.events import register @@ -50,7 +50,6 @@ from leap.services.eip.vpnlaunchers import (VPNLauncherException,                                              EIPNoPkexecAvailable,                                              EIPNoPolkitAuthAgentAvailable)  from leap.util import __version__ as VERSION -from leap.util.checkerthread import CheckerThread  from leap.services.mail.smtpconfig import SMTPConfig @@ -78,6 +77,9 @@ class MainWindow(QtGui.QMainWindow):      PORT_KEY = "port"      IP_KEY = "ip_address" +    OPENVPN_SERVICE = "openvpn" +    MX_SERVICE = "mx" +      # Signals      new_updates = QtCore.Signal(object)      raise_window = QtCore.Signal([]) @@ -155,9 +157,6 @@ class MainWindow(QtGui.QMainWindow):          # This is created once we have a valid provider config          self._srp_auth = None -        self._checker_thread = CheckerThread() -        self._checker_thread.start() -          # This thread is always running, although it's quite          # lightweight when it's done setting up provider          # configuration and certificate. @@ -187,6 +186,8 @@ class MainWindow(QtGui.QMainWindow):              self._finish_eip_bootstrap)          self._soledad_bootstrapper = SoledadBootstrapper() +        self._soledad_bootstrapper.download_config.connect( +            self._soledad_intermediate_stage)          self._soledad_bootstrapper.gen_key.connect(              self._soledad_bootstrapped_stage) @@ -262,8 +263,7 @@ class MainWindow(QtGui.QMainWindow):          if self._first_run():              self._wizard_firstrun = True -            self._wizard = Wizard(self._checker_thread, -                                  standalone=standalone, +            self._wizard = Wizard(standalone=standalone,                                    bypass_checks=bypass_checks)              # Give this window time to finish init and then show the wizard              QtCore.QTimer.singleShot(1, self._launch_wizard) @@ -281,8 +281,8 @@ class MainWindow(QtGui.QMainWindow):      def _launch_wizard(self):          if self._wizard is None: -            self._wizard = Wizard(self._checker_thread, -                                  bypass_checks=self._bypass_checks) +            self._wizard = Wizard(bypass_checks=self._bypass_checks) +        self._wizard.accepted.connect(self._finish_init)          self._wizard.exec_()          self._wizard = None @@ -369,6 +369,7 @@ class MainWindow(QtGui.QMainWindow):                                        msg)      def _finish_init(self): +        self.ui.cmbProviders.clear()          self.ui.cmbProviders.addItems(self._configured_providers())          self._show_systray()          self.show() @@ -425,6 +426,9 @@ class MainWindow(QtGui.QMainWindow):          """          Sets up the systray icon          """ +        if self._systray is not None: +            self._systray.setVisible(True) +            return          systrayMenu = QtGui.QMenu(self)          systrayMenu.addAction(self._action_visible)          systrayMenu.addAction(self.ui.action_sign_out) @@ -618,7 +622,6 @@ class MainWindow(QtGui.QMainWindow):          provider = self.ui.cmbProviders.currentText()          self._provider_bootstrapper.run_provider_select_checks( -            self._checker_thread,              provider,              download_if_needed=True) @@ -643,7 +646,6 @@ class MainWindow(QtGui.QMainWindow):                                                              provider,                                                              "provider.json")):                  self._provider_bootstrapper.run_provider_setup_checks( -                    self._checker_thread,                      self._provider_config,                      download_if_needed=True)              else: @@ -728,7 +730,7 @@ class MainWindow(QtGui.QMainWindow):              auth_partial = partial(self._srp_auth.authenticate,                                     username,                                     password) -            self._checker_thread.add_checks([auth_partial]) +            threads.deferToThread(auth_partial)          else:              self._set_status(data[self._provider_bootstrapper.ERROR_KEY])              self._login_set_enabled(True) @@ -760,7 +762,6 @@ class MainWindow(QtGui.QMainWindow):          self._systray.setIcon(self.LOGGED_IN_ICON)          self._soledad_bootstrapper.run_soledad_setup_checks( -            self._checker_thread,              self._provider_config,              self.ui.lnUser.text(),              self.ui.lnPassword.text(), @@ -768,6 +769,22 @@ class MainWindow(QtGui.QMainWindow):          self._download_eip_config() +    def _soledad_intermediate_stage(self, data): +        """ +        SLOT +        TRIGGERS: +          self._soledad_bootstrapper.download_config + +        If there was a problem, displays it, otherwise it does nothing. +        This is used for intermediate bootstrapping stages, in case +        they fail. +        """ +        passed = data[self._soledad_bootstrapper.PASSED_KEY] +        if not passed: +            # TODO: display in the GUI +            logger.error("Soledad failed to start: %s" % +                         (data[self._soledad_bootstrapper.ERROR_KEY],)) +      def _soledad_bootstrapped_stage(self, data):          """          SLOT @@ -787,14 +804,24 @@ class MainWindow(QtGui.QMainWindow):          else:              logger.debug("Done bootstrapping Soledad") -            self._soledad = data[self._soledad_bootstrapper.SOLEDAD_KEY] -            self._keymanager = data[self._soledad_bootstrapper.KEYMANAGER_KEY] +            self._soledad = self._soledad_bootstrapper.soledad +            self._keymanager = self._soledad_bootstrapper.keymanager -            self._smtp_bootstrapper.run_smtp_setup_checks( -                self._checker_thread, -                self._provider_config, -                self._smtp_config, -                True) +            if self._provider_config.provides_mx() and \ +                    self._enabled_services.count(self.MX_SERVICE) > 0: +                self._smtp_bootstrapper.run_smtp_setup_checks( +                    self._provider_config, +                    self._smtp_config, +                    True) +            else: +                if self._enabled_services.count(self.MX_SERVICE) > 0: +                    pass # TODO: show MX status +                    #self._set_eip_status(self.tr("%s does not support MX") % +                    #                     (self._provider_config.get_domain(),), +                    #                     error=True) +                else: +                    pass # TODO: show MX status +                    #self._set_eip_status(self.tr("MX is disabled"))      def _smtp_bootstrapped_stage(self, data):          """ @@ -914,14 +941,13 @@ class MainWindow(QtGui.QMainWindow):          self._set_eip_status(self.tr("Checking configuration, please wait..."))          if self._provider_config.provides_eip() and \ -                self._enabled_services.count("openvpn") > 0: +                self._enabled_services.count(self.OPENVPN_SERVICE) > 0:              self._vpn_systray.setVisible(True)              self._eip_bootstrapper.run_eip_setup_checks( -                self._checker_thread,                  self._provider_config,                  download_if_needed=True)          else: -            if self._enabled_services.count("openvpn") > 0: +            if self._enabled_services.count(self.OPENVPN_SERVICE) > 0:                  self._set_eip_status(self.tr("%s does not support EIP") %                                       (self._provider_config.get_domain(),),                                       error=True) @@ -1035,7 +1061,9 @@ class MainWindow(QtGui.QMainWindow):          """          self._set_eip_status_icon("error")          self._set_eip_status(self.tr("Signing out...")) -        self._checker_thread.add_checks([self._srp_auth.logout]) +        # XXX: If other defers are doing authenticated stuff, this +        # might conflict with those. CHECK! +        threads.deferToThread(self._srp_auth.logout)      def _done_logging_out(self, ok, message):          """ @@ -1121,8 +1149,6 @@ class MainWindow(QtGui.QMainWindow):          logger.debug('About to quit, doing cleanup...')          self._vpn.set_should_quit()          self._vpn.wait() -        self._checker_thread.set_should_quit() -        self._checker_thread.wait()          self._cleanup_pidfiles()      def quit(self): diff --git a/src/leap/gui/wizard.py b/src/leap/gui/wizard.py index 713383a6..552ace50 100644 --- a/src/leap/gui/wizard.py +++ b/src/leap/gui/wizard.py @@ -24,6 +24,7 @@ import json  from PySide import QtCore, QtGui  from functools import partial +from twisted.internet import threads  from ui_wizard import Ui_Wizard  from leap.config.providerconfig import ProviderConfig @@ -53,12 +54,10 @@ class Wizard(QtGui.QWizard):      BARE_USERNAME_REGEX = r"^[A-Za-z\d_]+$" -    def __init__(self, checker, standalone=False, bypass_checks=False): +    def __init__(self, standalone=False, bypass_checks=False):          """          Constructor for the main Wizard. -        :param checker: Checker thread that the wizard should use. -        :type checker: CheckerThread          :param standalone: If True, the application is running as standalone              and the wizard should display some messages according to this.          :type standalone: bool @@ -82,16 +81,19 @@ class Wizard(QtGui.QWizard):          # Correspondence for services and their name to display          EIP_LABEL = self.tr("Encrypted Internet") +        MX_LABEL = self.tr("Encrypted Mail")          if self._is_need_eip_password_warning():              EIP_LABEL += " " + self.tr(                  "(will need admin password to start)")          self.SERVICE_DISPLAY = [ -            EIP_LABEL +            EIP_LABEL, +            MX_LABEL          ]          self.SERVICE_CONFIG = [ -            "openvpn" +            "openvpn", +            "mx"          ]          self._selected_services = set() @@ -147,8 +149,6 @@ class Wizard(QtGui.QWizard):          self._username = None          self._password = None -        self._checker_thread = checker -          self.page(self.REGISTER_USER_PAGE).setButtonText(              QtGui.QWizard.CommitButton, self.tr("&Next >"))          self.page(self.FINISH_PAGE).setButtonText( @@ -231,10 +231,12 @@ class Wizard(QtGui.QWizard):              register = SRPRegister(provider_config=self._provider_config)              register.registration_finished.connect(                  self._registration_finished) -            self._checker_thread.add_checks( -                [partial(register.register_user, -                         username.encode("utf8"), -                         password.encode("utf8"))]) + +            threads.deferToThread( +                partial(register.register_user, +                        username.encode("utf8"), +                        password.encode("utf8"))) +              self._username = username              self._password = password              self._set_register_status(self.tr("Starting registration...")) @@ -318,7 +320,6 @@ class Wizard(QtGui.QWizard):          self.ui.lblNameResolution.setPixmap(self.QUESTION_ICON)          self._provider_bootstrapper.run_provider_select_checks( -            self._checker_thread,              self._domain)      def _complete_task(self, data, label, complete=False, complete_page=-1): @@ -510,8 +511,7 @@ class Wizard(QtGui.QWizard):                                             .get_name(),))              self.ui.lblDownloadCaCert.setPixmap(self.QUESTION_ICON)              self._provider_bootstrapper.\ -                run_provider_setup_checks(self._checker_thread, -                                          self._provider_config) +                run_provider_setup_checks(self._provider_config)          if pageId == self.PRESENT_PROVIDER_PAGE:              self.page(pageId).setSubTitle(self.tr("Description of services " diff --git a/src/leap/services/abstractbootstrapper.py b/src/leap/services/abstractbootstrapper.py new file mode 100644 index 00000000..bce03e6b --- /dev/null +++ b/src/leap/services/abstractbootstrapper.py @@ -0,0 +1,155 @@ +# -*- coding: utf-8 -*- +# abstractbootstrapper.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program.  If not, see <http://www.gnu.org/licenses/>. + +""" +Abstract bootstrapper implementation +""" +import logging + +import requests + +from PySide import QtCore +from twisted.internet import threads +from leap.common.check import leap_assert, leap_assert_type + +logger = logging.getLogger(__name__) + + +class AbstractBootstrapper(QtCore.QObject): +    """ +    Abstract Bootstrapper that implements the needed deferred callbacks +    """ + +    PASSED_KEY = "passed" +    ERROR_KEY = "error" + +    def __init__(self, bypass_checks=False): +        """ +        Constructor for the abstract bootstrapper + +        :param bypass_checks: Set to true if the app should bypass +                              first round of checks for CA +                              certificates at bootstrap +        :type bypass_checks: bool +        """ +        QtCore.QObject.__init__(self) + +        leap_assert(self._gui_errback.im_func == \ +                        AbstractBootstrapper._gui_errback.im_func, +                    "Cannot redefine _gui_errback") +        leap_assert(self._errback.im_func == \ +                        AbstractBootstrapper._errback.im_func, +                    "Cannot redefine _errback") +        leap_assert(self._gui_notify.im_func == \ +                        AbstractBootstrapper._gui_notify.im_func, +                    "Cannot redefine _gui_notify") + +        # **************************************************** # +        # Dependency injection helpers, override this for more +        # granular testing +        self._fetcher = requests +        # **************************************************** # + +        self._session = self._fetcher.session() +        self._bypass_checks = bypass_checks +        self._signal_to_emit = None +        self._err_msg = None + +    def _gui_errback(self, failure): +        """ +        Errback used to notify the GUI of a problem, it should be used +        as the last errback of the whole chain. + +        Traps all exceptions if a signal is defined, otherwise it just +        lets it continue. + +        NOTE: This method is final, it should not be redefined. + +        :param failure: failure object that Twisted generates +        :type failure: twisted.python.failure.Failure +        """ +        if self._signal_to_emit: +            err_msg = self._err_msg \ +                if self._err_msg is not None \ +                else str(failure.value) +            self._signal_to_emit.emit({ +                    self.PASSED_KEY: False, +                    self.ERROR_KEY: err_msg +                    }) +            failure.trap(Exception) + +    def _errback(self, failure, signal=None): +        """ +        Regular errback used for the middle of the chain. If it's +        executed, the first one will set the signal to emit as +        failure. + +        NOTE: This method is final, it should not be redefined. + +        :param failure: failure object that Twisted generates +        :type failure: twisted.python.failure.Failure +        :param signal: Signal to emit if it fails here first +        :type signal: QtCore.SignalInstance + +        :returns: failure object that Twisted generates +        :rtype: twisted.python.failure.Failure +        """ +        if self._signal_to_emit is None: +            self._signal_to_emit = signal +        return failure + +    def _gui_notify(self, _, signal=None): +        """ +        Callback used to notify the GUI of a success. Will emit signal +        if specified + +        NOTE: This method is final, it should not be redefined. + +        :param _: IGNORED. Returned from the previous callback +        :type _: IGNORED +        :param signal: Signal to emit if it fails here first +        :type signal: QtCore.SignalInstance +        """ +        if signal: +            logger.debug("Emitting %s" % (signal,)) +            signal.emit({self.PASSED_KEY: True, self.ERROR_KEY: ""}) + +    def addCallbackChain(self, callbacks): +        """ +        Creates a callback/errback chain on another thread using +        deferToThread and adds the _gui_errback to the end to notify +        the GUI on an error. + +        :param callbacks: List of tuples of callbacks and the signal +                          associated to that callback +        :type callbacks: list(tuple(func, func)) +        """ +        leap_assert_type(callbacks, list) + +        self._signal_to_emit = None +        self._err_msg = None + +        d = None +        for cb, sig in callbacks: +            if d is None: +                d = threads.deferToThread(cb) +            else: +                d.addCallback(cb) +            d.addErrback(self._errback, signal=sig) +            d.addCallback(self._gui_notify, signal=sig) +        d.addErrback(self._gui_errback) + diff --git a/src/leap/services/eip/eipbootstrapper.py b/src/leap/services/eip/eipbootstrapper.py index a881f235..7216bb80 100644 --- a/src/leap/services/eip/eipbootstrapper.py +++ b/src/leap/services/eip/eipbootstrapper.py @@ -22,9 +22,7 @@ EIP bootstrapping  import logging  import os -import requests - -from PySide import QtGui, QtCore +from PySide import QtCore  from leap.common.check import leap_assert, leap_assert_type  from leap.common.certs import is_valid_pemfile, should_redownload @@ -32,49 +30,34 @@ from leap.common.files import check_and_fix_urw_only, get_mtime, mkdir_p  from leap.config.providerconfig import ProviderConfig  from leap.crypto.srpauth import SRPAuth  from leap.services.eip.eipconfig import EIPConfig -from leap.util.checkerthread import CheckerThread  from leap.util.request_helpers import get_content +from leap.services.abstractbootstrapper import AbstractBootstrapper  logger = logging.getLogger(__name__) -class EIPBootstrapper(QtCore.QObject): +class EIPBootstrapper(AbstractBootstrapper):      """      Sets up EIP for a provider a series of checks and emits signals      after they are passed.      If a check fails, the subsequent checks are not executed      """ -    PASSED_KEY = "passed" -    ERROR_KEY = "error" - -    IDLE_SLEEP_INTERVAL = 100 -      # All dicts returned are of the form      # {"passed": bool, "error": str}      download_config = QtCore.Signal(dict)      download_client_certificate = QtCore.Signal(dict)      def __init__(self): -        QtCore.QObject.__init__(self) +        AbstractBootstrapper.__init__(self) -        # **************************************************** # -        # Dependency injection helpers, override this for more -        # granular testing -        self._fetcher = requests -        # **************************************************** # - -        self._session = self._fetcher.session()          self._provider_config = None          self._eip_config = None          self._download_if_needed = False -    def _download_config(self): +    def _download_config(self, *args):          """          Downloads the EIP config for the given provider - -        :return: True if the checks passed, False otherwise -        :rtype: bool          """          leap_assert(self._provider_config, @@ -83,65 +66,47 @@ class EIPBootstrapper(QtCore.QObject):          logger.debug("Downloading EIP config for %s" %                       (self._provider_config.get_domain(),)) -        download_config_data = { -            self.PASSED_KEY: False, -            self.ERROR_KEY: "" -        } -          self._eip_config = EIPConfig() -        try: -            headers = {} -            mtime = get_mtime(os.path.join(self._eip_config -                                           .get_path_prefix(), -                                           "leap", -                                           "providers", -                                           self._provider_config.get_domain(), -                                           "eip-service.json")) - -            if self._download_if_needed and mtime: -                headers['if-modified-since'] = mtime - -            # there is some confusion with this uri, -            # it's in 1/config/eip, config/eip and config/1/eip... -            config_uri = "%s/%s/config/eip-service.json" % ( -                self._provider_config.get_api_uri(), -                self._provider_config.get_api_version()) -            logger.debug('Downloading eip config from: %s' % config_uri) - -            res = self._session.get(config_uri, -                                    verify=self._provider_config -                                    .get_ca_cert_path(), -                                    headers=headers) -            res.raise_for_status() - -            # Not modified -            if res.status_code == 304: -                logger.debug("EIP definition has not been modified") -            else: -                eip_definition, mtime = get_content(res) - -                self._eip_config.load(data=eip_definition, mtime=mtime) -                self._eip_config.save(["leap", +        headers = {} +        mtime = get_mtime(os.path.join(self._eip_config +                                       .get_path_prefix(), +                                       "leap",                                         "providers",                                         self._provider_config.get_domain(), -                                       "eip-service.json"]) - -            download_config_data[self.PASSED_KEY] = True -        except Exception as e: -            download_config_data[self.ERROR_KEY] = "%s" % (e,) - -        logger.debug("Emitting download_config %s" % (download_config_data,)) -        self.download_config.emit(download_config_data) - -        return download_config_data[self.PASSED_KEY] - -    def _download_client_certificates(self): +                                       "eip-service.json")) + +        if self._download_if_needed and mtime: +            headers['if-modified-since'] = mtime + +        # there is some confusion with this uri, +        # it's in 1/config/eip, config/eip and config/1/eip... +        config_uri = "%s/%s/config/eip-service.json" % ( +            self._provider_config.get_api_uri(), +            self._provider_config.get_api_version()) +        logger.debug('Downloading eip config from: %s' % config_uri) + +        res = self._session.get(config_uri, +                                verify=self._provider_config +                                .get_ca_cert_path(), +                                headers=headers) +        res.raise_for_status() + +        # Not modified +        if res.status_code == 304: +            logger.debug("EIP definition has not been modified") +        else: +            eip_definition, mtime = get_content(res) + +            self._eip_config.load(data=eip_definition, mtime=mtime) +            self._eip_config.save(["leap", +                                   "providers", +                                   self._provider_config.get_domain(), +                                   "eip-service.json"]) + +    def _download_client_certificates(self, *args):          """          Downloads the EIP client certificate for the given provider - -        :return: True if the checks passed, False otherwise -        :rtype: bool          """          leap_assert(self._provider_config, "We need a provider configuration!")          leap_assert(self._eip_config, "We need an eip configuration!") @@ -149,11 +114,6 @@ class EIPBootstrapper(QtCore.QObject):          logger.debug("Downloading EIP client certificate for %s" %                       (self._provider_config.get_domain(),)) -        download_cert = { -            self.PASSED_KEY: False, -            self.ERROR_KEY: "" -        } -          client_cert_path = self._eip_config.\              get_client_cert_path(self._provider_config,                                   about_to_download=True) @@ -164,56 +124,39 @@ class EIPBootstrapper(QtCore.QObject):          if self._download_if_needed and \                  os.path.exists(client_cert_path): -            try: -                check_and_fix_urw_only(client_cert_path) -                download_cert[self.PASSED_KEY] = True -            except Exception as e: -                download_cert[self.PASSED_KEY] = False -                download_cert[self.ERROR_KEY] = "%s" % (e,) -            self.download_client_certificate.emit(download_cert) -            return download_cert[self.PASSED_KEY] - -        try: -            srp_auth = SRPAuth(self._provider_config) -            session_id = srp_auth.get_session_id() -            cookies = None -            if session_id: -                cookies = {"_session_id": session_id} -            cert_uri = "%s/%s/cert" % ( -                self._provider_config.get_api_uri(), -                self._provider_config.get_api_version()) -            logger.debug('getting cert from uri: %s' % cert_uri) -            res = self._session.get(cert_uri, -                                    verify=self._provider_config -                                    .get_ca_cert_path(), -                                    cookies=cookies) -            res.raise_for_status() -            client_cert = res.content - -            # TODO: check certificate validity - -            if not is_valid_pemfile(client_cert): -                raise Exception(self.tr("The downloaded certificate is not a " -                                        "valid PEM file")) - -            mkdir_p(os.path.dirname(client_cert_path)) - -            with open(client_cert_path, "w") as f: -                f.write(client_cert) -              check_and_fix_urw_only(client_cert_path) - -            download_cert[self.PASSED_KEY] = True -        except Exception as e: -            download_cert[self.ERROR_KEY] = "%s" % (e,) - -        logger.debug("Emitting download_client_certificates %s" % -                     (download_cert,)) -        self.download_client_certificate.emit(download_cert) - -        return download_cert[self.PASSED_KEY] - -    def run_eip_setup_checks(self, checker, +            return + +        srp_auth = SRPAuth(self._provider_config) +        session_id = srp_auth.get_session_id() +        cookies = None +        if session_id: +            cookies = {"_session_id": session_id} +        cert_uri = "%s/%s/cert" % ( +            self._provider_config.get_api_uri(), +            self._provider_config.get_api_version()) +        logger.debug('getting cert from uri: %s' % cert_uri) +        res = self._session.get(cert_uri, +                                verify=self._provider_config +                                .get_ca_cert_path(), +                                cookies=cookies) +        res.raise_for_status() +        client_cert = res.content + +        # TODO: check certificate validity + +        if not is_valid_pemfile(client_cert): +            raise Exception(self.tr("The downloaded certificate is not a " +                                    "valid PEM file")) + +        mkdir_p(os.path.dirname(client_cert_path)) + +        with open(client_cert_path, "w") as f: +            f.write(client_cert) + +        check_and_fix_urw_only(client_cert_path) + +    def run_eip_setup_checks(self,                               provider_config,                               download_if_needed=False):          """ @@ -228,60 +171,9 @@ class EIPBootstrapper(QtCore.QObject):          self._provider_config = provider_config          self._download_if_needed = download_if_needed -        checker.add_checks([ -            self._download_config, -            self._download_client_certificates -        ]) - - -if __name__ == "__main__": -    import sys -    from functools import partial -    app = QtGui.QApplication(sys.argv) - -    import signal - -    def sigint_handler(*args, **kwargs): -        logger.debug('SIGINT catched. shutting down...') -        checker = args[0] -        checker.set_should_quit() -        QtGui.QApplication.quit() - -    def signal_tester(d): -        print d - -    logger = logging.getLogger(name='leap') -    logger.setLevel(logging.DEBUG) -    console = logging.StreamHandler() -    console.setLevel(logging.DEBUG) -    formatter = logging.Formatter( -        '%(asctime)s ' -        '- %(name)s - %(levelname)s - %(message)s') -    console.setFormatter(formatter) -    logger.addHandler(console) - -    eip_checks = EIPBootstrapper() -    checker = CheckerThread() - -    sigint = partial(sigint_handler, checker) -    signal.signal(signal.SIGINT, sigint) - -    timer = QtCore.QTimer() -    timer.start(500) -    timer.timeout.connect(lambda: None) -    app.connect(app, QtCore.SIGNAL("aboutToQuit()"), -                checker.set_should_quit) -    w = QtGui.QWidget() -    w.resize(100, 100) -    w.show() - -    checker.start() - -    provider_config = ProviderConfig() -    if provider_config.load(os.path.join("leap", -                                         "providers", -                                         "bitmask.net", -                                         "provider.json")): -        eip_checks.run_eip_setup_checks(checker, provider_config) - -    sys.exit(app.exec_()) +        cb_chain = [ +            (self._download_config, self.download_config), +            (self._download_client_certificates, self.download_client_certificate) +        ] + +        self.addCallbackChain(cb_chain) diff --git a/src/leap/services/eip/providerbootstrapper.py b/src/leap/services/eip/providerbootstrapper.py index 289d212b..1339e086 100644 --- a/src/leap/services/eip/providerbootstrapper.py +++ b/src/leap/services/eip/providerbootstrapper.py @@ -24,30 +24,25 @@ import os  import requests -from PySide import QtGui, QtCore +from PySide import QtCore  from leap.common.certs import get_digest  from leap.common.files import check_and_fix_urw_only, get_mtime, mkdir_p  from leap.common.check import leap_assert, leap_assert_type  from leap.config.providerconfig import ProviderConfig -from leap.util.checkerthread import CheckerThread  from leap.util.request_helpers import get_content +from leap.services.abstractbootstrapper import AbstractBootstrapper  logger = logging.getLogger(__name__) -class ProviderBootstrapper(QtCore.QObject): +class ProviderBootstrapper(AbstractBootstrapper):      """      Given a provider URL performs a series of checks and emits signals      after they are passed.      If a check fails, the subsequent checks are not executed      """ -    PASSED_KEY = "passed" -    ERROR_KEY = "error" - -    IDLE_SLEEP_INTERVAL = 100 -      # All dicts returned are of the form      # {"passed": bool, "error": str}      name_resolution = QtCore.Signal(dict) @@ -66,68 +61,34 @@ class ProviderBootstrapper(QtCore.QObject):          first round of checks for CA certificates at bootstrap          :type bypass_checks: bool          """ -        QtCore.QObject.__init__(self) +        AbstractBootstrapper.__init__(self, bypass_checks) -        # **************************************************** # -        # Dependency injection helpers, override this for more -        # granular testing -        self._fetcher = requests -        # **************************************************** # - -        self._session = self._fetcher.session()          self._domain = None          self._provider_config = None          self._download_if_needed = False -        self._bypass_checks = bypass_checks      def _check_name_resolution(self):          """          Checks that the name resolution for the provider name works - -        :return: True if the checks passed, False otherwise -        :rtype: bool          """ -          leap_assert(self._domain, "Cannot check DNS without a domain")          logger.debug("Checking name resolution for %s" % (self._domain)) -        name_resolution_data = { -            self.PASSED_KEY: False, -            self.ERROR_KEY: "" -        } -          # We don't skip this check, since it's basic for the whole          # system to work -        try: -            socket.gethostbyname(self._domain) -            name_resolution_data[self.PASSED_KEY] = True -        except socket.gaierror as e: -            name_resolution_data[self.ERROR_KEY] = "%s" % (e,) +        socket.gethostbyname(self._domain) -        logger.debug("Emitting name_resolution %s" % (name_resolution_data,)) -        self.name_resolution.emit(name_resolution_data) - -        return name_resolution_data[self.PASSED_KEY] - -    def _check_https(self): +    def _check_https(self, *args):          """          Checks that https is working and that the provided certificate          checks out - -        :return: True if the checks passed, False otherwise -        :rtype: bool          """          leap_assert(self._domain, "Cannot check HTTPS without a domain")          logger.debug("Checking https for %s" % (self._domain)) -        https_data = { -            self.PASSED_KEY: False, -            self.ERROR_KEY: "" -        } -          # We don't skip this check, since it's basic for the whole          # system to work @@ -135,105 +96,75 @@ class ProviderBootstrapper(QtCore.QObject):              res = self._session.get("https://%s" % (self._domain,),                                      verify=not self._bypass_checks)              res.raise_for_status() -            https_data[self.PASSED_KEY] = True -        except requests.exceptions.SSLError as e: -            logger.error("%s" % (e,)) -            https_data[self.ERROR_KEY] = self.tr("Provider certificate could " -                                                 "not verify") -        except Exception as e: -            logger.error("%s" % (e,)) -            https_data[self.ERROR_KEY] = self.tr("Provider does not support " -                                                 "HTTPS") - -        logger.debug("Emitting https_connection %s" % (https_data,)) -        self.https_connection.emit(https_data) - -        return https_data[self.PASSED_KEY] - -    def _download_provider_info(self): +        except requests.exceptions.SSLError: +            self._err_msg = self.tr("Provider certificate could " +                                    "not be verified") +            raise +        except Exception: +            self._err_msg = self.tr("Provider does not support HTTPS") +            raise + +    def _download_provider_info(self, *args):          """          Downloads the provider.json defition - -        :return: True if the checks passed, False otherwise -        :rtype: bool          """          leap_assert(self._domain,                      "Cannot download provider info without a domain")          logger.debug("Downloading provider info for %s" % (self._domain)) -        download_data = { -            self.PASSED_KEY: False, -            self.ERROR_KEY: "" -        } - -        try: -            headers = {} -            mtime = get_mtime(os.path.join(ProviderConfig() -                                           .get_path_prefix(), -                                           "leap", -                                           "providers", -                                           self._domain, -                                           "provider.json")) -            if self._download_if_needed and mtime: -                headers['if-modified-since'] = mtime - -            res = self._session.get("https://%s/%s" % (self._domain, -                                                       "provider.json"), -                                    headers=headers, -                                    verify=not self._bypass_checks) -            res.raise_for_status() - -            # Not modified -            if res.status_code == 304: -                logger.debug("Provider definition has not been modified") -            else: -                provider_definition, mtime = get_content(res) - -                provider_config = ProviderConfig() -                provider_config.load(data=provider_definition, mtime=mtime) -                provider_config.save(["leap", -                                      "providers", -                                      self._domain, -                                      "provider.json"]) - -            download_data[self.PASSED_KEY] = True -        except Exception as e: -            download_data[self.ERROR_KEY] = "%s" % (e,) - -        logger.debug("Emitting download_provider_info %s" % (download_data,)) -        self.download_provider_info.emit(download_data) - -        return download_data[self.PASSED_KEY] - -    def run_provider_select_checks(self, checker, -                                   domain, download_if_needed=False): +        headers = {} +        mtime = get_mtime(os.path.join(ProviderConfig() +                                       .get_path_prefix(), +                                       "leap", +                                       "providers", +                                       self._domain, +                                       "provider.json")) +        if self._download_if_needed and mtime: +            headers['if-modified-since'] = mtime + +        res = self._session.get("https://%s/%s" % (self._domain, +                                                   "provider.json"), +                                headers=headers, +                                verify=not self._bypass_checks) +        res.raise_for_status() + +        # Not modified +        if res.status_code == 304: +            logger.debug("Provider definition has not been modified") +        else: +            provider_definition, mtime = get_content(res) + +            provider_config = ProviderConfig() +            provider_config.load(data=provider_definition, mtime=mtime) +            provider_config.save(["leap", +                                  "providers", +                                  self._domain, +                                  "provider.json"]) + +    def run_provider_select_checks(self, domain, download_if_needed=False):          """          Populates the check queue. -        :param checker: checker thread to be used to run this check -        :type checker: CheckerThread -          :param domain: domain to check          :type domain: str          :param download_if_needed: if True, makes the checks do not                                     overwrite already downloaded data          :type download_if_needed: bool - -        :return: True if the checks passed, False otherwise -        :rtype: bool          """          leap_assert(domain and len(domain) > 0, "We need a domain!")          self._domain = domain          self._download_if_needed = download_if_needed -        checker.add_checks([ -            self._check_name_resolution, -            self._check_https, -            self._download_provider_info -        ]) +        cb_chain = [ +            (self._check_name_resolution, self.name_resolution), +            (self._check_https, self.https_connection), +            (self._download_provider_info, self.download_provider_info) +        ] + +        self.addCallbackChain(cb_chain)      def _should_proceed_cert(self):          """ @@ -250,12 +181,9 @@ class ProviderBootstrapper(QtCore.QObject):          return not os.path.exists(self._provider_config                                    .get_ca_cert_path(about_to_download=True)) -    def _download_ca_cert(self): +    def _download_ca_cert(self, *args):          """          Downloads the CA cert that is going to be used for the api URL - -        :return: True if the checks passed, False otherwise -        :rtype: bool          """          leap_assert(self._provider_config, "Cannot download the ca cert " @@ -264,56 +192,28 @@ class ProviderBootstrapper(QtCore.QObject):          logger.debug("Downloading ca cert for %s at %s" %                       (self._domain, self._provider_config.get_ca_cert_uri())) -        download_ca_cert_data = { -            self.PASSED_KEY: False, -            self.ERROR_KEY: "" -        } -          if not self._should_proceed_cert(): -            try: -                check_and_fix_urw_only( -                    self._provider_config -                    .get_ca_cert_path(about_to_download=True)) -                download_ca_cert_data[self.PASSED_KEY] = True -            except Exception as e: -                download_ca_cert_data[self.PASSED_KEY] = False -                download_ca_cert_data[self.ERROR_KEY] = "%s" % (e,) -            self.download_ca_cert.emit(download_ca_cert_data) -            return download_ca_cert_data[self.PASSED_KEY] - -        try: -            res = self._session.get(self._provider_config.get_ca_cert_uri(), -                                    verify=not self._bypass_checks) -            res.raise_for_status() - -            cert_path = self._provider_config.get_ca_cert_path( -                about_to_download=True) - -            cert_dir = os.path.dirname(cert_path) +            check_and_fix_urw_only( +                self._provider_config +                .get_ca_cert_path(about_to_download=True)) -            mkdir_p(cert_dir) +        res = self._session.get(self._provider_config.get_ca_cert_uri(), +                                verify=not self._bypass_checks) +        res.raise_for_status() -            with open(cert_path, "w") as f: -                f.write(res.content) +        cert_path = self._provider_config.get_ca_cert_path( +            about_to_download=True) +        cert_dir = os.path.dirname(cert_path) +        mkdir_p(cert_dir) +        with open(cert_path, "w") as f: +            f.write(res.content) -            check_and_fix_urw_only(cert_path) +        check_and_fix_urw_only(cert_path) -            download_ca_cert_data[self.PASSED_KEY] = True -        except Exception as e: -            download_ca_cert_data[self.ERROR_KEY] = "%s" % (e,) - -        logger.debug("Emitting download_ca_cert %s" % (download_ca_cert_data,)) -        self.download_ca_cert.emit(download_ca_cert_data) - -        return download_ca_cert_data[self.PASSED_KEY] - -    def _check_ca_fingerprint(self): +    def _check_ca_fingerprint(self, *args):          """          Checks the CA cert fingerprint against the one provided in the          json definition - -        :return: True if the checks passed, False otherwise -        :rtype: bool          """          leap_assert(self._provider_config, "Cannot check the ca cert "                      "without a provider config!") @@ -322,50 +222,27 @@ class ProviderBootstrapper(QtCore.QObject):                       (self._domain,                        self._provider_config.get_ca_cert_path())) -        check_ca_fingerprint_data = { -            self.PASSED_KEY: False, -            self.ERROR_KEY: "" -        } -          if not self._should_proceed_cert(): -            check_ca_fingerprint_data[self.PASSED_KEY] = True -            self.check_ca_fingerprint.emit(check_ca_fingerprint_data) -            return True +            return -        try: -            parts = self._provider_config.get_ca_cert_fingerprint().split(":") -            leap_assert(len(parts) == 2, "Wrong fingerprint format") - -            method = parts[0].strip() -            fingerprint = parts[1].strip() -            cert_data = None -            with open(self._provider_config.get_ca_cert_path()) as f: -                cert_data = f.read() - -            leap_assert(len(cert_data) > 0, "Could not read certificate data") - -            digest = get_digest(cert_data, method) +        parts = self._provider_config.get_ca_cert_fingerprint().split(":") +        leap_assert(len(parts) == 2, "Wrong fingerprint format") -            leap_assert(digest == fingerprint, -                        "Downloaded certificate has a different fingerprint!") +        method = parts[0].strip() +        fingerprint = parts[1].strip() +        cert_data = None +        with open(self._provider_config.get_ca_cert_path()) as f: +            cert_data = f.read() -            check_ca_fingerprint_data[self.PASSED_KEY] = True -        except Exception as e: -            check_ca_fingerprint_data[self.ERROR_KEY] = "%s" % (e,) +        leap_assert(len(cert_data) > 0, "Could not read certificate data") +        digest = get_digest(cert_data, method) +        leap_assert(digest == fingerprint, +                    "Downloaded certificate has a different fingerprint!") -        logger.debug("Emitting check_ca_fingerprint %s" % -                     (check_ca_fingerprint_data,)) -        self.check_ca_fingerprint.emit(check_ca_fingerprint_data) - -        return check_ca_fingerprint_data[self.PASSED_KEY] - -    def _check_api_certificate(self): +    def _check_api_certificate(self, *args):          """          Tries to make an API call with the downloaded cert and checks          if it validates against it - -        :return: True if the checks passed, False otherwise -        :rtype: bool          """          leap_assert(self._provider_config, "Cannot check the ca cert "                      "without a provider config!") @@ -374,34 +251,17 @@ class ProviderBootstrapper(QtCore.QObject):                       (self._provider_config.get_api_uri(),                        self._provider_config.get_ca_cert_path())) -        check_api_certificate_data = { -            self.PASSED_KEY: False, -            self.ERROR_KEY: "" -        } -          if not self._should_proceed_cert(): -            check_api_certificate_data[self.PASSED_KEY] = True -            self.check_api_certificate.emit(check_api_certificate_data) -            return True +            return -        try: -            test_uri = "%s/%s/cert" % (self._provider_config.get_api_uri(), -                                       self._provider_config.get_api_version()) -            res = self._session.get(test_uri, -                                    verify=self._provider_config -                                    .get_ca_cert_path()) -            res.raise_for_status() -            check_api_certificate_data[self.PASSED_KEY] = True -        except Exception as e: -            check_api_certificate_data[self.ERROR_KEY] = "%s" % (e,) +        test_uri = "%s/%s/cert" % (self._provider_config.get_api_uri(), +                                   self._provider_config.get_api_version()) +        res = self._session.get(test_uri, +                                verify=self._provider_config +                                .get_ca_cert_path()) +        res.raise_for_status() -        logger.debug("Emitting check_api_certificate %s" % -                     (check_api_certificate_data,)) -        self.check_api_certificate.emit(check_api_certificate_data) - -        return check_api_certificate_data[self.PASSED_KEY] - -    def run_provider_setup_checks(self, checker, +    def run_provider_setup_checks(self,                                    provider_config,                                    download_if_needed=False):          """ @@ -420,64 +280,10 @@ class ProviderBootstrapper(QtCore.QObject):          self._provider_config = provider_config          self._download_if_needed = download_if_needed -        checker.add_checks([ -            self._download_ca_cert, -            self._check_ca_fingerprint, -            self._check_api_certificate -        ]) - -if __name__ == "__main__": -    import sys -    from functools import partial -    app = QtGui.QApplication(sys.argv) - -    import signal - -    def sigint_handler(*args, **kwargs): -        logger.debug('SIGINT catched. shutting down...') -        bootstrapper_checks = args[0] -        bootstrapper_checks.set_should_quit() -        QtGui.QApplication.quit() - -    def signal_tester(d): -        print d - -    logger = logging.getLogger(name='leap') -    logger.setLevel(logging.DEBUG) -    console = logging.StreamHandler() -    console.setLevel(logging.DEBUG) -    formatter = logging.Formatter( -        '%(asctime)s ' -        '- %(name)s - %(levelname)s - %(message)s') -    console.setFormatter(formatter) -    logger.addHandler(console) - -    bootstrapper_checks = ProviderBootstrapper() - -    checker = CheckerThread() -    checker.start() - -    sigint = partial(sigint_handler, checker) -    signal.signal(signal.SIGINT, sigint) - -    timer = QtCore.QTimer() -    timer.start(500) -    timer.timeout.connect(lambda: None) -    app.connect(app, QtCore.SIGNAL("aboutToQuit()"), -                checker.set_should_quit) -    w = QtGui.QWidget() -    w.resize(100, 100) -    w.show() - -    bootstrapper_checks.run_provider_select_checks(checker, -                                                   "bitmask.net") - -    provider_config = ProviderConfig() -    if provider_config.load(os.path.join("leap", -                                         "providers", -                                         "bitmask.net", -                                         "provider.json")): -        bootstrapper_checks.run_provider_setup_checks(checker, -                                                      provider_config) - -    sys.exit(app.exec_()) +        cb_chain = [ +            (self._download_ca_cert, self.download_ca_cert), +            (self._check_ca_fingerprint, self.check_ca_fingerprint), +            (self._check_api_certificate, self.check_api_certificate) +        ] + +        self.addCallbackChain(cb_chain) diff --git a/src/leap/services/eip/vpnlaunchers.py b/src/leap/services/eip/vpnlaunchers.py index addad959..0691e121 100644 --- a/src/leap/services/eip/vpnlaunchers.py +++ b/src/leap/services/eip/vpnlaunchers.py @@ -132,7 +132,7 @@ def _is_auth_agent_running():      """      polkit_gnome = 'ps aux | grep polkit-[g]nome-authentication-agent-1'      polkit_kde = 'ps aux | grep polkit-[k]de-authentication-agent-1' -    return (len(commands.getoutput(polkit_gnome) > 0) or +    return (len(commands.getoutput(polkit_gnome)) > 0 or              len(commands.getoutput(polkit_kde)) > 0) diff --git a/src/leap/services/mail/smtpbootstrapper.py b/src/leap/services/mail/smtpbootstrapper.py index 6e0a0a47..64bf3153 100644 --- a/src/leap/services/mail/smtpbootstrapper.py +++ b/src/leap/services/mail/smtpbootstrapper.py @@ -22,8 +22,6 @@ SMTP bootstrapping  import logging  import os -import requests -  from PySide import QtCore  from leap.common.check import leap_assert, leap_assert_type @@ -31,44 +29,30 @@ from leap.common.files import get_mtime  from leap.config.providerconfig import ProviderConfig  from leap.crypto.srpauth import SRPAuth  from leap.util.request_helpers import get_content +from leap.services.abstractbootstrapper import AbstractBootstrapper  logger = logging.getLogger(__name__) -class SMTPBootstrapper(QtCore.QObject): +class SMTPBootstrapper(AbstractBootstrapper):      """      SMTP init procedure      """ -    PASSED_KEY = "passed" -    ERROR_KEY = "error" - -    IDLE_SLEEP_INTERVAL = 100 -      # All dicts returned are of the form      # {"passed": bool, "error": str}      download_config = QtCore.Signal(dict)      def __init__(self): -        QtCore.QObject.__init__(self) +        AbstractBootstrapper.__init__(self) -        # **************************************************** # -        # Dependency injection helpers, override this for more -        # granular testing -        self._fetcher = requests -        # **************************************************** # - -        self._session = self._fetcher.session()          self._provider_config = None          self._smtp_config = None          self._download_if_needed = False -    def _download_config(self): +    def _download_config(self, *args):          """          Downloads the SMTP config for the given provider - -        :return: True if everything went as expected, False otherwise -        :rtype: bool          """          leap_assert(self._provider_config, @@ -77,79 +61,59 @@ class SMTPBootstrapper(QtCore.QObject):          logger.debug("Downloading SMTP config for %s" %                       (self._provider_config.get_domain(),)) -        download_config_data = { -            self.PASSED_KEY: False, -            self.ERROR_KEY: "" -        } - -        try: -            headers = {} -            mtime = get_mtime(os.path.join(self._smtp_config -                                           .get_path_prefix(), -                                           "leap", -                                           "providers", -                                           self._provider_config.get_domain(), -                                           "smtp-service.json")) - -            if self._download_if_needed and mtime: -                headers['if-modified-since'] = mtime - -            # there is some confusion with this uri, -            config_uri = "%s/%s/config/smtp-service.json" % ( -                self._provider_config.get_api_uri(), -                self._provider_config.get_api_version()) -            logger.debug('Downloading SMTP config from: %s' % config_uri) - -            srp_auth = SRPAuth(self._provider_config) -            session_id = srp_auth.get_session_id() -            cookies = None -            if session_id: -                cookies = {"_session_id": session_id} - -            res = self._session.get(config_uri, -                                    verify=self._provider_config -                                    .get_ca_cert_path(), -                                    headers=headers, -                                    cookies=cookies) -            res.raise_for_status() - -            # Not modified -            if res.status_code == 304: -                logger.debug("SMTP definition has not been modified") -                self._smtp_config.load(os.path.join( -                    "leap", "providers", -                    self._provider_config.get_domain(), -                    "smtp-service.json")) -            else: -                smtp_definition, mtime = get_content(res) - -                self._smtp_config.load(data=smtp_definition, mtime=mtime) -                self._smtp_config.save(["leap", -                                        "providers", -                                        self._provider_config.get_domain(), -                                        "smtp-service.json"]) - -            download_config_data[self.PASSED_KEY] = True -        except Exception as e: -            download_config_data[self.PASSED_KEY] = False -            download_config_data[self.ERROR_KEY] = "%s" % (e,) - -        logger.debug("Emitting download_config %s" % (download_config_data,)) -        self.download_config.emit(download_config_data) - -        return download_config_data[self.PASSED_KEY] +        headers = {} +        mtime = get_mtime(os.path.join(self._smtp_config +                                       .get_path_prefix(), +                                       "leap", +                                       "providers", +                                       self._provider_config.get_domain(), +                                       "smtp-service.json")) + +        if self._download_if_needed and mtime: +            headers['if-modified-since'] = mtime + +        # there is some confusion with this uri, +        config_uri = "%s/%s/config/smtp-service.json" % ( +            self._provider_config.get_api_uri(), +            self._provider_config.get_api_version()) +        logger.debug('Downloading SMTP config from: %s' % config_uri) + +        srp_auth = SRPAuth(self._provider_config) +        session_id = srp_auth.get_session_id() +        cookies = None +        if session_id: +            cookies = {"_session_id": session_id} + +        res = self._session.get(config_uri, +                                verify=self._provider_config +                                .get_ca_cert_path(), +                                headers=headers, +                                cookies=cookies) +        res.raise_for_status() + +        # Not modified +        if res.status_code == 304: +            logger.debug("SMTP definition has not been modified") +            self._smtp_config.load(os.path.join("leap", +                                                "providers", +                                                self._provider_config.get_domain(), +                                                "smtp-service.json")) +        else: +            smtp_definition, mtime = get_content(res) + +            self._smtp_config.load(data=smtp_definition, mtime=mtime) +            self._smtp_config.save(["leap", +                                    "providers", +                                    self._provider_config.get_domain(), +                                    "smtp-service.json"])      def run_smtp_setup_checks(self, -                              checker,                                provider_config,                                smtp_config,                                download_if_needed=False):          """          Starts the checks needed for a new smtp setup -        :param checker: Object that executes actions in a different -                        thread -        :type checker: leap.util.checkerthread.CheckerThread          :param provider_config: Provider configuration          :type provider_config: ProviderConfig          :param smtp_config: SMTP configuration to populate @@ -164,6 +128,8 @@ class SMTPBootstrapper(QtCore.QObject):          self._smtp_config = smtp_config          self._download_if_needed = download_if_needed -        checker.add_checks([ -            self._download_config -        ]) +        cb_chain = [ +            (self._download_config, self.download_config), +        ] + +        self.addCallbackChain(cb_chain) diff --git a/src/leap/services/soledad/soledadbootstrapper.py b/src/leap/services/soledad/soledadbootstrapper.py index eea9b0d5..2635a7e6 100644 --- a/src/leap/services/soledad/soledadbootstrapper.py +++ b/src/leap/services/soledad/soledadbootstrapper.py @@ -22,10 +22,7 @@ Soledad bootstrapping  import logging  import os -import requests -  from PySide import QtCore -from mock import Mock  from leap.common.check import leap_assert, leap_assert_type  from leap.common.files import get_mtime @@ -36,39 +33,29 @@ from leap.crypto.srpauth import SRPAuth  from leap.services.soledad.soledadconfig import SoledadConfig  from leap.util.request_helpers import get_content  from leap.soledad import Soledad +from leap.services.abstractbootstrapper import AbstractBootstrapper  logger = logging.getLogger(__name__) -class SoledadBootstrapper(QtCore.QObject): +class SoledadBootstrapper(AbstractBootstrapper):      """      Soledad init procedure      """ -    PASSED_KEY = "passed" -    ERROR_KEY = "error"      SOLEDAD_KEY = "soledad"      KEYMANAGER_KEY = "keymanager"      PUBKEY_KEY = "user[public_key]" -    IDLE_SLEEP_INTERVAL = 100 -      # All dicts returned are of the form      # {"passed": bool, "error": str}      download_config = QtCore.Signal(dict)      gen_key = QtCore.Signal(dict)      def __init__(self): -        QtCore.QObject.__init__(self) - -        # **************************************************** # -        # Dependency injection helpers, override this for more -        # granular testing -        self._fetcher = requests -        # **************************************************** # +        AbstractBootstrapper.__init__(self) -        self._session = self._fetcher.session()          self._provider_config = None          self._soledad_config = None          self._keymanager = None @@ -76,6 +63,14 @@ class SoledadBootstrapper(QtCore.QObject):          self._user = ""          self._password = "" +    @property +    def keymanager(self): +        return self._keymanager + +    @property +    def soledad(self): +        return self._soledad +      def _load_and_sync_soledad(self, srp_auth):          """          Once everthing is in the right place, we instantiate and sync @@ -92,7 +87,8 @@ class SoledadBootstrapper(QtCore.QObject):          local_db_path = "%s/%s.db" % (prefix, uuid)          # TODO: use the proper URL -        server_url = 'https://mole.dev.bitmask.net:2424/user-%s' % (uuid,) +        #server_url = 'https://mole.dev.bitmask.net:2424/user-%s' % (uuid,) +        server_url = 'https://gadwall.dev.bitmask.net:1111/user-%s' % (uuid,)          # server_url = self._soledad_config.get_hosts(...)          cert_file = self._provider_config.get_ca_cert_path() @@ -109,9 +105,6 @@ class SoledadBootstrapper(QtCore.QObject):      def _download_config(self):          """          Downloads the Soledad config for the given provider - -        :return: True if everything went as expected, False otherwise -        :rtype: bool          """          leap_assert(self._provider_config, @@ -120,150 +113,84 @@ class SoledadBootstrapper(QtCore.QObject):          logger.debug("Downloading Soledad config for %s" %                       (self._provider_config.get_domain(),)) -        download_config_data = { -            self.PASSED_KEY: False, -            self.ERROR_KEY: "" -        } -          self._soledad_config = SoledadConfig() -        try: -            headers = {} -            mtime = get_mtime(os.path.join(self._soledad_config -                                           .get_path_prefix(), -                                           "leap", -                                           "providers", -                                           self._provider_config.get_domain(), -                                           "soledad-service.json")) - -            if self._download_if_needed and mtime: -                headers['if-modified-since'] = mtime - -            # there is some confusion with this uri, -            config_uri = "%s/%s/config/soledad-service.json" % ( -                self._provider_config.get_api_uri(), -                self._provider_config.get_api_version()) -            logger.debug('Downloading soledad config from: %s' % config_uri) - -            srp_auth = SRPAuth(self._provider_config) -            session_id = srp_auth.get_session_id() -            cookies = None -            if session_id: -                cookies = {"_session_id": session_id} - -            res = self._session.get(config_uri, -                                    verify=self._provider_config -                                    .get_ca_cert_path(), -                                    headers=headers, -                                    cookies=cookies) -            res.raise_for_status() - -            # Not modified -            if res.status_code == 304: -                logger.debug("Soledad definition has not been modified") -            else: -                soledad_definition, mtime = get_content(res) - -                self._soledad_config.load(data=soledad_definition, mtime=mtime) -                self._soledad_config.save(["leap", -                                           "providers", -                                           self._provider_config.get_domain(), -                                           "soledad-service.json"]) - -            self._load_and_sync_soledad(srp_auth) - -            download_config_data[self.PASSED_KEY] = True -        except Exception as e: -            download_config_data[self.PASSED_KEY] = False -            download_config_data[self.ERROR_KEY] = "%s" % (e,) - -        logger.debug("Emitting download_config %s" % (download_config_data,)) -        self.download_config.emit(download_config_data) - -        return download_config_data[self.PASSED_KEY] +        headers = {} +        mtime = get_mtime(os.path.join(self._soledad_config +                                       .get_path_prefix(), +                                       "leap", +                                       "providers", +                                       self._provider_config.get_domain(), +                                       "soledad-service.json")) + +        if self._download_if_needed and mtime: +            headers['if-modified-since'] = mtime + +        # there is some confusion with this uri, +        config_uri = "%s/%s/config/soledad-service.json" % ( +            self._provider_config.get_api_uri(), +            self._provider_config.get_api_version()) +        logger.debug('Downloading soledad config from: %s' % config_uri) + +        srp_auth = SRPAuth(self._provider_config) +        session_id = srp_auth.get_session_id() +        cookies = None +        if session_id: +            cookies = {"_session_id": session_id} + +        res = self._session.get(config_uri, +                                verify=self._provider_config +                                .get_ca_cert_path(), +                                headers=headers, +                                cookies=cookies) +        res.raise_for_status() + +        # Not modified +        if res.status_code == 304: +            logger.debug("Soledad definition has not been modified") +        else: +            soledad_definition, mtime = get_content(res) + +            self._soledad_config.load(data=soledad_definition, mtime=mtime) +            self._soledad_config.save(["leap", +                                       "providers", +                                       self._provider_config.get_domain(), +                                       "soledad-service.json"]) + +        self._load_and_sync_soledad(srp_auth)      def _gen_key(self):          """          Generates the key pair if needed, uploads it to the webapp and          nickserver - -        :return: True if everything is done successfully, False -        otherwise -        :rtype: bool          """          leap_assert(self._provider_config,                      "We need a provider configuration!") -        # XXX Sanitize this          address = "%s@%s" % (self._user, self._provider_config.get_domain())          logger.debug("Retrieving key for %s" % (address,)) -        genkey_data = { -            self.PASSED_KEY: False, -            self.ERROR_KEY: "", -            self.SOLEDAD_KEY: None, -            self.KEYMANAGER_KEY: None -        } - +        srp_auth = SRPAuth(self._provider_config) +        self._keymanager = KeyManager( +            address, +            "https://%s:6425" % (self._provider_config.get_domain()), +            self._soledad, +            #token=srp_auth.get_token(), # TODO: enable token usage +            session_id=srp_auth.get_session_id(), +            ca_cert_path=self._provider_config.get_ca_cert_path(), +            api_uri=self._provider_config.get_api_uri(), +            api_version=self._provider_config.get_api_version(), +            uid=srp_auth.get_uid())          try: -            srp_auth = SRPAuth(self._provider_config) -            self._keymanager = KeyManager( -                address, -                "https://nickserver",  # TODO: nickserver url, none for now -                self._soledad, -                token=srp_auth.get_token()) -            self._keymanager._fetcher.put = Mock() -            try: -                self._keymanager.get_key(address, openpgp.OpenPGPKey, -                                         private=True, fetch_remote=False) -            except KeyNotFound: -                logger.debug( -                    "Key not found. Generating key for %s" % (address,)) -                self._keymanager.gen_key(openpgp.OpenPGPKey) - -                logger.debug("Key generated successfully.") - -            cookies = None -            session_id = srp_auth.get_session_id() -            if session_id: -                cookies = {"_session_id": session_id} - -            key_uri = "%s/%s/users/%s.json" % ( -                self._provider_config.get_api_uri(), -                self._provider_config.get_api_version(), -                srp_auth.get_uid()) - -            logger.debug("Uploading public key to %s" % (key_uri,)) - -            pubkey = self._keymanager.get_key( -                address, openpgp.OpenPGPKey, -                private=False, fetch_remote=False) -            key_data = { -                self.PUBKEY_KEY: pubkey.key_data, -            } - -            # TODO: check if uploaded before uploading it -            key_result = self._session.put(key_uri, -                                           data=key_data, -                                           verify=self._provider_config -                                           .get_ca_cert_path(), -                                           cookies=cookies) -            key_result.raise_for_status() -            genkey_data[self.PASSED_KEY] = True -            genkey_data[self.SOLEDAD_KEY] = self._soledad -            genkey_data[self.KEYMANAGER_KEY] = self._keymanager -        except Exception as e: -            genkey_data[self.PASSED_KEY] = False -            genkey_data[self.ERROR_KEY] = "%s" % (e,) - -        logger.debug("Emitting gen_key %s" % (genkey_data,)) -        self.gen_key.emit(genkey_data) - -        return genkey_data[self.PASSED_KEY] +            self._keymanager.get_key(address, openpgp.OpenPGPKey, +                                     private=True, fetch_remote=False) +        except KeyNotFound: +            logger.debug("Key not found. Generating key for %s" % (address,)) +            self._keymanager.gen_key(openpgp.OpenPGPKey) +            logger.debug("Key generated successfully.")      def run_soledad_setup_checks(self, -                                 checker,                                   provider_config,                                   user,                                   password, @@ -273,6 +200,10 @@ class SoledadBootstrapper(QtCore.QObject):          :param provider_config: Provider configuration          :type provider_config: ProviderConfig +        :param user: User's login +        :type user: str +        :param password: User's password +        :type password: str          """          leap_assert_type(provider_config, ProviderConfig) @@ -281,7 +212,9 @@ class SoledadBootstrapper(QtCore.QObject):          self._user = user          self._password = password -        checker.add_checks([ -            self._download_config, -            self._gen_key -        ]) +        cb_chain = [ +            (self._download_config, self.download_config), +            (self._gen_key, self.gen_key) +        ] + +        self.addCallbackChain(cb_chain) diff --git a/src/leap/util/checkerthread.py b/src/leap/util/checkerthread.py deleted file mode 100644 index 02aa333f..00000000 --- a/src/leap/util/checkerthread.py +++ /dev/null @@ -1,109 +0,0 @@ -# -*- coding: utf-8 -*- -# checkerthread.py -# Copyright (C) 2013 LEAP -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program.  If not, see <http://www.gnu.org/licenses/>. - -""" -Checker thread -""" - -import logging - -from PySide import QtCore - -from leap.common.check import leap_assert_type - -logger = logging.getLogger(__name__) - - -class CheckerThread(QtCore.QThread): -    """ -    Generic checker thread that can perform any type of operation as -    long as it returns a boolean value that identifies how the -    execution went. -    """ - -    IDLE_SLEEP_INTERVAL = 1 - -    def __init__(self): -        QtCore.QThread.__init__(self) - -        self._checks = [] -        self._checks_lock = QtCore.QMutex() - -        self._should_quit = False -        self._should_quit_lock = QtCore.QMutex() - -    def get_should_quit(self): -        """ -        Returns whether this thread should quit - -        :return: True if the thread should terminate itself, Flase otherwise -        :rtype: bool -        """ - -        QtCore.QMutexLocker(self._should_quit_lock) -        return self._should_quit - -    def set_should_quit(self): -        """ -        Sets the should_quit flag to True so that this thread -        terminates the first chance it gets -        """ -        QtCore.QMutexLocker(self._should_quit_lock) -        self._should_quit = True - -    def start(self): -        """ -        Starts the thread and resets the should_quit flag -        """ -        with QtCore.QMutexLocker(self._should_quit_lock): -            self._should_quit = False - -        QtCore.QThread.start(self) - -    def add_checks(self, checks): -        """ -        Adds a list of checks to the ones being executed - -        :param checks: check functions to perform -        :type checkes: list -        """ -        with QtCore.QMutexLocker(self._checks_lock): -            self._checks += checks - -    def run(self): -        """ -        Main run loop for this thread. Executes the checks. -        """ -        shouldContinue = False -        while True: -            if self.get_should_quit(): -                logger.debug("Quitting checker thread") -                return -            checkSomething = False -            with QtCore.QMutexLocker(self._checks_lock): -                if len(self._checks) > 0: -                    check = self._checks.pop(0) -                    shouldContinue = check() -                    leap_assert_type(shouldContinue, bool) -                    checkSomething = True -                    if not shouldContinue: -                        logger.debug("Something went wrong with the checks, " -                                     "clearing...") -                        self._checks = [] -                        checkSomething = False -            if not checkSomething: -                self.sleep(self.IDLE_SLEEP_INTERVAL) | 
