# Copyright 2011-2012 Canonical Ltd.
# Copyright 2016 LEAP Encryption Access Project
#
# This file is part of u1db.
#
# u1db is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# u1db is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with u1db. If not, see <http://www.gnu.org/licenses/>.
"""
Tests for the l2db HTTP client: query-parameter encoding and HTTPClientBase.
"""
import json
from unittest import skip
from leap.soledad.common.l2db import errors
from leap.soledad.common.l2db.remote import http_client
from test_soledad import u1db_tests as tests
@skip("Skiping tests imported from U1DB.")
class TestEncoder(tests.TestCase):
    """Check how http_client._encode_query_parameter renders values."""

    def _assertEncodes(self, value, expected):
        """Assert that ``value`` is encoded as the string ``expected``."""
        encoded = http_client._encode_query_parameter(value)
        self.assertEqual(expected, encoded)

    def test_encode_string(self):
        self._assertEncodes("foo", "foo")

    def test_encode_true(self):
        self._assertEncodes(True, "true")

    def test_encode_false(self):
        self._assertEncodes(False, "false")
@skip("Skiping tests imported from U1DB.")
class TestHTTPClientBase(tests.TestCaseWithServer):
    """Exercise http_client.HTTPClientBase against an in-process WSGI app.

    The app served by ``make_app`` dispatches on the suffix of PATH_INFO:

    * ``.../echo`` replies 200 with a JSON summary of the request;
    * ``.../error`` replies with the status and body described by the JSON
      request payload, counting each reply in ``self.errors``;
    * ``.../error_then_accept`` behaves like ``error`` while fewer than
      three errors have been served, then replies 200 with a JSON summary
      carrying a fixed body.
    """

    def setUp(self):
        super(TestHTTPClientBase, self).setUp()
        # Number of error replies served by the test app so far.
        self.errors = 0

    @staticmethod
    def _describe_request(environ):
        """Return a dict with the request method, path and query string."""
        return dict(
            (name, environ[name])
            for name in ('REQUEST_METHOD', 'PATH_INFO', 'QUERY_STRING'))

    def _error_response(self, environ, start_response):
        """Reply with the status and body requested in the JSON payload.

        The request body is expected to be a JSON object with a ``status``
        line (e.g. ``"503 Service Unavailable"``) and a ``response``. A
        string ``response`` is sent as text/plain; any other JSON value is
        re-serialized and sent as application/json. Increments
        ``self.errors``.
        """
        self.errors += 1
        content_length = int(environ['CONTENT_LENGTH'])
        error = json.loads(environ['wsgi.input'].read(content_length))
        response = error['response']
        # In debug mode, wsgiref has an assertion that the status parameter
        # is a 'str' object. However error['status'] returns a unicode
        # object.
        status = str(error['status'])
        if isinstance(response, unicode):
            response = str(response)
        if isinstance(response, str):
            start_response(status, [('Content-Type', 'text/plain')])
            return [response]
        start_response(status, [('Content-Type', 'application/json')])
        return [json.dumps(response)]

    def app(self, environ, start_response):
        """WSGI application used as the server under test."""
        path = environ['PATH_INFO']
        if path.endswith('echo'):
            start_response("200 OK", [('Content-Type', 'application/json')])
            ret = self._describe_request(environ)
            if environ['REQUEST_METHOD'] in ('PUT', 'POST'):
                ret['CONTENT_TYPE'] = environ['CONTENT_TYPE']
                content_length = int(environ['CONTENT_LENGTH'])
                ret['body'] = environ['wsgi.input'].read(content_length)
            return [json.dumps(ret)]
        elif path.endswith('error_then_accept'):
            if self.errors >= 3:
                start_response(
                    "200 OK", [('Content-Type', 'application/json')])
                ret = self._describe_request(environ)
                if environ['REQUEST_METHOD'] in ('PUT', 'POST'):
                    ret['CONTENT_TYPE'] = environ['CONTENT_TYPE']
                    # The request body is not echoed here; a fixed payload
                    # marks the eventual success.
                    ret['body'] = '{"oki": "doki"}'
                return [json.dumps(ret)]
            return self._error_response(environ, start_response)
        elif path.endswith('error'):
            return self._error_response(environ, start_response)

    def make_app(self):
        return self.app

    def getClient(self, **kwds):
        """Start the server and return a client for its ``dbase`` URL."""
        self.startServer()
        return http_client.HTTPClientBase(self.getURL('dbase'), **kwds)

    def _catch(self, exc_class, func, *args):
        """Call ``func(*args)`` and return the ``exc_class`` it raises.

        Fails the test if nothing is raised. This replaces the pattern
        ``except X, e: pass`` followed by using ``e``, which leaves ``e``
        unbound (NameError) when no exception occurs.
        """
        try:
            func(*args)
        except exc_class as exc:
            return exc
        self.fail('%s not raised' % exc_class.__name__)

    def test_construct(self):
        self.startServer()
        url = self.getURL()
        cli = http_client.HTTPClientBase(url)
        self.assertEqual(url, cli._url.geturl())
        self.assertIs(None, cli._conn)

    def test_parse_url(self):
        cli = http_client.HTTPClientBase(
            '%s://127.0.0.1:12345/' % self.url_scheme)
        self.assertEqual(self.url_scheme, cli._url.scheme)
        self.assertEqual('127.0.0.1', cli._url.hostname)
        self.assertEqual(12345, cli._url.port)
        self.assertEqual('/', cli._url.path)

    def test__ensure_connection(self):
        cli = self.getClient()
        self.assertIs(None, cli._conn)
        cli._ensure_connection()
        self.assertIsNot(None, cli._conn)
        conn = cli._conn
        cli._ensure_connection()
        self.assertIs(conn, cli._conn)

    def test_close(self):
        cli = self.getClient()
        cli._ensure_connection()
        cli.close()
        self.assertIs(None, cli._conn)

    def test__request(self):
        cli = self.getClient()
        res, headers = cli._request('PUT', ['echo'], {}, {})
        self.assertEqual({'CONTENT_TYPE': 'application/json',
                          'PATH_INFO': '/dbase/echo',
                          'QUERY_STRING': '',
                          'body': '{}',
                          'REQUEST_METHOD': 'PUT'}, json.loads(res))

        res, headers = cli._request('GET', ['doc', 'echo'], {'a': 1})
        self.assertEqual({'PATH_INFO': '/dbase/doc/echo',
                          'QUERY_STRING': 'a=1',
                          'REQUEST_METHOD': 'GET'}, json.loads(res))

        res, headers = cli._request('GET', ['doc', '%FFFF', 'echo'], {'a': 1})
        self.assertEqual({'PATH_INFO': '/dbase/doc/%FFFF/echo',
                          'QUERY_STRING': 'a=1',
                          'REQUEST_METHOD': 'GET'}, json.loads(res))

        res, headers = cli._request('POST', ['echo'], {'b': 2}, 'Body',
                                    'application/x-test')
        self.assertEqual({'CONTENT_TYPE': 'application/x-test',
                          'PATH_INFO': '/dbase/echo',
                          'QUERY_STRING': 'b=2',
                          'body': 'Body',
                          'REQUEST_METHOD': 'POST'}, json.loads(res))

    def test__request_json(self):
        cli = self.getClient()
        res, headers = cli._request_json(
            'POST', ['echo'], {'b': 2}, {'a': 'x'})
        self.assertEqual('application/json', headers['content-type'])
        self.assertEqual({'CONTENT_TYPE': 'application/json',
                          'PATH_INFO': '/dbase/echo',
                          'QUERY_STRING': 'b=2',
                          'body': '{"a": "x"}',
                          'REQUEST_METHOD': 'POST'}, res)

    def test_unspecified_http_error(self):
        cli = self.getClient()
        self.assertRaises(errors.HTTPError,
                          cli._request_json, 'POST', ['error'], {},
                          {'status': "500 Internal Error",
                           'response': "Crash."})
        exc = self._catch(errors.HTTPError,
                          cli._request_json, 'POST', ['error'], {},
                          {'status': "500 Internal Error",
                           'response': "Fail."})
        self.assertEqual(500, exc.status)
        self.assertEqual("Fail.", exc.message)
        self.assertIn("content-type", exc.headers)

    def test_revision_conflict(self):
        cli = self.getClient()
        self.assertRaises(errors.RevisionConflict,
                          cli._request_json, 'POST', ['error'], {},
                          {'status': "409 Conflict",
                           'response': {"error": "revision conflict"}})

    def test_unavailable_proper(self):
        cli = self.getClient()
        # Presumably the client's retry back-off; zeroed so retries happen
        # immediately.
        cli._delays = (0, 0, 0, 0, 0)
        self.assertRaises(errors.Unavailable,
                          cli._request_json, 'POST', ['error'], {},
                          {'status': "503 Service Unavailable",
                           'response': {"error": "unavailable"}})
        self.assertEqual(5, self.errors)

    def test_unavailable_then_available(self):
        cli = self.getClient()
        cli._delays = (0, 0, 0, 0, 0)
        res, headers = cli._request_json(
            'POST', ['error_then_accept'], {'b': 2},
            {'status': "503 Service Unavailable",
             'response': {"error": "unavailable"}})
        self.assertEqual('application/json', headers['content-type'])
        self.assertEqual({'CONTENT_TYPE': 'application/json',
                          'PATH_INFO': '/dbase/error_then_accept',
                          'QUERY_STRING': 'b=2',
                          'body': '{"oki": "doki"}',
                          'REQUEST_METHOD': 'POST'}, res)
        self.assertEqual(3, self.errors)

    def test_unavailable_random_source(self):
        cli = self.getClient()
        cli._delays = (0, 0, 0, 0, 0)
        exc = self._catch(errors.Unavailable,
                          cli._request_json, 'POST', ['error'], {},
                          {'status': "503 Service Unavailable",
                           'response': "random unavailable."})
        self.assertEqual(503, exc.status)
        self.assertEqual("random unavailable.", exc.message)
        self.assertIn("content-type", exc.headers)
        self.assertEqual(5, self.errors)

    def test_document_too_big(self):
        cli = self.getClient()
        self.assertRaises(errors.DocumentTooBig,
                          cli._request_json, 'POST', ['error'], {},
                          {'status': "403 Forbidden",
                           'response': {"error": "document too big"}})

    def test_user_quota_exceeded(self):
        cli = self.getClient()
        self.assertRaises(errors.UserQuotaExceeded,
                          cli._request_json, 'POST', ['error'], {},
                          {'status': "403 Forbidden",
                           'response': {"error": "user quota exceeded"}})

    def test_user_needs_subscription(self):
        cli = self.getClient()
        self.assertRaises(errors.SubscriptionNeeded,
                          cli._request_json, 'POST', ['error'], {},
                          {'status': "403 Forbidden",
                           'response': {"error": "user needs subscription"}})

    def test_generic_u1db_error(self):
        cli = self.getClient()
        self.assertRaises(errors.U1DBError,
                          cli._request_json, 'POST', ['error'], {},
                          {'status': "400 Bad Request",
                           'response': {"error": "error"}})
        exc = self._catch(errors.U1DBError,
                          cli._request_json, 'POST', ['error'], {},
                          {'status': "400 Bad Request",
                           'response': {"error": "error"}})
        # Must be exactly U1DBError, not a more specific subclass.
        self.assertIs(exc.__class__, errors.U1DBError)

    def test_unspecified_bad_request(self):
        cli = self.getClient()
        self.assertRaises(errors.HTTPError,
                          cli._request_json, 'POST', ['error'], {},
                          {'status': "400 Bad Request",
                           'response': ""})
        exc = self._catch(errors.HTTPError,
                          cli._request_json, 'POST', ['error'], {},
                          {'status': "400 Bad Request",
                           'response': ""})
        self.assertEqual(400, exc.status)
        self.assertEqual("", exc.message)
        self.assertIn("content-type", exc.headers)

    def test_unknown_creds(self):
        self.assertRaises(errors.UnknownAuthMethod,
                          self.getClient, creds={'foo': {}})
        self.assertRaises(errors.UnknownAuthMethod,
                          self.getClient, creds={})