summaryrefslogtreecommitdiff
path: root/src/leap
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap')
-rw-r--r--src/leap/crypto/srpauth.py177
-rw-r--r--src/leap/crypto/tests/test_srpauth.py755
-rw-r--r--src/leap/crypto/tests/test_srpregister.py5
-rw-r--r--src/leap/util/request_helpers.py2
4 files changed, 893 insertions, 46 deletions
diff --git a/src/leap/crypto/srpauth.py b/src/leap/crypto/srpauth.py
index 89fee80b..8e228e79 100644
--- a/src/leap/crypto/srpauth.py
+++ b/src/leap/crypto/srpauth.py
@@ -30,8 +30,8 @@ from PySide import QtCore
from twisted.internet import threads
from leap.common.check import leap_assert
-from leap.util.request_helpers import get_content
from leap.util.constants import REQUEST_TIMEOUT
+from leap.util import request_helpers as reqhelper
from leap.common.events import signal as events_signal
from leap.common.events import events_pb2 as proto
@@ -45,6 +45,71 @@ class SRPAuthenticationError(Exception):
pass
+class SRPAuthConnectionError(SRPAuthenticationError):
+ """
+ Exception raised when there's a connection error
+ """
+ pass
+
+class SRPAuthUnknownUser(SRPAuthenticationError):
+ """
+ Exception raised when trying to authenticate an unknown user
+ """
+ pass
+
+class SRPAuthBadStatusCode(SRPAuthenticationError):
+ """
+ Exception raised when we received an unknown bad status code
+ """
+ pass
+
+class SRPAuthNoSalt(SRPAuthenticationError):
+ """
+ Exception raised when we don't receive the salt param at a
+ specific point in the auth process
+ """
+ pass
+
+class SRPAuthNoB(SRPAuthenticationError):
+ """
+ Exception raised when we don't receive the B param at a specific
+ point in the auth process
+ """
+ pass
+
+class SRPAuthBadDataFromServer(SRPAuthenticationError):
+ """
+ Generic exception when we receive bad data from the server.
+ """
+ pass
+
+class SRPAuthJSONDecodeError(SRPAuthenticationError):
+ """
+ Exception raised when there's a problem decoding the JSON content
+ parsed as received from th e server.
+ """
+ pass
+
+class SRPAuthBadPassword(SRPAuthenticationError):
+ """
+ Exception raised when the user provided a bad password to auth.
+ """
+ pass
+
+class SRPAuthVerificationFailed(SRPAuthenticationError):
+ """
+ Exception raised when we can't verify the SRP data received from
+ the server.
+ """
+ pass
+
+class SRPAuthNoSessionId(SRPAuthenticationError):
+ """
+ Exception raised when we don't receive a session id from the
+ server.
+ """
+ pass
+
class SRPAuth(QtCore.QObject):
"""
SRPAuth singleton
@@ -126,19 +191,23 @@ class SRPAuth(QtCore.QObject):
self._srp_a = A
- def _start_authentication(self, _, username, password):
+ def _start_authentication(self, _, username):
"""
Sends the first request for authentication to retrieve the
salt and B parameter
- Might raise SRPAuthenticationError
+ Might raise all SRPAuthenticationError based:
+ SRPAuthenticationError
+ SRPAuthConnectionError
+ SRPAuthUnknownUser
+ SRPAuthBadStatusCode
+ SRPAuthNoSalt
+ SRPAuthNoB
:param _: IGNORED, output from the previous callback (None)
:type _: IGNORED
:param username: username to login
:type username: str
- :param password: password for the username
- :type password: str
:return: salt and B parameters
:rtype: tuple
@@ -158,24 +227,29 @@ class SRPAuth(QtCore.QObject):
verify=self._provider_config.
get_ca_cert_path(),
timeout=REQUEST_TIMEOUT)
+ # Clean up A value, we don't need it anymore
+ self._srp_a = None
except requests.exceptions.ConnectionError as e:
logger.error("No connection made (salt): %r" %
(e,))
- raise SRPAuthenticationError("Could not establish a "
+ raise SRPAuthConnectionError("Could not establish a "
"connection")
except Exception as e:
logger.error("Unknown error: %r" % (e,))
raise SRPAuthenticationError("Unknown error: %r" %
(e,))
- content, mtime = get_content(init_session)
+ content, mtime = reqhelper.get_content(init_session)
if init_session.status_code not in (200,):
logger.error("No valid response (salt): "
"Status code = %r. Content: %r" %
(init_session.status_code, content))
if init_session.status_code == 422:
- raise SRPAuthenticationError(self.tr("Unknown user"))
+ raise SRPAuthUnknownUser(self.tr("Unknown user"))
+
+ raise SRPAuthBadStatusCode(self.tr("There was a problem with"
+ " authentication"))
json_content = json.loads(content)
salt = json_content.get("salt", None)
@@ -183,12 +257,12 @@ class SRPAuth(QtCore.QObject):
if salt is None:
logger.error("No salt parameter sent")
- raise SRPAuthenticationError(self.tr("The server did not send "
- "the salt parameter"))
+ raise SRPAuthNoSalt(self.tr("The server did not send "
+ "the salt parameter"))
if B is None:
logger.error("No B parameter sent")
- raise SRPAuthenticationError(self.tr("The server did not send "
- "the B parameter"))
+ raise SRPAuthNoB(self.tr("The server did not send "
+ "the B parameter"))
return salt, B
@@ -197,7 +271,12 @@ class SRPAuth(QtCore.QObject):
Given the salt and B processes the auth challenge and
generates the M2 parameter
- Might throw SRPAuthenticationError
+ Might raise SRPAuthenticationError based:
+ SRPAuthenticationError
+ SRPAuthBadDataFromServer
+ SRPAuthConnectionError
+ SRPAuthJSONDecodeError
+ SRPAuthBadPassword
:param salt_B: salt and B parameters for the username
:type salt_B: tuple
@@ -212,10 +291,10 @@ class SRPAuth(QtCore.QObject):
salt, B = salt_B
unhex_salt = self._safe_unhexlify(salt)
unhex_B = self._safe_unhexlify(B)
- except TypeError as e:
+ except (TypeError, ValueError) as e:
logger.error("Bad data from server: %r" % (e,))
- raise SRPAuthenticationError(self.tr("The data sent from "
- "the server had errors"))
+ raise SRPAuthBadDataFromServer(
+ self.tr("The data sent from the server had errors"))
M = self._srp_user.process_challenge(unhex_salt, unhex_B)
auth_url = "%s/%s/%s/%s" % (self._provider_config.get_api_uri(),
@@ -236,13 +315,13 @@ class SRPAuth(QtCore.QObject):
timeout=REQUEST_TIMEOUT)
except requests.exceptions.ConnectionError as e:
logger.error("No connection made (HAMK): %r" % (e,))
- raise SRPAuthenticationError(self.tr("Could not connect to "
+ raise SRPAuthConnectionError(self.tr("Could not connect to "
"the server"))
try:
- content, mtime = get_content(auth_result)
+ content, mtime = reqhelper.get_content(auth_result)
except JSONDecodeError:
- raise SRPAuthenticationError("Bad JSON content in auth result")
+ raise SRPAuthJSONDecodeError("Bad JSON content in auth result")
if auth_result.status_code == 422:
error = ""
@@ -256,35 +335,47 @@ class SRPAuth(QtCore.QObject):
"received: %s", (content,))
logger.error("[%s] Wrong password (HAMK): [%s]" %
(auth_result.status_code, error))
- raise SRPAuthenticationError(self.tr("Wrong password"))
+ raise SRPAuthBadPassword(self.tr("Wrong password"))
if auth_result.status_code not in (200,):
logger.error("No valid response (HAMK): "
"Status code = %s. Content = %r" %
(auth_result.status_code, content))
- raise SRPAuthenticationError(self.tr("Unknown error (%s)") %
- (auth_result.status_code,))
+ raise SRPAuthBadStatusCode(self.tr("Unknown error (%s)") %
+ (auth_result.status_code,))
- json_content = json.loads(content)
+ return json.loads(content)
+
+ def _extract_data(self, json_content):
+ """
+ Extracts the necessary parameters from json_content (M2,
+ id, token)
+
+ Might raise SRPAuthenticationError based:
+ SRPBadDataFromServer
+ :param json_content: Data received from the server
+ :type json_content: dict
+ """
try:
M2 = json_content.get("M2", None)
uid = json_content.get("id", None)
token = json_content.get("token", None)
except Exception as e:
logger.error(e)
- raise Exception("Something went wrong with the login")
-
- events_signal(proto.CLIENT_UID, content=uid)
+ raise SRPAuthBadDataFromServer("Something went wrong with the "
+ "login")
self.set_uid(uid)
self.set_token(token)
if M2 is None or self.get_uid() is None:
logger.error("Something went wrong. Content = %r" %
- (content,))
- raise SRPAuthenticationError(self.tr("Problem getting data "
- "from server"))
+ (json_content,))
+ raise SRPAuthBadDataFromServer(self.tr("Problem getting data "
+ "from server"))
+
+ events_signal(proto.CLIENT_UID, content=uid)
return M2
@@ -294,7 +385,9 @@ class SRPAuth(QtCore.QObject):
verification succeeds, it sets the session_id for this
session
- Might throw SRPAuthenticationError
+ Might raise SRPAuthenticationError based:
+ SRPAuthBadDataFromServer
+ SRPAuthVerificationFailed
:param M2: M2 SRP parameter
:type M2: str
@@ -304,22 +397,22 @@ class SRPAuth(QtCore.QObject):
unhex_M2 = self._safe_unhexlify(M2)
except TypeError:
logger.error("Bad data from server (HAWK)")
- raise SRPAuthenticationError(self.tr("Bad data from server"))
+ raise SRPAuthBadDataFromServer(self.tr("Bad data from server"))
self._srp_user.verify_session(unhex_M2)
if not self._srp_user.authenticated():
logger.error("Auth verification failed")
- raise SRPAuthenticationError(self.tr("Auth verification "
- "failed"))
+ raise SRPAuthVerificationFailed(self.tr("Auth verification "
+ "failed"))
logger.debug("Session verified.")
session_id = self._session.cookies.get(self.SESSION_ID_KEY, None)
if not session_id:
logger.error("Bad cookie from server (missing _session_id)")
- raise SRPAuthenticationError(self.tr("Session cookie "
- "verification "
- "failed"))
+ raise SRPAuthNoSessionId(self.tr("Session cookie "
+ "verification "
+ "failed"))
events_signal(proto.CLIENT_SESSION_ID, content=session_id)
@@ -351,12 +444,15 @@ class SRPAuth(QtCore.QObject):
d.addCallback(
partial(self._threader,
self._start_authentication),
- username=username,
- password=password)
+ username=username)
d.addCallback(
partial(self._threader,
self._process_challenge),
username=username)
+ d.addCallback(
+ partial(self._threader,
+ self._extract_data),
+ username=username)
d.addCallback(partial(self._threader,
self._verify_session))
@@ -435,14 +531,11 @@ class SRPAuth(QtCore.QObject):
# Store instance reference as the only member in the handle
self.__dict__['_SRPAuth__instance'] = SRPAuth.__instance
- self._username = None
- self._password = None
-
def authenticate(self, username, password):
"""
Executes the whole authentication process for a user
- Might raise SRPAuthenticationError
+ Might raise SRPAuthenticationError based
:param username: username for this session
:type username: str
diff --git a/src/leap/crypto/tests/test_srpauth.py b/src/leap/crypto/tests/test_srpauth.py
new file mode 100644
index 00000000..9a684a8f
--- /dev/null
+++ b/src/leap/crypto/tests/test_srpauth.py
@@ -0,0 +1,755 @@
+# -*- coding: utf-8 -*-
+# test_srpauth.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+Tests for:
+ * leap/crypto/srpauth.py
+"""
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+import os
+import sys
+import binascii
+import requests
+import mock
+
+from mock import MagicMock
+from nose.twistedtools import reactor, deferred
+from twisted.python import log
+from twisted.internet import threads
+from functools import partial
+from requests.models import Response
+from simplejson.decoder import JSONDecodeError
+
+from leap.common.testing.https_server import where
+from leap.config.providerconfig import ProviderConfig
+from leap.crypto import srpregister, srpauth
+from leap.crypto.tests import fake_provider
+
+log.startLogging(sys.stdout)
+
+
+def _get_capath():
+ return where("cacert.pem")
+
+_here = os.path.split(__file__)[0]
+
+
+class ImproperlyConfiguredError(Exception):
+ """
+ Raised if the test provider is missing configuration
+ """
+
+
+class SRPAuthTestCase(unittest.TestCase):
+ """
+ Tests for the SRPAuth class
+ """
+ __name__ = "SRPAuth tests"
+
+ def setUp(self):
+ """
+ Sets up this TestCase with a simple and faked provider instance:
+
+ * runs a threaded reactor
+ * loads a mocked ProviderConfig that points to the certs in the
+ leap.common.testing module.
+ """
+ factory = fake_provider.get_provider_factory()
+ http = reactor.listenTCP(0, factory)
+ https = reactor.listenSSL(
+ 0, factory,
+ fake_provider.OpenSSLServerContextFactory())
+ get_port = lambda p: p.getHost().port
+ self.http_port = get_port(http)
+ self.https_port = get_port(https)
+
+ provider = ProviderConfig()
+ provider.get_ca_cert_path = MagicMock()
+ provider.get_ca_cert_path.return_value = _get_capath()
+
+ provider.get_api_uri = MagicMock()
+ provider.get_api_uri.return_value = self._get_https_uri()
+
+ loaded = provider.load(path=os.path.join(
+ _here, "test_provider.json"))
+ if not loaded:
+ raise ImproperlyConfiguredError(
+ "Could not load test provider config")
+ self.register = srpregister.SRPRegister(provider_config=provider)
+ self.provider = provider
+ self.TEST_USER = "register_test_auth"
+ self.TEST_PASS = "pass"
+
+ # Reset the singleton
+ srpauth.SRPAuth._SRPAuth__instance = None
+ self.auth = srpauth.SRPAuth(self.provider)
+ self.auth_backend = self.auth._SRPAuth__instance
+
+ self.old_post = self.auth_backend._session.post
+ self.old_put = self.auth_backend._session.put
+ self.old_delete = self.auth_backend._session.delete
+
+ self.old_start_auth = self.auth_backend._start_authentication
+ self.old_proc_challenge = self.auth_backend._process_challenge
+ self.old_extract_data = self.auth_backend._extract_data
+ self.old_verify_session = self.auth_backend._verify_session
+ self.old_auth_preproc = self.auth_backend._authentication_preprocessing
+ self.old_get_sid = self.auth_backend.get_session_id
+ self.old_cookie_get = self.auth_backend._session.cookies.get
+ self.old_auth = self.auth_backend.authenticate
+
+ def tearDown(self):
+ self.auth_backend._session.post = self.old_post
+ self.auth_backend._session.put = self.old_put
+ self.auth_backend._session.delete = self.old_delete
+
+ self.auth_backend._start_authentication = self.old_start_auth
+ self.auth_backend._process_challenge = self.old_proc_challenge
+ self.auth_backend._extract_data = self.old_extract_data
+ self.auth_backend._verify_session = self.old_verify_session
+ self.auth_backend._authentication_preprocessing = self.old_auth_preproc
+ self.auth_backend.get_session_id = self.old_get_sid
+ self.auth_backend._session.cookies.get = self.old_cookie_get
+ self.auth_backend.authenticate = self.old_auth
+
+ # helper methods
+
+ def _get_https_uri(self):
+ """
+ Returns a https uri with the right https port initialized
+ """
+ return "https://localhost:%s" % (self.https_port,)
+
+ # Auth tests
+
+ def _prepare_auth_test(self, code=200, side_effect=None):
+ """
+ Creates the needed defers to test several test situations. It
+ adds up to the auth preprocessing step.
+
+ :param code: status code for the response of POST in requests
+ :type code: int
+ :param side_effect: side effect triggered by the POST method
+ in requests
+ :type side_effect: some kind of Exception
+
+ :returns: the defer that is created
+ :rtype: defer.Deferred
+ """
+ res = Response()
+ res.status_code = code
+ self.auth_backend._session.post = MagicMock(return_value=res,
+ side_effect=side_effect)
+
+ d = threads.deferToThread(self.register.register_user,
+ self.TEST_USER,
+ self.TEST_PASS)
+
+ def wrapper_preproc(*args):
+ return threads.deferToThread(
+ self.auth_backend._authentication_preprocessing,
+ self.TEST_USER, self.TEST_PASS)
+
+ d.addCallback(wrapper_preproc)
+
+ return d
+
+ def test_safe_unhexlify(self):
+ input_value = "somestring"
+ test_value = binascii.hexlify(input_value)
+ self.assertEqual(
+ self.auth_backend._safe_unhexlify(test_value),
+ input_value)
+
+ def test_safe_unhexlify_not_raises(self):
+ input_value = "somestring"
+ test_value = binascii.hexlify(input_value)[:-1]
+
+ with self.assertRaises(TypeError):
+ binascii.unhexlify(test_value)
+
+ self.auth_backend._safe_unhexlify(test_value)
+
+ def test_preprocessing_loads_a(self):
+ self.assertEqual(self.auth_backend._srp_a, None)
+ self.auth_backend._authentication_preprocessing("user", "pass")
+ self.assertIsNotNone(self.auth_backend._srp_a)
+ self.assertTrue(len(self.auth_backend._srp_a) > 0)
+
+ @deferred()
+ def test_start_authentication(self):
+ d = threads.deferToThread(self.register.register_user, self.TEST_USER,
+ self.TEST_PASS)
+
+ def wrapper_preproc(*args):
+ return threads.deferToThread(
+ self.auth_backend._authentication_preprocessing,
+ self.TEST_USER, self.TEST_PASS)
+
+ d.addCallback(wrapper_preproc)
+
+ def wrapper(_):
+ return threads.deferToThread(
+ self.auth_backend._start_authentication,
+ None, self.TEST_USER)
+
+ d.addCallback(wrapper)
+ return d
+
+ @deferred()
+ def test_start_authentication_fails_connerror(self):
+ d = self._prepare_auth_test(
+ side_effect=requests.exceptions.ConnectionError())
+
+ def wrapper(_):
+ with self.assertRaises(srpauth.SRPAuthConnectionError):
+ self.auth_backend._start_authentication(None, self.TEST_USER)
+
+ d.addCallback(partial(threads.deferToThread, wrapper))
+ return d
+
+ @deferred()
+ def test_start_authentication_fails_any_error(self):
+ d = self._prepare_auth_test(side_effect=Exception())
+
+ def wrapper(_):
+ with self.assertRaises(srpauth.SRPAuthenticationError):
+ self.auth_backend._start_authentication(None, self.TEST_USER)
+
+ d.addCallback(partial(threads.deferToThread, wrapper))
+ return d
+
+ @deferred()
+ def test_start_authentication_fails_unknown_user(self):
+ d = self._prepare_auth_test(422)
+
+ def wrapper(_):
+ with self.assertRaises(srpauth.SRPAuthUnknownUser):
+ with mock.patch('leap.util.request_helpers.get_content',
+ new_callable=MagicMock()) as \
+ content:
+ content.return_value = ("{}", 0)
+
+ self.auth_backend._start_authentication(
+ None, self.TEST_USER)
+
+ d.addCallback(partial(threads.deferToThread, wrapper))
+ return d
+
+ @deferred()
+ def test_start_authentication_fails_errorcode(self):
+ d = self._prepare_auth_test(302)
+
+ def wrapper(_):
+ with self.assertRaises(srpauth.SRPAuthBadStatusCode):
+ with mock.patch('leap.util.request_helpers.get_content',
+ new_callable=MagicMock()) as \
+ content:
+ content.return_value = ("{}", 0)
+
+ self.auth_backend._start_authentication(None,
+ self.TEST_USER)
+
+ d.addCallback(partial(threads.deferToThread, wrapper))
+ return d
+
+ @deferred()
+ def test_start_authentication_fails_no_salt(self):
+ d = self._prepare_auth_test(200)
+
+ def wrapper(_):
+ with self.assertRaises(srpauth.SRPAuthNoSalt):
+ with mock.patch('leap.util.request_helpers.get_content',
+ new_callable=MagicMock()) as \
+ content:
+ content.return_value = ("{}", 0)
+
+ self.auth_backend._start_authentication(None,
+ self.TEST_USER)
+
+ d.addCallback(partial(threads.deferToThread, wrapper))
+ return d
+
+ @deferred()
+ def test_start_authentication_fails_no_B(self):
+ d = self._prepare_auth_test(200)
+
+ def wrapper(_):
+ with self.assertRaises(srpauth.SRPAuthNoB):
+ with mock.patch('leap.util.request_helpers.get_content',
+ new_callable=MagicMock()) as \
+ content:
+ content.return_value = ('{"salt": ""}', 0)
+
+ self.auth_backend._start_authentication(None,
+ self.TEST_USER)
+
+ d.addCallback(partial(threads.deferToThread, wrapper))
+ return d
+
+ @deferred()
+ def test_start_authentication_correct_saltb(self):
+ d = self._prepare_auth_test(200)
+
+ test_salt = "12345"
+ test_B = "67890"
+
+ def wrapper(_):
+ with mock.patch('leap.util.request_helpers.get_content',
+ new_callable=MagicMock()) as \
+ content:
+ content.return_value = ('{"salt":"%s", "B":"%s"}' % (test_salt,
+ test_B),
+ 0)
+
+ salt, B = self.auth_backend._start_authentication(
+ None,
+ self.TEST_USER)
+ self.assertEqual(salt, test_salt)
+ self.assertEqual(B, test_B)
+
+ d.addCallback(partial(threads.deferToThread, wrapper))
+ return d
+
+ def _prepare_auth_challenge(self):
+ """
+ Creates the needed defers to test several test situations. It
+ adds up to the start authentication step.
+
+ :returns: the defer that is created
+ :rtype: defer.Deferred
+ """
+ d = threads.deferToThread(self.register.register_user,
+ self.TEST_USER,
+ self.TEST_PASS)
+
+ def wrapper_preproc(*args):
+ return threads.deferToThread(
+ self.auth_backend._authentication_preprocessing,
+ self.TEST_USER, self.TEST_PASS)
+
+ d.addCallback(wrapper_preproc)
+
+ def wrapper_start(*args):
+ return threads.deferToThread(
+ self.auth_backend._start_authentication,
+ None, self.TEST_USER)
+
+ d.addCallback(wrapper_start)
+
+ return d
+
+ @deferred()
+ def test_process_challenge_wrong_saltb(self):
+ d = self._prepare_auth_challenge()
+
+ def wrapper(salt_B):
+ with self.assertRaises(srpauth.SRPAuthBadDataFromServer):
+ self.auth_backend._process_challenge("",
+ username=self.TEST_USER)
+
+ d.addCallback(partial(threads.deferToThread, wrapper))
+ return d
+
+ @deferred()
+ def test_process_challenge_requests_problem_raises(self):
+ d = self._prepare_auth_challenge()
+
+ self.auth_backend._session.put = MagicMock(
+ side_effect=requests.exceptions.ConnectionError())
+
+ def wrapper(salt_B):
+ with self.assertRaises(srpauth.SRPAuthConnectionError):
+ self.auth_backend._process_challenge(salt_B,
+ username=self.TEST_USER)
+
+ d.addCallback(partial(threads.deferToThread, wrapper))
+
+ return d
+
+ @deferred()
+ def test_process_challenge_json_decode_error(self):
+ d = self._prepare_auth_challenge()
+
+ def wrapper(salt_B):
+ with mock.patch('leap.util.request_helpers.get_content',
+ new_callable=MagicMock()) as \
+ content:
+ content.return_value = ("{", 0)
+ content.side_effect = JSONDecodeError("", "", 0)
+
+ with self.assertRaises(srpauth.SRPAuthJSONDecodeError):
+ self.auth_backend._process_challenge(
+ salt_B,
+ username=self.TEST_USER)
+
+ d.addCallback(partial(threads.deferToThread, wrapper))
+
+ return d
+
+ @deferred()
+ def test_process_challenge_bad_password(self):
+ d = self._prepare_auth_challenge()
+
+ res = Response()
+ res.status_code = 422
+ self.auth_backend._session.put = MagicMock(return_value=res)
+
+ def wrapper(salt_B):
+ with mock.patch('leap.util.request_helpers.get_content',
+ new_callable=MagicMock()) as \
+ content:
+ content.return_value = ("", 0)
+ with self.assertRaises(srpauth.SRPAuthBadPassword):
+ self.auth_backend._process_challenge(
+ salt_B,
+ username=self.TEST_USER)
+
+ d.addCallback(partial(threads.deferToThread, wrapper))
+
+ return d
+
+ @deferred()
+ def test_process_challenge_bad_password2(self):
+ d = self._prepare_auth_challenge()
+
+ res = Response()
+ res.status_code = 422
+ self.auth_backend._session.put = MagicMock(return_value=res)
+
+ def wrapper(salt_B):
+ with mock.patch('leap.util.request_helpers.get_content',
+ new_callable=MagicMock()) as \
+ content:
+ content.return_value = ("[]", 0)
+ with self.assertRaises(srpauth.SRPAuthBadPassword):
+ self.auth_backend._process_challenge(
+ salt_B,
+ username=self.TEST_USER)
+
+ d.addCallback(partial(threads.deferToThread, wrapper))
+
+ return d
+
+ @deferred()
+ def test_process_challenge_other_error_code(self):
+ d = self._prepare_auth_challenge()
+
+ res = Response()
+ res.status_code = 300
+ self.auth_backend._session.put = MagicMock(return_value=res)
+
+ def wrapper(salt_B):
+ with mock.patch('leap.util.request_helpers.get_content',
+ new_callable=MagicMock()) as \
+ content:
+ content.return_value = ("{}", 0)
+ with self.assertRaises(srpauth.SRPAuthBadStatusCode):
+ self.auth_backend._process_challenge(
+ salt_B,
+ username=self.TEST_USER)
+
+ d.addCallback(partial(threads.deferToThread, wrapper))
+
+ return d
+
+ @deferred()
+ def test_process_challenge(self):
+ d = self._prepare_auth_challenge()
+
+ def wrapper(salt_B):
+ self.auth_backend._process_challenge(salt_B,
+ username=self.TEST_USER)
+
+ d.addCallback(partial(threads.deferToThread, wrapper))
+
+ return d
+
+ def test_extract_data_wrong_data(self):
+ with self.assertRaises(srpauth.SRPAuthBadDataFromServer):
+ self.auth_backend._extract_data(None)
+
+ with self.assertRaises(srpauth.SRPAuthBadDataFromServer):
+ self.auth_backend._extract_data("")
+
+ def test_extract_data_fails_on_wrong_data_from_server(self):
+ with self.assertRaises(srpauth.SRPAuthBadDataFromServer):
+ self.auth_backend._extract_data({})
+
+ with self.assertRaises(srpauth.SRPAuthBadDataFromServer):
+ self.auth_backend._extract_data({"M2": ""})
+
+ def test_extract_data_sets_uidtoken(self):
+ test_uid = "someuid"
+ test_m2 = "somem2"
+ test_token = "sometoken"
+ test_data = {
+ "M2": test_m2,
+ "id": test_uid,
+ "token": test_token
+ }
+ m2 = self.auth_backend._extract_data(test_data)
+
+ self.assertEqual(m2, test_m2)
+ self.assertEqual(self.auth_backend.get_uid(), test_uid)
+ self.assertEqual(self.auth_backend.get_uid(),
+ self.auth.get_uid())
+ self.assertEqual(self.auth_backend.get_token(), test_token)
+ self.assertEqual(self.auth_backend.get_token(),
+ self.auth.get_token())
+
+ def _prepare_verify_session(self):
+ """
+ Prepares the tests for verify session with needed steps
+ before. It adds up to the extract_data step.
+
+ :returns: The defer to chain to
+ :rtype: defer.Deferred
+ """
+ d = self._prepare_auth_challenge()
+
+ def wrapper_proc_challenge(salt_B):
+ return self.auth_backend._process_challenge(
+ salt_B,
+ username=self.TEST_USER)
+
+ def wrapper_extract_data(data):
+ return self.auth_backend._extract_data(data)
+
+ d.addCallback(partial(threads.deferToThread, wrapper_proc_challenge))
+ d.addCallback(partial(threads.deferToThread, wrapper_extract_data))
+
+ return d
+
+ @deferred()
+ def test_verify_session_unhexlifiable_m2(self):
+ d = self._prepare_verify_session()
+
+ def wrapper(M2):
+ with self.assertRaises(srpauth.SRPAuthBadDataFromServer):
+ self.auth_backend._verify_session("za") # unhexlifiable value
+
+ d.addCallback(wrapper)
+
+ return d
+
+ @deferred()
+ def test_verify_session_unverifiable_m2(self):
+ d = self._prepare_verify_session()
+
+ def wrapper(M2):
+ with self.assertRaises(srpauth.SRPAuthVerificationFailed):
+ # Correctly unhelifiable value, but not for verifying the
+ # session
+ self.auth_backend._verify_session("abc12")
+
+ d.addCallback(wrapper)
+
+ return d
+
+ @deferred()
+ def test_verify_session_fails_on_no_session_id(self):
+ d = self._prepare_verify_session()
+
+ def wrapper(M2):
+ self.auth_backend._session.cookies.get = MagicMock(
+ return_value=None)
+ with self.assertRaises(srpauth.SRPAuthNoSessionId):
+ self.auth_backend._verify_session(M2)
+
+ d.addCallback(wrapper)
+
+ return d
+
+ @deferred()
+ def test_verify_session_session_id(self):
+ d = self._prepare_verify_session()
+
+ test_session_id = "12345"
+
+ def wrapper(M2):
+ self.auth_backend._session.cookies.get = MagicMock(
+ return_value=test_session_id)
+ self.auth_backend._verify_session(M2)
+ self.assertEqual(self.auth_backend.get_session_id(),
+ test_session_id)
+ self.assertEqual(self.auth_backend.get_session_id(),
+ self.auth.get_session_id())
+
+ d.addCallback(wrapper)
+
+ return d
+
+ @deferred()
+ def test_verify_session(self):
+ d = self._prepare_verify_session()
+
+ def wrapper(M2):
+ self.auth_backend._verify_session(M2)
+
+ d.addCallback(wrapper)
+
+ return d
+
+ @deferred()
+ def test_authenticate(self):
+ self.auth_backend._authentication_preprocessing = MagicMock(
+ return_value=None)
+ self.auth_backend._start_authentication = MagicMock(return_value=None)
+ self.auth_backend._process_challenge = MagicMock(return_value=None)
+ self.auth_backend._extract_data = MagicMock(return_value=None)
+ self.auth_backend._verify_session = MagicMock(return_value=None)
+
+ d = self.auth_backend.authenticate(self.TEST_USER, self.TEST_PASS)
+
+ def check(*args):
+ self.auth_backend._authentication_preprocessing.\
+ assert_called_once_with(
+ username=self.TEST_USER,
+ password=self.TEST_PASS
+ )
+ self.auth_backend._start_authentication.assert_called_once_with(
+ None,
+ username=self.TEST_USER)
+ self.auth_backend._process_challenge.assert_called_once_with(
+ None,
+ username=self.TEST_USER)
+ self.auth_backend._extract_data.assert_called_once_with(
+ None,
+ username=self.TEST_USER)
+ self.auth_backend._verify_session.assert_called_once_with(None)
+
+ d.addCallback(check)
+
+ return d
+
+ @deferred()
+ def test_logout_fails_if_not_logged_in(self):
+
+ def wrapper(*args):
+ with self.assertRaises(AssertionError):
+ self.auth_backend.logout()
+
+ d = threads.deferToThread(wrapper)
+ return d
+
+ @deferred()
+ def test_logout_traps_delete(self):
+ self.auth_backend.get_session_id = MagicMock(return_value="1234")
+ self.auth_backend._session.delete = MagicMock(side_effect=Exception())
+
+ def wrapper(*args):
+ self.auth_backend.logout()
+
+ d = threads.deferToThread(wrapper)
+ return d
+
+ @deferred()
+ def test_logout_clears(self):
+ self.auth_backend._session_id = "1234"
+
+ def wrapper(*args):
+ old_session = self.auth_backend._session
+ self.auth_backend.logout()
+ self.assertIsNone(self.auth_backend.get_session_id())
+ self.assertIsNone(self.auth_backend.get_uid())
+ self.assertNotEqual(old_session, self.auth_backend._session)
+
+ d = threads.deferToThread(wrapper)
+ return d
+
+
+class SRPAuthSingletonTestCase(unittest.TestCase):
+ def setUp(self):
+ self.old_auth = srpauth.SRPAuth._SRPAuth__impl.authenticate
+
+ def tearDown(self):
+ srpauth.SRPAuth._SRPAuth__impl.authenticate = self.old_auth
+
+ def test_singleton(self):
+ obj1 = srpauth.SRPAuth(ProviderConfig())
+ obj2 = srpauth.SRPAuth(ProviderConfig())
+ self.assertEqual(obj1._SRPAuth__instance, obj2._SRPAuth__instance)
+
+ @deferred()
+ def test_authenticate_notifies_gui(self):
+ auth = srpauth.SRPAuth(ProviderConfig())
+ auth._SRPAuth__instance.authenticate = MagicMock(
+ return_value=threads.deferToThread(lambda: None))
+ auth._gui_notify = MagicMock()
+
+ d = auth.authenticate("", "")
+
+ def check(*args):
+ auth._gui_notify.assert_called_once_with(None)
+
+ d.addCallback(check)
+ return d
+
+ @deferred()
+ def test_authenticate_errsback(self):
+ auth = srpauth.SRPAuth(ProviderConfig())
+ auth._SRPAuth__instance.authenticate = MagicMock(
+ return_value=threads.deferToThread(MagicMock(
+ side_effect=Exception())))
+ auth._gui_notify = MagicMock()
+ auth._errback = MagicMock()
+
+ d = auth.authenticate("", "")
+
+ def check(*args):
+ self.assertFalse(auth._gui_notify.called)
+ self.assertEqual(auth._errback.call_count, 1)
+
+ d.addCallback(check)
+ return d
+
+ @deferred()
+ def test_authenticate_runs_cleanly_when_raises(self):
+ auth = srpauth.SRPAuth(ProviderConfig())
+ auth._SRPAuth__instance.authenticate = MagicMock(
+ return_value=threads.deferToThread(MagicMock(
+ side_effect=Exception())))
+
+ d = auth.authenticate("", "")
+
+ return d
+
+ @deferred()
+ def test_authenticate_runs_cleanly(self):
+ auth = srpauth.SRPAuth(ProviderConfig())
+ auth._SRPAuth__instance.authenticate = MagicMock(
+ return_value=threads.deferToThread(MagicMock()))
+
+ d = auth.authenticate("", "")
+
+ return d
+
+ def test_logout(self):
+ auth = srpauth.SRPAuth(ProviderConfig())
+ auth._SRPAuth__instance.logout = MagicMock()
+
+ self.assertTrue(auth.logout())
+
+ def test_logout_rets_false_when_raises(self):
+ auth = srpauth.SRPAuth(ProviderConfig())
+ auth._SRPAuth__instance.logout = MagicMock(
+ side_effect=Exception())
+
+ self.assertFalse(auth.logout())
diff --git a/src/leap/crypto/tests/test_srpregister.py b/src/leap/crypto/tests/test_srpregister.py
index 6d2b52e8..66b815f2 100644
--- a/src/leap/crypto/tests/test_srpregister.py
+++ b/src/leap/crypto/tests/test_srpregister.py
@@ -17,7 +17,6 @@
"""
Tests for:
* leap/crypto/srpregister.py
- * leap/crypto/srpauth.py
"""
try:
import unittest2 as unittest
@@ -53,9 +52,9 @@ class ImproperlyConfiguredError(Exception):
class SRPTestCase(unittest.TestCase):
"""
- Tests for the SRP Register and Auth classes
+ Tests for the SRPRegister class
"""
- __name__ = "SRPRegister and SRPAuth tests"
+ __name__ = "SRPRegister tests"
@classmethod
def setUpClass(cls):
diff --git a/src/leap/util/request_helpers.py b/src/leap/util/request_helpers.py
index e06dabb8..350abfbd 100644
--- a/src/leap/util/request_helpers.py
+++ b/src/leap/util/request_helpers.py
@@ -41,7 +41,7 @@ def get_content(request):
contents = ""
mtime = None
- if request.json:
+ if request.content and request.json:
if callable(request.json):
contents = json.dumps(request.json())
else: