diff options
| -rw-r--r-- | src/leap/crypto/srpauth.py | 177 | ||||
| -rw-r--r-- | src/leap/crypto/tests/test_srpauth.py | 755 | ||||
| -rw-r--r-- | src/leap/crypto/tests/test_srpregister.py | 5 | ||||
| -rw-r--r-- | src/leap/util/request_helpers.py | 2 | 
4 files changed, 893 insertions, 46 deletions
| diff --git a/src/leap/crypto/srpauth.py b/src/leap/crypto/srpauth.py index 89fee80b..8e228e79 100644 --- a/src/leap/crypto/srpauth.py +++ b/src/leap/crypto/srpauth.py @@ -30,8 +30,8 @@ from PySide import QtCore  from twisted.internet import threads  from leap.common.check import leap_assert -from leap.util.request_helpers import get_content  from leap.util.constants import REQUEST_TIMEOUT +from leap.util import request_helpers as reqhelper  from leap.common.events import signal as events_signal  from leap.common.events import events_pb2 as proto @@ -45,6 +45,71 @@ class SRPAuthenticationError(Exception):      pass +class SRPAuthConnectionError(SRPAuthenticationError): +    """ +    Exception raised when there's a connection error +    """ +    pass + +class SRPAuthUnknownUser(SRPAuthenticationError): +    """ +    Exception raised when trying to authenticate an unknown user +    """ +    pass + +class SRPAuthBadStatusCode(SRPAuthenticationError): +    """ +    Exception raised when we received an unknown bad status code +    """ +    pass + +class SRPAuthNoSalt(SRPAuthenticationError): +    """ +    Exception raised when we don't receive the salt param at a +    specific point in the auth process +    """ +    pass + +class SRPAuthNoB(SRPAuthenticationError): +    """ +    Exception raised when we don't receive the B param at a specific +    point in the auth process +    """ +    pass + +class SRPAuthBadDataFromServer(SRPAuthenticationError): +    """ +    Generic exception when we receive bad data from the server. +    """ +    pass + +class SRPAuthJSONDecodeError(SRPAuthenticationError): +    """ +    Exception raised when there's a problem decoding the JSON content +    parsed as received from th e server. +    """ +    pass + +class SRPAuthBadPassword(SRPAuthenticationError): +    """ +    Exception raised when the user provided a bad password to auth. +    """ +    pass + +class SRPAuthVerificationFailed(SRPAuthenticationError): +    """ +    Exception raised when we can't verify the SRP data received from +    the server. +    """ +    pass + +class SRPAuthNoSessionId(SRPAuthenticationError): +    """ +    Exception raised when we don't receive a session id from the +    server. +    """ +    pass +  class SRPAuth(QtCore.QObject):      """      SRPAuth singleton @@ -126,19 +191,23 @@ class SRPAuth(QtCore.QObject):              self._srp_a = A -        def _start_authentication(self, _, username, password): +        def _start_authentication(self, _, username):              """              Sends the first request for authentication to retrieve the              salt and B parameter -            Might raise SRPAuthenticationError +            Might raise all SRPAuthenticationError based: +              SRPAuthenticationError +              SRPAuthConnectionError +              SRPAuthUnknownUser +              SRPAuthBadStatusCode +              SRPAuthNoSalt +              SRPAuthNoB              :param _: IGNORED, output from the previous callback (None)              :type _: IGNORED              :param username: username to login              :type username: str -            :param password: password for the username -            :type password: str              :return: salt and B parameters              :rtype: tuple @@ -158,24 +227,29 @@ class SRPAuth(QtCore.QObject):                                                    verify=self._provider_config.                                                    get_ca_cert_path(),                                                    timeout=REQUEST_TIMEOUT) +                # Clean up A value, we don't need it anymore +                self._srp_a = None              except requests.exceptions.ConnectionError as e:                  logger.error("No connection made (salt): %r" %                               (e,)) -                raise SRPAuthenticationError("Could not establish a " +                raise SRPAuthConnectionError("Could not establish a "                                               "connection")              except Exception as e:                  logger.error("Unknown error: %r" % (e,))                  raise SRPAuthenticationError("Unknown error: %r" %                                               (e,)) -            content, mtime = get_content(init_session) +            content, mtime = reqhelper.get_content(init_session)              if init_session.status_code not in (200,):                  logger.error("No valid response (salt): "                               "Status code = %r. Content: %r" %                               (init_session.status_code, content))                  if init_session.status_code == 422: -                    raise SRPAuthenticationError(self.tr("Unknown user")) +                    raise SRPAuthUnknownUser(self.tr("Unknown user")) + +                raise SRPAuthBadStatusCode(self.tr("There was a problem with" +                                                   " authentication"))              json_content = json.loads(content)              salt = json_content.get("salt", None) @@ -183,12 +257,12 @@ class SRPAuth(QtCore.QObject):              if salt is None:                  logger.error("No salt parameter sent") -                raise SRPAuthenticationError(self.tr("The server did not send " -                                                     "the salt parameter")) +                raise SRPAuthNoSalt(self.tr("The server did not send " +                                            "the salt parameter"))              if B is None:                  logger.error("No B parameter sent") -                raise SRPAuthenticationError(self.tr("The server did not send " -                                                     "the B parameter")) +                raise SRPAuthNoB(self.tr("The server did not send " +                                         "the B parameter"))              return salt, B @@ -197,7 +271,12 @@ class SRPAuth(QtCore.QObject):              Given the salt and B processes the auth challenge and              generates the M2 parameter -            Might throw SRPAuthenticationError +            Might raise SRPAuthenticationError based: +              SRPAuthenticationError +              SRPAuthBadDataFromServer +              SRPAuthConnectionError +              SRPAuthJSONDecodeError +              SRPAuthBadPassword              :param salt_B: salt and B parameters for the username              :type salt_B: tuple @@ -212,10 +291,10 @@ class SRPAuth(QtCore.QObject):                  salt, B = salt_B                  unhex_salt = self._safe_unhexlify(salt)                  unhex_B = self._safe_unhexlify(B) -            except TypeError as e: +            except (TypeError, ValueError) as e:                  logger.error("Bad data from server: %r" % (e,)) -                raise SRPAuthenticationError(self.tr("The data sent from " -                                                     "the server had errors")) +                raise SRPAuthBadDataFromServer( +                    self.tr("The data sent from the server had errors"))              M = self._srp_user.process_challenge(unhex_salt, unhex_B)              auth_url = "%s/%s/%s/%s" % (self._provider_config.get_api_uri(), @@ -236,13 +315,13 @@ class SRPAuth(QtCore.QObject):                                                  timeout=REQUEST_TIMEOUT)              except requests.exceptions.ConnectionError as e:                  logger.error("No connection made (HAMK): %r" % (e,)) -                raise SRPAuthenticationError(self.tr("Could not connect to " +                raise SRPAuthConnectionError(self.tr("Could not connect to "                                                       "the server"))              try: -                content, mtime = get_content(auth_result) +                content, mtime = reqhelper.get_content(auth_result)              except JSONDecodeError: -                raise SRPAuthenticationError("Bad JSON content in auth result") +                raise SRPAuthJSONDecodeError("Bad JSON content in auth result")              if auth_result.status_code == 422:                  error = "" @@ -256,35 +335,47 @@ class SRPAuth(QtCore.QObject):                                   "received: %s", (content,))                  logger.error("[%s] Wrong password (HAMK): [%s]" %                               (auth_result.status_code, error)) -                raise SRPAuthenticationError(self.tr("Wrong password")) +                raise SRPAuthBadPassword(self.tr("Wrong password"))              if auth_result.status_code not in (200,):                  logger.error("No valid response (HAMK): "                               "Status code = %s. Content = %r" %                               (auth_result.status_code, content)) -                raise SRPAuthenticationError(self.tr("Unknown error (%s)") % -                                             (auth_result.status_code,)) +                raise SRPAuthBadStatusCode(self.tr("Unknown error (%s)") % +                                           (auth_result.status_code,)) -            json_content = json.loads(content) +            return json.loads(content) + +        def _extract_data(self, json_content): +            """ +            Extracts the necessary parameters from json_content (M2, +            id, token) + +            Might raise SRPAuthenticationError based: +              SRPBadDataFromServer +            :param json_content: Data received from the server +            :type json_content: dict +            """              try:                  M2 = json_content.get("M2", None)                  uid = json_content.get("id", None)                  token = json_content.get("token", None)              except Exception as e:                  logger.error(e) -                raise Exception("Something went wrong with the login") - -            events_signal(proto.CLIENT_UID, content=uid) +                raise SRPAuthBadDataFromServer("Something went wrong with the " +                                               "login")              self.set_uid(uid)              self.set_token(token)              if M2 is None or self.get_uid() is None:                  logger.error("Something went wrong. Content = %r" % -                             (content,)) -                raise SRPAuthenticationError(self.tr("Problem getting data " -                                                     "from server")) +                             (json_content,)) +                raise SRPAuthBadDataFromServer(self.tr("Problem getting data " +                                                       "from server")) + +            events_signal(proto.CLIENT_UID, content=uid)              return M2 @@ -294,7 +385,9 @@ class SRPAuth(QtCore.QObject):              verification succeeds, it sets the session_id for this              session -            Might throw SRPAuthenticationError +            Might raise SRPAuthenticationError based: +              SRPAuthBadDataFromServer +              SRPAuthVerificationFailed              :param M2: M2 SRP parameter              :type M2: str @@ -304,22 +397,22 @@ class SRPAuth(QtCore.QObject):                  unhex_M2 = self._safe_unhexlify(M2)              except TypeError:                  logger.error("Bad data from server (HAWK)") -                raise SRPAuthenticationError(self.tr("Bad data from server")) +                raise SRPAuthBadDataFromServer(self.tr("Bad data from server"))              self._srp_user.verify_session(unhex_M2)              if not self._srp_user.authenticated():                  logger.error("Auth verification failed") -                raise SRPAuthenticationError(self.tr("Auth verification " -                                                     "failed")) +                raise SRPAuthVerificationFailed(self.tr("Auth verification " +                                                        "failed"))              logger.debug("Session verified.")              session_id = self._session.cookies.get(self.SESSION_ID_KEY, None)              if not session_id:                  logger.error("Bad cookie from server (missing _session_id)") -                raise SRPAuthenticationError(self.tr("Session cookie " -                                                     "verification " -                                                     "failed")) +                raise SRPAuthNoSessionId(self.tr("Session cookie " +                                                 "verification " +                                                 "failed"))              events_signal(proto.CLIENT_SESSION_ID, content=session_id) @@ -351,12 +444,15 @@ class SRPAuth(QtCore.QObject):              d.addCallback(                  partial(self._threader,                          self._start_authentication), -                username=username, -                password=password) +                username=username)              d.addCallback(                  partial(self._threader,                          self._process_challenge),                  username=username) +            d.addCallback( +                partial(self._threader, +                        self._extract_data), +                username=username)              d.addCallback(partial(self._threader,                                    self._verify_session)) @@ -435,14 +531,11 @@ class SRPAuth(QtCore.QObject):          # Store instance reference as the only member in the handle          self.__dict__['_SRPAuth__instance'] = SRPAuth.__instance -        self._username = None -        self._password = None -      def authenticate(self, username, password):          """          Executes the whole authentication process for a user -        Might raise SRPAuthenticationError +        Might raise SRPAuthenticationError based          :param username: username for this session          :type username: str diff --git a/src/leap/crypto/tests/test_srpauth.py b/src/leap/crypto/tests/test_srpauth.py new file mode 100644 index 00000000..9a684a8f --- /dev/null +++ b/src/leap/crypto/tests/test_srpauth.py @@ -0,0 +1,755 @@ +# -*- coding: utf-8 -*- +# test_srpauth.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program.  If not, see <http://www.gnu.org/licenses/>. +""" +Tests for: +    * leap/crypto/srpauth.py +""" +try: +    import unittest2 as unittest +except ImportError: +    import unittest +import os +import sys +import binascii +import requests +import mock + +from mock import MagicMock +from nose.twistedtools import reactor, deferred +from twisted.python import log +from twisted.internet import threads +from functools import partial +from requests.models import Response +from simplejson.decoder import JSONDecodeError + +from leap.common.testing.https_server import where +from leap.config.providerconfig import ProviderConfig +from leap.crypto import srpregister, srpauth +from leap.crypto.tests import fake_provider + +log.startLogging(sys.stdout) + + +def _get_capath(): +    return where("cacert.pem") + +_here = os.path.split(__file__)[0] + + +class ImproperlyConfiguredError(Exception): +    """ +    Raised if the test provider is missing configuration +    """ + + +class SRPAuthTestCase(unittest.TestCase): +    """ +    Tests for the SRPAuth class +    """ +    __name__ = "SRPAuth tests" + +    def setUp(self): +        """ +        Sets up this TestCase with a simple and faked provider instance: + +        * runs a threaded reactor +        * loads a mocked ProviderConfig that points to the certs in the +          leap.common.testing module. +        """ +        factory = fake_provider.get_provider_factory() +        http = reactor.listenTCP(0, factory) +        https = reactor.listenSSL( +            0, factory, +            fake_provider.OpenSSLServerContextFactory()) +        get_port = lambda p: p.getHost().port +        self.http_port = get_port(http) +        self.https_port = get_port(https) + +        provider = ProviderConfig() +        provider.get_ca_cert_path = MagicMock() +        provider.get_ca_cert_path.return_value = _get_capath() + +        provider.get_api_uri = MagicMock() +        provider.get_api_uri.return_value = self._get_https_uri() + +        loaded = provider.load(path=os.path.join( +            _here, "test_provider.json")) +        if not loaded: +            raise ImproperlyConfiguredError( +                "Could not load test provider config") +        self.register = srpregister.SRPRegister(provider_config=provider) +        self.provider = provider +        self.TEST_USER = "register_test_auth" +        self.TEST_PASS = "pass" + +        # Reset the singleton +        srpauth.SRPAuth._SRPAuth__instance = None +        self.auth = srpauth.SRPAuth(self.provider) +        self.auth_backend = self.auth._SRPAuth__instance + +        self.old_post = self.auth_backend._session.post +        self.old_put = self.auth_backend._session.put +        self.old_delete = self.auth_backend._session.delete + +        self.old_start_auth = self.auth_backend._start_authentication +        self.old_proc_challenge = self.auth_backend._process_challenge +        self.old_extract_data = self.auth_backend._extract_data +        self.old_verify_session = self.auth_backend._verify_session +        self.old_auth_preproc = self.auth_backend._authentication_preprocessing +        self.old_get_sid = self.auth_backend.get_session_id +        self.old_cookie_get = self.auth_backend._session.cookies.get +        self.old_auth = self.auth_backend.authenticate + +    def tearDown(self): +        self.auth_backend._session.post = self.old_post +        self.auth_backend._session.put = self.old_put +        self.auth_backend._session.delete = self.old_delete + +        self.auth_backend._start_authentication = self.old_start_auth +        self.auth_backend._process_challenge = self.old_proc_challenge +        self.auth_backend._extract_data = self.old_extract_data +        self.auth_backend._verify_session = self.old_verify_session +        self.auth_backend._authentication_preprocessing = self.old_auth_preproc +        self.auth_backend.get_session_id = self.old_get_sid +        self.auth_backend._session.cookies.get = self.old_cookie_get +        self.auth_backend.authenticate = self.old_auth + +    # helper methods + +    def _get_https_uri(self): +        """ +        Returns a https uri with the right https port initialized +        """ +        return "https://localhost:%s" % (self.https_port,) + +    # Auth tests + +    def _prepare_auth_test(self, code=200, side_effect=None): +        """ +        Creates the needed defers to test several test situations. It +        adds up to the auth preprocessing step. + +        :param code: status code for the response of POST in requests +        :type code: int +        :param side_effect: side effect triggered by the POST method +                            in requests +        :type side_effect: some kind of Exception + +        :returns: the defer that is created +        :rtype: defer.Deferred +        """ +        res = Response() +        res.status_code = code +        self.auth_backend._session.post = MagicMock(return_value=res, +                                                    side_effect=side_effect) + +        d = threads.deferToThread(self.register.register_user, +                                  self.TEST_USER, +                                  self.TEST_PASS) + +        def wrapper_preproc(*args): +            return threads.deferToThread( +                self.auth_backend._authentication_preprocessing, +                self.TEST_USER, self.TEST_PASS) + +        d.addCallback(wrapper_preproc) + +        return d + +    def test_safe_unhexlify(self): +        input_value = "somestring" +        test_value = binascii.hexlify(input_value) +        self.assertEqual( +            self.auth_backend._safe_unhexlify(test_value), +            input_value) + +    def test_safe_unhexlify_not_raises(self): +        input_value = "somestring" +        test_value = binascii.hexlify(input_value)[:-1] + +        with self.assertRaises(TypeError): +            binascii.unhexlify(test_value) + +        self.auth_backend._safe_unhexlify(test_value) + +    def test_preprocessing_loads_a(self): +        self.assertEqual(self.auth_backend._srp_a, None) +        self.auth_backend._authentication_preprocessing("user", "pass") +        self.assertIsNotNone(self.auth_backend._srp_a) +        self.assertTrue(len(self.auth_backend._srp_a) > 0) + +    @deferred() +    def test_start_authentication(self): +        d = threads.deferToThread(self.register.register_user, self.TEST_USER, +                                  self.TEST_PASS) + +        def wrapper_preproc(*args): +            return threads.deferToThread( +                self.auth_backend._authentication_preprocessing, +                self.TEST_USER, self.TEST_PASS) + +        d.addCallback(wrapper_preproc) + +        def wrapper(_): +            return threads.deferToThread( +                self.auth_backend._start_authentication, +                None, self.TEST_USER) + +        d.addCallback(wrapper) +        return d + +    @deferred() +    def test_start_authentication_fails_connerror(self): +        d = self._prepare_auth_test( +            side_effect=requests.exceptions.ConnectionError()) + +        def wrapper(_): +            with self.assertRaises(srpauth.SRPAuthConnectionError): +                self.auth_backend._start_authentication(None, self.TEST_USER) + +        d.addCallback(partial(threads.deferToThread, wrapper)) +        return d + +    @deferred() +    def test_start_authentication_fails_any_error(self): +        d = self._prepare_auth_test(side_effect=Exception()) + +        def wrapper(_): +            with self.assertRaises(srpauth.SRPAuthenticationError): +                self.auth_backend._start_authentication(None, self.TEST_USER) + +        d.addCallback(partial(threads.deferToThread, wrapper)) +        return d + +    @deferred() +    def test_start_authentication_fails_unknown_user(self): +        d = self._prepare_auth_test(422) + +        def wrapper(_): +            with self.assertRaises(srpauth.SRPAuthUnknownUser): +                with mock.patch('leap.util.request_helpers.get_content', +                                new_callable=MagicMock()) as \ +                        content: +                    content.return_value = ("{}", 0) + +                    self.auth_backend._start_authentication( +                        None, self.TEST_USER) + +        d.addCallback(partial(threads.deferToThread, wrapper)) +        return d + +    @deferred() +    def test_start_authentication_fails_errorcode(self): +        d = self._prepare_auth_test(302) + +        def wrapper(_): +            with self.assertRaises(srpauth.SRPAuthBadStatusCode): +                with mock.patch('leap.util.request_helpers.get_content', +                                new_callable=MagicMock()) as \ +                        content: +                    content.return_value = ("{}", 0) + +                    self.auth_backend._start_authentication(None, +                                                            self.TEST_USER) + +        d.addCallback(partial(threads.deferToThread, wrapper)) +        return d + +    @deferred() +    def test_start_authentication_fails_no_salt(self): +        d = self._prepare_auth_test(200) + +        def wrapper(_): +            with self.assertRaises(srpauth.SRPAuthNoSalt): +                with mock.patch('leap.util.request_helpers.get_content', +                                new_callable=MagicMock()) as \ +                        content: +                    content.return_value = ("{}", 0) + +                    self.auth_backend._start_authentication(None, +                                                            self.TEST_USER) + +        d.addCallback(partial(threads.deferToThread, wrapper)) +        return d + +    @deferred() +    def test_start_authentication_fails_no_B(self): +        d = self._prepare_auth_test(200) + +        def wrapper(_): +            with self.assertRaises(srpauth.SRPAuthNoB): +                with mock.patch('leap.util.request_helpers.get_content', +                                new_callable=MagicMock()) as \ +                        content: +                    content.return_value = ('{"salt": ""}', 0) + +                    self.auth_backend._start_authentication(None, +                                                            self.TEST_USER) + +        d.addCallback(partial(threads.deferToThread, wrapper)) +        return d + +    @deferred() +    def test_start_authentication_correct_saltb(self): +        d = self._prepare_auth_test(200) + +        test_salt = "12345" +        test_B = "67890" + +        def wrapper(_): +            with mock.patch('leap.util.request_helpers.get_content', +                            new_callable=MagicMock()) as \ +                    content: +                content.return_value = ('{"salt":"%s", "B":"%s"}' % (test_salt, +                                                                     test_B), +                                        0) + +                salt, B = self.auth_backend._start_authentication( +                    None, +                    self.TEST_USER) +                self.assertEqual(salt, test_salt) +                self.assertEqual(B, test_B) + +        d.addCallback(partial(threads.deferToThread, wrapper)) +        return d + +    def _prepare_auth_challenge(self): +        """ +        Creates the needed defers to test several test situations. It +        adds up to the start authentication step. + +        :returns: the defer that is created +        :rtype: defer.Deferred +        """ +        d = threads.deferToThread(self.register.register_user, +                                  self.TEST_USER, +                                  self.TEST_PASS) + +        def wrapper_preproc(*args): +            return threads.deferToThread( +                self.auth_backend._authentication_preprocessing, +                self.TEST_USER, self.TEST_PASS) + +        d.addCallback(wrapper_preproc) + +        def wrapper_start(*args): +            return threads.deferToThread( +                self.auth_backend._start_authentication, +                None, self.TEST_USER) + +        d.addCallback(wrapper_start) + +        return d + +    @deferred() +    def test_process_challenge_wrong_saltb(self): +        d = self._prepare_auth_challenge() + +        def wrapper(salt_B): +            with self.assertRaises(srpauth.SRPAuthBadDataFromServer): +                self.auth_backend._process_challenge("", +                                                     username=self.TEST_USER) + +        d.addCallback(partial(threads.deferToThread, wrapper)) +        return d + +    @deferred() +    def test_process_challenge_requests_problem_raises(self): +        d = self._prepare_auth_challenge() + +        self.auth_backend._session.put = MagicMock( +            side_effect=requests.exceptions.ConnectionError()) + +        def wrapper(salt_B): +            with self.assertRaises(srpauth.SRPAuthConnectionError): +                self.auth_backend._process_challenge(salt_B, +                                                     username=self.TEST_USER) + +        d.addCallback(partial(threads.deferToThread, wrapper)) + +        return d + +    @deferred() +    def test_process_challenge_json_decode_error(self): +        d = self._prepare_auth_challenge() + +        def wrapper(salt_B): +            with mock.patch('leap.util.request_helpers.get_content', +                            new_callable=MagicMock()) as \ +                    content: +                content.return_value = ("{", 0) +                content.side_effect = JSONDecodeError("", "", 0) + +                with self.assertRaises(srpauth.SRPAuthJSONDecodeError): +                   self.auth_backend._process_challenge( +                       salt_B, +                       username=self.TEST_USER) + +        d.addCallback(partial(threads.deferToThread, wrapper)) + +        return d + +    @deferred() +    def test_process_challenge_bad_password(self): +        d = self._prepare_auth_challenge() + +        res = Response() +        res.status_code = 422 +        self.auth_backend._session.put = MagicMock(return_value=res) + +        def wrapper(salt_B): +            with mock.patch('leap.util.request_helpers.get_content', +                            new_callable=MagicMock()) as \ +                    content: +                content.return_value = ("", 0) +                with self.assertRaises(srpauth.SRPAuthBadPassword): +                    self.auth_backend._process_challenge( +                        salt_B, +                        username=self.TEST_USER) + +        d.addCallback(partial(threads.deferToThread, wrapper)) + +        return d + +    @deferred() +    def test_process_challenge_bad_password2(self): +        d = self._prepare_auth_challenge() + +        res = Response() +        res.status_code = 422 +        self.auth_backend._session.put = MagicMock(return_value=res) + +        def wrapper(salt_B): +            with mock.patch('leap.util.request_helpers.get_content', +                            new_callable=MagicMock()) as \ +                    content: +                content.return_value = ("[]", 0) +                with self.assertRaises(srpauth.SRPAuthBadPassword): +                    self.auth_backend._process_challenge( +                        salt_B, +                        username=self.TEST_USER) + +        d.addCallback(partial(threads.deferToThread, wrapper)) + +        return d + +    @deferred() +    def test_process_challenge_other_error_code(self): +        d = self._prepare_auth_challenge() + +        res = Response() +        res.status_code = 300 +        self.auth_backend._session.put = MagicMock(return_value=res) + +        def wrapper(salt_B): +            with mock.patch('leap.util.request_helpers.get_content', +                            new_callable=MagicMock()) as \ +                    content: +                content.return_value = ("{}", 0) +                with self.assertRaises(srpauth.SRPAuthBadStatusCode): +                    self.auth_backend._process_challenge( +                        salt_B, +                        username=self.TEST_USER) + +        d.addCallback(partial(threads.deferToThread, wrapper)) + +        return d + +    @deferred() +    def test_process_challenge(self): +        d = self._prepare_auth_challenge() + +        def wrapper(salt_B): +            self.auth_backend._process_challenge(salt_B, +                                                 username=self.TEST_USER) + +        d.addCallback(partial(threads.deferToThread, wrapper)) + +        return d + +    def test_extract_data_wrong_data(self): +        with self.assertRaises(srpauth.SRPAuthBadDataFromServer): +            self.auth_backend._extract_data(None) + +        with self.assertRaises(srpauth.SRPAuthBadDataFromServer): +            self.auth_backend._extract_data("") + +    def test_extract_data_fails_on_wrong_data_from_server(self): +        with self.assertRaises(srpauth.SRPAuthBadDataFromServer): +            self.auth_backend._extract_data({}) + +        with self.assertRaises(srpauth.SRPAuthBadDataFromServer): +            self.auth_backend._extract_data({"M2": ""}) + +    def test_extract_data_sets_uidtoken(self): +        test_uid = "someuid" +        test_m2 = "somem2" +        test_token = "sometoken" +        test_data = { +            "M2": test_m2, +            "id": test_uid, +            "token": test_token +        } +        m2 = self.auth_backend._extract_data(test_data) + +        self.assertEqual(m2, test_m2) +        self.assertEqual(self.auth_backend.get_uid(), test_uid) +        self.assertEqual(self.auth_backend.get_uid(), +                         self.auth.get_uid()) +        self.assertEqual(self.auth_backend.get_token(), test_token) +        self.assertEqual(self.auth_backend.get_token(), +                         self.auth.get_token()) + +    def _prepare_verify_session(self): +        """ +        Prepares the tests for verify session with needed steps +        before. It adds up to the extract_data step. + +        :returns: The defer to chain to +        :rtype: defer.Deferred +        """ +        d = self._prepare_auth_challenge() + +        def wrapper_proc_challenge(salt_B): +            return self.auth_backend._process_challenge( +                salt_B, +                username=self.TEST_USER) + +        def wrapper_extract_data(data): +            return self.auth_backend._extract_data(data) + +        d.addCallback(partial(threads.deferToThread, wrapper_proc_challenge)) +        d.addCallback(partial(threads.deferToThread, wrapper_extract_data)) + +        return d + +    @deferred() +    def test_verify_session_unhexlifiable_m2(self): +        d = self._prepare_verify_session() + +        def wrapper(M2): +            with self.assertRaises(srpauth.SRPAuthBadDataFromServer): +                self.auth_backend._verify_session("za")  # unhexlifiable value + +        d.addCallback(wrapper) + +        return d + +    @deferred() +    def test_verify_session_unverifiable_m2(self): +        d = self._prepare_verify_session() + +        def wrapper(M2): +            with self.assertRaises(srpauth.SRPAuthVerificationFailed): +                # Correctly unhelifiable value, but not for verifying the +                # session +                self.auth_backend._verify_session("abc12") + +        d.addCallback(wrapper) + +        return d + +    @deferred() +    def test_verify_session_fails_on_no_session_id(self): +        d = self._prepare_verify_session() + +        def wrapper(M2): +            self.auth_backend._session.cookies.get = MagicMock( +                return_value=None) +            with self.assertRaises(srpauth.SRPAuthNoSessionId): +                self.auth_backend._verify_session(M2) + +        d.addCallback(wrapper) + +        return d + +    @deferred() +    def test_verify_session_session_id(self): +        d = self._prepare_verify_session() + +        test_session_id = "12345" + +        def wrapper(M2): +            self.auth_backend._session.cookies.get = MagicMock( +                return_value=test_session_id) +            self.auth_backend._verify_session(M2) +            self.assertEqual(self.auth_backend.get_session_id(), +                             test_session_id) +            self.assertEqual(self.auth_backend.get_session_id(), +                             self.auth.get_session_id()) + +        d.addCallback(wrapper) + +        return d + +    @deferred() +    def test_verify_session(self): +        d = self._prepare_verify_session() + +        def wrapper(M2): +            self.auth_backend._verify_session(M2) + +        d.addCallback(wrapper) + +        return d + +    @deferred() +    def test_authenticate(self): +        self.auth_backend._authentication_preprocessing = MagicMock( +            return_value=None) +        self.auth_backend._start_authentication = MagicMock(return_value=None) +        self.auth_backend._process_challenge = MagicMock(return_value=None) +        self.auth_backend._extract_data = MagicMock(return_value=None) +        self.auth_backend._verify_session = MagicMock(return_value=None) + +        d = self.auth_backend.authenticate(self.TEST_USER, self.TEST_PASS) + +        def check(*args): +            self.auth_backend._authentication_preprocessing.\ +                assert_called_once_with( +                    username=self.TEST_USER, +                    password=self.TEST_PASS +                ) +            self.auth_backend._start_authentication.assert_called_once_with( +                None, +                username=self.TEST_USER) +            self.auth_backend._process_challenge.assert_called_once_with( +                None, +                username=self.TEST_USER) +            self.auth_backend._extract_data.assert_called_once_with( +                None, +                username=self.TEST_USER) +            self.auth_backend._verify_session.assert_called_once_with(None) + +        d.addCallback(check) + +        return d + +    @deferred() +    def test_logout_fails_if_not_logged_in(self): + +        def wrapper(*args): +            with self.assertRaises(AssertionError): +                self.auth_backend.logout() + +        d = threads.deferToThread(wrapper) +        return d + +    @deferred() +    def test_logout_traps_delete(self): +        self.auth_backend.get_session_id = MagicMock(return_value="1234") +        self.auth_backend._session.delete = MagicMock(side_effect=Exception()) + +        def wrapper(*args): +            self.auth_backend.logout() + +        d = threads.deferToThread(wrapper) +        return d + +    @deferred() +    def test_logout_clears(self): +        self.auth_backend._session_id = "1234" + +        def wrapper(*args): +            old_session = self.auth_backend._session +            self.auth_backend.logout() +            self.assertIsNone(self.auth_backend.get_session_id()) +            self.assertIsNone(self.auth_backend.get_uid()) +            self.assertNotEqual(old_session, self.auth_backend._session) + +        d = threads.deferToThread(wrapper) +        return d + + +class SRPAuthSingletonTestCase(unittest.TestCase): +    def setUp(self): +        self.old_auth = srpauth.SRPAuth._SRPAuth__impl.authenticate + +    def tearDown(self): +        srpauth.SRPAuth._SRPAuth__impl.authenticate = self.old_auth + +    def test_singleton(self): +        obj1 = srpauth.SRPAuth(ProviderConfig()) +        obj2 = srpauth.SRPAuth(ProviderConfig()) +        self.assertEqual(obj1._SRPAuth__instance, obj2._SRPAuth__instance) + +    @deferred() +    def test_authenticate_notifies_gui(self): +        auth = srpauth.SRPAuth(ProviderConfig()) +        auth._SRPAuth__instance.authenticate = MagicMock( +            return_value=threads.deferToThread(lambda: None)) +        auth._gui_notify = MagicMock() + +        d = auth.authenticate("", "") + +        def check(*args): +            auth._gui_notify.assert_called_once_with(None) + +        d.addCallback(check) +        return d + +    @deferred() +    def test_authenticate_errsback(self): +        auth = srpauth.SRPAuth(ProviderConfig()) +        auth._SRPAuth__instance.authenticate = MagicMock( +            return_value=threads.deferToThread(MagicMock( +                side_effect=Exception()))) +        auth._gui_notify = MagicMock() +        auth._errback = MagicMock() + +        d = auth.authenticate("", "") + +        def check(*args): +            self.assertFalse(auth._gui_notify.called) +            self.assertEqual(auth._errback.call_count, 1) + +        d.addCallback(check) +        return d + +    @deferred() +    def test_authenticate_runs_cleanly_when_raises(self): +        auth = srpauth.SRPAuth(ProviderConfig()) +        auth._SRPAuth__instance.authenticate = MagicMock( +            return_value=threads.deferToThread(MagicMock( +                side_effect=Exception()))) + +        d = auth.authenticate("", "") + +        return d + +    @deferred() +    def test_authenticate_runs_cleanly(self): +        auth = srpauth.SRPAuth(ProviderConfig()) +        auth._SRPAuth__instance.authenticate = MagicMock( +            return_value=threads.deferToThread(MagicMock())) + +        d = auth.authenticate("", "") + +        return d + +    def test_logout(self): +        auth = srpauth.SRPAuth(ProviderConfig()) +        auth._SRPAuth__instance.logout = MagicMock() + +        self.assertTrue(auth.logout()) + +    def test_logout_rets_false_when_raises(self): +        auth = srpauth.SRPAuth(ProviderConfig()) +        auth._SRPAuth__instance.logout = MagicMock( +            side_effect=Exception()) + +        self.assertFalse(auth.logout()) diff --git a/src/leap/crypto/tests/test_srpregister.py b/src/leap/crypto/tests/test_srpregister.py index 6d2b52e8..66b815f2 100644 --- a/src/leap/crypto/tests/test_srpregister.py +++ b/src/leap/crypto/tests/test_srpregister.py @@ -17,7 +17,6 @@  """  Tests for:      * leap/crypto/srpregister.py -    * leap/crypto/srpauth.py  """  try:      import unittest2 as unittest @@ -53,9 +52,9 @@ class ImproperlyConfiguredError(Exception):  class SRPTestCase(unittest.TestCase):      """ -    Tests for the SRP Register and Auth classes +    Tests for the SRPRegister class      """ -    __name__ = "SRPRegister and SRPAuth tests" +    __name__ = "SRPRegister tests"      @classmethod      def setUpClass(cls): diff --git a/src/leap/util/request_helpers.py b/src/leap/util/request_helpers.py index e06dabb8..350abfbd 100644 --- a/src/leap/util/request_helpers.py +++ b/src/leap/util/request_helpers.py @@ -41,7 +41,7 @@ def get_content(request):      contents = ""      mtime = None -    if request.json: +    if request.content and request.json:          if callable(request.json):              contents = json.dumps(request.json())          else: | 
