1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
|
# Copyright 2012 Canonical Ltd.
#
# This file is part of u1db.
#
# u1db is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# u1db is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with u1db. If not, see <http://www.gnu.org/licenses/>.
"""U1DB OAuth authorisation WSGI middleware."""
import httplib
from oauth import oauth
try:
import simplejson as json
except ImportError:
import json # noqa
from urllib import quote
from wsgiref.util import shift_path_info
# Signature-method instances created once at import time; OAuthMiddleware.verify
# registers both of them on the oauth.OAuthServer it builds for each request.
sign_meth_HMAC_SHA1, sign_meth_PLAINTEXT = (
    oauth.OAuthSignatureMethod_HMAC_SHA1(),
    oauth.OAuthSignatureMethod_PLAINTEXT(),
)
class OAuthMiddleware(object):
    """U1DB OAuth Authorisation WSGI middleware.

    Wraps a WSGI application. A request is forwarded to the wrapped app only
    if its PATH_INFO starts with ``prefix`` (when ``prefix`` is set) and it
    carries OAuth parameters that ``verify`` accepts. Otherwise a JSON error
    response is returned (400 for a bad prefix, 401 for missing or invalid
    OAuth).

    Subclasses must implement ``get_oauth_data_store``.
    """

    # max seconds the request timestamp is allowed to be shifted
    # from arrival time
    timestamp_threshold = 300

    def __init__(self, app, base_url, prefix='/~/'):
        """Wrap ``app``.

        :param app: WSGI application called for authorised requests.
        :param base_url: string prepended to PATH_INFO to rebuild the URL
            handed to oauth as ``http_url`` (presumably
            ``scheme://host[:port]`` -- confirm with callers).
        :param prefix: required leading part of PATH_INFO; a falsy value
            disables the prefix check.
        """
        self.app = app
        self.base_url = base_url
        self.prefix = prefix

    def get_oauth_data_store(self):
        """Provide an oauth.OAuthDataStore.

        Abstract hook that subclasses must override. ``verify`` calls it once
        per request.
        """
        raise NotImplementedError(self.get_oauth_data_store)

    def _error(self, start_response, status, description, message=None):
        """Start a JSON error response and return its body.

        :param start_response: the WSGI ``start_response`` callable.
        :param status: integer HTTP status code; its reason phrase is looked
            up in ``httplib.responses``.
        :param description: value stored under the ``"error"`` key.
        :param message: optional detail, stored under ``"message"`` only when
            truthy.
        :return: a one-element list holding the JSON-encoded body.
        """
        start_response("%d %s" % (status, httplib.responses[status]),
                       [('content-type', 'application/json')])
        err = {"error": description}
        if message:
            err['message'] = message
        return [json.dumps(err)]

    def __call__(self, environ, start_response):
        """WSGI entry point: check prefix and OAuth, then forward to the app.

        Returns a 400 JSON error when PATH_INFO lacks the prefix, and a 401
        JSON error when OAuth parameters are missing or ``verify`` raises
        ``oauth.OAuthError``.
        """
        # PEP 3333 lets servers omit PATH_INFO and QUERY_STRING when they are
        # empty, so read them with defaults rather than indexing (which would
        # raise KeyError and turn into a 500).
        path_info = environ.get('PATH_INFO', '')
        if self.prefix and not path_info.startswith(self.prefix):
            return self._error(start_response, 400, "bad request")
        headers = {}
        if 'HTTP_AUTHORIZATION' in environ:
            headers['Authorization'] = environ['HTTP_AUTHORIZATION']
        oauth_req = oauth.OAuthRequest.from_request(
            http_method=environ['REQUEST_METHOD'],
            http_url=self.base_url + path_info,
            headers=headers,
            query_string=environ.get('QUERY_STRING', '')
        )
        if oauth_req is None:
            return self._error(start_response, 401, "unauthorized",
                               "Missing OAuth.")
        try:
            self.verify(environ, oauth_req)
        except oauth.OAuthError as e:
            return self._error(start_response, 401, "unauthorized",
                               e.message)
        # Move the first path segment from PATH_INFO to SCRIPT_NAME (for the
        # default '/~/' prefix that segment is '~').
        # NOTE(review): this also runs when prefix is falsy, stripping an
        # arbitrary first segment, and shifts only one segment for a
        # multi-segment prefix -- confirm this is intended.
        shift_path_info(environ)
        return self.app(environ, start_response)

    def verify(self, environ, oauth_req):
        """Verify the OAuth request and rewrite QUERY_STRING without OAuth bits.

        This base implementation does not store any user id in ``environ``.
        Overriding subclasses may do so using the returned consumer/token.

        :param environ: WSGI environ; its QUERY_STRING is replaced.
        :param oauth_req: the ``oauth.OAuthRequest`` built by ``__call__``.
        :return: ``(consumer, token)`` as returned by
            ``OAuthServer.verify_request``.
        :raises oauth.OAuthError: propagated from ``verify_request``;
            ``__call__`` turns it into a 401.
        """
        oauth_server = oauth.OAuthServer(self.get_oauth_data_store())
        oauth_server.timestamp_threshold = self.timestamp_threshold
        oauth_server.add_signature_method(sign_meth_HMAC_SHA1)
        oauth_server.add_signature_method(sign_meth_PLAINTEXT)
        consumer, token, parameters = oauth_server.verify_request(oauth_req)
        # Rebuild QUERY_STRING from the parameters verify_request hands back
        # (presumably only the non-oauth_* ones), percent-encoding every
        # character outside the unreserved set in both keys and values.
        environ['QUERY_STRING'] = '&'.join(
            "%s=%s" % (quote(k, safe=''), quote(v, safe=''))
            for k, v in parameters.iteritems())
        return consumer, token
|