diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/leap/crypto/srpauth.py | 177 | ||||
-rw-r--r-- | src/leap/crypto/tests/test_srpauth.py | 755 | ||||
-rw-r--r-- | src/leap/crypto/tests/test_srpregister.py | 5 | ||||
-rw-r--r-- | src/leap/util/request_helpers.py | 2 |
4 files changed, 893 insertions, 46 deletions
diff --git a/src/leap/crypto/srpauth.py b/src/leap/crypto/srpauth.py index 89fee80b..8e228e79 100644 --- a/src/leap/crypto/srpauth.py +++ b/src/leap/crypto/srpauth.py @@ -30,8 +30,8 @@ from PySide import QtCore from twisted.internet import threads from leap.common.check import leap_assert -from leap.util.request_helpers import get_content from leap.util.constants import REQUEST_TIMEOUT +from leap.util import request_helpers as reqhelper from leap.common.events import signal as events_signal from leap.common.events import events_pb2 as proto @@ -45,6 +45,71 @@ class SRPAuthenticationError(Exception): pass +class SRPAuthConnectionError(SRPAuthenticationError): + """ + Exception raised when there's a connection error + """ + pass + +class SRPAuthUnknownUser(SRPAuthenticationError): + """ + Exception raised when trying to authenticate an unknown user + """ + pass + +class SRPAuthBadStatusCode(SRPAuthenticationError): + """ + Exception raised when we received an unknown bad status code + """ + pass + +class SRPAuthNoSalt(SRPAuthenticationError): + """ + Exception raised when we don't receive the salt param at a + specific point in the auth process + """ + pass + +class SRPAuthNoB(SRPAuthenticationError): + """ + Exception raised when we don't receive the B param at a specific + point in the auth process + """ + pass + +class SRPAuthBadDataFromServer(SRPAuthenticationError): + """ + Generic exception when we receive bad data from the server. + """ + pass + +class SRPAuthJSONDecodeError(SRPAuthenticationError): + """ + Exception raised when there's a problem decoding the JSON content + parsed as received from th e server. + """ + pass + +class SRPAuthBadPassword(SRPAuthenticationError): + """ + Exception raised when the user provided a bad password to auth. + """ + pass + +class SRPAuthVerificationFailed(SRPAuthenticationError): + """ + Exception raised when we can't verify the SRP data received from + the server. + """ + pass + +class SRPAuthNoSessionId(SRPAuthenticationError): + """ + Exception raised when we don't receive a session id from the + server. + """ + pass + class SRPAuth(QtCore.QObject): """ SRPAuth singleton @@ -126,19 +191,23 @@ class SRPAuth(QtCore.QObject): self._srp_a = A - def _start_authentication(self, _, username, password): + def _start_authentication(self, _, username): """ Sends the first request for authentication to retrieve the salt and B parameter - Might raise SRPAuthenticationError + Might raise all SRPAuthenticationError based: + SRPAuthenticationError + SRPAuthConnectionError + SRPAuthUnknownUser + SRPAuthBadStatusCode + SRPAuthNoSalt + SRPAuthNoB :param _: IGNORED, output from the previous callback (None) :type _: IGNORED :param username: username to login :type username: str - :param password: password for the username - :type password: str :return: salt and B parameters :rtype: tuple @@ -158,24 +227,29 @@ class SRPAuth(QtCore.QObject): verify=self._provider_config. get_ca_cert_path(), timeout=REQUEST_TIMEOUT) + # Clean up A value, we don't need it anymore + self._srp_a = None except requests.exceptions.ConnectionError as e: logger.error("No connection made (salt): %r" % (e,)) - raise SRPAuthenticationError("Could not establish a " + raise SRPAuthConnectionError("Could not establish a " "connection") except Exception as e: logger.error("Unknown error: %r" % (e,)) raise SRPAuthenticationError("Unknown error: %r" % (e,)) - content, mtime = get_content(init_session) + content, mtime = reqhelper.get_content(init_session) if init_session.status_code not in (200,): logger.error("No valid response (salt): " "Status code = %r. Content: %r" % (init_session.status_code, content)) if init_session.status_code == 422: - raise SRPAuthenticationError(self.tr("Unknown user")) + raise SRPAuthUnknownUser(self.tr("Unknown user")) + + raise SRPAuthBadStatusCode(self.tr("There was a problem with" + " authentication")) json_content = json.loads(content) salt = json_content.get("salt", None) @@ -183,12 +257,12 @@ class SRPAuth(QtCore.QObject): if salt is None: logger.error("No salt parameter sent") - raise SRPAuthenticationError(self.tr("The server did not send " - "the salt parameter")) + raise SRPAuthNoSalt(self.tr("The server did not send " + "the salt parameter")) if B is None: logger.error("No B parameter sent") - raise SRPAuthenticationError(self.tr("The server did not send " - "the B parameter")) + raise SRPAuthNoB(self.tr("The server did not send " + "the B parameter")) return salt, B @@ -197,7 +271,12 @@ class SRPAuth(QtCore.QObject): Given the salt and B processes the auth challenge and generates the M2 parameter - Might throw SRPAuthenticationError + Might raise SRPAuthenticationError based: + SRPAuthenticationError + SRPAuthBadDataFromServer + SRPAuthConnectionError + SRPAuthJSONDecodeError + SRPAuthBadPassword :param salt_B: salt and B parameters for the username :type salt_B: tuple @@ -212,10 +291,10 @@ class SRPAuth(QtCore.QObject): salt, B = salt_B unhex_salt = self._safe_unhexlify(salt) unhex_B = self._safe_unhexlify(B) - except TypeError as e: + except (TypeError, ValueError) as e: logger.error("Bad data from server: %r" % (e,)) - raise SRPAuthenticationError(self.tr("The data sent from " - "the server had errors")) + raise SRPAuthBadDataFromServer( + self.tr("The data sent from the server had errors")) M = self._srp_user.process_challenge(unhex_salt, unhex_B) auth_url = "%s/%s/%s/%s" % (self._provider_config.get_api_uri(), @@ -236,13 +315,13 @@ class SRPAuth(QtCore.QObject): timeout=REQUEST_TIMEOUT) except requests.exceptions.ConnectionError as e: logger.error("No connection made (HAMK): %r" % (e,)) - raise SRPAuthenticationError(self.tr("Could not connect to " + raise SRPAuthConnectionError(self.tr("Could not connect to " "the server")) try: - content, mtime = get_content(auth_result) + content, mtime = reqhelper.get_content(auth_result) except JSONDecodeError: - raise SRPAuthenticationError("Bad JSON content in auth result") + raise SRPAuthJSONDecodeError("Bad JSON content in auth result") if auth_result.status_code == 422: error = "" @@ -256,35 +335,47 @@ class SRPAuth(QtCore.QObject): "received: %s", (content,)) logger.error("[%s] Wrong password (HAMK): [%s]" % (auth_result.status_code, error)) - raise SRPAuthenticationError(self.tr("Wrong password")) + raise SRPAuthBadPassword(self.tr("Wrong password")) if auth_result.status_code not in (200,): logger.error("No valid response (HAMK): " "Status code = %s. Content = %r" % (auth_result.status_code, content)) - raise SRPAuthenticationError(self.tr("Unknown error (%s)") % - (auth_result.status_code,)) + raise SRPAuthBadStatusCode(self.tr("Unknown error (%s)") % + (auth_result.status_code,)) - json_content = json.loads(content) + return json.loads(content) + + def _extract_data(self, json_content): + """ + Extracts the necessary parameters from json_content (M2, + id, token) + + Might raise SRPAuthenticationError based: + SRPBadDataFromServer + :param json_content: Data received from the server + :type json_content: dict + """ try: M2 = json_content.get("M2", None) uid = json_content.get("id", None) token = json_content.get("token", None) except Exception as e: logger.error(e) - raise Exception("Something went wrong with the login") - - events_signal(proto.CLIENT_UID, content=uid) + raise SRPAuthBadDataFromServer("Something went wrong with the " + "login") self.set_uid(uid) self.set_token(token) if M2 is None or self.get_uid() is None: logger.error("Something went wrong. Content = %r" % - (content,)) - raise SRPAuthenticationError(self.tr("Problem getting data " - "from server")) + (json_content,)) + raise SRPAuthBadDataFromServer(self.tr("Problem getting data " + "from server")) + + events_signal(proto.CLIENT_UID, content=uid) return M2 @@ -294,7 +385,9 @@ class SRPAuth(QtCore.QObject): verification succeeds, it sets the session_id for this session - Might throw SRPAuthenticationError + Might raise SRPAuthenticationError based: + SRPAuthBadDataFromServer + SRPAuthVerificationFailed :param M2: M2 SRP parameter :type M2: str @@ -304,22 +397,22 @@ class SRPAuth(QtCore.QObject): unhex_M2 = self._safe_unhexlify(M2) except TypeError: logger.error("Bad data from server (HAWK)") - raise SRPAuthenticationError(self.tr("Bad data from server")) + raise SRPAuthBadDataFromServer(self.tr("Bad data from server")) self._srp_user.verify_session(unhex_M2) if not self._srp_user.authenticated(): logger.error("Auth verification failed") - raise SRPAuthenticationError(self.tr("Auth verification " - "failed")) + raise SRPAuthVerificationFailed(self.tr("Auth verification " + "failed")) logger.debug("Session verified.") session_id = self._session.cookies.get(self.SESSION_ID_KEY, None) if not session_id: logger.error("Bad cookie from server (missing _session_id)") - raise SRPAuthenticationError(self.tr("Session cookie " - "verification " - "failed")) + raise SRPAuthNoSessionId(self.tr("Session cookie " + "verification " + "failed")) events_signal(proto.CLIENT_SESSION_ID, content=session_id) @@ -351,12 +444,15 @@ class SRPAuth(QtCore.QObject): d.addCallback( partial(self._threader, self._start_authentication), - username=username, - password=password) + username=username) d.addCallback( partial(self._threader, self._process_challenge), username=username) + d.addCallback( + partial(self._threader, + self._extract_data), + username=username) d.addCallback(partial(self._threader, self._verify_session)) @@ -435,14 +531,11 @@ class SRPAuth(QtCore.QObject): # Store instance reference as the only member in the handle self.__dict__['_SRPAuth__instance'] = SRPAuth.__instance - self._username = None - self._password = None - def authenticate(self, username, password): """ Executes the whole authentication process for a user - Might raise SRPAuthenticationError + Might raise SRPAuthenticationError based :param username: username for this session :type username: str diff --git a/src/leap/crypto/tests/test_srpauth.py b/src/leap/crypto/tests/test_srpauth.py new file mode 100644 index 00000000..9a684a8f --- /dev/null +++ b/src/leap/crypto/tests/test_srpauth.py @@ -0,0 +1,755 @@ +# -*- coding: utf-8 -*- +# test_srpauth.py +# Copyright (C) 2013 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +""" +Tests for: + * leap/crypto/srpauth.py +""" +try: + import unittest2 as unittest +except ImportError: + import unittest +import os +import sys +import binascii +import requests +import mock + +from mock import MagicMock +from nose.twistedtools import reactor, deferred +from twisted.python import log +from twisted.internet import threads +from functools import partial +from requests.models import Response +from simplejson.decoder import JSONDecodeError + +from leap.common.testing.https_server import where +from leap.config.providerconfig import ProviderConfig +from leap.crypto import srpregister, srpauth +from leap.crypto.tests import fake_provider + +log.startLogging(sys.stdout) + + +def _get_capath(): + return where("cacert.pem") + +_here = os.path.split(__file__)[0] + + +class ImproperlyConfiguredError(Exception): + """ + Raised if the test provider is missing configuration + """ + + +class SRPAuthTestCase(unittest.TestCase): + """ + Tests for the SRPAuth class + """ + __name__ = "SRPAuth tests" + + def setUp(self): + """ + Sets up this TestCase with a simple and faked provider instance: + + * runs a threaded reactor + * loads a mocked ProviderConfig that points to the certs in the + leap.common.testing module. + """ + factory = fake_provider.get_provider_factory() + http = reactor.listenTCP(0, factory) + https = reactor.listenSSL( + 0, factory, + fake_provider.OpenSSLServerContextFactory()) + get_port = lambda p: p.getHost().port + self.http_port = get_port(http) + self.https_port = get_port(https) + + provider = ProviderConfig() + provider.get_ca_cert_path = MagicMock() + provider.get_ca_cert_path.return_value = _get_capath() + + provider.get_api_uri = MagicMock() + provider.get_api_uri.return_value = self._get_https_uri() + + loaded = provider.load(path=os.path.join( + _here, "test_provider.json")) + if not loaded: + raise ImproperlyConfiguredError( + "Could not load test provider config") + self.register = srpregister.SRPRegister(provider_config=provider) + self.provider = provider + self.TEST_USER = "register_test_auth" + self.TEST_PASS = "pass" + + # Reset the singleton + srpauth.SRPAuth._SRPAuth__instance = None + self.auth = srpauth.SRPAuth(self.provider) + self.auth_backend = self.auth._SRPAuth__instance + + self.old_post = self.auth_backend._session.post + self.old_put = self.auth_backend._session.put + self.old_delete = self.auth_backend._session.delete + + self.old_start_auth = self.auth_backend._start_authentication + self.old_proc_challenge = self.auth_backend._process_challenge + self.old_extract_data = self.auth_backend._extract_data + self.old_verify_session = self.auth_backend._verify_session + self.old_auth_preproc = self.auth_backend._authentication_preprocessing + self.old_get_sid = self.auth_backend.get_session_id + self.old_cookie_get = self.auth_backend._session.cookies.get + self.old_auth = self.auth_backend.authenticate + + def tearDown(self): + self.auth_backend._session.post = self.old_post + self.auth_backend._session.put = self.old_put + self.auth_backend._session.delete = self.old_delete + + self.auth_backend._start_authentication = self.old_start_auth + self.auth_backend._process_challenge = self.old_proc_challenge + self.auth_backend._extract_data = self.old_extract_data + self.auth_backend._verify_session = self.old_verify_session + self.auth_backend._authentication_preprocessing = self.old_auth_preproc + self.auth_backend.get_session_id = self.old_get_sid + self.auth_backend._session.cookies.get = self.old_cookie_get + self.auth_backend.authenticate = self.old_auth + + # helper methods + + def _get_https_uri(self): + """ + Returns a https uri with the right https port initialized + """ + return "https://localhost:%s" % (self.https_port,) + + # Auth tests + + def _prepare_auth_test(self, code=200, side_effect=None): + """ + Creates the needed defers to test several test situations. It + adds up to the auth preprocessing step. + + :param code: status code for the response of POST in requests + :type code: int + :param side_effect: side effect triggered by the POST method + in requests + :type side_effect: some kind of Exception + + :returns: the defer that is created + :rtype: defer.Deferred + """ + res = Response() + res.status_code = code + self.auth_backend._session.post = MagicMock(return_value=res, + side_effect=side_effect) + + d = threads.deferToThread(self.register.register_user, + self.TEST_USER, + self.TEST_PASS) + + def wrapper_preproc(*args): + return threads.deferToThread( + self.auth_backend._authentication_preprocessing, + self.TEST_USER, self.TEST_PASS) + + d.addCallback(wrapper_preproc) + + return d + + def test_safe_unhexlify(self): + input_value = "somestring" + test_value = binascii.hexlify(input_value) + self.assertEqual( + self.auth_backend._safe_unhexlify(test_value), + input_value) + + def test_safe_unhexlify_not_raises(self): + input_value = "somestring" + test_value = binascii.hexlify(input_value)[:-1] + + with self.assertRaises(TypeError): + binascii.unhexlify(test_value) + + self.auth_backend._safe_unhexlify(test_value) + + def test_preprocessing_loads_a(self): + self.assertEqual(self.auth_backend._srp_a, None) + self.auth_backend._authentication_preprocessing("user", "pass") + self.assertIsNotNone(self.auth_backend._srp_a) + self.assertTrue(len(self.auth_backend._srp_a) > 0) + + @deferred() + def test_start_authentication(self): + d = threads.deferToThread(self.register.register_user, self.TEST_USER, + self.TEST_PASS) + + def wrapper_preproc(*args): + return threads.deferToThread( + self.auth_backend._authentication_preprocessing, + self.TEST_USER, self.TEST_PASS) + + d.addCallback(wrapper_preproc) + + def wrapper(_): + return threads.deferToThread( + self.auth_backend._start_authentication, + None, self.TEST_USER) + + d.addCallback(wrapper) + return d + + @deferred() + def test_start_authentication_fails_connerror(self): + d = self._prepare_auth_test( + side_effect=requests.exceptions.ConnectionError()) + + def wrapper(_): + with self.assertRaises(srpauth.SRPAuthConnectionError): + self.auth_backend._start_authentication(None, self.TEST_USER) + + d.addCallback(partial(threads.deferToThread, wrapper)) + return d + + @deferred() + def test_start_authentication_fails_any_error(self): + d = self._prepare_auth_test(side_effect=Exception()) + + def wrapper(_): + with self.assertRaises(srpauth.SRPAuthenticationError): + self.auth_backend._start_authentication(None, self.TEST_USER) + + d.addCallback(partial(threads.deferToThread, wrapper)) + return d + + @deferred() + def test_start_authentication_fails_unknown_user(self): + d = self._prepare_auth_test(422) + + def wrapper(_): + with self.assertRaises(srpauth.SRPAuthUnknownUser): + with mock.patch('leap.util.request_helpers.get_content', + new_callable=MagicMock()) as \ + content: + content.return_value = ("{}", 0) + + self.auth_backend._start_authentication( + None, self.TEST_USER) + + d.addCallback(partial(threads.deferToThread, wrapper)) + return d + + @deferred() + def test_start_authentication_fails_errorcode(self): + d = self._prepare_auth_test(302) + + def wrapper(_): + with self.assertRaises(srpauth.SRPAuthBadStatusCode): + with mock.patch('leap.util.request_helpers.get_content', + new_callable=MagicMock()) as \ + content: + content.return_value = ("{}", 0) + + self.auth_backend._start_authentication(None, + self.TEST_USER) + + d.addCallback(partial(threads.deferToThread, wrapper)) + return d + + @deferred() + def test_start_authentication_fails_no_salt(self): + d = self._prepare_auth_test(200) + + def wrapper(_): + with self.assertRaises(srpauth.SRPAuthNoSalt): + with mock.patch('leap.util.request_helpers.get_content', + new_callable=MagicMock()) as \ + content: + content.return_value = ("{}", 0) + + self.auth_backend._start_authentication(None, + self.TEST_USER) + + d.addCallback(partial(threads.deferToThread, wrapper)) + return d + + @deferred() + def test_start_authentication_fails_no_B(self): + d = self._prepare_auth_test(200) + + def wrapper(_): + with self.assertRaises(srpauth.SRPAuthNoB): + with mock.patch('leap.util.request_helpers.get_content', + new_callable=MagicMock()) as \ + content: + content.return_value = ('{"salt": ""}', 0) + + self.auth_backend._start_authentication(None, + self.TEST_USER) + + d.addCallback(partial(threads.deferToThread, wrapper)) + return d + + @deferred() + def test_start_authentication_correct_saltb(self): + d = self._prepare_auth_test(200) + + test_salt = "12345" + test_B = "67890" + + def wrapper(_): + with mock.patch('leap.util.request_helpers.get_content', + new_callable=MagicMock()) as \ + content: + content.return_value = ('{"salt":"%s", "B":"%s"}' % (test_salt, + test_B), + 0) + + salt, B = self.auth_backend._start_authentication( + None, + self.TEST_USER) + self.assertEqual(salt, test_salt) + self.assertEqual(B, test_B) + + d.addCallback(partial(threads.deferToThread, wrapper)) + return d + + def _prepare_auth_challenge(self): + """ + Creates the needed defers to test several test situations. It + adds up to the start authentication step. + + :returns: the defer that is created + :rtype: defer.Deferred + """ + d = threads.deferToThread(self.register.register_user, + self.TEST_USER, + self.TEST_PASS) + + def wrapper_preproc(*args): + return threads.deferToThread( + self.auth_backend._authentication_preprocessing, + self.TEST_USER, self.TEST_PASS) + + d.addCallback(wrapper_preproc) + + def wrapper_start(*args): + return threads.deferToThread( + self.auth_backend._start_authentication, + None, self.TEST_USER) + + d.addCallback(wrapper_start) + + return d + + @deferred() + def test_process_challenge_wrong_saltb(self): + d = self._prepare_auth_challenge() + + def wrapper(salt_B): + with self.assertRaises(srpauth.SRPAuthBadDataFromServer): + self.auth_backend._process_challenge("", + username=self.TEST_USER) + + d.addCallback(partial(threads.deferToThread, wrapper)) + return d + + @deferred() + def test_process_challenge_requests_problem_raises(self): + d = self._prepare_auth_challenge() + + self.auth_backend._session.put = MagicMock( + side_effect=requests.exceptions.ConnectionError()) + + def wrapper(salt_B): + with self.assertRaises(srpauth.SRPAuthConnectionError): + self.auth_backend._process_challenge(salt_B, + username=self.TEST_USER) + + d.addCallback(partial(threads.deferToThread, wrapper)) + + return d + + @deferred() + def test_process_challenge_json_decode_error(self): + d = self._prepare_auth_challenge() + + def wrapper(salt_B): + with mock.patch('leap.util.request_helpers.get_content', + new_callable=MagicMock()) as \ + content: + content.return_value = ("{", 0) + content.side_effect = JSONDecodeError("", "", 0) + + with self.assertRaises(srpauth.SRPAuthJSONDecodeError): + self.auth_backend._process_challenge( + salt_B, + username=self.TEST_USER) + + d.addCallback(partial(threads.deferToThread, wrapper)) + + return d + + @deferred() + def test_process_challenge_bad_password(self): + d = self._prepare_auth_challenge() + + res = Response() + res.status_code = 422 + self.auth_backend._session.put = MagicMock(return_value=res) + + def wrapper(salt_B): + with mock.patch('leap.util.request_helpers.get_content', + new_callable=MagicMock()) as \ + content: + content.return_value = ("", 0) + with self.assertRaises(srpauth.SRPAuthBadPassword): + self.auth_backend._process_challenge( + salt_B, + username=self.TEST_USER) + + d.addCallback(partial(threads.deferToThread, wrapper)) + + return d + + @deferred() + def test_process_challenge_bad_password2(self): + d = self._prepare_auth_challenge() + + res = Response() + res.status_code = 422 + self.auth_backend._session.put = MagicMock(return_value=res) + + def wrapper(salt_B): + with mock.patch('leap.util.request_helpers.get_content', + new_callable=MagicMock()) as \ + content: + content.return_value = ("[]", 0) + with self.assertRaises(srpauth.SRPAuthBadPassword): + self.auth_backend._process_challenge( + salt_B, + username=self.TEST_USER) + + d.addCallback(partial(threads.deferToThread, wrapper)) + + return d + + @deferred() + def test_process_challenge_other_error_code(self): + d = self._prepare_auth_challenge() + + res = Response() + res.status_code = 300 + self.auth_backend._session.put = MagicMock(return_value=res) + + def wrapper(salt_B): + with mock.patch('leap.util.request_helpers.get_content', + new_callable=MagicMock()) as \ + content: + content.return_value = ("{}", 0) + with self.assertRaises(srpauth.SRPAuthBadStatusCode): + self.auth_backend._process_challenge( + salt_B, + username=self.TEST_USER) + + d.addCallback(partial(threads.deferToThread, wrapper)) + + return d + + @deferred() + def test_process_challenge(self): + d = self._prepare_auth_challenge() + + def wrapper(salt_B): + self.auth_backend._process_challenge(salt_B, + username=self.TEST_USER) + + d.addCallback(partial(threads.deferToThread, wrapper)) + + return d + + def test_extract_data_wrong_data(self): + with self.assertRaises(srpauth.SRPAuthBadDataFromServer): + self.auth_backend._extract_data(None) + + with self.assertRaises(srpauth.SRPAuthBadDataFromServer): + self.auth_backend._extract_data("") + + def test_extract_data_fails_on_wrong_data_from_server(self): + with self.assertRaises(srpauth.SRPAuthBadDataFromServer): + self.auth_backend._extract_data({}) + + with self.assertRaises(srpauth.SRPAuthBadDataFromServer): + self.auth_backend._extract_data({"M2": ""}) + + def test_extract_data_sets_uidtoken(self): + test_uid = "someuid" + test_m2 = "somem2" + test_token = "sometoken" + test_data = { + "M2": test_m2, + "id": test_uid, + "token": test_token + } + m2 = self.auth_backend._extract_data(test_data) + + self.assertEqual(m2, test_m2) + self.assertEqual(self.auth_backend.get_uid(), test_uid) + self.assertEqual(self.auth_backend.get_uid(), + self.auth.get_uid()) + self.assertEqual(self.auth_backend.get_token(), test_token) + self.assertEqual(self.auth_backend.get_token(), + self.auth.get_token()) + + def _prepare_verify_session(self): + """ + Prepares the tests for verify session with needed steps + before. It adds up to the extract_data step. + + :returns: The defer to chain to + :rtype: defer.Deferred + """ + d = self._prepare_auth_challenge() + + def wrapper_proc_challenge(salt_B): + return self.auth_backend._process_challenge( + salt_B, + username=self.TEST_USER) + + def wrapper_extract_data(data): + return self.auth_backend._extract_data(data) + + d.addCallback(partial(threads.deferToThread, wrapper_proc_challenge)) + d.addCallback(partial(threads.deferToThread, wrapper_extract_data)) + + return d + + @deferred() + def test_verify_session_unhexlifiable_m2(self): + d = self._prepare_verify_session() + + def wrapper(M2): + with self.assertRaises(srpauth.SRPAuthBadDataFromServer): + self.auth_backend._verify_session("za") # unhexlifiable value + + d.addCallback(wrapper) + + return d + + @deferred() + def test_verify_session_unverifiable_m2(self): + d = self._prepare_verify_session() + + def wrapper(M2): + with self.assertRaises(srpauth.SRPAuthVerificationFailed): + # Correctly unhelifiable value, but not for verifying the + # session + self.auth_backend._verify_session("abc12") + + d.addCallback(wrapper) + + return d + + @deferred() + def test_verify_session_fails_on_no_session_id(self): + d = self._prepare_verify_session() + + def wrapper(M2): + self.auth_backend._session.cookies.get = MagicMock( + return_value=None) + with self.assertRaises(srpauth.SRPAuthNoSessionId): + self.auth_backend._verify_session(M2) + + d.addCallback(wrapper) + + return d + + @deferred() + def test_verify_session_session_id(self): + d = self._prepare_verify_session() + + test_session_id = "12345" + + def wrapper(M2): + self.auth_backend._session.cookies.get = MagicMock( + return_value=test_session_id) + self.auth_backend._verify_session(M2) + self.assertEqual(self.auth_backend.get_session_id(), + test_session_id) + self.assertEqual(self.auth_backend.get_session_id(), + self.auth.get_session_id()) + + d.addCallback(wrapper) + + return d + + @deferred() + def test_verify_session(self): + d = self._prepare_verify_session() + + def wrapper(M2): + self.auth_backend._verify_session(M2) + + d.addCallback(wrapper) + + return d + + @deferred() + def test_authenticate(self): + self.auth_backend._authentication_preprocessing = MagicMock( + return_value=None) + self.auth_backend._start_authentication = MagicMock(return_value=None) + self.auth_backend._process_challenge = MagicMock(return_value=None) + self.auth_backend._extract_data = MagicMock(return_value=None) + self.auth_backend._verify_session = MagicMock(return_value=None) + + d = self.auth_backend.authenticate(self.TEST_USER, self.TEST_PASS) + + def check(*args): + self.auth_backend._authentication_preprocessing.\ + assert_called_once_with( + username=self.TEST_USER, + password=self.TEST_PASS + ) + self.auth_backend._start_authentication.assert_called_once_with( + None, + username=self.TEST_USER) + self.auth_backend._process_challenge.assert_called_once_with( + None, + username=self.TEST_USER) + self.auth_backend._extract_data.assert_called_once_with( + None, + username=self.TEST_USER) + self.auth_backend._verify_session.assert_called_once_with(None) + + d.addCallback(check) + + return d + + @deferred() + def test_logout_fails_if_not_logged_in(self): + + def wrapper(*args): + with self.assertRaises(AssertionError): + self.auth_backend.logout() + + d = threads.deferToThread(wrapper) + return d + + @deferred() + def test_logout_traps_delete(self): + self.auth_backend.get_session_id = MagicMock(return_value="1234") + self.auth_backend._session.delete = MagicMock(side_effect=Exception()) + + def wrapper(*args): + self.auth_backend.logout() + + d = threads.deferToThread(wrapper) + return d + + @deferred() + def test_logout_clears(self): + self.auth_backend._session_id = "1234" + + def wrapper(*args): + old_session = self.auth_backend._session + self.auth_backend.logout() + self.assertIsNone(self.auth_backend.get_session_id()) + self.assertIsNone(self.auth_backend.get_uid()) + self.assertNotEqual(old_session, self.auth_backend._session) + + d = threads.deferToThread(wrapper) + return d + + +class SRPAuthSingletonTestCase(unittest.TestCase): + def setUp(self): + self.old_auth = srpauth.SRPAuth._SRPAuth__impl.authenticate + + def tearDown(self): + srpauth.SRPAuth._SRPAuth__impl.authenticate = self.old_auth + + def test_singleton(self): + obj1 = srpauth.SRPAuth(ProviderConfig()) + obj2 = srpauth.SRPAuth(ProviderConfig()) + self.assertEqual(obj1._SRPAuth__instance, obj2._SRPAuth__instance) + + @deferred() + def test_authenticate_notifies_gui(self): + auth = srpauth.SRPAuth(ProviderConfig()) + auth._SRPAuth__instance.authenticate = MagicMock( + return_value=threads.deferToThread(lambda: None)) + auth._gui_notify = MagicMock() + + d = auth.authenticate("", "") + + def check(*args): + auth._gui_notify.assert_called_once_with(None) + + d.addCallback(check) + return d + + @deferred() + def test_authenticate_errsback(self): + auth = srpauth.SRPAuth(ProviderConfig()) + auth._SRPAuth__instance.authenticate = MagicMock( + return_value=threads.deferToThread(MagicMock( + side_effect=Exception()))) + auth._gui_notify = MagicMock() + auth._errback = MagicMock() + + d = auth.authenticate("", "") + + def check(*args): + self.assertFalse(auth._gui_notify.called) + self.assertEqual(auth._errback.call_count, 1) + + d.addCallback(check) + return d + + @deferred() + def test_authenticate_runs_cleanly_when_raises(self): + auth = srpauth.SRPAuth(ProviderConfig()) + auth._SRPAuth__instance.authenticate = MagicMock( + return_value=threads.deferToThread(MagicMock( + side_effect=Exception()))) + + d = auth.authenticate("", "") + + return d + + @deferred() + def test_authenticate_runs_cleanly(self): + auth = srpauth.SRPAuth(ProviderConfig()) + auth._SRPAuth__instance.authenticate = MagicMock( + return_value=threads.deferToThread(MagicMock())) + + d = auth.authenticate("", "") + + return d + + def test_logout(self): + auth = srpauth.SRPAuth(ProviderConfig()) + auth._SRPAuth__instance.logout = MagicMock() + + self.assertTrue(auth.logout()) + + def test_logout_rets_false_when_raises(self): + auth = srpauth.SRPAuth(ProviderConfig()) + auth._SRPAuth__instance.logout = MagicMock( + side_effect=Exception()) + + self.assertFalse(auth.logout()) diff --git a/src/leap/crypto/tests/test_srpregister.py b/src/leap/crypto/tests/test_srpregister.py index 6d2b52e8..66b815f2 100644 --- a/src/leap/crypto/tests/test_srpregister.py +++ b/src/leap/crypto/tests/test_srpregister.py @@ -17,7 +17,6 @@ """ Tests for: * leap/crypto/srpregister.py - * leap/crypto/srpauth.py """ try: import unittest2 as unittest @@ -53,9 +52,9 @@ class ImproperlyConfiguredError(Exception): class SRPTestCase(unittest.TestCase): """ - Tests for the SRP Register and Auth classes + Tests for the SRPRegister class """ - __name__ = "SRPRegister and SRPAuth tests" + __name__ = "SRPRegister tests" @classmethod def setUpClass(cls): diff --git a/src/leap/util/request_helpers.py b/src/leap/util/request_helpers.py index e06dabb8..350abfbd 100644 --- a/src/leap/util/request_helpers.py +++ b/src/leap/util/request_helpers.py @@ -41,7 +41,7 @@ def get_content(request): contents = "" mtime = None - if request.json: + if request.content and request.json: if callable(request.json): contents = json.dumps(request.json()) else: |