diff options
author | drebs <drebs@leap.se> | 2012-12-24 10:14:58 -0200 |
---|---|---|
committer | drebs <drebs@leap.se> | 2012-12-24 10:14:58 -0200 |
commit | 319e279b59ac080779d0a3375ae4d6582f5ee6a3 (patch) | |
tree | 118dd0f495c0d54f2b2c66ea235e4e4e6b8cefd5 /src/leap | |
parent | ca5fb41a55e1292005ed186baf3710831d9ad678 (diff) | |
parent | a7b091a0553e6120f3e0eb6d4e73a89732c589b2 (diff) |
Merge branch 'develop' of ssh://code.leap.se/leap_client into develop
Diffstat (limited to 'src/leap')
56 files changed, 3177 insertions, 1047 deletions
diff --git a/src/leap/app.py b/src/leap/app.py index d594c7cd..334b58c8 100644 --- a/src/leap/app.py +++ b/src/leap/app.py @@ -8,10 +8,11 @@ import sip sip.setapi('QVariant', 2) sip.setapi('QString', 2) from PyQt4.QtGui import (QApplication, QSystemTrayIcon, QMessageBox) -from PyQt4.QtCore import QTimer +from PyQt4 import QtCore from leap import __version__ as VERSION from leap.baseapp.mainwindow import LeapWindow +from leap.gui import locale_rc def sigint_handler(*args, **kwargs): @@ -62,6 +63,17 @@ def main(): logger.info('Starting app') app = QApplication(sys.argv) + # To test: + # $ LANG=es ./app.py + locale = QtCore.QLocale.system().name() + print locale + qtTranslator = QtCore.QTranslator() + if qtTranslator.load("qt_%s" % locale, ":/translations"): + app.installTranslator(qtTranslator) + appTranslator = QtCore.QTranslator() + if appTranslator.load("leap_client_%s" % locale, ":/translations"): + app.installTranslator(appTranslator) + # needed for initializing qsettings # it will write .config/leap/leap.conf # top level app settings @@ -83,7 +95,7 @@ def main(): # this dummy timer ensures that # control is given to the outside loop, so we # can hook our sigint handler. - timer = QTimer() + timer = QtCore.QTimer() timer.start(500) timer.timeout.connect(lambda: None) diff --git a/src/leap/base/auth.py b/src/leap/base/auth.py index 50533278..ecc24179 100644 --- a/src/leap/base/auth.py +++ b/src/leap/base/auth.py @@ -10,6 +10,7 @@ from PyQt4 import QtCore from leap.base import constants as baseconstants from leap.crypto import leapkeyring +from leap.util.misc import null_check from leap.util.web import get_https_domain_and_port logger = logging.getLogger(__name__) @@ -26,11 +27,6 @@ one if not. """ -class ImproperlyConfigured(Exception): - """ - """ - - class SRPAuthenticationError(Exception): """ exception raised @@ -38,14 +34,6 @@ class SRPAuthenticationError(Exception): """ -def null_check(value, value_name): - try: - assert value is not None - except AssertionError: - raise ImproperlyConfigured( - "%s parameter cannot be None" % value_name) - - safe_unhexlify = lambda x: binascii.unhexlify(x) \ if (len(x) % 2 == 0) else binascii.unhexlify('0' + x) @@ -55,7 +43,7 @@ class LeapSRPRegister(object): def __init__(self, schema="https", provider=None, - port=None, + #port=None, verify=True, register_path="1/users.json", method="POST", @@ -64,13 +52,13 @@ class LeapSRPRegister(object): hashfun=srp.SHA256, ng_constant=srp.NG_1024): - null_check(provider, provider) + null_check(provider, "provider") self.schema = schema # XXX FIXME - self.provider = provider - self.port = port + #self.provider = provider + #self.port = port # XXX splitting server,port # deprecate port call. domain, port = get_https_domain_and_port(provider) @@ -154,9 +142,6 @@ class SRPAuth(requests.auth.AuthBase): self.init_srp() - def get_json_data(self, response): - return json.loads(response.content) - def init_srp(self): usr = srp.User( self.username, @@ -187,8 +172,7 @@ class SRPAuth(requests.auth.AuthBase): raise SRPAuthenticationError( "No valid response (salt).") - # XXX should get auth_result.json instead - self.init_data = self.get_json_data(init_session) + self.init_data = init_session.json return self.init_data def get_server_proof_data(self): @@ -206,13 +190,7 @@ class SRPAuth(requests.auth.AuthBase): raise SRPAuthenticationError( "No valid response (HAMK).") - # XXX should get auth_result.json instead - try: - self.auth_data = self.get_json_data(auth_result) - except ValueError: - raise SRPAuthenticationError( - "No valid data sent (HAMK)") - + self.auth_data = auth_result.json return self.auth_data def authenticate(self): @@ -267,13 +245,14 @@ class SRPAuth(requests.auth.AuthBase): try: assert self.srp_usr.authenticated() logger.debug('user is authenticated!') + print 'user is authenticated!' except (AssertionError): raise SRPAuthenticationError( "Auth verification failed.") def __call__(self, req): self.authenticate() - req.session = self.session + req.cookies = self.session.cookies return req @@ -367,8 +346,10 @@ if __name__ == "__main__": req.raise_for_status return req - req = test_srp_protected_get('https://localhost:8443/1/cert') - print 'cert :', req.content[:200] + "..." + #req = test_srp_protected_get('https://localhost:8443/1/cert') + req = test_srp_protected_get('%s/1/cert' % SERVER) + #print 'cert :', req.content[:200] + "..." + print req.content sys.exit(0) if action == "add": diff --git a/src/leap/base/checks.py b/src/leap/base/checks.py index 23446f4a..dc2602c2 100644 --- a/src/leap/base/checks.py +++ b/src/leap/base/checks.py @@ -39,9 +39,6 @@ class LeapNetworkChecker(object): # XXX remove this hardcoded random ip # ping leap.se or eip provider instead...? requests.get('http://216.172.161.165') - - except (requests.HTTPError, requests.RequestException) as e: - raise exceptions.NoInternetConnection(e.message) except requests.ConnectionError as e: error = "Unidentified Connection Error" if e.message == "[Errno 113] No route to host": @@ -51,11 +48,17 @@ class LeapNetworkChecker(object): error = "Provider server appears to be down." logger.error(error) raise exceptions.NoInternetConnection(error) + except (requests.HTTPError, requests.RequestException) as e: + raise exceptions.NoInternetConnection(e.message) logger.debug('Network appears to be up.') def is_internet_up(self): iface, gateway = self.get_default_interface_gateway() - self.ping_gateway(self.provider_gateway) + try: + self.ping_gateway(self.provider_gateway) + except exceptions.NoConnectionToGateway: + return False + return True def check_tunnel_default_interface(self): """ diff --git a/src/leap/base/config.py b/src/leap/base/config.py index 0255fbab..438d1993 100644 --- a/src/leap/base/config.py +++ b/src/leap/base/config.py @@ -5,11 +5,12 @@ import grp import json import logging import socket -import tempfile +import time import os logger = logging.getLogger(name=__name__) +from dateutil import parser as dateparser import requests from leap.base import exceptions @@ -125,17 +126,43 @@ class JSONLeapConfig(BaseLeapConfig): # mandatory baseconfig interface - def save(self, to=None): - if to is None: - to = self.filename - folder, filename = os.path.split(to) - if folder and not os.path.isdir(folder): - mkdir_p(folder) - self._config.serialize(to) + def save(self, to=None, force=False): + """ + force param will skip the dirty check. + :type force: bool + """ + # XXX this force=True does not feel to right + # but still have to look for a better way + # of dealing with dirtiness and the + # trick of loading remote config only + # when newer. + + if force: + do_save = True + else: + do_save = self._config.is_dirty() + + if do_save: + if to is None: + to = self.filename + folder, filename = os.path.split(to) + if folder and not os.path.isdir(folder): + mkdir_p(folder) + self._config.serialize(to) + return True + + else: + return False + + def load(self, fromfile=None, from_uri=None, fetcher=None, + force_download=False, verify=False): - def load(self, fromfile=None, from_uri=None, fetcher=None, verify=False): if from_uri is not None: - fetched = self.fetch(from_uri, fetcher=fetcher, verify=verify) + fetched = self.fetch( + from_uri, + fetcher=fetcher, + verify=verify, + force_dl=force_download) if fetched: return if fromfile is None: @@ -146,33 +173,69 @@ class JSONLeapConfig(BaseLeapConfig): logger.error('tried to load config from non-existent path') logger.error('Not Found: %s', fromfile) - def fetch(self, uri, fetcher=None, verify=True): + def fetch(self, uri, fetcher=None, verify=True, force_dl=False): if not fetcher: fetcher = self.fetcher + logger.debug('verify: %s', verify) logger.debug('uri: %s', uri) - request = fetcher.get(uri, verify=verify) - # XXX should send a if-modified-since header - # XXX get 404, ... - # and raise a UnableToFetch... + rargs = (uri, ) + rkwargs = {'verify': verify} + headers = {} + + curmtime = self.get_mtime() if not force_dl else None + if curmtime: + logger.debug('requesting with if-modified-since %s' % curmtime) + headers['if-modified-since'] = curmtime + rkwargs['headers'] = headers + + #request = fetcher.get(uri, verify=verify) + request = fetcher.get(*rargs, **rkwargs) request.raise_for_status() - fd, fname = tempfile.mkstemp(suffix=".json") - if request.json: - self._config.load(json.dumps(request.json)) + if request.status_code == 304: + logger.debug('...304 Not Changed') + # On this point, we have to assume that + # we HAD the filename. If that filename is corruct, + # we should enforce a force_download in the load + # method above. + self._config.load(fromfile=self.filename) + return True + if request.json: + mtime = None + last_modified = request.headers.get('last-modified', None) + if last_modified: + _mtime = dateparser.parse(last_modified) + mtime = int(_mtime.strftime("%s")) + if callable(request.json): + _json = request.json() + else: + # back-compat + _json = request.json + self._config.load(json.dumps(_json), mtime=mtime) + self._config.set_dirty() else: # not request.json # might be server did not announce content properly, # let's try deserializing all the same. try: self._config.load(request.content) + self._config.set_dirty() except ValueError: raise eipexceptions.LeapBadConfigFetchedError return True + def get_mtime(self): + try: + _mtime = os.stat(self.filename)[8] + mtime = time.strftime("%c GMT", time.gmtime(_mtime)) + return mtime + except OSError: + return None + def get_config(self): return self._config.config diff --git a/src/leap/base/constants.py b/src/leap/base/constants.py index f7be8d98..b38723be 100644 --- a/src/leap/base/constants.py +++ b/src/leap/base/constants.py @@ -14,18 +14,27 @@ DEFAULT_PROVIDER = __branding.get( DEFINITION_EXPECTED_PATH = "provider.json" DEFAULT_PROVIDER_DEFINITION = { - u'api_uri': u'https://api.%s/' % DEFAULT_PROVIDER, - u'api_version': u'0.1.0', - u'ca_cert_fingerprint': u'8aab80ae4326fd30721689db813733783fe0bd7e', - u'ca_cert_uri': u'https://%s/cacert.pem' % DEFAULT_PROVIDER, - u'description': {u'en': u'This is a test provider'}, - u'display_name': {u'en': u'Test Provider'}, - u'domain': u'%s' % DEFAULT_PROVIDER, - u'enrollment_policy': u'open', - u'public_key': u'cb7dbd679f911e85bc2e51bd44afd7308ee19c21', - u'serial': 1, - u'services': [u'eip'], - u'version': u'0.1.0'} + u"api_uri": "https://api.%s/" % DEFAULT_PROVIDER, + u"api_version": u"1", + u"ca_cert_fingerprint": "SHA256: fff", + u"ca_cert_uri": u"https://%s/ca.crt" % DEFAULT_PROVIDER, + u"default_language": u"en", + u"description": { + u"en": u"A demonstration service provider using the LEAP platform" + }, + u"domain": "%s" % DEFAULT_PROVIDER, + u"enrollment_policy": u"open", + u"languages": [ + u"en" + ], + u"name": { + u"en": u"Test Provider" + }, + u"services": [ + "openvpn" + ] +} + MAX_ICMP_PACKET_LOSS = 10 diff --git a/src/leap/base/network.py b/src/leap/base/network.py index 3aba3f61..765d8ea0 100644 --- a/src/leap/base/network.py +++ b/src/leap/base/network.py @@ -3,10 +3,11 @@ from __future__ import (print_function) import logging import threading -from leap.eip.config import get_eip_gateway +from leap.eip import config as eipconfig from leap.base.checks import LeapNetworkChecker from leap.base.constants import ROUTE_CHECK_INTERVAL from leap.base.exceptions import TunnelNotDefaultRouteError +from leap.util.misc import null_check from leap.util.coroutines import (launch_thread, process_events) from time import sleep @@ -27,11 +28,20 @@ class NetworkCheckerThread(object): lambda exc: logger.error("%s", exc.message)) self.shutdown = threading.Event() - # XXX get provider_gateway and pass it to checker - # see in eip.config for function - # #718 + # XXX get provider passed here + provider = kwargs.pop('provider', None) + null_check(provider, 'provider') + + eipconf = eipconfig.EIPConfig(domain=provider) + eipconf.load() + eipserviceconf = eipconfig.EIPServiceConfig(domain=provider) + eipserviceconf.load() + + gw = eipconfig.get_eip_gateway( + eipconfig=eipconf, + eipserviceconfig=eipserviceconf) self.checker = LeapNetworkChecker( - provider_gw=get_eip_gateway()) + provider_gw=gw) def start(self): self.process_handle = self._launch_recurrent_network_checks( diff --git a/src/leap/base/pluggableconfig.py b/src/leap/base/pluggableconfig.py index b8615ad8..0ca985ea 100644 --- a/src/leap/base/pluggableconfig.py +++ b/src/leap/base/pluggableconfig.py @@ -180,6 +180,8 @@ class PluggableConfig(object): self.adaptors = adaptors self.types = types self._format = format + self.mtime = None + self.dirty = False @property def option_dict(self): @@ -319,6 +321,13 @@ class PluggableConfig(object): serializable = self.prep_value(config) adaptor.write(serializable, filename) + if self.mtime: + self.touch_mtime(filename) + + def touch_mtime(self, filename): + mtime = self.mtime + os.utime(filename, (mtime, mtime)) + def deserialize(self, string=None, fromfile=None, format=None): """ load configuration from a file or string @@ -364,6 +373,12 @@ class PluggableConfig(object): content = _try_deserialize() return content + def set_dirty(self): + self.dirty = True + + def is_dirty(self): + return self.dirty + def load(self, *args, **kwargs): """ load from string or file @@ -373,6 +388,8 @@ class PluggableConfig(object): """ string = args[0] if args else None fromfile = kwargs.get("fromfile", None) + mtime = kwargs.pop("mtime", None) + self.mtime = mtime content = None # start with defaults, so we can @@ -402,7 +419,8 @@ class PluggableConfig(object): return True -def testmain(): +def testmain(): # pragma: no cover + from tests import test_validation as t import pprint diff --git a/src/leap/base/specs.py b/src/leap/base/specs.py index b4bb8dcf..962aa07d 100644 --- a/src/leap/base/specs.py +++ b/src/leap/base/specs.py @@ -2,22 +2,26 @@ leap_provider_spec = { 'description': 'provider definition', 'type': 'object', 'properties': { - 'serial': { - 'type': int, - 'default': 1, - 'required': True, - }, + #'serial': { + #'type': int, + #'default': 1, + #'required': True, + #}, 'version': { 'type': unicode, 'default': '0.1.0' #'required': True }, + "default_language": { + 'type': unicode, + 'default': 'en' + }, 'domain': { 'type': unicode, # XXX define uri type 'default': 'testprovider.example.org' #'required': True, }, - 'display_name': { + 'name': { 'type': dict, # XXX multilingual object? 'default': {u'en': u'Test Provider'} #'required': True diff --git a/src/leap/base/tests/test_auth.py b/src/leap/base/tests/test_auth.py new file mode 100644 index 00000000..17b84b52 --- /dev/null +++ b/src/leap/base/tests/test_auth.py @@ -0,0 +1,58 @@ +from BaseHTTPServer import BaseHTTPRequestHandler +import urlparse +try: + import unittest2 as unittest +except ImportError: + import unittest + +import requests +#from mock import Mock + +from leap.base import auth +#from leap.base import exceptions +from leap.eip.tests.test_checks import NoLogRequestHandler +from leap.testing.basetest import BaseLeapTest +from leap.testing.https_server import BaseHTTPSServerTestCase + + +class LeapSRPRegisterTests(BaseHTTPSServerTestCase, BaseLeapTest): + __name__ = "leap_srp_register_test" + provider = "testprovider.example.org" + + class request_handler(NoLogRequestHandler, BaseHTTPRequestHandler): + responses = { + '/': ['OK', '']} + + def do_GET(self): + path = urlparse.urlparse(self.path) + message = '\n'.join(self.responses.get( + path.path, None)) + self.send_response(200) + self.end_headers() + self.wfile.write(message) + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_srp_auth_should_implement_check_methods(self): + SERVER = "https://localhost:8443" + srp_auth = auth.LeapSRPRegister(provider=SERVER, verify=False) + + self.assertTrue(hasattr(srp_auth, "init_session"), + "missing meth") + self.assertTrue(hasattr(srp_auth, "get_registration_uri"), + "missing meth") + self.assertTrue(hasattr(srp_auth, "register_user"), + "missing meth") + + def test_srp_auth_basic_functionality(self): + SERVER = "https://localhost:8443" + srp_auth = auth.LeapSRPRegister(provider=SERVER, verify=False) + + self.assertIsInstance(srp_auth.session, requests.sessions.Session) + self.assertEqual( + srp_auth.get_registration_uri(), + "https://localhost:8443/1/users.json") diff --git a/src/leap/base/tests/test_checks.py b/src/leap/base/tests/test_checks.py index 8d573b1e..7a694f89 100644 --- a/src/leap/base/tests/test_checks.py +++ b/src/leap/base/tests/test_checks.py @@ -118,6 +118,22 @@ class LeapNetworkCheckTest(BaseLeapTest): with self.assertRaises(exceptions.NoInternetConnection): checker.check_internet_connection() + with patch.object(requests, "get") as mocked_get: + mocked_get.side_effect = requests.ConnectionError( + "[Errno 113] No route to host") + with self.assertRaises(exceptions.NoInternetConnection): + with patch.object(checker, "ping_gateway") as mock_ping: + mock_ping.return_value = True + checker.check_internet_connection() + + with patch.object(requests, "get") as mocked_get: + mocked_get.side_effect = requests.ConnectionError( + "[Errno 113] No route to host") + with self.assertRaises(exceptions.NoInternetConnection): + with patch.object(checker, "ping_gateway") as mock_ping: + mock_ping.side_effect = exceptions.NoConnectionToGateway + checker.check_internet_connection() + @unittest.skipUnless(_uid == 0, "root only") def test_ping_gateway(self): checker = checks.LeapNetworkChecker() diff --git a/src/leap/base/tests/test_providers.py b/src/leap/base/tests/test_providers.py index 15c4ed58..9c11f270 100644 --- a/src/leap/base/tests/test_providers.py +++ b/src/leap/base/tests/test_providers.py @@ -8,7 +8,7 @@ import os import jsonschema -from leap import __branding as BRANDING +#from leap import __branding as BRANDING from leap.testing.basetest import BaseLeapTest from leap.base import providers @@ -16,10 +16,12 @@ from leap.base import providers EXPECTED_DEFAULT_CONFIG = { u"api_version": u"0.1.0", u"description": {u'en': u"Test provider"}, - u"display_name": {u'en': u"Test Provider"}, + u"default_language": u"en", + #u"display_name": {u'en': u"Test Provider"}, u"domain": u"testprovider.example.org", + u'name': {u'en': u'Test Provider'}, u"enrollment_policy": u"open", - u"serial": 1, + #u"serial": 1, u"services": [ u"eip" ], @@ -33,8 +35,8 @@ class TestLeapProviderDefinition(BaseLeapTest): self.domain = "testprovider.example.org" self.definition = providers.LeapProviderDefinition( domain=self.domain) - self.definition.save() - self.definition.load() + self.definition.save(force=True) + self.definition.load() # why have to load after save?? self.config = self.definition.config def tearDown(self): @@ -61,7 +63,7 @@ class TestLeapProviderDefinition(BaseLeapTest): def test_provider_dump(self): # check a good provider definition is dumped to disk self.testfile = self.get_tempfile('test.json') - self.definition.save(to=self.testfile) + self.definition.save(to=self.testfile, force=True) deserialized = json.load(open(self.testfile, 'rb')) self.maxDiff = None self.assertEqual(deserialized, EXPECTED_DEFAULT_CONFIG) @@ -88,7 +90,8 @@ class TestLeapProviderDefinition(BaseLeapTest): def test_provider_validation(self): self.definition.validate(self.config) _config = copy.deepcopy(self.config) - _config['serial'] = 'aaa' + # bad type, raise validation error + _config['domain'] = 111 with self.assertRaises(jsonschema.ValidationError): self.definition.validate(_config) diff --git a/src/leap/baseapp/eip.py b/src/leap/baseapp/eip.py index 54acbc0e..55ecfa79 100644 --- a/src/leap/baseapp/eip.py +++ b/src/leap/baseapp/eip.py @@ -203,7 +203,6 @@ class EIPConductorAppMixin(object): # we could bring Timer Init to this Mixin # or to its own Mixin. self.timer.start(constants.TIMER_MILLISECONDS) - self.network_checker.start() return if self.eip_service_started is True: diff --git a/src/leap/baseapp/leap_app.py b/src/leap/baseapp/leap_app.py index 4b63dd2f..4d3aebd6 100644 --- a/src/leap/baseapp/leap_app.py +++ b/src/leap/baseapp/leap_app.py @@ -148,6 +148,6 @@ class MainWindowMixin(object): # in conductor # XXX send signal instead? logger.info('Shutting down') - self.conductor.cleanup() + self.conductor.disconnect(shutdown=True) logger.info('Exiting. Bye.') QtGui.qApp.quit() diff --git a/src/leap/baseapp/mainwindow.py b/src/leap/baseapp/mainwindow.py index 8d61bf5c..02adab65 100644 --- a/src/leap/baseapp/mainwindow.py +++ b/src/leap/baseapp/mainwindow.py @@ -41,6 +41,7 @@ class LeapWindow(QtGui.QMainWindow, triggerEIPError = QtCore.pyqtSignal([object]) start_eipconnection = QtCore.pyqtSignal([]) shutdownSignal = QtCore.pyqtSignal([]) + initNetworkChecker = QtCore.pyqtSignal([]) # this is status change got from openvpn management openvpnStatusChange = QtCore.pyqtSignal([object]) @@ -61,10 +62,15 @@ class LeapWindow(QtGui.QMainWindow, logger.debug('provider: %s', self.provider_domain) logger.debug('eip_username: %s', self.eip_username) + provider = self.provider_domain EIPConductorAppMixin.__init__( - self, opts=opts, provider=self.provider_domain) + self, opts=opts, provider=provider) StatusAwareTrayIconMixin.__init__(self) - NetworkCheckerAppMixin.__init__(self) + + # XXX network checker should probably not + # trigger run_checks on init... but wait + # for ready signal instead... + NetworkCheckerAppMixin.__init__(self, provider=provider) MainWindowMixin.__init__(self) geom_key = "DebugGeometry" if self.debugmode else "Geometry" @@ -97,6 +103,8 @@ class LeapWindow(QtGui.QMainWindow, lambda: self.start_or_stopVPN()) self.shutdownSignal.connect( self.cleanupAndQuit) + self.initNetworkChecker.connect( + lambda: self.init_network_checker(self.provider_domain)) # status change. # TODO unify diff --git a/src/leap/baseapp/network.py b/src/leap/baseapp/network.py index 077d5164..a33265e5 100644 --- a/src/leap/baseapp/network.py +++ b/src/leap/baseapp/network.py @@ -9,19 +9,30 @@ from PyQt4 import QtCore from leap.baseapp.dialogs import ErrorDialog from leap.base.network import NetworkCheckerThread +from leap.util.misc import null_check + class NetworkCheckerAppMixin(object): """ initialize an instance of the Network Checker, which gathers error and passes them on. """ - def __init__(self, *args, **kwargs): - self.network_checker = NetworkCheckerThread( - error_cb=self.networkError.emit, - debug=self.debugmode) + provider = kwargs.pop('provider', None) + if provider: + self.init_network_checker(provider) + + def init_network_checker(self, provider): + null_check(provider, "provider_domain") + if not hasattr(self, 'network_checker'): + self.network_checker = NetworkCheckerThread( + error_cb=self.networkError.emit, + debug=self.debugmode, + provider=provider) + self.network_checker.start() - # XXX move run_checks to slot + @QtCore.pyqtSlot(object) + def runNetworkChecks(self): self.network_checker.run_checks() @QtCore.pyqtSlot(object) diff --git a/src/leap/baseapp/systray.py b/src/leap/baseapp/systray.py index 49f044aa..0dd0f195 100644 --- a/src/leap/baseapp/systray.py +++ b/src/leap/baseapp/systray.py @@ -217,6 +217,8 @@ class StatusAwareTrayIconMixin(object): updates icon, according to the openvpn status change. """ icon_name = self.conductor.get_icon_name() + if not icon_name: + return # XXX refactor. Use QStateMachine @@ -228,6 +230,11 @@ class StatusAwareTrayIconMixin(object): leap_status_name = self.conductor.get_leap_status() self.eipStatusChange.emit(leap_status_name) + if icon_name == "connected": + # When we change to "connected', we launch + # the network checker. + self.initNetworkChecker.emit() + self.setIcon(icon_name) # change connection pixmap widget self.setConnWidget(icon_name) diff --git a/src/leap/crypto/certs.py b/src/leap/crypto/certs.py index 8908865d..78f49fb0 100644 --- a/src/leap/crypto/certs.py +++ b/src/leap/crypto/certs.py @@ -1,10 +1,17 @@ import ctypes +from StringIO import StringIO import socket import gnutls.connection import gnutls.crypto import gnutls.library +from leap.util.misc import null_check + + +class BadCertError(Exception): + """raised for malformed certs""" + def get_https_cert_from_domain(domain): """ @@ -20,12 +27,43 @@ def get_https_cert_from_domain(domain): return cert -def get_cert_from_file(filepath): - with open(filepath) as f: - cert = gnutls.crypto.X509Certificate(f.read()) +def get_cert_from_file(_file): + getcert = lambda f: gnutls.crypto.X509Certificate(f.read()) + if isinstance(_file, str): + with open(_file) as f: + cert = getcert(f) + else: + cert = getcert(_file) return cert +def get_pkey_from_file(_file): + getkey = lambda f: gnutls.crypto.X509PrivateKey(f.read()) + if isinstance(_file, str): + with open(_file) as f: + key = getkey(f) + else: + key = getkey(_file) + return key + + +def can_load_cert_and_pkey(string): + try: + f = StringIO(string) + cert = get_cert_from_file(f) + + f = StringIO(string) + key = get_pkey_from_file(f) + + null_check(cert, 'certificate') + null_check(key, 'private key') + except: + # XXX catch GNUTLSError? + raise BadCertError + else: + return True + + def get_cert_fingerprint(domain=None, filepath=None, hash_type="SHA256", sep=":"): """ diff --git a/src/leap/crypto/leapkeyring.py b/src/leap/crypto/leapkeyring.py index d4be7bf9..c241d0bc 100644 --- a/src/leap/crypto/leapkeyring.py +++ b/src/leap/crypto/leapkeyring.py @@ -53,6 +53,7 @@ class LeapCryptedFileKeyring(keyring.backend.CryptedFileKeyring): def leap_set_password(key, value, seed="xxx"): + key, value = map(unicode, (key, value)) keyring.set_keyring(LeapCryptedFileKeyring(seed=seed)) keyring.set_password('leap', key, value) diff --git a/src/leap/crypto/tests/__init__.py b/src/leap/crypto/tests/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/crypto/tests/__init__.py diff --git a/src/leap/crypto/tests/test_certs.py b/src/leap/crypto/tests/test_certs.py new file mode 100644 index 00000000..e476b630 --- /dev/null +++ b/src/leap/crypto/tests/test_certs.py @@ -0,0 +1,22 @@ +import unittest + +from leap.testing.https_server import where +from leap.crypto import certs + + +class CertTestCase(unittest.TestCase): + + def test_can_load_client_and_pkey(self): + with open(where('leaptestscert.pem')) as cf: + cs = cf.read() + with open(where('leaptestskey.pem')) as kf: + ks = kf.read() + certs.can_load_cert_and_pkey(cs + ks) + + with self.assertRaises(certs.BadCertError): + # screw header + certs.can_load_cert_and_pkey(cs.replace("BEGIN", "BEGINN") + ks) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py index 116c535e..9ae6e5f5 100644 --- a/src/leap/eip/checks.py +++ b/src/leap/eip/checks.py @@ -84,8 +84,7 @@ class ProviderCertChecker(object): # For MVS checker.is_there_provider_ca() - # XXX FAKE IT!!! - checker.is_https_working(verify=do_verify, autocacert=True) + checker.is_https_working(verify=do_verify, autocacert=False) checker.check_new_cert_needed(verify=do_verify) def download_ca_cert(self, uri=None, verify=True): @@ -160,7 +159,6 @@ class ProviderCertChecker(object): if autocacert and verify is True and self.cacert is not None: logger.debug('verify cert: %s', self.cacert) verify = self.cacert - #import pdb4qt; pdb4qt.set_trace() logger.debug('is https working?') logger.debug('uri: %s (verify:%s)', uri, verify) try: @@ -242,7 +240,9 @@ class ProviderCertChecker(object): raise try: pemfile_content = req.content - self.is_valid_pemfile(pemfile_content) + valid = self.is_valid_pemfile(pemfile_content) + if not valid: + return False cert_path = self._get_client_cert_path() self.write_cert(pemfile_content, to=cert_path) except: @@ -276,7 +276,10 @@ class ProviderCertChecker(object): cert = gnutls.crypto.X509Certificate(cert_s) from_ = time.gmtime(cert.activation_time) to_ = time.gmtime(cert.expiration_time) + # FIXME BUG ON LEAP_CLI, certs are not valid on gmtime + # See #1153 return from_ < now() < to_ + #return now() < to_ def is_valid_pemfile(self, cert_s=None): """ @@ -291,22 +294,11 @@ class ProviderCertChecker(object): with open(certfile) as cf: cert_s = cf.read() try: - # XXX get a real cert validation - # so far this is only checking begin/end - # delimiters :) - # XXX use gnutls for get proper - # validation. - # crypto.X509Certificate(cert_s) - sep = "-" * 5 + "BEGIN CERTIFICATE" + "-" * 5 - # we might have private key and cert in the same file - certparts = cert_s.split(sep) - if len(certparts) > 1: - cert_s = sep + certparts[1] - ssl.PEM_cert_to_DER_cert(cert_s) - except: - # XXX raise proper exception - raise - return True + valid = certs.can_load_cert_and_pkey(cert_s) + except certs.BadCertError: + logger.warning("Not valid pemfile") + valid = False + return valid @property def ca_cert_path(self): @@ -427,6 +419,7 @@ class EIPConfigChecker(object): return True def fetch_definition(self, skip_download=False, + force_download=False, config=None, uri=None, domain=None): """ @@ -459,6 +452,7 @@ class EIPConfigChecker(object): self.defaultprovider.save() def fetch_eip_service_config(self, skip_download=False, + force_download=False, config=None, uri=None, domain=None): if skip_download: return True @@ -469,7 +463,10 @@ class EIPConfigChecker(object): domain = self.domain or config.get('provider', None) uri = self._get_eip_service_uri(domain=domain) - self.eipserviceconfig.load(from_uri=uri, fetcher=self.fetcher) + self.eipserviceconfig.load( + from_uri=uri, + fetcher=self.fetcher, + force_download=force_download) self.eipserviceconfig.save() def check_complete_eip_config(self, config=None): @@ -497,7 +494,7 @@ class EIPConfigChecker(object): return self.eipconfig.exists() def _dump_default_eipconfig(self): - self.eipconfig.save() + self.eipconfig.save(force=True) def _get_provider_definition_uri(self, domain=None, path=None): if domain is None: diff --git a/src/leap/eip/config.py b/src/leap/eip/config.py index 42c00380..48e6e9a7 100644 --- a/src/leap/eip/config.py +++ b/src/leap/eip/config.py @@ -1,10 +1,12 @@ import logging import os import platform +import re import tempfile from leap import __branding as BRANDING from leap import certs +from leap.util.misc import null_check from leap.util.fileutil import (which, mkdir_p, check_and_fix_urw_only) from leap.base import config as baseconfig @@ -53,53 +55,80 @@ def get_socket_path(): socket_path = os.path.join( tempfile.mkdtemp(prefix="leap-tmp"), 'openvpn.socket') - logger.debug('socket path: %s', socket_path) + #logger.debug('socket path: %s', socket_path) return socket_path -def get_eip_gateway(provider=None): +def get_eip_gateway(eipconfig=None, eipserviceconfig=None): """ return the first host in eip service config that matches the name defined in the eip.json config file. """ - placeholder = "testprovider.example.org" - # XXX check for null on provider?? + # XXX eventually we should move to a more clever + # gateway selection. maybe we could return + # all gateways that match our cluster. + + null_check(eipconfig, "eipconfig") + null_check(eipserviceconfig, "eipserviceconfig") + PLACEHOLDER = "testprovider.example.org" - eipconfig = EIPConfig(domain=provider) - eipconfig.load() conf = eipconfig.config + eipsconf = eipserviceconfig.config primary_gateway = conf.get('primary_gateway', None) if not primary_gateway: - return placeholder + return PLACEHOLDER - eipserviceconfig = EIPServiceConfig(domain=provider) - eipserviceconfig.load() - eipsconf = eipserviceconfig.get_config() gateways = eipsconf.get('gateways', None) if not gateways: logger.error('missing gateways in eip service config') - return placeholder + return PLACEHOLDER + if len(gateways) > 0: for gw in gateways: - name = gw.get('name', None) - if not name: + clustername = gw.get('cluster', None) + if not clustername: + logger.error('no cluster name') return - if name == primary_gateway: - hosts = gw.get('hosts', None) - if not hosts: - logger.error('no hosts') + if clustername == primary_gateway: + # XXX at some moment, we must + # make this a more generic function, + # and return ports, protocols... + ipaddress = gw.get('ip_address', None) + if not ipaddress: + logger.error('no ip_address') return - if len(hosts) > 0: - return hosts[0] - else: - logger.error('no hosts') + return ipaddress logger.error('could not find primary gateway in provider' 'gateway list') +def get_cipher_options(eipserviceconfig=None): + """ + gathers optional cipher options from eip-service config. + :param eipserviceconfig: EIPServiceConfig instance + """ + null_check(eipserviceconfig, 'eipserviceconfig') + eipsconf = eipserviceconfig.get_config() + + ALLOWED_KEYS = ("auth", "cipher", "tls-cipher") + CIPHERS_REGEX = re.compile("[A-Z0-9\-]+") + opts = [] + if 'openvpn_configuration' in eipsconf: + config = eipserviceconfig.config.get( + "openvpn_configuration", {}) + for key, value in config.items(): + if key in ALLOWED_KEYS and value is not None: + sanitized_val = CIPHERS_REGEX.findall(value) + if len(sanitized_val) != 0: + _val = sanitized_val[0] + opts.append('--%s' % key) + opts.append('%s' % _val) + return opts + + def build_ovpn_options(daemon=False, socket_path=None, **kwargs): """ build a list of options @@ -116,6 +145,10 @@ def build_ovpn_options(daemon=False, socket_path=None, **kwargs): # things from there if present. provider = kwargs.pop('provider', None) + eipconfig = EIPConfig(domain=provider) + eipconfig.load() + eipserviceconfig = EIPServiceConfig(domain=provider) + eipserviceconfig.load() # get user/group name # also from config. @@ -137,11 +170,17 @@ def build_ovpn_options(daemon=False, socket_path=None, **kwargs): opts.append('--verb') opts.append("%s" % verbosity) - # remote + # remote ############################## + # (server, port, protocol) + opts.append('--remote') - gw = get_eip_gateway(provider=provider) + + gw = get_eip_gateway(eipconfig=eipconfig, + eipserviceconfig=eipserviceconfig) logger.debug('setting eip gateway to %s', gw) opts.append(str(gw)) + + # get port/protocol from eipservice too opts.append('1194') #opts.append('80') opts.append('udp') @@ -150,6 +189,13 @@ def build_ovpn_options(daemon=False, socket_path=None, **kwargs): opts.append('--remote-cert-tls') opts.append('server') + # get ciphers ####################### + + ciphers = get_cipher_options( + eipserviceconfig=eipserviceconfig) + for cipheropt in ciphers: + opts.append(str(cipheropt)) + # set user and group opts.append('--user') opts.append('%s' % user) diff --git a/src/leap/eip/eipconnection.py b/src/leap/eip/eipconnection.py index 7828c864..27734f80 100644 --- a/src/leap/eip/eipconnection.py +++ b/src/leap/eip/eipconnection.py @@ -5,6 +5,9 @@ from __future__ import (absolute_import,) import logging import Queue import sys +import time + +from dateutil.parser import parse as dateparse from leap.eip.checks import ProviderCertChecker from leap.eip.checks import EIPConfigChecker @@ -15,20 +18,142 @@ from leap.eip.openvpnconnection import OpenVPNConnection logger = logging.getLogger(name=__name__) -class EIPConnection(OpenVPNConnection): +class StatusMixIn(object): + + # a bunch of methods related with querying the connection + # state/status and displaying useful info. + # Needs to get clear on what is what, and + # separate functions. + # Should separate EIPConnectionStatus (self.status) + # from the OpenVPN state/status command and parsing. + + def connection_state(self): + """ + returns the current connection state + """ + return self.status.current + + def get_icon_name(self): + """ + get icon name from status object + """ + return self.status.get_state_icon() + + def get_leap_status(self): + return self.status.get_leap_status() + + def poll_connection_state(self): + """ + """ + try: + state = self.get_connection_state() + except eip_exceptions.ConnectionRefusedError: + # connection refused. might be not ready yet. + logger.warning('connection refused') + return + if not state: + logger.debug('no state') + return + (ts, status_step, + ok, ip, remote) = state + self.status.set_vpn_state(status_step) + status_step = self.status.get_readable_status() + return (ts, status_step, ok, ip, remote) + + def make_error(self): + """ + capture error and wrap it in an + understandable format + """ + # mostly a hack to display errors in the debug UI + # w/o breaking the polling. + #XXX get helpful error codes + self.with_errors = True + now = int(time.time()) + return '%s,LAUNCHER ERROR,ERROR,-,-' % now + + def state(self): + """ + Sends OpenVPN command: state + """ + state = self._send_command("state") + if not state: + return None + if isinstance(state, str): + return state + if isinstance(state, list): + if len(state) == 1: + return state[0] + else: + return state[-1] + + def vpn_status(self): + """ + OpenVPN command: status + """ + status = self._send_command("status") + return status + + def vpn_status2(self): + """ + OpenVPN command: last 2 statuses + """ + return self._send_command("status 2") + + # + # parse info as the UI expects + # + + def get_status_io(self): + status = self.vpn_status() + if isinstance(status, str): + lines = status.split('\n') + if isinstance(status, list): + lines = status + try: + (header, when, tun_read, tun_write, + tcp_read, tcp_write, auth_read) = tuple(lines) + except ValueError: + return None + + when_ts = dateparse(when.split(',')[1]).timetuple() + sep = ',' + # XXX clean up this! + tun_read = tun_read.split(sep)[1] + tun_write = tun_write.split(sep)[1] + tcp_read = tcp_read.split(sep)[1] + tcp_write = tcp_write.split(sep)[1] + auth_read = auth_read.split(sep)[1] + + # XXX this could be a named tuple. prettier. + return when_ts, (tun_read, tun_write, tcp_read, tcp_write, auth_read) + + def get_connection_state(self): + state = self.state() + if state is not None: + ts, status_step, ok, ip, remote = state.split(',') + ts = time.gmtime(float(ts)) + # XXX this could be a named tuple. prettier. + return ts, status_step, ok, ip, remote + + +class EIPConnection(OpenVPNConnection, StatusMixIn): """ + Aka conductor. Manages the execution of the OpenVPN process, auto starts, monitors the network connection, handles configuration, fixes leaky hosts, handles errors, etc. Status updates (connected, bandwidth, etc) are signaled to the GUI. """ + # XXX change name to EIPConductor ?? + def __init__(self, provider_cert_checker=ProviderCertChecker, config_checker=EIPConfigChecker, *args, **kwargs): - self.settingsfile = kwargs.get('settingsfile', None) - self.logfile = kwargs.get('logfile', None) + #self.settingsfile = kwargs.get('settingsfile', None) + #self.logfile = kwargs.get('logfile', None) self.provider = kwargs.pop('provider', None) self._providercertchecker = provider_cert_checker self._configchecker = config_checker @@ -48,11 +173,27 @@ class EIPConnection(OpenVPNConnection): super(EIPConnection, self).__init__(*args, **kwargs) + def connect(self): + """ + entry point for connection process + """ + # in OpenVPNConnection + self.try_openvpn_connection() + + def disconnect(self, shutdown=False): + """ + disconnects client + """ + self.terminate_openvpn_connection(shutdown=shutdown) + self.status.change_to(self.status.DISCONNECTED) + def has_errors(self): return True if self.error_queue.qsize() != 0 else False def init_checkers(self): - # initialize checkers + """ + initialize checkers + """ self.provider_cert_checker = self._providercertchecker( domain=self.provider) self.config_checker = self._configchecker(domain=self.provider) @@ -101,96 +242,6 @@ class EIPConnection(OpenVPNConnection): except Exception as exc: push_err(exc) - def connect(self): - """ - entry point for connection process - """ - #self.forget_errors() - self._try_connection() - - def disconnect(self): - """ - disconnects client - """ - self.cleanup() - logger.debug("disconnect: clicked.") - self.status.change_to(self.status.DISCONNECTED) - - #def shutdown(self): - #""" - #shutdown and quit - #""" - #self.desired_con_state = self.status.DISCONNECTED - - def connection_state(self): - """ - returns the current connection state - """ - return self.status.current - - def poll_connection_state(self): - """ - """ - try: - state = self.get_connection_state() - except eip_exceptions.ConnectionRefusedError: - # connection refused. might be not ready yet. - logger.warning('connection refused') - return - if not state: - logger.debug('no state') - return - (ts, status_step, - ok, ip, remote) = state - self.status.set_vpn_state(status_step) - status_step = self.status.get_readable_status() - return (ts, status_step, ok, ip, remote) - - def get_icon_name(self): - """ - get icon name from status object - """ - return self.status.get_state_icon() - - def get_leap_status(self): - return self.status.get_leap_status() - - # - # private methods - # - - #def _disconnect(self): - # """ - # private method for disconnecting - # """ - # if self.subp is not None: - # logger.debug('disconnecting...') - # self.subp.terminate() - # self.subp = None - - #def _is_alive(self): - #""" - #don't know yet - #""" - #pass - - def _connect(self): - """ - entry point for connection cascade methods. - """ - try: - conn_result = self._try_connection() - except eip_exceptions.UnrecoverableError as except_msg: - logger.error("FATAL: %s" % unicode(except_msg)) - conn_result = self.status.UNRECOVERABLE - - # XXX enqueue exceptions themselves instead? - except Exception as except_msg: - self.error_queue.append(except_msg) - logger.error("Failed Connection: %s" % - unicode(except_msg)) - return conn_result - class EIPConnectionStatus(object): """ diff --git a/src/leap/eip/openvpnconnection.py b/src/leap/eip/openvpnconnection.py index 859378c0..c2dc71a6 100644 --- a/src/leap/eip/openvpnconnection.py +++ b/src/leap/eip/openvpnconnection.py @@ -7,7 +7,6 @@ import os import psutil import shutil import socket -import time from functools import partial logger = logging.getLogger(name=__name__) @@ -20,12 +19,123 @@ from leap.eip import config as eip_config from leap.eip import exceptions as eip_exceptions -class OpenVPNConnection(Connection): +class OpenVPNManagement(object): + + # TODO explain a little bit how management interface works + # and our telnet interface with support for unix sockets. + + """ + for more information, read openvpn management notes. + zcat `dpkg -L openvpn | grep management` + """ + + def _connect_to_management(self): + """ + Connect to openvpn management interface + """ + if hasattr(self, 'tn'): + self._close_management_socket() + self.tn = UDSTelnet(self.host, self.port) + + # XXX make password optional + # specially for win. we should generate + # the pass on the fly when invoking manager + # from conductor + + #self.tn.read_until('ENTER PASSWORD:', 2) + #self.tn.write(self.password + '\n') + #self.tn.read_until('SUCCESS:', 2) + if self.tn: + self._seek_to_eof() + return True + + def _close_management_socket(self, announce=True): + """ + Close connection to openvpn management interface + """ + logger.debug('closing socket') + if announce: + self.tn.write("quit\n") + self.tn.read_all() + self.tn.get_socket().close() + del self.tn + + def _seek_to_eof(self): + """ + Read as much as available. Position seek pointer to end of stream + """ + try: + b = self.tn.read_eager() + except EOFError: + logger.debug("Could not read from socket. Assuming it died.") + return + while b: + try: + b = self.tn.read_eager() + except EOFError: + logger.debug("Could not read from socket. Assuming it died.") + + def _send_command(self, cmd): + """ + Send a command to openvpn and return response as list + """ + if not self.connected(): + try: + self._connect_to_management() + except eip_exceptions.MissingSocketError: + logger.warning('missing management socket') + return [] + try: + if hasattr(self, 'tn'): + self.tn.write(cmd + "\n") + except socket.error: + logger.error('socket error') + self._close_management_socket(announce=False) + return [] + buf = self.tn.read_until(b"END", 2) + self._seek_to_eof() + blist = buf.split('\r\n') + if blist[-1].startswith('END'): + del blist[-1] + return blist + else: + return [] + + def _send_short_command(self, cmd): + """ + parse output from commands that are + delimited by "success" instead + """ + if not self.connected(): + self.connect() + self.tn.write(cmd + "\n") + # XXX not working? + buf = self.tn.read_until(b"SUCCESS", 2) + self._seek_to_eof() + blist = buf.split('\r\n') + return blist + + # + # random maybe useful vpn commands + # + + def pid(self): + #XXX broken + return self._send_short_command("pid") + + +class OpenVPNConnection(Connection, OpenVPNManagement): """ All related to invocation - of the openvpn binary + of the openvpn binary. + It's extended by EIPConnection. """ + # XXX Inheriting from Connection was an early design idea + # but currently that's an empty class. + # We can get rid of that if we don't use it for sharing + # state with other leap modules. + def __init__(self, watcher_cb=None, debug=False, @@ -34,24 +144,21 @@ class OpenVPNConnection(Connection): password=None, *args, **kwargs): """ - :param config_file: configuration file to read from :param watcher_cb: callback to be \ called for each line in watched stdout :param signal_map: dictionary of signal names and callables \ to be triggered for each one of them. - :type config_file: str :type watcher_cb: function :type signal_map: dict """ #XXX FIXME #change watcher_cb to line_observer + # XXX if not host: raise ImproperlyConfigured logger.debug('init openvpn connection') self.debug = debug - # XXX if not host: raise ImproperlyConfigured self.ovpn_verbosity = kwargs.get('ovpn_verbosity', None) - #self.config_file = config_file self.watcher_cb = watcher_cb #self.signal_maps = signal_maps @@ -62,21 +169,13 @@ to be triggered for each one of them. self.port = None self.proto = None - #XXX workaround for signaling - #the ui that we don't know how to - #manage a connection error - #self.with_errors = False - self.command = None self.args = None # XXX get autostart from config self.autostart = True - # - # management init methods - # - + # management interface init self.host = host if isinstance(port, str) and port.isdigit(): port = int(port) @@ -88,101 +187,47 @@ to be triggered for each one of them. self.password = password def run_openvpn_checks(self): + """ + runs check needed before launching + openvpn subprocess. will raise if errors found. + """ logger.debug('running openvpn checks') + # XXX I think that "check_if_running" should be called + # from try openvpn connection instead. -- kali. + # let's prepare tests for that before changing it... self._check_if_running_instance() self._set_ovpn_command() self._check_vpn_keys() - def _set_ovpn_command(self): - # XXX check also for command-line --command flag - try: - command, args = eip_config.build_ovpn_command( - provider=self.provider, - debug=self.debug, - socket_path=self.host, - ovpn_verbosity=self.ovpn_verbosity) - except eip_exceptions.EIPNoPolkitAuthAgentAvailable: - command = args = None - raise - except eip_exceptions.EIPNoPkexecAvailable: - command = args = None - raise - - # XXX if not command, signal error. - self.command = command - self.args = args - - def _check_vpn_keys(self): - """ - checks for correct permissions on vpn keys - """ - try: - eip_config.check_vpn_keys(provider=self.provider) - except eip_exceptions.EIPInitBadKeyFilePermError: - logger.error('Bad VPN Keys permission!') - # do nothing now - # and raise the rest ... - - def _launch_openvpn(self): - """ - invocation of openvpn binaries in a subprocess. - """ - #XXX TODO: - #deprecate watcher_cb, - #use _only_ signal_maps instead - - logger.debug('_launch_openvpn called') - if self.watcher_cb is not None: - linewrite_callback = self.watcher_cb - else: - #XXX get logger instead - linewrite_callback = lambda line: print('watcher: %s' % line) - - # the partial is not - # being applied now because we're not observing the process - # stdout like we did in the early stages. but I leave it - # here since it will be handy for observing patterns in the - # thru-the-manager updates (with regex) - observers = (linewrite_callback, - partial(lambda con_status, line: None, self.status)) - subp, watcher = spawn_and_watch_process( - self.command, - self.args, - observers=observers) - self.subp = subp - self.watcher = watcher - - def _try_connection(self): + def try_openvpn_connection(self): """ attempts to connect """ + # XXX should make public method if self.command is None: raise eip_exceptions.EIPNoCommandError if self.subp is not None: logger.debug('cowardly refusing to launch subprocess again') + # XXX this is not returning ???!! + # FIXME -- so it's calling it all the same!! self._launch_openvpn() - def _check_if_running_instance(self): + def connected(self): """ - check if openvpn is already running + Returns True if connected + rtype: bool """ - for process in psutil.get_process_list(): - if process.name == "openvpn": - logger.debug('an openvpn instance is already running.') - logger.debug('attempting to stop openvpn instance.') - if not self._stop(): - raise eip_exceptions.OpenVPNAlreadyRunning - - logger.debug('no openvpn instance found.') + # XXX make a property + return hasattr(self, 'tn') - def cleanup(self): + def terminate_openvpn_connection(self, shutdown=False): """ terminates openvpn child subprocess """ if self.subp: try: - self._stop() + self._stop_openvpn() except eip_exceptions.ConnectionRefusedError: logger.warning( 'unable to send sigterm signal to openvpn: ' @@ -201,9 +246,10 @@ to be triggered for each one of them. 'cannot terminate subprocess! Retcode %s' '(We might have left openvpn running)' % RETCODE) - self.cleanup_tempfiles() + if shutdown: + self._cleanup_tempfiles() - def cleanup_tempfiles(self): + def _cleanup_tempfiles(self): """ remove all temporal files we might have left behind @@ -223,172 +269,93 @@ to be triggered for each one of them. except OSError: logger.error('could not delete tmpfolder %s' % tempfolder) - def _get_openvpn_process(self): - # plist = [p for p in psutil.get_process_list() if p.name == "openvpn"] - # return plist[0] if plist else None - for process in psutil.get_process_list(): - if process.name == "openvpn": - return process - return None + # checks - # management methods - # - # XXX REVIEW-ME - # REFACTOR INFO: (former "manager". - # Can we move to another - # base class to test independently?) - # - - #def forget_errors(self): - #logger.debug('forgetting errors') - #self.with_errors = False - - def connect_to_management(self): - """Connect to openvpn management interface""" - #logger.debug('connecting socket') - if hasattr(self, 'tn'): - self.close() - self.tn = UDSTelnet(self.host, self.port) - - # XXX make password optional - # specially for win. we should generate - # the pass on the fly when invoking manager - # from conductor - - #self.tn.read_until('ENTER PASSWORD:', 2) - #self.tn.write(self.password + '\n') - #self.tn.read_until('SUCCESS:', 2) - if self.tn: - self._seek_to_eof() - return True - - def _seek_to_eof(self): + def _check_if_running_instance(self): """ - Read as much as available. Position seek pointer to end of stream + check if openvpn is already running """ try: - b = self.tn.read_eager() - except EOFError: - logger.debug("Could not read from socket. Assuming it died.") - return - while b: - try: - b = self.tn.read_eager() - except EOFError: - logger.debug("Could not read from socket. Assuming it died.") + for process in psutil.get_process_list(): + if process.name == "openvpn": + logger.debug('an openvpn instance is already running.') + logger.debug('attempting to stop openvpn instance.') + if not self._stop_openvpn(): + raise eip_exceptions.OpenVPNAlreadyRunning - def connected(self): - """ - Returns True if connected - rtype: bool - """ - return hasattr(self, 'tn') + except psutil.error.NoSuchProcess: + logger.debug('detected a process which died. passing.') - def close(self, announce=True): - """ - Close connection to openvpn management interface - """ - logger.debug('closing socket') - if announce: - self.tn.write("quit\n") - self.tn.read_all() - self.tn.get_socket().close() - del self.tn + logger.debug('no openvpn instance found.') - def _send_command(self, cmd): - """ - Send a command to openvpn and return response as list - """ - if not self.connected(): - try: - self.connect_to_management() - except eip_exceptions.MissingSocketError: - logger.warning('missing management socket') - return [] + def _set_ovpn_command(self): try: - if hasattr(self, 'tn'): - self.tn.write(cmd + "\n") - except socket.error: - logger.error('socket error') - self.close(announce=False) - return [] - buf = self.tn.read_until(b"END", 2) - self._seek_to_eof() - blist = buf.split('\r\n') - if blist[-1].startswith('END'): - del blist[-1] - return blist - else: - return [] - - def _send_short_command(self, cmd): - """ - parse output from commands that are - delimited by "success" instead - """ - if not self.connected(): - self.connect() - self.tn.write(cmd + "\n") - # XXX not working? - buf = self.tn.read_until(b"SUCCESS", 2) - self._seek_to_eof() - blist = buf.split('\r\n') - return blist - - # - # useful vpn commands - # + command, args = eip_config.build_ovpn_command( + provider=self.provider, + debug=self.debug, + socket_path=self.host, + ovpn_verbosity=self.ovpn_verbosity) + except eip_exceptions.EIPNoPolkitAuthAgentAvailable: + command = args = None + raise + except eip_exceptions.EIPNoPkexecAvailable: + command = args = None + raise - def pid(self): - #XXX broken - return self._send_short_command("pid") + # XXX if not command, signal error. + self.command = command + self.args = args - def make_error(self): + def _check_vpn_keys(self): """ - capture error and wrap it in an - understandable format + checks for correct permissions on vpn keys """ - #XXX get helpful error codes - self.with_errors = True - now = int(time.time()) - return '%s,LAUNCHER ERROR,ERROR,-,-' % now + try: + eip_config.check_vpn_keys(provider=self.provider) + except eip_exceptions.EIPInitBadKeyFilePermError: + logger.error('Bad VPN Keys permission!') + # do nothing now + # and raise the rest ... - def state(self): - """ - OpenVPN command: state - """ - state = self._send_command("state") - if not state: - return None - if isinstance(state, str): - return state - if isinstance(state, list): - if len(state) == 1: - return state[0] - else: - return state[-1] + # starting and stopping openvpn subprocess - def vpn_status(self): + def _launch_openvpn(self): """ - OpenVPN command: status + invocation of openvpn binaries in a subprocess. """ - #logger.debug('status called') - status = self._send_command("status") - return status + #XXX TODO: + #deprecate watcher_cb, + #use _only_ signal_maps instead - def vpn_status2(self): - """ - OpenVPN command: last 2 statuses - """ - return self._send_command("status 2") + logger.debug('_launch_openvpn called') + if self.watcher_cb is not None: + linewrite_callback = self.watcher_cb + else: + #XXX get logger instead + linewrite_callback = lambda line: print('watcher: %s' % line) - def _stop(self): + # the partial is not + # being applied now because we're not observing the process + # stdout like we did in the early stages. but I leave it + # here since it will be handy for observing patterns in the + # thru-the-manager updates (with regex) + observers = (linewrite_callback, + partial(lambda con_status, line: None, self.status)) + subp, watcher = spawn_and_watch_process( + self.command, + self.args, + observers=observers) + self.subp = subp + self.watcher = watcher + + def _stop_openvpn(self): """ stop openvpn process by sending SIGTERM to the management interface """ - logger.debug("disconnecting...") + # XXX method a bit too long, split + logger.debug("terminating openvpn process...") if self.connected(): try: self._send_command("signal SIGTERM\n") @@ -407,8 +374,9 @@ to be triggered for each one of them. logger.debug('process :%s' % process) cmdline = process.cmdline - if isinstance(cmdline, list): - _index = cmdline.index("--management") + manag_flag = "--management" + if isinstance(cmdline, list) and manag_flag in cmdline: + _index = cmdline.index(manag_flag) self.host = cmdline[_index + 1] self._send_command("signal SIGTERM\n") @@ -423,38 +391,10 @@ to be triggered for each one of them. return True - # - # parse info - # - - def get_status_io(self): - status = self.vpn_status() - if isinstance(status, str): - lines = status.split('\n') - if isinstance(status, list): - lines = status - try: - (header, when, tun_read, tun_write, - tcp_read, tcp_write, auth_read) = tuple(lines) - except ValueError: - return None - - when_ts = time.strptime(when.split(',')[1], "%a %b %d %H:%M:%S %Y") - sep = ',' - # XXX cleanup! - tun_read = tun_read.split(sep)[1] - tun_write = tun_write.split(sep)[1] - tcp_read = tcp_read.split(sep)[1] - tcp_write = tcp_write.split(sep)[1] - auth_read = auth_read.split(sep)[1] - - # XXX this could be a named tuple. prettier. - return when_ts, (tun_read, tun_write, tcp_read, tcp_write, auth_read) - - def get_connection_state(self): - state = self.state() - if state is not None: - ts, status_step, ok, ip, remote = state.split(',') - ts = time.gmtime(float(ts)) - # XXX this could be a named tuple. prettier. - return ts, status_step, ok, ip, remote + def _get_openvpn_process(self): + # plist = [p for p in psutil.get_process_list() if p.name == "openvpn"] + # return plist[0] if plist else None + for process in psutil.get_process_list(): + if process.name == "openvpn": + return process + return None diff --git a/src/leap/eip/specs.py b/src/leap/eip/specs.py index 57e7537b..c41fd29b 100644 --- a/src/leap/eip/specs.py +++ b/src/leap/eip/specs.py @@ -77,12 +77,12 @@ eipconfig_spec = { }, 'primary_gateway': { 'type': unicode, - 'default': u"turkey", + 'default': u"location_unknown", #'required': True }, 'secondary_gateway': { 'type': unicode, - 'default': u"france" + 'default': u"location_unknown2" }, 'management_password': { 'type': unicode @@ -100,25 +100,37 @@ eipservice_config_spec = { 'default': 1 }, 'version': { - 'type': unicode, + 'type': int, 'required': True, - 'default': "0.1.0" + 'default': 1 }, - 'capabilities': { - 'type': dict, - 'default': { - "transport": ["openvpn"], - "ports": ["80", "53"], - "protocols": ["udp", "tcp"], - "static_ips": True, - "adblock": True} + 'clusters': { + 'type': list, + 'default': [ + {"label": { + "en": "Location Unknown"}, + "name": "location_unknown"}] }, 'gateways': { 'type': list, - 'default': [{"country_code": "us", - "label": {"en":"west"}, - "capabilities": {}, - "hosts": ["1.2.3.4", "1.2.3.5"]}] + 'default': [ + {"capabilities": { + "adblock": True, + "filter_dns": True, + "ports": ["80", "53", "443", "1194"], + "protocols": ["udp", "tcp"], + "transport": ["openvpn"], + "user_ips": False}, + "cluster": "location_unknown", + "host": "location.example.org", + "ip_address": "127.0.0.1"}] + }, + 'openvpn_configuration': { + 'type': dict, + 'default': { + "auth": None, + "cipher": None, + "tls-cipher": None} } } } diff --git a/src/leap/eip/tests/data.py b/src/leap/eip/tests/data.py index cadf720e..a7fe1853 100644 --- a/src/leap/eip/tests/data.py +++ b/src/leap/eip/tests/data.py @@ -23,26 +23,29 @@ EIP_SAMPLE_CONFIG = { "keys/client/openvpn.pem" % PROVIDER), "connect_on_login": True, "block_cleartext_traffic": True, - "primary_gateway": "turkey", - "secondary_gateway": "france", + "primary_gateway": "location_unknown", + "secondary_gateway": "location_unknown2", #"management_password": "oph7Que1othahwiech6J" } EIP_SAMPLE_SERVICE = { "serial": 1, - "version": "0.1.0", - "capabilities": { - "transport": ["openvpn"], - "ports": ["80", "53"], - "protocols": ["udp", "tcp"], - "static_ips": True, - "adblock": True - }, + "version": 1, + "clusters": [ + {"label": { + "en": "Location Unknown"}, + "name": "location_unknown"} + ], "gateways": [ - {"country_code": "tr", - "name": "turkey", - "label": {"en":"Ankara, Turkey"}, - "capabilities": {}, - "hosts": ["192.0.43.10"]} + {"capabilities": { + "adblock": True, + "filter_dns": True, + "ports": ["80", "53", "443", "1194"], + "protocols": ["udp", "tcp"], + "transport": ["openvpn"], + "user_ips": False}, + "cluster": "location_unknown", + "host": "location.example.org", + "ip_address": "192.0.43.10"} ] } diff --git a/src/leap/eip/tests/test_checks.py b/src/leap/eip/tests/test_checks.py index 1d7bfc17..ab11037a 100644 --- a/src/leap/eip/tests/test_checks.py +++ b/src/leap/eip/tests/test_checks.py @@ -25,6 +25,7 @@ from leap.eip.tests import data as testdata from leap.testing.basetest import BaseLeapTest from leap.testing.https_server import BaseHTTPSServerTestCase from leap.testing.https_server import where as where_cert +from leap.util.fileutil import mkdir_f class NoLogRequestHandler: @@ -118,6 +119,7 @@ class EIPCheckTest(BaseLeapTest): sampleconfig = copy.copy(testdata.EIP_SAMPLE_CONFIG) sampleconfig['provider'] = None eipcfg_path = checker.eipconfig.filename + mkdir_f(eipcfg_path) with open(eipcfg_path, 'w') as fp: json.dump(sampleconfig, fp) #with self.assertRaises(eipexceptions.EIPMissingDefaultProvider): @@ -138,6 +140,8 @@ class EIPCheckTest(BaseLeapTest): def test_fetch_definition(self): with patch.object(requests, "get") as mocked_get: mocked_get.return_value.status_code = 200 + mocked_get.return_value.headers = { + 'last-modified': "Wed Dec 12 12:12:12 GMT 2012"} mocked_get.return_value.json = DEFAULT_PROVIDER_DEFINITION checker = eipchecks.EIPConfigChecker(fetcher=requests) sampleconfig = testdata.EIP_SAMPLE_CONFIG @@ -156,6 +160,8 @@ class EIPCheckTest(BaseLeapTest): def test_fetch_eip_service_config(self): with patch.object(requests, "get") as mocked_get: mocked_get.return_value.status_code = 200 + mocked_get.return_value.headers = { + 'last-modified': "Wed Dec 12 12:12:12 GMT 2012"} mocked_get.return_value.json = testdata.EIP_SAMPLE_SERVICE checker = eipchecks.EIPConfigChecker(fetcher=requests) sampleconfig = testdata.EIP_SAMPLE_CONFIG diff --git a/src/leap/eip/tests/test_config.py b/src/leap/eip/tests/test_config.py index 50538240..5977ef3c 100644 --- a/src/leap/eip/tests/test_config.py +++ b/src/leap/eip/tests/test_config.py @@ -1,3 +1,4 @@ +from collections import OrderedDict import json import os import platform @@ -10,11 +11,11 @@ except ImportError: #from leap.base import constants #from leap.eip import config as eip_config -from leap import __branding as BRANDING +#from leap import __branding as BRANDING from leap.eip import config as eipconfig from leap.eip.tests.data import EIP_SAMPLE_CONFIG, EIP_SAMPLE_SERVICE from leap.testing.basetest import BaseLeapTest -from leap.util.fileutil import mkdir_p +from leap.util.fileutil import mkdir_p, mkdir_f _system = platform.system() @@ -47,11 +48,22 @@ class EIPConfigTest(BaseLeapTest): open(tfile, 'wb').close() os.chmod(tfile, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) - def write_sample_eipservice(self): + def write_sample_eipservice(self, vpnciphers=False, extra_vpnopts=None, + gateways=None): conf = eipconfig.EIPServiceConfig() - folder, f = os.path.split(conf.filename) - if not os.path.isdir(folder): - mkdir_p(folder) + mkdir_f(conf.filename) + if gateways: + EIP_SAMPLE_SERVICE['gateways'] = gateways + if vpnciphers: + openvpnconfig = OrderedDict({ + "auth": "SHA1", + "cipher": "AES-128-CBC", + "tls-cipher": "DHE-RSA-AES128-SHA"}) + if extra_vpnopts: + for k, v in extra_vpnopts.items(): + openvpnconfig[k] = v + EIP_SAMPLE_SERVICE['openvpn_configuration'] = openvpnconfig + with open(conf.filename, 'w') as fd: fd.write(json.dumps(EIP_SAMPLE_SERVICE)) @@ -63,8 +75,17 @@ class EIPConfigTest(BaseLeapTest): with open(conf.filename, 'w') as fd: fd.write(json.dumps(EIP_SAMPLE_CONFIG)) - def get_expected_openvpn_args(self): + def get_expected_openvpn_args(self, with_openvpn_ciphers=False): + """ + yeah, this is almost as duplicating the + code for building the command + """ args = [] + eipconf = eipconfig.EIPConfig(domain=self.provider) + eipconf.load() + eipsconf = eipconfig.EIPServiceConfig(domain=self.provider) + eipsconf.load() + username = self.get_username() groupname = self.get_groupname() @@ -75,8 +96,10 @@ class EIPConfigTest(BaseLeapTest): args.append('--persist-tun') args.append('--persist-key') args.append('--remote') + args.append('%s' % eipconfig.get_eip_gateway( - provider=self.provider)) + eipconfig=eipconf, + eipserviceconfig=eipsconf)) # XXX get port!? args.append('1194') # XXX get proto @@ -85,6 +108,14 @@ class EIPConfigTest(BaseLeapTest): args.append('--remote-cert-tls') args.append('server') + if with_openvpn_ciphers: + CIPHERS = [ + "--tls-cipher", "DHE-RSA-AES128-SHA", + "--cipher", "AES-128-CBC", + "--auth", "SHA1"] + for opt in CIPHERS: + args.append(opt) + args.append('--user') args.append(username) args.append('--group') @@ -130,6 +161,55 @@ class EIPConfigTest(BaseLeapTest): # params in the function call, to disable # some checks. + def test_get_eip_gateway(self): + self.write_sample_eipconfig() + eipconf = eipconfig.EIPConfig(domain=self.provider) + + # default eipservice + self.write_sample_eipservice() + eipsconf = eipconfig.EIPServiceConfig(domain=self.provider) + + gateway = eipconfig.get_eip_gateway( + eipconfig=eipconf, + eipserviceconfig=eipsconf) + + # in spec is local gateway by default + self.assertEqual(gateway, '127.0.0.1') + + # change eipservice + # right now we only check that cluster == selected primary gw in + # eip.json, and pick first matching ip + eipconf._config.config['primary_gateway'] = "foo_provider" + newgateways = [{"cluster": "foo_provider", + "ip_address": "127.0.0.99"}] + self.write_sample_eipservice(gateways=newgateways) + eipsconf = eipconfig.EIPServiceConfig(domain=self.provider) + # load from disk file + eipsconf.load() + + gateway = eipconfig.get_eip_gateway( + eipconfig=eipconf, + eipserviceconfig=eipsconf) + self.assertEqual(gateway, '127.0.0.99') + + # change eipservice, several gateways + # right now we only check that cluster == selected primary gw in + # eip.json, and pick first matching ip + eipconf._config.config['primary_gateway'] = "bar_provider" + newgateways = [{"cluster": "foo_provider", + "ip_address": "127.0.0.99"}, + {'cluster': "bar_provider", + "ip_address": "127.0.0.88"}] + self.write_sample_eipservice(gateways=newgateways) + eipsconf = eipconfig.EIPServiceConfig(domain=self.provider) + # load from disk file + eipsconf.load() + + gateway = eipconfig.get_eip_gateway( + eipconfig=eipconf, + eipserviceconfig=eipsconf) + self.assertEqual(gateway, '127.0.0.88') + def test_build_ovpn_command_empty_config(self): self.touch_exec() self.write_sample_eipservice() @@ -139,14 +219,63 @@ class EIPConfigTest(BaseLeapTest): from leap.util.fileutil import which path = os.environ['PATH'] vpnbin = which('openvpn', path=path) - print 'path =', path - print 'vpnbin = ', vpnbin - command, args = eipconfig.build_ovpn_command( + #print 'path =', path + #print 'vpnbin = ', vpnbin + vpncommand, vpnargs = eipconfig.build_ovpn_command( + do_pkexec_check=False, vpnbin=vpnbin, + socket_path="/tmp/test.socket", + provider=self.provider) + self.assertEqual(vpncommand, self.home + '/bin/openvpn') + self.assertEqual(vpnargs, self.get_expected_openvpn_args()) + + def test_build_ovpn_command_openvpnoptions(self): + self.touch_exec() + + from leap.eip import config as eipconfig + from leap.util.fileutil import which + path = os.environ['PATH'] + vpnbin = which('openvpn', path=path) + + self.write_sample_eipconfig() + + # regular run, everything normal + self.write_sample_eipservice(vpnciphers=True) + vpncommand, vpnargs = eipconfig.build_ovpn_command( + do_pkexec_check=False, vpnbin=vpnbin, + socket_path="/tmp/test.socket", + provider=self.provider) + self.assertEqual(vpncommand, self.home + '/bin/openvpn') + expected = self.get_expected_openvpn_args( + with_openvpn_ciphers=True) + self.assertEqual(vpnargs, expected) + + # bad options -- illegal options + self.write_sample_eipservice( + vpnciphers=True, + # WE ONLY ALLOW vpn options in auth, cipher, tls-cipher + extra_vpnopts={"notallowedconfig": "badvalue"}) + vpncommand, vpnargs = eipconfig.build_ovpn_command( + do_pkexec_check=False, vpnbin=vpnbin, + socket_path="/tmp/test.socket", + provider=self.provider) + self.assertEqual(vpncommand, self.home + '/bin/openvpn') + expected = self.get_expected_openvpn_args( + with_openvpn_ciphers=True) + self.assertEqual(vpnargs, expected) + + # bad options -- illegal chars + self.write_sample_eipservice( + vpnciphers=True, + # WE ONLY ALLOW A-Z09\- + extra_vpnopts={"cipher": "AES-128-CBC;FOOTHING"}) + vpncommand, vpnargs = eipconfig.build_ovpn_command( do_pkexec_check=False, vpnbin=vpnbin, socket_path="/tmp/test.socket", provider=self.provider) - self.assertEqual(command, self.home + '/bin/openvpn') - self.assertEqual(args, self.get_expected_openvpn_args()) + self.assertEqual(vpncommand, self.home + '/bin/openvpn') + expected = self.get_expected_openvpn_args( + with_openvpn_ciphers=True) + self.assertEqual(vpnargs, expected) if __name__ == "__main__": diff --git a/src/leap/eip/tests/test_eipconnection.py b/src/leap/eip/tests/test_eipconnection.py index aefca36f..163f8d45 100644 --- a/src/leap/eip/tests/test_eipconnection.py +++ b/src/leap/eip/tests/test_eipconnection.py @@ -1,6 +1,8 @@ +import glob import logging import platform -import os +#import os +import shutil logging.basicConfig() logger = logging.getLogger(name=__name__) @@ -66,11 +68,26 @@ class EIPConductorTest(BaseLeapTest): self.manager = Mock(name="openvpnmanager_mock") self.con = MockedEIPConnection() self.con.provider = self.provider + + # XXX watch out. This sometimes is throwing the following error: + # NoSuchProcess: process no longer exists (pid=6571) + # because of a bad implementation of _check_if_running_instance + self.con.run_openvpn_checks() def tearDown(self): + pass + + def doCleanups(self): + super(BaseLeapTest, self).doCleanups() + self.cleanupSocketDir() del self.con + def cleanupSocketDir(self): + ptt = ('/tmp/leap-tmp*') + for tmpdir in glob.glob(ptt): + shutil.rmtree(tmpdir) + # # tests # @@ -81,6 +98,7 @@ class EIPConductorTest(BaseLeapTest): """ con = self.con self.assertEqual(con.autostart, True) + # XXX moar! def test_ovpn_command(self): """ @@ -98,6 +116,7 @@ class EIPConductorTest(BaseLeapTest): # needed to run tests. (roughly 3 secs for this only) # We should modularize and inject Mocks on more places. + oldcon = self.con del(self.con) config_checker = Mock() self.con = MockedEIPConnection(config_checker=config_checker) @@ -107,6 +126,7 @@ class EIPConductorTest(BaseLeapTest): skip_download=False) # XXX test for cert_checker also + self.con = oldcon # connect/disconnect calls @@ -123,9 +143,14 @@ class EIPConductorTest(BaseLeapTest): self.con.status.CONNECTED) # disconnect - self.con.cleanup = Mock() + self.con.terminate_openvpn_connection = Mock() self.con.disconnect() - self.con.cleanup.assert_called_once_with() + self.con.terminate_openvpn_connection.assert_called_once_with( + shutdown=False) + self.con.terminate_openvpn_connection = Mock() + self.con.disconnect(shutdown=True) + self.con.terminate_openvpn_connection.assert_called_once_with( + shutdown=True) # new status should be disconnected # XXX this should evolve and check no errors diff --git a/src/leap/eip/tests/test_openvpnconnection.py b/src/leap/eip/tests/test_openvpnconnection.py index 0f27facf..f7493567 100644 --- a/src/leap/eip/tests/test_openvpnconnection.py +++ b/src/leap/eip/tests/test_openvpnconnection.py @@ -58,16 +58,27 @@ class OpenVPNConnectionTest(BaseLeapTest): def setUp(self): # XXX this will have to change for win, host=localhost host = eipconfig.get_socket_path() + self.host = host self.manager = MockedOpenVPNConnection(host=host) def tearDown(self): + pass + + def doCleanups(self): + super(BaseLeapTest, self).doCleanups() + self.cleanupSocketDir() + + def cleanupSocketDir(self): # remove the socket folder. # XXX only if posix. in win, host is localhost, so nothing # has to be done. - if self.manager.host: - folder, fpath = os.path.split(self.manager.host) - assert folder.startswith('/tmp/leap-tmp') # safety check - shutil.rmtree(folder) + if self.host: + folder, fpath = os.path.split(self.host) + try: + assert folder.startswith('/tmp/leap-tmp') # safety check + shutil.rmtree(folder) + except: + self.fail("could not remove temp file") del self.manager @@ -108,12 +119,14 @@ class OpenVPNConnectionTest(BaseLeapTest): self.assertEqual(self.manager.port, 7777) def test_port_types_init(self): + oldmanager = self.manager self.manager = MockedOpenVPNConnection(port="42") self.assertEqual(self.manager.port, 42) self.manager = MockedOpenVPNConnection() self.assertEqual(self.manager.port, "unix") self.manager = MockedOpenVPNConnection(port="bad") self.assertEqual(self.manager.port, None) + self.manager = oldmanager def test_uds_telnet_called_on_connect(self): self.manager.connect_to_management() diff --git a/src/leap/gui/__init__.py b/src/leap/gui/__init__.py index 9b8f8746..804bfbc1 100644 --- a/src/leap/gui/__init__.py +++ b/src/leap/gui/__init__.py @@ -6,5 +6,6 @@ except ValueError: pass import firstrun +import firstrun.wizard -__all__ = ['firstrun'] +__all__ = ['firstrun', 'firstrun.wizard'] diff --git a/src/leap/gui/firstrun/__init__.py b/src/leap/gui/firstrun/__init__.py index 8a70d90e..d380b75a 100644 --- a/src/leap/gui/firstrun/__init__.py +++ b/src/leap/gui/firstrun/__init__.py @@ -5,7 +5,6 @@ try: except ValueError: pass -import connect import intro import last import login @@ -17,7 +16,6 @@ import register import regvalidation __all__ = [ - 'connect', 'intro', 'last', 'login', @@ -26,4 +24,4 @@ __all__ = [ 'providerselect', 'providersetup', 'register', - 'regvalidation'] + 'regvalidation'] # ,'wizard'] diff --git a/src/leap/gui/firstrun/connect.py b/src/leap/gui/firstrun/connect.py deleted file mode 100644 index a0fe021c..00000000 --- a/src/leap/gui/firstrun/connect.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -Connecting Page, used in First Run Wizard -""" -# XXX FIXME -# DEPRECATED. All functionality moved to regvalidation -# This file should be removed after checking that one is ok. -# XXX - -import logging - -from PyQt4 import QtGui - -logger = logging.getLogger(__name__) - -from leap.base import auth - -from leap.gui.constants import APP_LOGO -from leap.gui.styles import ErrorLabelStyleSheet - - -class ConnectingPage(QtGui.QWizardPage): - - # XXX change to a ValidationPage - - def __init__(self, parent=None): - super(ConnectingPage, self).__init__(parent) - - self.setTitle("Connecting") - self.setSubTitle('Connecting to provider.') - - self.setPixmap( - QtGui.QWizard.LogoPixmap, - QtGui.QPixmap(APP_LOGO)) - - self.status = QtGui.QLabel("") - self.status.setWordWrap(True) - self.progress = QtGui.QProgressBar() - self.progress.setMaximum(100) - self.progress.hide() - - # for pre-checks - self.status_line_1 = QtGui.QLabel() - self.status_line_2 = QtGui.QLabel() - self.status_line_3 = QtGui.QLabel() - self.status_line_4 = QtGui.QLabel() - - # for connecting signals... - self.status_line_5 = QtGui.QLabel() - - layout = QtGui.QGridLayout() - layout.addWidget(self.status, 0, 1) - layout.addWidget(self.progress, 5, 1) - layout.addWidget(self.status_line_1, 8, 1) - layout.addWidget(self.status_line_2, 9, 1) - layout.addWidget(self.status_line_3, 10, 1) - layout.addWidget(self.status_line_4, 11, 1) - - # XXX to be used? - #self.validation_status = QtGui.QLabel("") - #self.validation_status.setStyleSheet( - #ErrorLabelStyleSheet) - #self.validation_msg = QtGui.QLabel("") - - self.setLayout(layout) - - self.goto_login_again = False - - def set_status(self, status): - self.status.setText(status) - self.status.setWordWrap(True) - - def set_status_line(self, line, status): - line = getattr(self, 'status_line_%s' % line) - if line: - line.setText(status) - - def set_validation_status(self, status): - # Do not remember if we're using - # status lines > 3 now... - # if we are, move below - self.status_line_3.setStyleSheet( - ErrorLabelStyleSheet) - self.status_line_3.setText(status) - - def set_validation_message(self, message): - self.status_line_4.setText(message) - self.status_line_4.setWordWrap(True) - - def get_donemsg(self, msg): - return "%s ... done" % msg - - def run_eip_checks_for_provider_and_connect(self, domain): - wizard = self.wizard() - conductor = wizard.conductor - start_eip_signal = getattr( - wizard, - 'start_eipconnection_signal', None) - - if conductor: - conductor.set_provider_domain(domain) - conductor.run_checks() - self.conductor = conductor - errors = self.eip_error_check() - if not errors and start_eip_signal: - start_eip_signal.emit() - - else: - logger.warning( - "No conductor found. This means that " - "probably the wizard has been launched " - "in an stand-alone way") - - def eip_error_check(self): - """ - a version of the main app error checker, - but integrated within the connecting page of the wizard. - consumes the conductor error queue. - pops errors, and add those to the wizard page - """ - logger.debug('eip error check from connecting page') - errq = self.conductor.error_queue - # XXX missing! - - def fetch_and_validate(self): - # XXX MOVE TO validate function in register-validation - import time - domain = self.field('provider_domain') - wizard = self.wizard() - #pconfig = wizard.providerconfig - eipconfigchecker = wizard.eipconfigchecker() - pCertChecker = wizard.providercertchecker( - domain=domain) - - # username and password are in different fields - # if they were stored in log_in or sign_up pages. - from_login = self.wizard().from_login - unamek_base = 'userName' - passwk_base = 'userPassword' - unamek = 'login_%s' % unamek_base if from_login else unamek_base - passwk = 'login_%s' % passwk_base if from_login else passwk_base - - username = self.field(unamek) - password = self.field(passwk) - credentials = username, password - - self.progress.show() - - fetching_eip_conf_msg = 'Fetching eip service configuration' - self.set_status(fetching_eip_conf_msg) - self.progress.setValue(30) - - # Fetching eip service - eipconfigchecker.fetch_eip_service_config( - domain=domain) - - self.status_line_1.setText( - self.get_donemsg(fetching_eip_conf_msg)) - - getting_client_cert_msg = 'Getting client certificate' - self.set_status(getting_client_cert_msg) - self.progress.setValue(66) - - # Download cert - try: - pCertChecker.download_new_client_cert( - credentials=credentials, - # FIXME FIXME FIXME - # XXX FIX THIS!!!!! - # BUG #638. remove verify - # FIXME FIXME FIXME - verify=False) - except auth.SRPAuthenticationError as exc: - self.set_validation_status( - "Authentication error: %s" % exc.message) - return False - - time.sleep(2) - self.status_line_2.setText( - self.get_donemsg(getting_client_cert_msg)) - - validating_clientcert_msg = 'Validating client certificate' - self.set_status(validating_clientcert_msg) - self.progress.setValue(90) - time.sleep(2) - self.status_line_3.setText( - self.get_donemsg(validating_clientcert_msg)) - - self.progress.setValue(100) - time.sleep(3) - - # here we go! :) - self.run_eip_checks_for_provider_and_connect(domain) - - #self.validation_block = self.wait_for_validation_block() - - # XXX signal timeout! - return True - - # - # wizardpage methods - # - - def nextId(self): - wizard = self.wizard() - # XXX this does not work because - # page login has already been met - #if self.goto_login_again: - #next_ = "login" - #else: - #next_ = "lastpage" - next_ = "lastpage" - return wizard.get_page_index(next_) - - def initializePage(self): - # XXX if we're coming from signup page - # we could say something like - # 'registration successful!' - self.status.setText( - "We have " - "all we need to connect with the provider.<br><br> " - "Click <i>next</i> to continue. ") - self.progress.setValue(0) - self.progress.hide() - self.status_line_1.setText('') - self.status_line_2.setText('') - self.status_line_3.setText('') - - def validatePage(self): - # XXX remove - validated = self.fetch_and_validate() - return validated diff --git a/src/leap/gui/firstrun/intro.py b/src/leap/gui/firstrun/intro.py index 4bb008c7..0a7484e2 100644 --- a/src/leap/gui/firstrun/intro.py +++ b/src/leap/gui/firstrun/intro.py @@ -11,7 +11,7 @@ class IntroPage(QtGui.QWizardPage): def __init__(self, parent=None): super(IntroPage, self).__init__(parent) - self.setTitle("First run wizard.") + self.setTitle(self.tr("First run wizard.")) #self.setPixmap( #QtGui.QWizard.WatermarkPixmap, @@ -21,7 +21,7 @@ class IntroPage(QtGui.QWizardPage): QtGui.QWizard.LogoPixmap, QtGui.QPixmap(APP_LOGO)) - label = QtGui.QLabel( + label = QtGui.QLabel(self.tr( "Now we will guide you through " "some configuration that is needed before you " "can connect for the first time.<br><br>" @@ -29,16 +29,16 @@ class IntroPage(QtGui.QWizardPage): "you can find the wizard in the '<i>Settings</i>' menu from the " "main window.<br><br>" "Do you want to <b>sign up</b> for a new account, or <b>log " - "in</b> with an already existing username?<br>") + "in</b> with an already existing username?<br>")) label.setWordWrap(True) radiobuttonGroup = QtGui.QGroupBox() self.sign_up = QtGui.QRadioButton( - "Sign up for a new account.") + self.tr("Sign up for a new account.")) self.sign_up.setChecked(True) self.log_in = QtGui.QRadioButton( - "Log In with my credentials.") + self.tr("Log In with my credentials.")) radiobLayout = QtGui.QVBoxLayout() radiobLayout.addWidget(self.sign_up) diff --git a/src/leap/gui/firstrun/last.py b/src/leap/gui/firstrun/last.py index d33d2e77..1d8caca4 100644 --- a/src/leap/gui/firstrun/last.py +++ b/src/leap/gui/firstrun/last.py @@ -58,6 +58,8 @@ class LastPage(QtGui.QWizardPage): self.label.setText( "Click '<i>%s</i>' to end the wizard and " "save your settings." % finishText) + # XXX init network checker + # trigger signal @coroutine def eip_status_handler(self): diff --git a/src/leap/gui/firstrun/login.py b/src/leap/gui/firstrun/login.py index 02bace86..e7afee9f 100644 --- a/src/leap/gui/firstrun/login.py +++ b/src/leap/gui/firstrun/login.py @@ -82,6 +82,120 @@ class LogInPage(InlineValidationPage, UserFormMixIn): # InlineValidationPage #self.registerField('is_login_wizard') + # actual checks + + def _do_checks(self): + + full_username = self.userNameLineEdit.text() + ########################### + # 0) check user@domain form + ########################### + + def checkusername(): + if full_username.count('@') != 1: + return self.fail( + self.tr( + "Username must be in the username@provider form.")) + else: + return True + + yield(("head_sentinel", 0), checkusername) + + username, domain = full_username.split('@') + password = self.userPasswordLineEdit.text() + + # We try a call to an authenticated + # page here as a mean to catch + # srp authentication errors while + wizard = self.wizard() + eipconfigchecker = wizard.eipconfigchecker() + + ######################## + # 1) try name resolution + ######################## + # show the frame before going on... + QtCore.QMetaObject.invokeMethod( + self, "showStepsFrame") + + # Able to contact domain? + # can get definition? + # two-by-one + def resolvedomain(): + try: + eipconfigchecker.fetch_definition(domain=domain) + + # we're using requests here for all + # the possible error cases that it catches. + except requests.exceptions.ConnectionError as exc: + return self.fail(exc.message[1]) + except requests.exceptions.HTTPError as exc: + return self.fail(exc.message) + except Exception as exc: + # XXX get catchall error msg + return self.fail( + exc.message) + else: + return True + + yield((self.tr("Resolving domain name"), 20), resolvedomain) + + wizard.set_providerconfig( + eipconfigchecker.defaultprovider.config) + + ######################## + # 2) do authentication + ######################## + credentials = username, password + pCertChecker = wizard.providercertchecker( + domain=domain) + + def validate_credentials(): + ################# + # FIXME #BUG #638 + verify = False + + try: + pCertChecker.download_new_client_cert( + credentials=credentials, + verify=verify) + + except auth.SRPAuthenticationError as exc: + return self.fail( + self.tr("Authentication error: %s" % exc.message)) + + except Exception as exc: + return self.fail(exc.message) + + else: + return True + + yield(('Validating credentials', 60), validate_credentials) + + self.set_done() + yield(("end_sentinel", 100), lambda: None) + + def green_validation_status(self): + val = self.validationMsg + val.setText(self.tr('Credentials validated.')) + val.setStyleSheet(styles.GreenLineEdit) + + def on_checks_validation_ready(self): + """ + after checks + """ + if self.is_done(): + self.disableFields() + self.cleanup_errormsg() + self.clean_wizard_errors(self.current_page) + # make the user confirm the transition + # to next page. + self.nextText('&Next') + self.nextFocus() + self.green_validation_status() + self.do_confirm_next = True + + # ui update + def nextText(self, text): self.setButtonText( QtGui.QWizard.NextButton, text) @@ -94,12 +208,18 @@ class LogInPage(InlineValidationPage, UserFormMixIn): # InlineValidationPage self.wizard().button( QtGui.QWizard.NextButton).setDisabled(True) - def onUserNameEdit(self, *args): + def onUserNamePositionChanged(self, *args): if self.initial_username_sample: self.userNameLineEdit.setText('') # XXX set regular color self.initial_username_sample = None + def onUserNameTextChanged(self, *args): + if self.initial_username_sample: + k = args[0][-1] + self.initial_username_sample = None + self.userNameLineEdit.setText(k) + def disableFields(self): for field in (self.userNameLineEdit, self.userPasswordLineEdit): @@ -111,13 +231,8 @@ class LogInPage(InlineValidationPage, UserFormMixIn): # InlineValidationPage errors = self.wizard().get_validation_error( self.current_page) - #prev_er = getattr(self, 'prevalidation_error', None) showerr = self.validationMsg.setText - #if not errors and prev_er: - #showerr(prev_er) - #return -# if errors: bad_str = getattr(self, 'bad_string', None) cur_str = self.userNameLineEdit.text() @@ -128,9 +243,6 @@ class LogInPage(InlineValidationPage, UserFormMixIn): # InlineValidationPage self.bad_string = cur_str showerr(errors) else: - #if prev_er: - #showerr(prev_er) - #return # not the first time if cur_str == bad_str: showerr(errors) @@ -177,7 +289,9 @@ class LogInPage(InlineValidationPage, UserFormMixIn): # InlineValidationPage username = self.userNameLineEdit username.setText('username@provider.example.org') username.cursorPositionChanged.connect( - self.onUserNameEdit) + self.onUserNamePositionChanged) + username.textChanged.connect( + self.onUserNameTextChanged) self.initial_username_sample = True self.validationMsg.setText('') self.valFrame.hide() @@ -215,116 +329,3 @@ class LogInPage(InlineValidationPage, UserFormMixIn): # InlineValidationPage self.do_checks() return self.is_done() - - def _do_checks(self): - # XXX convert this to inline - - full_username = self.userNameLineEdit.text() - ########################### - # 0) check user@domain form - ########################### - - def checkusername(): - if full_username.count('@') != 1: - return self.fail( - self.tr( - "Username must be in the username@provider form.")) - else: - return True - - yield(("head_sentinel", 0), checkusername) - - # XXX I think this is not needed - # since we're also checking for the is_signup field. - #self.wizard().from_login = True - - username, domain = full_username.split('@') - password = self.userPasswordLineEdit.text() - - # We try a call to an authenticated - # page here as a mean to catch - # srp authentication errors while - wizard = self.wizard() - eipconfigchecker = wizard.eipconfigchecker() - - ######################## - # 1) try name resolution - ######################## - # show the frame before going on... - QtCore.QMetaObject.invokeMethod( - self, "showStepsFrame") - - # Able to contact domain? - # can get definition? - # two-by-one - def resolvedomain(): - try: - eipconfigchecker.fetch_definition(domain=domain) - - # we're using requests here for all - # the possible error cases that it catches. - except requests.exceptions.ConnectionError as exc: - return self.fail(exc.message[1]) - except requests.exceptions.HTTPError as exc: - return self.fail(exc.message) - except Exception as exc: - # XXX get catchall error msg - return self.fail( - exc.message) - - yield((self.tr("resolving domain name"), 20), resolvedomain) - - wizard.set_providerconfig( - eipconfigchecker.defaultprovider.config) - - ######################## - # 2) do authentication - ######################## - credentials = username, password - pCertChecker = wizard.providercertchecker( - domain=domain) - - def validate_credentials(): - ################# - # FIXME #BUG #638 - verify = False - - try: - pCertChecker.download_new_client_cert( - credentials=credentials, - verify=verify) - - except auth.SRPAuthenticationError as exc: - return self.fail( - self.tr("Authentication error: %s" % exc.message)) - - except Exception as exc: - return self.fail(exc.message) - - else: - return True - - yield(('Validating credentials', 20), validate_credentials) - - self.set_done() - yield(("end_sentinel", 0), lambda: None) - - def green_validation_status(self): - val = self.validationMsg - val.setText(self.tr('Credentials validated.')) - val.setStyleSheet(styles.GreenLineEdit) - - def on_checks_validation_ready(self): - """ - after checks - """ - if self.is_done(): - self.disableFields() - self.cleanup_errormsg() - self.clean_wizard_errors(self.current_page) - # make the user confirm the transition - # to next page. - self.nextText('&Next') - self.nextFocus() - self.green_validation_status() - self.do_confirm_next = True diff --git a/src/leap/gui/firstrun/providerselect.py b/src/leap/gui/firstrun/providerselect.py index a4be51a9..fd48f7f9 100644 --- a/src/leap/gui/firstrun/providerselect.py +++ b/src/leap/gui/firstrun/providerselect.py @@ -40,7 +40,7 @@ class SelectProviderPage(InlineValidationPage): self.did_cert_check = False - self.is_done = False + self.done = False self.setupSteps() self.setupUI() @@ -131,7 +131,7 @@ class SelectProviderPage(InlineValidationPage): # certinfo - def setupCertInfoGroup(self): + def setupCertInfoGroup(self): # pragma: no cover # XXX not used now. certinfoGroup = QtGui.QGroupBox( self.tr("Certificate validation")) @@ -188,7 +188,6 @@ class SelectProviderPage(InlineValidationPage): _domain = u"%s:%s" % (domain, port) if port != 443 else unicode(domain) netchecker = wizard.netchecker() - providercertchecker = wizard.providercertchecker() eipconfigchecker = wizard.eipconfigchecker(domain=_domain) @@ -205,6 +204,7 @@ class SelectProviderPage(InlineValidationPage): this domain """ try: + #import ipdb;ipdb.set_trace() netchecker.check_name_resolution( domain) @@ -306,7 +306,7 @@ class SelectProviderPage(InlineValidationPage): # done! - self.is_done = True + self.done = True yield(("end_sentinel", 100), lambda: None) def on_checks_validation_ready(self): @@ -316,7 +316,7 @@ class SelectProviderPage(InlineValidationPage): self.domain_checked = True self.completeChanged.emit() # let's set focus... - if self.is_done: + if self.is_done(): self.wizard().clean_validation_error(self.current_page) nextbutton = self.wizard().button(QtGui.QWizard.NextButton) nextbutton.setFocus() @@ -329,7 +329,7 @@ class SelectProviderPage(InlineValidationPage): def is_insecure_cert_trusted(self): return self.trustProviderCertCheckBox.isChecked() - def onTrustCheckChanged(self, state): + def onTrustCheckChanged(self, state): # pragma: no cover XXX checked = False if state == 2: checked = True @@ -342,7 +342,7 @@ class SelectProviderPage(InlineValidationPage): # trigger signal to redraw next button self.completeChanged.emit() - def add_cert_info(self, certinfo): + def add_cert_info(self, certinfo): # pragma: no cover XXX self.certWarning.setText( "Do you want to <b>trust this provider certificate?</b>") self.certInfo.setText( @@ -351,7 +351,7 @@ class SelectProviderPage(InlineValidationPage): self.certinfoGroup.show() def onProviderChanged(self, text): - self.is_done = False + self.done = False provider = self.providerNameEdit.text() if provider: self.providerCheckButton.setDisabled(False) @@ -374,7 +374,7 @@ class SelectProviderPage(InlineValidationPage): def isComplete(self): provider = self.providerNameEdit.text() - if not self.is_done: + if not self.is_done(): return False if not provider: @@ -383,7 +383,7 @@ class SelectProviderPage(InlineValidationPage): if self.is_insecure_cert_trusted(): return True if not self.did_cert_check: - if self.is_done: + if self.is_done(): # XXX sure? return True return False @@ -452,7 +452,7 @@ class SelectProviderPage(InlineValidationPage): if hasattr(self, 'certinfoGroup'): # XXX remove ? self.certinfoGroup.hide() - self.is_done = False + self.done = False self.providerCheckButton.setDisabled(True) self.valFrame.hide() self.steps.removeAllSteps() diff --git a/src/leap/gui/firstrun/register.py b/src/leap/gui/firstrun/register.py index e85723cb..4c811093 100644 --- a/src/leap/gui/firstrun/register.py +++ b/src/leap/gui/firstrun/register.py @@ -131,6 +131,16 @@ class RegisterUserPage(InlineValidationPage, UserFormMixIn): field.setDisabled(True) # error painting + def paintEvent(self, event): + """ + we hook our populate errors + on paintEvent because we need it to catch + when user enters the page coming from next, + and initializePage does not cover that case. + Maybe there's a better event to hook upon. + """ + super(RegisterUserPage, self).paintEvent(event) + self.populateErrors() def markRedAndGetFocus(self, field): field.setStyleSheet(styles.ErrorLineEdit) @@ -193,16 +203,21 @@ class RegisterUserPage(InlineValidationPage, UserFormMixIn): """ self.bad_string = None - def paintEvent(self, event): + def green_validation_status(self): + val = self.validationMsg + val.setText(self.tr('Registration succeeded!')) + val.setStyleSheet(styles.GreenLineEdit) + + def reset_validation_status(self): """ - we hook our populate errors - on paintEvent because we need it to catch - when user enters the page coming from next, - and initializePage does not cover that case. - Maybe there's a better event to hook upon. + empty the validation msg + and clean the inline validation widget. """ - super(RegisterUserPage, self).paintEvent(event) - self.populateErrors() + self.validationMsg.setText('') + self.steps.removeAllSteps() + self.clearTable() + + # actual checks def _do_checks(self): """ @@ -255,6 +270,7 @@ class RegisterUserPage(InlineValidationPage, UserFormMixIn): schema="https", provider=provider, verify=verify) + #import ipdb;ipdb.set_trace() try: ok, req = signup.register_user( username, password) @@ -277,9 +293,15 @@ class RegisterUserPage(InlineValidationPage, UserFormMixIn): self.tr( "Error during registration (%s)") % req.status_code) - validation_msgs = json.loads(req.content) - errors = validation_msgs.get('errors', None) - logger.debug('validation errors: %s' % validation_msgs) + try: + validation_msgs = json.loads(req.content) + errors = validation_msgs.get('errors', None) + logger.debug('validation errors: %s' % validation_msgs) + except ValueError: + # probably bad json returned + return self.fail( + self.tr( + "Could not register (bad response)")) if errors and errors.get('login', None): # XXX this sometimes catch the blank username @@ -287,11 +309,13 @@ class RegisterUserPage(InlineValidationPage, UserFormMixIn): return self.fail( self.tr('Username not available.')) + return True + logger.debug('registering user') yield(("registering with provider", 40), register) self.set_done() - yield(("end_sentinel", 0), lambda: None) + yield(("end_sentinel", 100), lambda: None) def on_checks_validation_ready(self): """ @@ -308,20 +332,6 @@ class RegisterUserPage(InlineValidationPage, UserFormMixIn): self.green_validation_status() self.do_confirm_next = True - def green_validation_status(self): - val = self.validationMsg - val.setText(self.tr('Registration succeeded!')) - val.setStyleSheet(styles.GreenLineEdit) - - def reset_validation_status(self): - """ - empty the validation msg - and clean the inline validation widget. - """ - self.validationMsg.setText('') - self.steps.removeAllSteps() - self.clearTable() - # pagewizard methods def validatePage(self): @@ -352,17 +362,26 @@ class RegisterUserPage(InlineValidationPage, UserFormMixIn): """ inits wizard page """ - provider = self.field('provider_domain') - self.setSubTitle( - self.tr("Register a new user with provider %s.") % - provider) + provider = unicode(self.field('provider_domain')) + if provider: + # here we should have provider + # but in tests we might not. + + # XXX this error causes a segfault on free() + # that we might want to get fixed ... + #self.setSubTitle( + #self.tr("Register a new user with provider %s.") % + #provider) + self.setSubTitle( + self.tr("Register a new user with provider %s." % + provider)) self.validationMsg.setText('') self.userPassword2LineEdit.setText('') self.valFrame.hide() def nextId(self): wizard = self.wizard() - if not wizard: - return + #if not wizard: + #return # XXX this should be called connect return wizard.get_page_index('signupvalidation') diff --git a/src/leap/gui/firstrun/regvalidation.py b/src/leap/gui/firstrun/regvalidation.py index 0e67834b..b86583e0 100644 --- a/src/leap/gui/firstrun/regvalidation.py +++ b/src/leap/gui/firstrun/regvalidation.py @@ -29,6 +29,7 @@ class RegisterUserValidationPage(ValidationPage): def __init__(self, parent=None): super(RegisterUserValidationPage, self).__init__(parent) + self.current_page = "signupvalidation" title = "Connecting..." # XXX uh... really? @@ -100,9 +101,12 @@ class RegisterUserValidationPage(ValidationPage): def fetcheipcert(): try: - pCertChecker.download_new_client_cert( + downloaded = pCertChecker.download_new_client_cert( credentials=credentials, verify=verify) + if not downloaded: + logger.error('Could not download client cert.') + return False except auth.SRPAuthenticationError as exc: return self.fail(self.tr( @@ -126,10 +130,12 @@ class RegisterUserValidationPage(ValidationPage): """ # this should be called CONNECT PAGE AGAIN. # here we go! :) - full_domain = self.field('provider_domain') - domain, port = get_https_domain_and_port(full_domain) - _domain = u"%s:%s" % (domain, port) if port != 443 else unicode(domain) - self.run_eip_checks_for_provider_and_connect(_domain) + if self.is_done(): + full_domain = self.field('provider_domain') + domain, port = get_https_domain_and_port(full_domain) + _domain = u"%s:%s" % ( + domain, port) if port != 443 else unicode(domain) + self.run_eip_checks_for_provider_and_connect(_domain) def run_eip_checks_for_provider_and_connect(self, domain): wizard = self.wizard() diff --git a/src/leap/gui/firstrun/tests/integration/fake_provider.py b/src/leap/gui/firstrun/tests/integration/fake_provider.py index 33ee0ee6..445b4487 100755 --- a/src/leap/gui/firstrun/tests/integration/fake_provider.py +++ b/src/leap/gui/firstrun/tests/integration/fake_provider.py @@ -40,6 +40,8 @@ from twisted.web.static import File from twisted.web.resource import Resource from twisted.internet import reactor +from leap.testing.https_server import where + # See # http://twistedmatrix.com/documents/current/web/howto/web-in-60/index.htmln # for more examples @@ -229,14 +231,13 @@ def get_certs_path(): def get_TLS_credentials(): # XXX this is giving errors # XXX REview! We want to use gnutls! - certs_path = get_certs_path() cert = crypto.X509Certificate( - open(certs_path + '/leaptestscert.pem').read()) + open(where('leaptestscert.pem')).read()) key = crypto.X509PrivateKey( - open(certs_path + '/leaptestskey.pem').read()) + open(where('leaptestskey.pem')).read()) ca = crypto.X509Certificate( - open(certs_path + '/cacert.pem').read()) + open(where('cacert.pem')).read()) #crl = crypto.X509CRL(open(certs_path + '/crl.pem').read()) #cred = crypto.X509Credentials(cert, key, [ca], [crl]) cred = X509Credentials(cert, key, [ca]) @@ -253,19 +254,17 @@ class OpenSSLServerContextFactory: """Create an SSL context. This is a sample implementation that loads a certificate from a file called 'server.pem'.""" - certs_path = get_certs_path() ctx = SSL.Context(SSL.SSLv23_METHOD) - ctx.use_certificate_file(certs_path + '/leaptestscert.pem') - ctx.use_privatekey_file(certs_path + '/leaptestskey.pem') + #certs_path = get_certs_path() + #ctx.use_certificate_file(certs_path + '/leaptestscert.pem') + #ctx.use_privatekey_file(certs_path + '/leaptestskey.pem') + ctx.use_certificate_file(where('leaptestscert.pem')) + ctx.use_privatekey_file(where('leaptestskey.pem')) return ctx -if __name__ == "__main__": - - from twisted.python import log - log.startLogging(sys.stdout) - +def serve_fake_provider(): root = Resource() root.putChild("provider.json", File("./provider.json")) config = Resource() @@ -293,3 +292,11 @@ if __name__ == "__main__": reactor.listenSSL(8443, factory, OpenSSLServerContextFactory()) reactor.run() + + +if __name__ == "__main__": + + from twisted.python import log + log.startLogging(sys.stdout) + + serve_fake_provider() diff --git a/src/leap/gui/firstrun/wizard.py b/src/leap/gui/firstrun/wizard.py index 9b77b877..89209401 100755 --- a/src/leap/gui/firstrun/wizard.py +++ b/src/leap/gui/firstrun/wizard.py @@ -2,8 +2,11 @@ import logging import sip -sip.setapi('QString', 2) -sip.setapi('QVariant', 2) +try: + sip.setapi('QString', 2) + sip.setapi('QVariant', 2) +except ValueError: + pass from PyQt4 import QtCore from PyQt4 import QtGui @@ -46,12 +49,29 @@ TODO-ish: """ +def get_pages_dict(): + return OrderedDict(( + ('intro', firstrun.intro.IntroPage), + ('providerselection', + firstrun.providerselect.SelectProviderPage), + ('login', firstrun.login.LogInPage), + ('providerinfo', firstrun.providerinfo.ProviderInfoPage), + ('providersetupvalidation', + firstrun.providersetup.ProviderSetupValidationPage), + ('signup', firstrun.register.RegisterUserPage), + ('signupvalidation', + firstrun.regvalidation.RegisterUserValidationPage), + ('lastpage', firstrun.last.LastPage) + )) + + class FirstRunWizard(QtGui.QWizard): def __init__( self, conductor_instance, parent=None, + pages_dict=None, eip_username=None, providers=None, success_cb=None, is_provider_setup=False, @@ -112,20 +132,7 @@ class FirstRunWizard(QtGui.QWizard): self.is_previously_registered = bool(self.eip_username) self.from_login = False - pages_dict = OrderedDict(( - ('intro', firstrun.intro.IntroPage), - ('providerselection', - firstrun.providerselect.SelectProviderPage), - ('login', firstrun.login.LogInPage), - ('providerinfo', firstrun.providerinfo.ProviderInfoPage), - ('providersetupvalidation', - firstrun.providersetup.ProviderSetupValidationPage), - ('signup', firstrun.register.RegisterUserPage), - ('signupvalidation', - firstrun.regvalidation.RegisterUserValidationPage), - ('connecting', firstrun.connect.ConnectingPage), - ('lastpage', firstrun.last.LastPage) - )) + pages_dict = pages_dict or get_pages_dict() self.add_pages_from_dict(pages_dict) self.validation_errors = {} @@ -146,6 +153,10 @@ class FirstRunWizard(QtGui.QWizard): # TODO: set style for MAC / windows ... #self.setWizardStyle() + # + # setup pages in wizard + # + def add_pages_from_dict(self, pages_dict): """ @param pages_dict: the dictionary with pages, where @@ -168,6 +179,10 @@ class FirstRunWizard(QtGui.QWizard): """ return self.pages_dict.keys().index(page_name) + # + # validation errors + # + def set_validation_error(self, pagename, error): self.validation_errors[pagename] = error @@ -179,20 +194,6 @@ class FirstRunWizard(QtGui.QWizard): def get_validation_error(self, pagename): return self.validation_errors.get(pagename, None) - def set_providerconfig(self, providerconfig): - self.providerconfig = providerconfig - - def setWindowFlags(self, flags): - logger.debug('setting window flags') - QtGui.QWizard.setWindowFlags(self, flags) - - def focusOutEvent(self, event): - # needed ? - self.setFocus(True) - self.activateWindow() - self.raise_() - self.show() - def accept(self): """ final step in the wizard. @@ -246,11 +247,14 @@ class FirstRunWizard(QtGui.QWizard): if cb and callable(cb): self.success_cb() - def get_provider_by_index(self): - provider = self.field('provider_index') - return self.providers[provider] + # misc helpers def get_random_str(self, n): + """ + returns a random string + :param n: the length of the desired string + :rvalue: str + """ from string import (ascii_uppercase, ascii_lowercase, digits) from random import choice return ''.join(choice( @@ -258,6 +262,24 @@ class FirstRunWizard(QtGui.QWizard): ascii_lowercase + digits) for x in range(n)) + def set_providerconfig(self, providerconfig): + """ + sets a providerconfig attribute + used when we fetch and parse a json configuration + """ + self.providerconfig = providerconfig + + def get_provider_by_index(self): # pragma: no cover + """ + returns the value of a provider given its index. + this was used in the select provider page, + in the case where we were preseeding providers in a combobox + """ + # Leaving it here for the moment when we go back at the + # option of preseeding with known provider values. + provider = self.field('provider_index') + return self.providers[provider] + if __name__ == '__main__': # standalone test diff --git a/src/leap/gui/locale_rc.py b/src/leap/gui/locale_rc.py new file mode 100644 index 00000000..f165ff8e --- /dev/null +++ b/src/leap/gui/locale_rc.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- + +# Resource object code +# +# Created: vie nov 16 22:33:33 2012 +# by: The Resource Compiler for PyQt (Qt v4.8.2) +# +# WARNING! All changes made in this file will be lost! + +from PyQt4 import QtCore + +qt_resource_data = "\ +\x00\x00\x05\xaa\ +\x3c\ +\xb8\x64\x18\xca\xef\x9c\x95\xcd\x21\x1c\xbf\x60\xa1\xbd\xdd\x42\ +\x00\x00\x00\x20\x09\xfc\x2c\x8e\x00\x00\x04\xfb\x0a\x74\xb8\x1e\ +\x00\x00\x00\xd6\x0a\xfd\x99\xfe\x00\x00\x00\x51\x0c\x44\x41\xbe\ +\x00\x00\x00\x00\x69\x00\x00\x05\x69\x03\x00\x00\x00\x22\x00\x50\ +\x00\x72\x00\x69\x00\x6d\x00\x65\x00\x72\x00\x61\x00\x20\x00\x63\ +\x00\x6f\x00\x6e\x00\x65\x00\x78\x00\x69\x00\x6f\x00\x6e\x00\x2e\ +\x08\x00\x00\x00\x00\x06\x00\x00\x00\x11\x46\x69\x72\x73\x74\x20\ +\x72\x75\x6e\x20\x77\x69\x7a\x61\x72\x64\x2e\x07\x00\x00\x00\x09\ +\x49\x6e\x74\x72\x6f\x50\x61\x67\x65\x01\x03\x00\x00\x00\x4c\x00\ +\x4c\x00\x6f\x00\x67\x00\x75\x00\x65\x00\x61\x00\x72\x00\x6d\x00\ +\x65\x00\x20\x00\x63\x00\x6f\x00\x6e\x00\x20\x00\x6d\x00\x69\x00\ +\x20\x00\x75\x00\x73\x00\x75\x00\x61\x00\x72\x00\x69\x00\x6f\x00\ +\x20\x00\x79\x00\x20\x00\x63\x00\x6f\x00\x6e\x00\x74\x00\x72\x00\ +\x61\x00\x73\x00\x65\x00\x6e\x00\x61\x00\x2e\x08\x00\x00\x00\x00\ +\x06\x00\x00\x00\x1b\x4c\x6f\x67\x20\x49\x6e\x20\x77\x69\x74\x68\ +\x20\x6d\x79\x20\x63\x72\x65\x64\x65\x6e\x74\x69\x61\x6c\x73\x2e\ +\x07\x00\x00\x00\x09\x49\x6e\x74\x72\x6f\x50\x61\x67\x65\x01\x03\ +\x00\x00\x02\xaa\x00\x56\x00\x61\x00\x6d\x00\x6f\x00\x73\x00\x20\ +\x00\x61\x00\x20\x00\x72\x00\x65\x00\x75\x00\x6e\x00\x69\x00\x72\ +\x00\x20\x00\x6c\x00\x61\x00\x20\x00\x69\x00\x6e\x00\x66\x00\x6f\ +\x00\x72\x00\x6d\x00\x61\x00\x63\x00\x69\x00\x6f\x00\x6e\x00\x20\ +\x00\x71\x00\x75\x00\x65\x00\x20\x00\x6e\x00\x65\x00\x63\x00\x65\ +\x00\x73\x00\x69\x00\x74\x00\x61\x00\x73\x00\x20\x00\x61\x00\x6e\ +\x00\x74\x00\x65\x00\x73\x00\x20\x00\x64\x00\x65\x00\x20\x00\x6c\ +\x00\x61\x00\x20\x00\x70\x00\x72\x00\x69\x00\x6d\x00\x65\x00\x72\ +\x00\x61\x00\x20\x00\x63\x00\x6f\x00\x6e\x00\x65\x00\x78\x00\x69\ +\x00\x6f\x00\x6e\x00\x2e\x00\x3c\x00\x62\x00\x72\x00\x3e\x00\x3c\ +\x00\x62\x00\x72\x00\x3e\x00\x53\x00\x69\x00\x20\x00\x61\x00\x6c\ +\x00\x67\x00\x75\x00\x6e\x00\x61\x00\x20\x00\x76\x00\x65\x00\x7a\ +\x00\x20\x00\x6e\x00\x65\x00\x63\x00\x65\x00\x73\x00\x69\x00\x74\ +\x00\x61\x00\x73\x00\x20\x00\x6d\x00\x6f\x00\x64\x00\x69\x00\x66\ +\x00\x69\x00\x63\x00\x61\x00\x72\x00\x20\x00\x65\x00\x73\x00\x74\ +\x00\x61\x00\x73\x00\x20\x00\x6f\x00\x70\x00\x63\x00\x69\x00\x6f\ +\x00\x6e\x00\x65\x00\x73\x00\x20\x00\x64\x00\x65\x00\x20\x00\x6e\ +\x00\x75\x00\x65\x00\x76\x00\x6f\x00\x2c\x00\x20\x00\x70\x00\x75\ +\x00\x65\x00\x64\x00\x65\x00\x73\x00\x20\x00\x65\x00\x6e\x00\x63\ +\x00\x6f\x00\x6e\x00\x74\x00\x72\x00\x61\x00\x72\x00\x20\x00\x65\ +\x00\x73\x00\x74\x00\x65\x00\x20\x00\x61\x00\x73\x00\x69\x00\x73\ +\x00\x74\x00\x65\x00\x6e\x00\x74\x00\x65\x00\x20\x00\x65\x00\x6e\ +\x00\x20\x00\x65\x00\x6c\x00\x20\x00\x6d\x00\x65\x00\x6e\x00\x75\ +\x00\x20\x00\x3c\x00\x69\x00\x3e\x00\x4f\x00\x70\x00\x63\x00\x69\ +\x00\x6f\x00\x6e\x00\x65\x00\x73\x00\x3c\x00\x2f\x00\x69\x00\x3e\ +\x00\x20\x00\x65\x00\x6e\x00\x20\x00\x6c\x00\x61\x00\x20\x00\x76\ +\x00\x65\x00\x6e\x00\x74\x00\x61\x00\x6e\x00\x61\x00\x20\x00\x70\ +\x00\x72\x00\x69\x00\x6e\x00\x63\x00\x69\x00\x70\x00\x61\x00\x6c\ +\x00\x2e\x00\x3c\x00\x62\x00\x72\x00\x3e\x00\x3c\x00\x62\x00\x72\ +\x00\x3e\x00\x51\x00\x75\x00\x65\x00\x20\x00\x64\x00\x65\x00\x73\ +\x00\x65\x00\x61\x00\x73\x00\x20\x00\x68\x00\x61\x00\x63\x00\x65\ +\x00\x72\x00\x20\x00\x61\x00\x68\x00\x6f\x00\x72\x00\x61\x00\x3f\ +\x00\x20\x00\x50\x00\x75\x00\x65\x00\x64\x00\x65\x00\x73\x00\x20\ +\x00\x3c\x00\x62\x00\x3e\x00\x72\x00\x65\x00\x67\x00\x69\x00\x73\ +\x00\x74\x00\x72\x00\x61\x00\x72\x00\x3c\x00\x2f\x00\x62\x00\x3e\ +\x00\x20\x00\x75\x00\x6e\x00\x61\x00\x20\x00\x6e\x00\x75\x00\x65\ +\x00\x76\x00\x61\x00\x20\x00\x63\x00\x75\x00\x65\x00\x6e\x00\x74\ +\x00\x61\x00\x20\x00\x6f\x00\x20\x00\x3c\x00\x62\x00\x3e\x00\x6c\ +\x00\x6f\x00\x67\x00\x75\x00\x65\x00\x61\x00\x72\x00\x74\x00\x65\ +\x00\x3c\x00\x2f\x00\x62\x00\x3e\x00\x20\x00\x63\x00\x6f\x00\x6e\ +\x00\x20\x00\x75\x00\x6e\x00\x61\x00\x20\x00\x71\x00\x75\x00\x65\ +\x00\x20\x00\x79\x00\x61\x00\x20\x00\x74\x00\x69\x00\x65\x00\x6e\ +\x00\x65\x00\x73\x00\x3f\x00\x3c\x00\x62\x00\x72\x00\x3e\x08\x00\ +\x00\x00\x00\x06\x00\x00\x01\x5d\x4e\x6f\x77\x20\x77\x65\x20\x77\ +\x69\x6c\x6c\x20\x67\x75\x69\x64\x65\x20\x79\x6f\x75\x20\x74\x68\ +\x72\x6f\x75\x67\x68\x20\x73\x6f\x6d\x65\x20\x63\x6f\x6e\x66\x69\ +\x67\x75\x72\x61\x74\x69\x6f\x6e\x20\x74\x68\x61\x74\x20\x69\x73\ +\x20\x6e\x65\x65\x64\x65\x64\x20\x62\x65\x66\x6f\x72\x65\x20\x79\ +\x6f\x75\x20\x63\x61\x6e\x20\x63\x6f\x6e\x6e\x65\x63\x74\x20\x66\ +\x6f\x72\x20\x74\x68\x65\x20\x66\x69\x72\x73\x74\x20\x74\x69\x6d\ +\x65\x2e\x3c\x62\x72\x3e\x3c\x62\x72\x3e\x49\x66\x20\x79\x6f\x75\ +\x20\x65\x76\x65\x72\x20\x6e\x65\x65\x64\x20\x74\x6f\x20\x6d\x6f\ +\x64\x69\x66\x79\x20\x74\x68\x65\x73\x65\x20\x6f\x70\x74\x69\x6f\ +\x6e\x73\x20\x61\x67\x61\x69\x6e\x2c\x20\x79\x6f\x75\x20\x63\x61\ +\x6e\x20\x66\x69\x6e\x64\x20\x74\x68\x65\x20\x77\x69\x7a\x61\x72\ +\x64\x20\x69\x6e\x20\x74\x68\x65\x20\x27\x3c\x69\x3e\x53\x65\x74\ +\x74\x69\x6e\x67\x73\x3c\x2f\x69\x3e\x27\x20\x6d\x65\x6e\x75\x20\ +\x66\x72\x6f\x6d\x20\x74\x68\x65\x20\x6d\x61\x69\x6e\x20\x77\x69\ +\x6e\x64\x6f\x77\x2e\x3c\x62\x72\x3e\x3c\x62\x72\x3e\x44\x6f\x20\ +\x79\x6f\x75\x20\x77\x61\x6e\x74\x20\x74\x6f\x20\x3c\x62\x3e\x73\ +\x69\x67\x6e\x20\x75\x70\x3c\x2f\x62\x3e\x20\x66\x6f\x72\x20\x61\ +\x20\x6e\x65\x77\x20\x61\x63\x63\x6f\x75\x6e\x74\x2c\x20\x6f\x72\ +\x20\x3c\x62\x3e\x6c\x6f\x67\x20\x69\x6e\x3c\x2f\x62\x3e\x20\x77\ +\x69\x74\x68\x20\x61\x6e\x20\x61\x6c\x72\x65\x61\x64\x79\x20\x65\ +\x78\x69\x73\x74\x69\x6e\x67\x20\x75\x73\x65\x72\x6e\x61\x6d\x65\ +\x3f\x3c\x62\x72\x3e\x07\x00\x00\x00\x09\x49\x6e\x74\x72\x6f\x50\ +\x61\x67\x65\x01\x03\x00\x00\x00\x36\x00\x52\x00\x65\x00\x67\x00\ +\x69\x00\x73\x00\x74\x00\x72\x00\x61\x00\x72\x00\x20\x00\x75\x00\ +\x6e\x00\x61\x00\x20\x00\x63\x00\x75\x00\x65\x00\x6e\x00\x74\x00\ +\x61\x00\x20\x00\x6e\x00\x75\x00\x65\x00\x76\x00\x61\x00\x2e\x08\ +\x00\x00\x00\x00\x06\x00\x00\x00\x1a\x53\x69\x67\x6e\x20\x75\x70\ +\x20\x66\x6f\x72\x20\x61\x20\x6e\x65\x77\x20\x61\x63\x63\x6f\x75\ +\x6e\x74\x2e\x07\x00\x00\x00\x09\x49\x6e\x74\x72\x6f\x50\x61\x67\ +\x65\x01\x88\x00\x00\x00\x02\x01\x01\ +" + +qt_resource_name = "\ +\x00\x0c\ +\x0d\xfc\x11\x13\ +\x00\x74\ +\x00\x72\x00\x61\x00\x6e\x00\x73\x00\x6c\x00\x61\x00\x74\x00\x69\x00\x6f\x00\x6e\x00\x73\ +\x00\x14\ +\x08\xa9\x0f\x1d\ +\x00\x6c\ +\x00\x65\x00\x61\x00\x70\x00\x5f\x00\x63\x00\x6c\x00\x69\x00\x65\x00\x6e\x00\x74\x00\x5f\x00\x65\x00\x73\x00\x5f\x00\x45\x00\x53\ +\x00\x2e\x00\x71\x00\x6d\ +" + +qt_resource_struct = "\ +\x00\x00\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x01\ +\x00\x00\x00\x00\x00\x02\x00\x00\x00\x01\x00\x00\x00\x02\ +\x00\x00\x00\x1e\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\ +" + +def qInitResources(): + QtCore.qRegisterResourceData(0x01, qt_resource_struct, qt_resource_name, qt_resource_data) + +def qCleanupResources(): + QtCore.qUnregisterResourceData(0x01, qt_resource_struct, qt_resource_name, qt_resource_data) + +qInitResources() diff --git a/src/leap/gui/progress.py b/src/leap/gui/progress.py index 64b87b2c..ffea80de 100644 --- a/src/leap/gui/progress.py +++ b/src/leap/gui/progress.py @@ -4,7 +4,7 @@ from first run wizard """ try: from collections import OrderedDict -except ImportError: +except ImportError: # pragma: no cover # We must be in 2.6 from leap.util.dicts import OrderedDict @@ -73,15 +73,16 @@ class ProgressStepContainer(object): self.steps = {} def step(self, identity): - return self.step.get(identity) + return self.steps.get(identity, None) def addStep(self, step): self.steps[step.index] = step def removeStep(self, step): - del self.steps[step.index] - del step - self.dirty = True + if step and self.steps.get(step.index, None): + del self.steps[step.index] + del step + self.dirty = True def removeAllSteps(self): for item in iter(self): @@ -107,7 +108,7 @@ class StepsTableWidget(QtGui.QTableWidget): """ def __init__(self, parent=None): - super(StepsTableWidget, self).__init__(parent) + super(StepsTableWidget, self).__init__(parent=parent) # remove headers and all edit/select behavior self.horizontalHeader().hide() @@ -149,18 +150,39 @@ class StepsTableWidget(QtGui.QTableWidget): class WithStepsMixIn(object): + """ + This Class is a mixin that can be inherited + by InlineValidation pages (which will display + a progress steps widget in the same page as the form) + or by Validation Pages (which will only display + the progress steps in the page, below a progress bar widget) + """ + STEPS_TIMER_MS = 100 - # worker threads for checks + # + # methods related to worker threads + # launched for individual checks + # def setupStepsProcessingQueue(self): + """ + should be called from the init method + of the derived classes + """ self.steps_queue = Queue.Queue() self.stepscheck_timer = QtCore.QTimer() self.stepscheck_timer.timeout.connect(self.processStepsQueue) - self.stepscheck_timer.start(100) + self.stepscheck_timer.start(self.STEPS_TIMER_MS) # we need to keep a reference to child threads self.threads = [] def do_checks(self): + """ + main entry point for checks. + it calls _do_checks in derived classes, + and it expects it to be a generator + yielding a tuple in the form (("message", progress_int), checkfunction) + """ # yo dawg, I heard you like checks # so I put a __do_checks in your do_checks @@ -168,7 +190,7 @@ class WithStepsMixIn(object): def __do_checks(fun=None, queue=None): - for checkcase in fun(): + for checkcase in fun(): # pragma: no cover checkmsg, checkfun = checkcase queue.put(checkmsg) @@ -180,15 +202,34 @@ class WithStepsMixIn(object): __do_checks, fun=self._do_checks, queue=self.steps_queue)) - t.finished.connect(self.on_checks_validation_ready) + if hasattr(self, 'on_checks_validation_ready'): + t.finished.connect(self.on_checks_validation_ready) t.begin() self.threads.append(t) + def processStepsQueue(self): + """ + consume steps queue + and pass messages + to the ui updater functions + """ + while self.steps_queue.qsize(): + try: + status = self.steps_queue.get(0) + if status == "failed": + self.set_failed_icon() + else: + self.onStepStatusChanged(*status) + except Queue.Empty: # pragma: no cover + pass + def fail(self, err=None): """ return failed state and send error notification as - a nice side effect + a nice side effect. this function is called from + the _do_checks check functions returned in the + generator. """ wizard = self.wizard() senderr = lambda err: wizard.set_validation_error( @@ -202,38 +243,29 @@ class WithStepsMixIn(object): def launch_checks(self): self.do_checks() + # (gui) presentation stuff begins ##################### + # slot #@QtCore.pyqtSlot(str, int) def onStepStatusChanged(self, status, progress=None): + status = unicode(status) if status not in ("head_sentinel", "end_sentinel"): self.add_status_line(status) if status in ("end_sentinel"): - self.checks_finished = True + #self.checks_finished = True self.set_checked_icon() if progress and hasattr(self, 'progress'): self.progress.setValue(progress) self.progress.update() - def processStepsQueue(self): - """ - consume steps queue - and pass messages - to the ui updater functions - """ - while self.steps_queue.qsize(): - try: - status = self.steps_queue.get(0) - if status == "failed": - self.set_failed_icon() - else: - self.onStepStatusChanged(*status) - except Queue.Empty: - pass - def setupSteps(self): self.steps = ProgressStepContainer() # steps table widget - self.stepsTableWidget = StepsTableWidget(self) + if isinstance(self, QtCore.QObject): + parent = self + else: + parent = None + self.stepsTableWidget = StepsTableWidget(parent=parent) zeros = (0, 0, 0, 0) self.stepsTableWidget.setContentsMargins(*zeros) self.errors = OrderedDict() @@ -242,15 +274,17 @@ class WithStepsMixIn(object): self.errors[name] = error def pop_first_error(self): - return list(reversed(self.errors.items())).pop() + errkey, errval = list(reversed(self.errors.items())).pop() + del self.errors[errkey] + return errkey, errval def clean_errors(self): self.errors = OrderedDict() def clean_wizard_errors(self, pagename=None): - if pagename is None: + if pagename is None: # pragma: no cover pagename = getattr(self, 'prev_page', None) - if pagename is None: + if pagename is None: # pragma: no cover return logger.debug('cleaning wizard errors for %s' % pagename) self.wizard().set_validation_error(pagename, None) @@ -295,6 +329,8 @@ class WithStepsMixIn(object): # setting cell widget. # see note on StepsTableWidget about plans to # change this for a better solution. + if not hasattr(self, 'steps'): + return index = len(self.steps) table = self.stepsTableWidget _index = index - 1 if current else index - 2 @@ -340,6 +376,9 @@ class WithStepsMixIn(object): def is_done(self): return self.done + # convenience for going back and forth + # in the wizard pages. + def go_back(self): self.wizard().back() diff --git a/src/leap/gui/tests/__init__.py b/src/leap/gui/tests/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/src/leap/gui/tests/__init__.py diff --git a/src/leap/gui/tests/test_firstrun_login.py b/src/leap/gui/tests/test_firstrun_login.py new file mode 100644 index 00000000..fa800c23 --- /dev/null +++ b/src/leap/gui/tests/test_firstrun_login.py @@ -0,0 +1,212 @@ +import sys +import unittest + +import mock + +from leap.testing import qunittest +#from leap.testing import pyqt + +from PyQt4 import QtGui +#from PyQt4 import QtCore +#import PyQt4.QtCore # some weirdness with mock module + +from PyQt4.QtTest import QTest +from PyQt4.QtCore import Qt + +from leap.gui import firstrun + +try: + from collections import OrderedDict +except ImportError: + # We must be in 2.6 + from leap.util.dicts import OrderedDict + + +class TestPage(firstrun.login.LogInPage): + pass + + +class LogInPageLogicTestCase(qunittest.TestCase): + + # XXX can spy on signal connections + __name__ = "register user page logic tests" + + def setUp(self): + self.app = QtGui.QApplication(sys.argv) + QtGui.qApp = self.app + self.page = TestPage(None) + self.page.wizard = mock.MagicMock() + + def tearDown(self): + QtGui.qApp = None + self.app = None + self.page = None + + def test__do_checks(self): + eq = self.assertEqual + + self.page.userNameLineEdit.setText('testuser@domain') + self.page.userPasswordLineEdit.setText('testpassword') + + # fake register process + with mock.patch('leap.base.auth.LeapSRPRegister') as mockAuth: + mockSignup = mock.MagicMock() + + reqMockup = mock.Mock() + # XXX should inject bad json to get error + reqMockup.content = '{"errors": null}' + mockSignup.register_user.return_value = (True, reqMockup) + mockAuth.return_value = mockSignup + checks = [x for x in self.page._do_checks()] + + eq(len(checks), 4) + labels = [str(x) for (x, y), z in checks] + eq(labels, ['head_sentinel', + 'Resolving domain name', + 'Validating credentials', + 'end_sentinel']) + progress = [y for (x, y), z in checks] + eq(progress, [0, 20, 60, 100]) + + # normal run, ie, no exceptions + + checkfuns = [z for (x, y), z in checks] + checkusername, resolvedomain, valcreds = checkfuns[:-1] + + self.assertTrue(checkusername()) + #self.mocknetchecker.check_name_resolution.assert_called_with( + #'test_provider1') + + self.assertTrue(resolvedomain()) + #self.mockpcertchecker.is_https_working.assert_called_with( + #"https://test_provider1", verify=True) + + self.assertTrue(valcreds()) + + # XXX missing: inject failing exceptions + # XXX TODO make it break + + +class RegisterUserPageUITestCase(qunittest.TestCase): + + # XXX can spy on signal connections + __name__ = "Register User Page UI tests" + + def setUp(self): + self.app = QtGui.QApplication(sys.argv) + QtGui.qApp = self.app + + self.pagename = "signup" + pages = OrderedDict(( + (self.pagename, TestPage), + ('providersetupvalidation', + firstrun.regvalidation.RegisterUserValidationPage))) + self.wizard = firstrun.wizard.FirstRunWizard(None, pages_dict=pages) + self.page = self.wizard.page(self.wizard.get_page_index(self.pagename)) + + self.page.do_checks = mock.Mock() + + # wizard would do this for us + self.page.initializePage() + + def tearDown(self): + QtGui.qApp = None + self.app = None + self.wizard = None + + # XXX refactor out + def fill_field(self, field, text): + """ + fills a field (line edit) that is passed along + :param field: the qLineEdit + :param text: the text to be filled + :type field: QLineEdit widget + :type text: str + """ + keyp = QTest.keyPress + field.setFocus(True) + for c in text: + keyp(field, c) + self.assertEqual(field.text(), text) + + def del_field(self, field): + """ + deletes entried text in + field line edit + :param field: the QLineEdit + :type field: QLineEdit widget + """ + keyp = QTest.keyPress + for c in range(len(field.text())): + keyp(field, Qt.Key_Backspace) + self.assertEqual(field.text(), "") + + def test_buttons_disabled_until_textentry(self): + # it's a commit button this time + nextbutton = self.wizard.button(QtGui.QWizard.CommitButton) + + self.assertFalse(nextbutton.isEnabled()) + + f_username = self.page.userNameLineEdit + f_password = self.page.userPasswordLineEdit + + self.fill_field(f_username, "testuser") + self.fill_field(f_password, "testpassword") + + # commit should be enabled + # XXX Need a workaround here + # because the isComplete is not being evaluated... + # (no event loop running??) + #import ipdb;ipdb.set_trace() + #self.assertTrue(nextbutton.isEnabled()) + self.assertTrue(self.page.isComplete()) + + self.del_field(f_username) + self.del_field(f_password) + + # after rm fields commit button + # should be disabled again + #self.assertFalse(nextbutton.isEnabled()) + self.assertFalse(self.page.isComplete()) + + def test_validate_page(self): + self.assertFalse(self.page.validatePage()) + # XXX TODO MOAR CASES... + # add errors, False + # change done, False + # not done, do_checks called + # click confirm, True + # done and do_confirm, True + + def test_next_id(self): + self.assertEqual(self.page.nextId(), 1) + + def test_paint_event(self): + self.page.populateErrors = mock.Mock() + self.page.paintEvent(None) + self.page.populateErrors.assert_called_with() + + def test_validation_ready(self): + f_username = self.page.userNameLineEdit + f_password = self.page.userPasswordLineEdit + + self.fill_field(f_username, "testuser") + self.fill_field(f_password, "testpassword") + + self.page.done = True + self.page.on_checks_validation_ready() + self.assertFalse(f_username.isEnabled()) + self.assertFalse(f_password.isEnabled()) + + self.assertEqual(self.page.validationMsg.text(), + "Credentials validated.") + self.assertEqual(self.page.do_confirm_next, True) + + def test_regex(self): + # XXX enter invalid username with key presses + # check text is not updated + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/gui/tests/test_firstrun_providerselect.py b/src/leap/gui/tests/test_firstrun_providerselect.py new file mode 100644 index 00000000..976c68cd --- /dev/null +++ b/src/leap/gui/tests/test_firstrun_providerselect.py @@ -0,0 +1,201 @@ +import sys +import unittest + +import mock + +from leap.testing import qunittest +#from leap.testing import pyqt + +from PyQt4 import QtGui +#from PyQt4 import QtCore +#import PyQt4.QtCore # some weirdness with mock module + +from PyQt4.QtTest import QTest +from PyQt4.QtCore import Qt + +from leap.gui import firstrun + +try: + from collections import OrderedDict +except ImportError: + # We must be in 2.6 + from leap.util.dicts import OrderedDict + + +class TestPage(firstrun.providerselect.SelectProviderPage): + pass + + +class SelectProviderPageLogicTestCase(qunittest.TestCase): + + # XXX can spy on signal connections + + def setUp(self): + self.app = QtGui.QApplication(sys.argv) + QtGui.qApp = self.app + self.page = TestPage(None) + self.page.wizard = mock.MagicMock() + + mocknetchecker = mock.Mock() + self.page.wizard().netchecker.return_value = mocknetchecker + self.mocknetchecker = mocknetchecker + + mockpcertchecker = mock.Mock() + self.page.wizard().providercertchecker.return_value = mockpcertchecker + self.mockpcertchecker = mockpcertchecker + + mockeipconfchecker = mock.Mock() + self.page.wizard().eipconfigchecker.return_value = mockeipconfchecker + self.mockeipconfchecker = mockeipconfchecker + + def tearDown(self): + QtGui.qApp = None + self.app = None + self.page = None + + def test__do_checks(self): + eq = self.assertEqual + + self.page.providerNameEdit.setText('test_provider1') + + checks = [x for x in self.page._do_checks()] + eq(len(checks), 5) + labels = [str(x) for (x, y), z in checks] + eq(labels, ['head_sentinel', 'checking domain name', + 'checking https connection', + 'fetching provider info', 'end_sentinel']) + progress = [y for (x, y), z in checks] + eq(progress, [0, 20, 40, 80, 100]) + + # normal run, ie, no exceptions + + checkfuns = [z for (x, y), z in checks] + namecheck, httpscheck, fetchinfo = checkfuns[1:-1] + + self.assertTrue(namecheck()) + self.mocknetchecker.check_name_resolution.assert_called_with( + 'test_provider1') + + self.assertTrue(httpscheck()) + self.mockpcertchecker.is_https_working.assert_called_with( + "https://test_provider1", verify=True) + + self.assertTrue(fetchinfo()) + self.mockeipconfchecker.fetch_definition.assert_called_with( + domain="test_provider1") + + # XXX missing: inject failing exceptions + # XXX TODO make it break + + +class SelectProviderPageUITestCase(qunittest.TestCase): + + # XXX can spy on signal connections + __name__ = "Select Provider Page UI tests" + + def setUp(self): + self.app = QtGui.QApplication(sys.argv) + QtGui.qApp = self.app + + self.pagename = "providerselection" + pages = OrderedDict(( + (self.pagename, TestPage), + ('providerinfo', + firstrun.providerinfo.ProviderInfoPage))) + self.wizard = firstrun.wizard.FirstRunWizard(None, pages_dict=pages) + self.page = self.wizard.page(self.wizard.get_page_index(self.pagename)) + + self.page.do_checks = mock.Mock() + + # wizard would do this for us + self.page.initializePage() + + def tearDown(self): + QtGui.qApp = None + self.app = None + self.wizard = None + + def fill_provider(self): + """ + fills provider line edit + """ + keyp = QTest.keyPress + pedit = self.page.providerNameEdit + pedit.setFocus(True) + for c in "testprovider": + keyp(pedit, c) + self.assertEqual(pedit.text(), "testprovider") + + def del_provider(self): + """ + deletes entried provider in + line edit + """ + keyp = QTest.keyPress + pedit = self.page.providerNameEdit + for c in range(len("testprovider")): + keyp(pedit, Qt.Key_Backspace) + self.assertEqual(pedit.text(), "") + + def test_buttons_disabled_until_textentry(self): + nextbutton = self.wizard.button(QtGui.QWizard.NextButton) + checkbutton = self.page.providerCheckButton + + self.assertFalse(nextbutton.isEnabled()) + self.assertFalse(checkbutton.isEnabled()) + + self.fill_provider() + # checkbutton should be enabled + self.assertTrue(checkbutton.isEnabled()) + self.assertFalse(nextbutton.isEnabled()) + + self.del_provider() + # after rm provider checkbutton disabled again + self.assertFalse(checkbutton.isEnabled()) + self.assertFalse(nextbutton.isEnabled()) + + def test_check_button_triggers_tests(self): + checkbutton = self.page.providerCheckButton + self.assertFalse(checkbutton.isEnabled()) + self.assertFalse(self.page.do_checks.called) + + self.fill_provider() + + self.assertTrue(checkbutton.isEnabled()) + mclick = QTest.mouseClick + # click! + mclick(checkbutton, Qt.LeftButton) + self.waitFor(seconds=0.1) + self.assertTrue(self.page.do_checks.called) + + # XXX + # can play with different side_effects for do_checks mock... + # so we can see what happens with errors and so on + + def test_page_completed_after_checks(self): + nextbutton = self.wizard.button(QtGui.QWizard.NextButton) + self.assertFalse(nextbutton.isEnabled()) + + self.assertFalse(self.page.isComplete()) + self.fill_provider() + # simulate checks done + self.page.done = True + self.page.on_checks_validation_ready() + self.assertTrue(self.page.isComplete()) + # cannot test for nexbutton enabled + # cause it's the the wizard loop + # that would do that I think + + def test_validate_page(self): + self.assertTrue(self.page.validatePage()) + + def test_next_id(self): + self.assertEqual(self.page.nextId(), 1) + + def test_paint_event(self): + self.page.populateErrors = mock.Mock() + self.page.paintEvent(None) + self.page.populateErrors.assert_called_with() + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/gui/tests/test_firstrun_register.py b/src/leap/gui/tests/test_firstrun_register.py new file mode 100644 index 00000000..3447fe9d --- /dev/null +++ b/src/leap/gui/tests/test_firstrun_register.py @@ -0,0 +1,244 @@ +import sys +import unittest + +import mock + +from leap.testing import qunittest +#from leap.testing import pyqt + +from PyQt4 import QtGui +#from PyQt4 import QtCore +#import PyQt4.QtCore # some weirdness with mock module + +from PyQt4.QtTest import QTest +from PyQt4.QtCore import Qt + +from leap.gui import firstrun + +try: + from collections import OrderedDict +except ImportError: + # We must be in 2.6 + from leap.util.dicts import OrderedDict + + +class TestPage(firstrun.register.RegisterUserPage): + + def field(self, field): + if field == "provider_domain": + return "testprovider" + + +class RegisterUserPageLogicTestCase(qunittest.TestCase): + + # XXX can spy on signal connections + __name__ = "register user page logic tests" + + def setUp(self): + self.app = QtGui.QApplication(sys.argv) + QtGui.qApp = self.app + self.page = TestPage(None) + self.page.wizard = mock.MagicMock() + + #mocknetchecker = mock.Mock() + #self.page.wizard().netchecker.return_value = mocknetchecker + #self.mocknetchecker = mocknetchecker +# + #mockpcertchecker = mock.Mock() + #self.page.wizard().providercertchecker.return_value = mockpcertchecker + #self.mockpcertchecker = mockpcertchecker +# + #mockeipconfchecker = mock.Mock() + #self.page.wizard().eipconfigchecker.return_value = mockeipconfchecker + #self.mockeipconfchecker = mockeipconfchecker + + def tearDown(self): + QtGui.qApp = None + self.app = None + self.page = None + + def test__do_checks(self): + eq = self.assertEqual + + self.page.userNameLineEdit.setText('testuser') + self.page.userPasswordLineEdit.setText('testpassword') + self.page.userPassword2LineEdit.setText('testpassword') + + # fake register process + with mock.patch('leap.base.auth.LeapSRPRegister') as mockAuth: + mockSignup = mock.MagicMock() + + reqMockup = mock.Mock() + # XXX should inject bad json to get error + reqMockup.content = '{"errors": null}' + mockSignup.register_user.return_value = (True, reqMockup) + mockAuth.return_value = mockSignup + checks = [x for x in self.page._do_checks()] + + eq(len(checks), 3) + labels = [str(x) for (x, y), z in checks] + eq(labels, ['head_sentinel', + 'registering with provider', + 'end_sentinel']) + progress = [y for (x, y), z in checks] + eq(progress, [0, 40, 100]) + + # normal run, ie, no exceptions + + checkfuns = [z for (x, y), z in checks] + passcheck, register = checkfuns[:-1] + + self.assertTrue(passcheck()) + #self.mocknetchecker.check_name_resolution.assert_called_with( + #'test_provider1') + + self.assertTrue(register()) + #self.mockpcertchecker.is_https_working.assert_called_with( + #"https://test_provider1", verify=True) + + # XXX missing: inject failing exceptions + # XXX TODO make it break + + +class RegisterUserPageUITestCase(qunittest.TestCase): + + # XXX can spy on signal connections + __name__ = "Register User Page UI tests" + + def setUp(self): + self.app = QtGui.QApplication(sys.argv) + QtGui.qApp = self.app + + self.pagename = "signup" + pages = OrderedDict(( + (self.pagename, TestPage), + ('signupvalidation', + firstrun.regvalidation.RegisterUserValidationPage))) + self.wizard = firstrun.wizard.FirstRunWizard(None, pages_dict=pages) + self.page = self.wizard.page(self.wizard.get_page_index(self.pagename)) + + self.page.do_checks = mock.Mock() + + # wizard would do this for us + self.page.initializePage() + + def tearDown(self): + QtGui.qApp = None + self.app = None + self.wizard = None + + def fill_field(self, field, text): + """ + fills a field (line edit) that is passed along + :param field: the qLineEdit + :param text: the text to be filled + :type field: QLineEdit widget + :type text: str + """ + keyp = QTest.keyPress + field.setFocus(True) + for c in text: + keyp(field, c) + self.assertEqual(field.text(), text) + + def del_field(self, field): + """ + deletes entried text in + field line edit + :param field: the QLineEdit + :type field: QLineEdit widget + """ + keyp = QTest.keyPress + for c in range(len(field.text())): + keyp(field, Qt.Key_Backspace) + self.assertEqual(field.text(), "") + + def test_buttons_disabled_until_textentry(self): + # it's a commit button this time + nextbutton = self.wizard.button(QtGui.QWizard.CommitButton) + + self.assertFalse(nextbutton.isEnabled()) + + f_username = self.page.userNameLineEdit + f_password = self.page.userPasswordLineEdit + f_passwor2 = self.page.userPassword2LineEdit + + self.fill_field(f_username, "testuser") + self.fill_field(f_password, "testpassword") + self.fill_field(f_passwor2, "testpassword") + + # commit should be enabled + # XXX Need a workaround here + # because the isComplete is not being evaluated... + # (no event loop running??) + #import ipdb;ipdb.set_trace() + #self.assertTrue(nextbutton.isEnabled()) + self.assertTrue(self.page.isComplete()) + + self.del_field(f_username) + self.del_field(f_password) + self.del_field(f_passwor2) + + # after rm fields commit button + # should be disabled again + #self.assertFalse(nextbutton.isEnabled()) + self.assertFalse(self.page.isComplete()) + + @unittest.skip + def test_check_button_triggers_tests(self): + checkbutton = self.page.providerCheckButton + self.assertFalse(checkbutton.isEnabled()) + self.assertFalse(self.page.do_checks.called) + + self.fill_provider() + + self.assertTrue(checkbutton.isEnabled()) + mclick = QTest.mouseClick + # click! + mclick(checkbutton, Qt.LeftButton) + self.waitFor(seconds=0.1) + self.assertTrue(self.page.do_checks.called) + + # XXX + # can play with different side_effects for do_checks mock... + # so we can see what happens with errors and so on + + def test_validate_page(self): + self.assertFalse(self.page.validatePage()) + # XXX TODO MOAR CASES... + # add errors, False + # change done, False + # not done, do_checks called + # click confirm, True + # done and do_confirm, True + + def test_next_id(self): + self.assertEqual(self.page.nextId(), 1) + + def test_paint_event(self): + self.page.populateErrors = mock.Mock() + self.page.paintEvent(None) + self.page.populateErrors.assert_called_with() + + def test_validation_ready(self): + f_username = self.page.userNameLineEdit + f_password = self.page.userPasswordLineEdit + f_passwor2 = self.page.userPassword2LineEdit + + self.fill_field(f_username, "testuser") + self.fill_field(f_password, "testpassword") + self.fill_field(f_passwor2, "testpassword") + + self.page.done = True + self.page.on_checks_validation_ready() + self.assertFalse(f_username.isEnabled()) + self.assertFalse(f_password.isEnabled()) + self.assertFalse(f_passwor2.isEnabled()) + + self.assertEqual(self.page.validationMsg.text(), + "Registration succeeded!") + self.assertEqual(self.page.do_confirm_next, True) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/gui/tests/test_firstrun_wizard.py b/src/leap/gui/tests/test_firstrun_wizard.py new file mode 100644 index 00000000..091cd932 --- /dev/null +++ b/src/leap/gui/tests/test_firstrun_wizard.py @@ -0,0 +1,137 @@ +import sys +import unittest + +import mock + +from leap.testing import qunittest +from leap.testing import pyqt + +from PyQt4 import QtGui +#from PyQt4 import QtCore +import PyQt4.QtCore # some weirdness with mock module + +from PyQt4.QtTest import QTest +#from PyQt4.QtCore import Qt + +from leap.gui import firstrun + + +class TestWizard(firstrun.wizard.FirstRunWizard): + pass + + +PAGES_DICT = dict(( + ('intro', firstrun.intro.IntroPage), + ('providerselection', + firstrun.providerselect.SelectProviderPage), + ('login', firstrun.login.LogInPage), + ('providerinfo', firstrun.providerinfo.ProviderInfoPage), + ('providersetupvalidation', + firstrun.providersetup.ProviderSetupValidationPage), + ('signup', firstrun.register.RegisterUserPage), + ('signupvalidation', + firstrun.regvalidation.RegisterUserValidationPage), + ('lastpage', firstrun.last.LastPage) +)) + + +mockQSettings = mock.MagicMock() +mockQSettings().setValue.return_value = True + +#PyQt4.QtCore.QSettings = mockQSettings + + +class FirstRunWizardTestCase(qunittest.TestCase): + + # XXX can spy on signal connections + + def setUp(self): + self.app = QtGui.QApplication(sys.argv) + QtGui.qApp = self.app + self.wizard = TestWizard(None) + + def tearDown(self): + QtGui.qApp = None + self.app = None + self.wizard = None + + def test_defaults(self): + self.assertEqual(self.wizard.pages_dict, PAGES_DICT) + + @mock.patch('PyQt4.QtCore.QSettings', mockQSettings) + def test_accept(self): + """ + test the main accept method + that gets called when user has gone + thru all the wizard and click on finish button + """ + + self.wizard.success_cb = mock.Mock() + self.wizard.success_cb.return_value = True + + # dummy values; we inject them in the field + # mocks (where wizard gets them) and then + # we check that they are passed to QSettings.setValue + field_returns = ["testuser", "1234", "testprovider", True] + + def field_side_effects(*args): + return field_returns.pop(0) + + self.wizard.field = mock.Mock(side_effect=field_side_effects) + self.wizard.get_random_str = mock.Mock() + RANDOMSTR = "thisisarandomstringTM" + self.wizard.get_random_str.return_value = RANDOMSTR + + # mocked settings (see decorator on this method) + mqs = PyQt4.QtCore.QSettings + + # go! call accept... + self.wizard.accept() + + # did settings().setValue get called with the proper + # arguments? + call = mock.call + calls = [call("FirstRunWizardDone", True), + call("provider_domain", "testprovider"), + call("remember_user_and_pass", True), + call("eip_username", "testuser@testprovider"), + call("testprovider_seed", RANDOMSTR)] + mqs().setValue.assert_has_calls(calls, any_order=True) + + # assert success callback is success oh boy + self.wizard.success_cb.assert_called_with() + + def test_random_str(self): + r = self.wizard.get_random_str(42) + self.assertTrue(len(r) == 42) + + def test_page_index(self): + """ + we test both the get_page_index function + and the correct ordering of names + """ + # remember it's implemented as an ordered dict + + pagenames = ('intro', 'providerselection', 'login', 'providerinfo', + 'providersetupvalidation', 'signup', 'signupvalidation', + 'lastpage') + eq = self.assertEqual + w = self.wizard + for index, name in enumerate(pagenames): + eq(w.get_page_index(name), index) + + def test_validation_errors(self): + """ + tests getters and setters for validation errors + """ + page = "testpage" + eq = self.assertEqual + w = self.wizard + eq(w.get_validation_error(page), None) + w.set_validation_error(page, "error") + eq(w.get_validation_error(page), "error") + w.clean_validation_error(page) + eq(w.get_validation_error(page), None) + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/gui/test_mainwindow_rc.py b/src/leap/gui/tests/test_mainwindow_rc.py index c5abb4aa..67b9fae0 100644 --- a/src/leap/gui/test_mainwindow_rc.py +++ b/src/leap/gui/tests/test_mainwindow_rc.py @@ -27,3 +27,6 @@ class MainWindowResourcesTest(unittest.TestCase): self.assertEqual( hashlib.md5(mainwindow_rc.qt_resource_data).hexdigest(), '53e196f29061d8f08f112e5a2e64eb53') + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/gui/tests/test_progress.py b/src/leap/gui/tests/test_progress.py new file mode 100644 index 00000000..1f9f9e38 --- /dev/null +++ b/src/leap/gui/tests/test_progress.py @@ -0,0 +1,449 @@ +from collections import namedtuple +import sys +import unittest +import Queue + +import mock + +from leap.testing import qunittest +from leap.testing import pyqt + +from PyQt4 import QtGui +from PyQt4 import QtCore +from PyQt4.QtTest import QTest +from PyQt4.QtCore import Qt + +from leap.gui import progress + + +class ProgressStepTestCase(unittest.TestCase): + + def test_step_attrs(self): + ps = progress.ProgressStep + step = ps('test', False, 1) + # instance + self.assertEqual(step.index, 1) + self.assertEqual(step.name, "test") + self.assertEqual(step.done, False) + step = ps('test2', True, 2) + self.assertEqual(step.index, 2) + self.assertEqual(step.name, "test2") + self.assertEqual(step.done, True) + + # class methods and attrs + self.assertEqual(ps.columns(), ('name', 'done')) + self.assertEqual(ps.NAME, 0) + self.assertEqual(ps.DONE, 1) + + +class ProgressStepContainerTestCase(unittest.TestCase): + def setUp(self): + self.psc = progress.ProgressStepContainer() + + def addSteps(self, number): + Step = progress.ProgressStep + for n in range(number): + self.psc.addStep(Step("%s" % n, False, n)) + + def test_attrs(self): + self.assertEqual(self.psc.columns, + ('name', 'done')) + + def test_add_steps(self): + Step = progress.ProgressStep + self.assertTrue(len(self.psc) == 0) + self.psc.addStep(Step('one', False, 0)) + self.assertTrue(len(self.psc) == 1) + self.psc.addStep(Step('two', False, 1)) + self.assertTrue(len(self.psc) == 2) + + def test_del_all_steps(self): + self.assertTrue(len(self.psc) == 0) + self.addSteps(5) + self.assertTrue(len(self.psc) == 5) + self.psc.removeAllSteps() + self.assertTrue(len(self.psc) == 0) + + def test_del_step(self): + Step = progress.ProgressStep + self.addSteps(5) + self.assertTrue(len(self.psc) == 5) + self.psc.removeStep(self.psc.step(4)) + self.assertTrue(len(self.psc) == 4) + self.psc.removeStep(self.psc.step(4)) + self.psc.removeStep(Step('none', False, 5)) + self.psc.removeStep(self.psc.step(4)) + + def test_iter(self): + self.addSteps(10) + self.assertEqual( + [x.index for x in self.psc], + [x for x in range(10)]) + + +class StepsTableWidgetTestCase(unittest.TestCase): + + def setUp(self): + self.app = QtGui.QApplication(sys.argv) + QtGui.qApp = self.app + self.stw = progress.StepsTableWidget() + + def tearDown(self): + QtGui.qApp = None + self.app = None + + def test_defaults(self): + self.assertTrue(isinstance(self.stw, QtGui.QTableWidget)) + self.assertEqual(self.stw.focusPolicy(), 0) + + +class TestWithStepsClass(QtGui.QWidget, progress.WithStepsMixIn): + + def __init__(self, parent=None): + super(TestWithStepsClass, self).__init__(parent=parent) + self.setupStepsProcessingQueue() + self.statuses = [] + self.current_page = "testpage" + + def onStepStatusChanged(self, *args): + """ + blank out this gui method + that will add status lines + """ + self.statuses.append(args) + + +class WithStepsMixInTestCase(qunittest.TestCase): + + TIMER_WAIT = 2 * progress.WithStepsMixIn.STEPS_TIMER_MS / 1000.0 + + # XXX can spy on signal connections + + def setUp(self): + self.app = QtGui.QApplication(sys.argv) + QtGui.qApp = self.app + self.stepy = TestWithStepsClass() + #self.connects = [] + #pyqt.enableSignalDebugging( + #connectCall=lambda *args: self.connects.append(args)) + #self.assertEqual(self.connects, []) + #self.stepy.stepscheck_timer.timeout.disconnect( + #self.stepy.processStepsQueue) + + def tearDown(self): + QtGui.qApp = None + self.app = None + + def test_has_queue(self): + s = self.stepy + self.assertTrue(hasattr(s, 'steps_queue')) + self.assertTrue(isinstance(s.steps_queue, Queue.Queue)) + self.assertTrue(isinstance(s.stepscheck_timer, QtCore.QTimer)) + + def test_do_checks_delegation(self): + s = self.stepy + + _do_checks = mock.Mock() + _do_checks.return_value = ( + (("test", 0), lambda: None), + (("test", 0), lambda: None)) + s._do_checks = _do_checks + s.do_checks() + self.waitFor(seconds=self.TIMER_WAIT) + _do_checks.assert_called_with() + self.assertEqual(len(s.statuses), 2) + + # test that a failed test interrupts the run + + s.statuses = [] + _do_checks = mock.Mock() + _do_checks.return_value = ( + (("test", 0), lambda: None), + (("test", 0), lambda: False), + (("test", 0), lambda: None)) + s._do_checks = _do_checks + s.do_checks() + self.waitFor(seconds=self.TIMER_WAIT) + _do_checks.assert_called_with() + self.assertEqual(len(s.statuses), 2) + + def test_process_queue(self): + s = self.stepy + q = s.steps_queue + s.set_failed_icon = mock.MagicMock() + with self.assertRaises(AssertionError): + q.put('foo') + self.waitFor(seconds=self.TIMER_WAIT) + s.set_failed_icon.assert_called_with() + q.put("failed") + self.waitFor(seconds=self.TIMER_WAIT) + s.set_failed_icon.assert_called_with() + + def test_on_checks_validation_ready_called(self): + s = self.stepy + s.on_checks_validation_ready = mock.MagicMock() + + _do_checks = mock.Mock() + _do_checks.return_value = ( + (("test", 0), lambda: None),) + s._do_checks = _do_checks + s.do_checks() + + self.waitFor(seconds=self.TIMER_WAIT) + s.on_checks_validation_ready.assert_called_with() + + def test_fail(self): + s = self.stepy + + s.wizard = mock.Mock() + wizard = s.wizard.return_value + wizard.set_validation_error.return_value = True + s.completeChanged = mock.Mock() + s.completeChanged.emit.return_value = True + + self.assertFalse(s.fail(err="foo")) + self.waitFor(seconds=self.TIMER_WAIT) + wizard.set_validation_error.assert_called_with('testpage', 'foo') + s.completeChanged.emit.assert_called_with() + + # with no args + s.wizard = mock.Mock() + wizard = s.wizard.return_value + wizard.set_validation_error.return_value = True + s.completeChanged = mock.Mock() + s.completeChanged.emit.return_value = True + + self.assertFalse(s.fail()) + self.waitFor(seconds=self.TIMER_WAIT) + with self.assertRaises(AssertionError): + wizard.set_validation_error.assert_called_with() + s.completeChanged.emit.assert_called_with() + + def test_done(self): + s = self.stepy + s.done = False + + s.completeChanged = mock.Mock() + s.completeChanged.emit.return_value = True + + self.assertFalse(s.is_done()) + s.set_done() + self.assertTrue(s.is_done()) + s.completeChanged.emit.assert_called_with() + + s.completeChanged = mock.Mock() + s.completeChanged.emit.return_value = True + s.set_undone() + self.assertFalse(s.is_done()) + + def test_back_and_next(self): + s = self.stepy + s.wizard = mock.Mock() + wizard = s.wizard.return_value + wizard.back.return_value = True + wizard.next.return_value = True + s.go_back() + wizard.back.assert_called_with() + s.go_next() + wizard.next.assert_called_with() + + def test_on_step_statuschanged_slot(self): + s = self.stepy + s.onStepStatusChanged = progress.WithStepsMixIn.onStepStatusChanged + s.add_status_line = mock.Mock() + s.set_checked_icon = mock.Mock() + s.progress = mock.Mock() + s.progress.setValue.return_value = True + s.progress.update.return_value = True + + s.onStepStatusChanged(s, "end_sentinel") + s.set_checked_icon.assert_called_with() + + s.onStepStatusChanged(s, "foo") + s.add_status_line.assert_called_with("foo") + + s.onStepStatusChanged(s, "bar", 42) + s.progress.setValue.assert_called_with(42) + s.progress.update.assert_called_with() + + def test_steps_and_errors(self): + s = self.stepy + s.setupSteps() + self.assertTrue(isinstance(s.steps, progress.ProgressStepContainer)) + self.assertEqual(s.errors, {}) + s.set_error('fooerror', 'barerror') + self.assertEqual(s.errors, {'fooerror': 'barerror'}) + s.set_error('2', 42) + self.assertEqual(s.errors, {'fooerror': 'barerror', '2': 42}) + fe = s.pop_first_error() + self.assertEqual(fe, ('fooerror', 'barerror')) + self.assertEqual(s.errors, {'2': 42}) + s.clean_errors() + self.assertEqual(s.errors, {}) + + def test_launch_chechs_slot(self): + s = self.stepy + s.do_checks = mock.Mock() + s.launch_checks() + s.do_checks.assert_called_with() + + def test_clean_wizard_errors(self): + s = self.stepy + s.wizard = mock.Mock() + wizard = s.wizard.return_value + wizard.set_validation_error.return_value = True + s.clean_wizard_errors(pagename="foopage") + wizard.set_validation_error.assert_called_with("foopage", None) + + def test_clear_table(self): + s = self.stepy + s.stepsTableWidget = mock.Mock() + s.stepsTableWidget.clearContents.return_value = True + s.clearTable() + s.stepsTableWidget.clearContents.assert_called_with() + + def test_populate_steps_table(self): + s = self.stepy + Step = namedtuple('Step', ['name', 'done']) + + class Steps(object): + columns = ("name", "done") + _items = (Step('step1', False), Step('step2', False)) + + def __len__(self): + return 2 + + def __iter__(self): + for i in self._items: + yield i + + s.steps = Steps() + + s.stepsTableWidget = mock.Mock() + s.stepsTableWidget.setItem.return_value = True + s.resizeTable = mock.Mock() + s.update = mock.Mock() + s.populateStepsTable() + s.update.assert_called_with() + s.resizeTable.assert_called_with() + + # assert stepsTableWidget.setItem called ... + # we do not want to get into the actual + # <QTableWidgetItem object at 0x92a565c> + call_list = s.stepsTableWidget.setItem.call_args_list + indexes = [(y, z) for y, z, xx in [x[0] for x in call_list]] + self.assertEqual(indexes, + [(0, 0), (0, 1), (1, 0), (1, 1)]) + + def test_add_status_line(self): + s = self.stepy + s.steps = progress.ProgressStepContainer() + s.stepsTableWidget = mock.Mock() + s.stepsTableWidget.width.return_value = 100 + s.set_item = mock.Mock() + s.set_item_icon = mock.Mock() + s.add_status_line("new status") + s.set_item_icon.assert_called_with(current=False) + + def test_set_item_icon(self): + s = self.stepy + s.steps = progress.ProgressStepContainer() + s.stepsTableWidget = mock.Mock() + s.stepsTableWidget.setCellWidget.return_value = True + s.stepsTableWidget.width.return_value = 100 + #s.set_item = mock.Mock() + #s.set_item_icon = mock.Mock() + s.add_status_line("new status") + s.add_status_line("new 2 status") + s.add_status_line("new 3 status") + call_list = s.stepsTableWidget.setCellWidget.call_args_list + indexes = [(y, z) for y, z, xx in [x[0] for x in call_list]] + self.assertEqual( + indexes, + [(0, 1), (-1, 1), (1, 1), (0, 1), (2, 1), (1, 1)]) + + +class TestInlineValidationPage(progress.InlineValidationPage): + pass + + +class InlineValidationPageTestCase(unittest.TestCase): + + def setUp(self): + self.app = QtGui.QApplication(sys.argv) + QtGui.qApp = self.app + self.page = TestInlineValidationPage() + + def tearDown(self): + QtGui.qApp = None + self.app = None + + def test_defaults(self): + self.assertFalse(self.page.done) + # if setupProcessingQueue was called + self.assertTrue(isinstance(self.page.stepscheck_timer, QtCore.QTimer)) + self.assertTrue(isinstance(self.page.steps_queue, Queue.Queue)) + + def test_validation_frame(self): + # test frame creation + self.page.stepsTableWidget = progress.StepsTableWidget( + parent=self.page) + self.page.setupValidationFrame() + self.assertTrue(isinstance(self.page.valFrame, QtGui.QFrame)) + + # test show steps calls frame.show + self.page.valFrame = mock.Mock() + self.page.valFrame.show.return_value = True + self.page.showStepsFrame() + self.page.valFrame.show.assert_called_with() + + +class TestValidationPage(progress.ValidationPage): + pass + + +class ValidationPageTestCase(unittest.TestCase): + + def setUp(self): + self.app = QtGui.QApplication(sys.argv) + QtGui.qApp = self.app + self.page = TestValidationPage() + + def tearDown(self): + QtGui.qApp = None + self.app = None + + def test_defaults(self): + self.assertFalse(self.page.done) + # if setupProcessingQueue was called + self.assertTrue(isinstance(self.page.timer, QtCore.QTimer)) + self.assertTrue(isinstance(self.page.stepscheck_timer, QtCore.QTimer)) + self.assertTrue(isinstance(self.page.steps_queue, Queue.Queue)) + + def test_is_complete(self): + self.assertFalse(self.page.isComplete()) + self.page.done = True + self.assertTrue(self.page.isComplete()) + self.page.done = False + self.assertFalse(self.page.isComplete()) + + def test_show_hide_progress(self): + p = self.page + p.progress = mock.Mock() + p.progress.show.return_code = True + p.show_progress() + p.progress.show.assert_called_with() + p.progress.hide.return_code = True + p.hide_progress() + p.progress.hide.assert_called_with() + + def test_initialize_page(self): + p = self.page + p.timer = mock.Mock() + p.timer.singleShot.return_code = True + p.initializePage() + p.timer.singleShot.assert_called_with(0, p.do_checks) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/gui/tests/test_threads.py b/src/leap/gui/tests/test_threads.py new file mode 100644 index 00000000..06c19606 --- /dev/null +++ b/src/leap/gui/tests/test_threads.py @@ -0,0 +1,27 @@ +import unittest + +import mock +from leap.gui import threads + + +class FunThreadTestCase(unittest.TestCase): + + def setUp(self): + self.fun = mock.MagicMock() + self.fun.return_value = "foo" + self.t = threads.FunThread(fun=self.fun) + + def test_thread(self): + self.t.begin() + self.t.wait() + self.fun.assert_called() + del self.t + + def test_run(self): + # this is called by PyQt + self.t.run() + del self.t + self.fun.assert_called() + +if __name__ == "__main__": + unittest.main() diff --git a/src/leap/testing/pyqt.py b/src/leap/testing/pyqt.py new file mode 100644 index 00000000..6edaf059 --- /dev/null +++ b/src/leap/testing/pyqt.py @@ -0,0 +1,52 @@ +from PyQt4 import QtCore + +_oldConnect = QtCore.QObject.connect +_oldDisconnect = QtCore.QObject.disconnect +_oldEmit = QtCore.QObject.emit + + +def _wrapConnect(callableObject): + """ + Returns a wrapped call to the old version of QtCore.QObject.connect + """ + @staticmethod + def call(*args): + callableObject(*args) + _oldConnect(*args) + return call + + +def _wrapDisconnect(callableObject): + """ + Returns a wrapped call to the old version of QtCore.QObject.disconnect + """ + @staticmethod + def call(*args): + callableObject(*args) + _oldDisconnect(*args) + return call + + +def enableSignalDebugging(**kwargs): + """ + Call this to enable Qt Signal debugging. This will trap all + connect, and disconnect calls. + """ + + f = lambda *args: None + connectCall = kwargs.get('connectCall', f) + disconnectCall = kwargs.get('disconnectCall', f) + emitCall = kwargs.get('emitCall', f) + + def printIt(msg): + def call(*args): + print msg, args + return call + QtCore.QObject.connect = _wrapConnect(connectCall) + QtCore.QObject.disconnect = _wrapDisconnect(disconnectCall) + + def new_emit(self, *args): + emitCall(self, *args) + _oldEmit(self, *args) + + QtCore.QObject.emit = new_emit diff --git a/src/leap/testing/qunittest.py b/src/leap/testing/qunittest.py new file mode 100644 index 00000000..b89ccec3 --- /dev/null +++ b/src/leap/testing/qunittest.py @@ -0,0 +1,302 @@ +# -*- coding: utf-8 -*- + +# **qunittest** is an standard Python `unittest` enhancement for PyQt4, +# allowing +# you to test asynchronous code using standard synchronous testing facility. +# +# The source for `qunittest` is available on [GitHub][gh], and released under +# the MIT license. +# +# Slightly modified by The Leap Project. + +### Prerequisites + +# Import unittest2 or unittest +try: + import unittest2 as unittest +except ImportError: + import unittest + +# ... and some standard Python libraries +import sys +import functools +import contextlib +import re + +# ... and several PyQt classes +from PyQt4.QtCore import QTimer +from PyQt4.QtTest import QTest +from PyQt4 import QtGui + +### The code + + +# Override standard main method, by invoking it inside PyQt event loop + +def main(*args, **kwargs): + qapplication = QtGui.QApplication(sys.argv) + + QTimer.singleShot(0, unittest.main(*args, **kwargs)) + qapplication.exec_() + +""" +This main substitute does not integrate with unittest. + +Note about mixing the event loop and unittests: + +Unittest will fail if we keep more than one reference to a QApplication. +(pyqt expects to be and only one). +So, for the things that need a QApplication to exist, do something like: + + self.app = QApplication() + QtGui.qApp = self.app + +in the class setUp, and:: + + QtGui.qApp = None + self.app = None + +in the class tearDown. + +For some explanation about this, see + http://stuvel.eu/blog/127/multiple-instances-of-qapplication-in-one-process +and + http://www.riverbankcomputing.com/pipermail/pyqt/2010-September/027705.html +""" + + +# Helper returning the name of a given signal + +def _signal_name(signal): + s = repr(signal) + name_re = "signal (\w+) of (\w+)" + match = re.search(name_re, s, re.I) + if not match: + return "??" + return "%s#%s" % (match.group(2), match.group(1)) + + +class _SignalConnector(object): + """ Encapsulates signal assertion testing """ + def __init__(self, test, signal, callable_): + self.test = test + self.callable_ = callable_ + self.called_with = None + self.emited = False + self.signal = signal + self._asserted = False + + signal.connect(self.on_signal_emited) + + # Store given parameters and mark signal as `emited` + def on_signal_emited(self, *args, **kwargs): + self.called_with = (args, kwargs) + self.emited = True + + def assertEmission(self): + # Assert once wheter signal was emited or not + was_asserted = self._asserted + self._asserted = True + + if not was_asserted: + if not self.emited: + self.test.fail( + "signal %s not emited" % (_signal_name(self.signal))) + + # Call given callable is necessary + if self.callable_: + args, kwargs = self.called_with + self.callable_(*args, **kwargs) + + def __enter__(self): + # Assert emission when context is entered + self.assertEmission() + return self.called_with + + def __exit__(self, *_): + return False + +### Unit Testing + +# `qunittest` does not force much abould how test should look - it just adds +# several helpers for asynchronous code testing. +# +# Common test case may look like this: +# +# import qunittest +# from calculator import Calculator +# +# class TestCalculator(qunittest.TestCase): +# def setUp(self): +# self.calc = Calculator() +# +# def test_should_add_two_numbers_synchronously(self): +# # given +# a, b = 2, 3 +# +# # when +# r = self.calc.add(a, b) +# +# # then +# self.assertEqual(5, r) +# +# def test_should_calculate_factorial_in_background(self): +# # given +# +# # when +# self.calc.factorial(20) +# +# # then +# self.assertEmited(self.calc.done) with (args, kwargs): +# self.assertEqual([2432902008176640000], args) +# +# if __name__ == "__main__": +# main() +# +# Test can be run by typing: +# +# python test_calculator.py +# +# Automatic test discovery is not supported now, because testing PyQt needs +# an instance of `QApplication` and its `exec_` method is blocking. +# + + +### TestCase class + +class TestCase(unittest.TestCase): + """ + Extends standard `unittest.TestCase` with several PyQt4 testing features + useful for asynchronous testing. + """ + def __init__(self, *args, **kwargs): + super(TestCase, self).__init__(*args, **kwargs) + + self._clearSignalConnectors() + self._succeeded = False + self.addCleanup(self._clearSignalConnectors) + self.tearDown = self._decorateTearDown(self.tearDown) + + ### Protected methods + + def _clearSignalConnectors(self): + self._connectedSignals = [] + + def _decorateTearDown(self, tearDown): + @functools.wraps(tearDown) + def decorator(): + self._ensureEmitedSignals() + return tearDown() + return decorator + + def _ensureEmitedSignals(self): + """ + Checks if signals were acually emited. Raises AssertionError if no. + """ + # TODO: add information about line + for signal in self._connectedSignals: + signal.assertEmission() + + ### Assertions + + def assertEmited(self, signal, callable_=None, timeout=1): + """ + Asserts if given `signal` was emited. Waits 1 second by default, + before asserts signal emission. + + If `callable_` is given, it should be a function which takes two + arguments: `args` and `kwargs`. It will be called after blocking + operation or when assertion about signal emission is made and + signal was emited. + + When timeout is not `False`, method call is blocking, and ends + after `timeout` seconds. After that time, it validates wether + signal was emited. + + When timeout is `False`, method is non blocking, and test should wait + for signals afterwards. Otherwise, at the end of the test, all + signal emissions are checked if appeared. + + Function returns context, which yields to list of parameters given + to signal. It can be useful for testing given parameters. Following + code: + + with self.assertEmited(widget.signal) as (args, kwargs): + self.assertEqual(1, len(args)) + self.assertEqual("Hello World!", args[0]) + + will wait 1 second and test for correct parameters, is signal was + emtied. + + Note that code: + + with self.assertEmited(widget.signal, timeout=False) as (a, k): + # Will not be invoked + + will always fail since signal cannot be emited in the time of its + connection - code inside the context will not be invoked at all. + """ + + connector = _SignalConnector(self, signal, callable_) + self._connectedSignals.append(connector) + if timeout: + self.waitFor(timeout) + connector.assertEmission() + + return connector + + ### Helper methods + + @contextlib.contextmanager + def invokeAfter(self, seconds, callable_=None): + """ + Waits given amount of time and executes the context. + + If `callable_` is given, executes it, instead of context. + """ + self.waitFor(seconds) + if callable_: + callable_() + else: + yield + + def waitFor(self, seconds): + """ + Waits given amount of time. + + self.widget.loadImage(url) + self.waitFor(seconds=10) + """ + QTest.qWait(seconds * 1000) + + def succeed(self, bool_=True): + """ Marks test as suceeded for next `failAfter()` invocation. """ + self._succeeded = self._succeeded or bool_ + + def failAfter(self, seconds, message=None): + """ + Waits given amount of time, and fails the test if `succeed(bool)` + is not called - in most common case, `succeed(bool)` should be called + asynchronously (in signal handler): + + self.widget.signal.connect(lambda: self.succeed()) + self.failAfter(1, "signal not emited?") + + After invocation, test is no longer consider as succeeded. + """ + self.waitFor(seconds) + if not self._succeeded: + self.fail(message) + + self._succeeded = False + +### Credits +# +# * **Who is responsible:** [Dawid Fatyga][df] +# * **Source:** [GitHub][gh] +# * **Doc. generator:** [rocco][ro] +# +# [gh]: https://www.github.com/dejw/qunittest +# [df]: https://github.com/dejw +# [ro]: http://rtomayko.github.com/rocco/ +# diff --git a/src/leap/util/fileutil.py b/src/leap/util/fileutil.py index aef4cfe0..820ffe46 100644 --- a/src/leap/util/fileutil.py +++ b/src/leap/util/fileutil.py @@ -93,6 +93,11 @@ def mkdir_p(path): raise +def mkdir_f(path): + folder, fname = os.path.split(path) + mkdir_p(folder) + + def check_and_fix_urw_only(_file): """ test for 600 mode and try diff --git a/src/leap/util/misc.py b/src/leap/util/misc.py new file mode 100644 index 00000000..3c26892b --- /dev/null +++ b/src/leap/util/misc.py @@ -0,0 +1,16 @@ +""" +misc utils +""" + + +class ImproperlyConfigured(Exception): + """ + """ + + +def null_check(value, value_name): + try: + assert value is not None + except AssertionError: + raise ImproperlyConfigured( + "%s parameter cannot be None" % value_name) diff --git a/src/leap/util/web.py b/src/leap/util/web.py index b2aef058..15de0561 100644 --- a/src/leap/util/web.py +++ b/src/leap/util/web.py @@ -13,6 +13,7 @@ def get_https_domain_and_port(full_domain): from a full_domain string that can contain a colon """ + full_domain = unicode(full_domain) if full_domain is None: return None, None |