summaryrefslogtreecommitdiff
path: root/u1db/remote/http_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'u1db/remote/http_client.py')
-rw-r--r--u1db/remote/http_client.py218
1 files changed, 218 insertions, 0 deletions
diff --git a/u1db/remote/http_client.py b/u1db/remote/http_client.py
new file mode 100644
index 00000000..decddda3
--- /dev/null
+++ b/u1db/remote/http_client.py
@@ -0,0 +1,218 @@
+# Copyright 2011-2012 Canonical Ltd.
+#
+# This file is part of u1db.
+#
+# u1db is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Lesser General Public License version 3
+# as published by the Free Software Foundation.
+#
+# u1db is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with u1db. If not, see <http://www.gnu.org/licenses/>.
+
+"""Base class to make requests to a remote HTTP server."""
+
+import httplib
+from oauth import oauth
+try:
+ import simplejson as json
+except ImportError:
+ import json # noqa
+import socket
+import ssl
+import sys
+import urlparse
+import urllib
+
+from time import sleep
+from u1db import (
+ errors,
+ )
+from u1db.remote import (
+ http_errors,
+ )
+
+from u1db.remote.ssl_match_hostname import ( # noqa
+ CertificateError,
+ match_hostname,
+ )
+
+# Path of the CA bundle handed to ssl.wrap_socket() as 'ca_certs' when
+# verifying HTTPS server certificates. This is the Ubuntu/Debian location.
+# XXX other distributions/platforms keep their bundle elsewhere and are not
+# handled yet.
+CA_CERTS = "/etc/ssl/certs/ca-certificates.crt"
+
+
+def _encode_query_parameter(value):
+    """Return `value` as a UTF-8 encoded string for use in a query string.
+
+    Booleans are spelled as the lowercase words 'true' and 'false'; every
+    other value is converted with unicode() before being encoded.
+    """
+    if isinstance(value, bool):
+        value = 'true' if value else 'false'
+    return unicode(value).encode('utf-8')
+
+
+class _VerifiedHTTPSConnection(httplib.HTTPSConnection):
+    """HTTPSConnection verifying server side certificates.
+
+    On Linux the peer certificate is required, validated against the CA
+    bundle at CA_CERTS and then matched against the requested host name.
+    On every other platform no certificate verification is done yet.
+    """
+    # derived from httplib.py
+
+    def connect(self):
+        """Connect to a host on a given (SSL) port.
+
+        Opens the TCP connection (going through the configured tunnel, if
+        any), wraps it in SSL and, where verification is enabled, checks
+        that the certificate belongs to self.host.
+
+        Whatever match_hostname() raises on a mismatch (presumably
+        CertificateError) is propagated after the socket has been closed.
+        """
+        sock = socket.create_connection((self.host, self.port),
+                                        self.timeout, self.source_address)
+        if self._tunnel_host:
+            self.sock = sock
+            self._tunnel()
+        if sys.platform.startswith('linux'):
+            cert_opts = {
+                'cert_reqs': ssl.CERT_REQUIRED,
+                'ca_certs': CA_CERTS,
+                }
+        else:
+            # XXX no cert verification implemented elsewhere for now
+            cert_opts = {}
+        # SSLv3 is broken (POODLE) and prohibited by RFC 7568, so do not pin
+        # it: PROTOCOL_SSLv23 negotiates the highest protocol version that
+        # both client and server support.
+        self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file,
+                                    ssl_version=ssl.PROTOCOL_SSLv23,
+                                    **cert_opts)
+        if cert_opts:
+            try:
+                match_hostname(self.sock.getpeercert(), self.host)
+            except Exception:
+                # Never keep a connection to an unverified peer around:
+                # close it and forget it so it cannot be reused by mistake.
+                self.sock.close()
+                self.sock = None
+                raise
+
+
+class HTTPClientBase(object):
+    """Base class to make requests to a remote HTTP server.
+
+    Holds the parsed base URL, a lazily created connection and optional
+    OAuth credentials, and offers _request()/_request_json() helpers that
+    sign, send and retry requests and turn error responses into exceptions.
+    """
+
+    # by default use HMAC-SHA1 OAuth signature method to not disclose
+    # tokens
+    # NB: given that the content bodies are not covered by the
+    # signatures though, to achieve security (against man-in-the-middle
+    # attacks for example) one would need HTTPS
+    oauth_signature_method = oauth.OAuthSignatureMethod_HMAC_SHA1()
+
+    # Will use these delays to retry on 503 before finally giving up. The
+    # final 0 is there to not wait after the final try fails.
+    _delays = (1, 1, 2, 4, 0)
+
+    def __init__(self, url, creds=None):
+        """Remember the server `url` and optionally install credentials.
+
+        :param url: base URL of the remote server.
+        :param creds: None, or a dict with exactly one entry that maps an
+            authentication method name to the keyword arguments of the
+            matching set_<method>_credentials() method.
+        :raise errors.UnknownAuthMethod: if `creds` does not contain
+            exactly one entry, or names a method without such a setter.
+        """
+        self._url = urlparse.urlsplit(url)
+        self._conn = None
+        self._creds = {}
+        if creds is not None:
+            if len(creds) != 1:
+                raise errors.UnknownAuthMethod()
+            auth_meth, credentials = list(creds.items())[0]
+            try:
+                set_creds = getattr(self, 'set_%s_credentials' % auth_meth)
+            except AttributeError:
+                raise errors.UnknownAuthMethod(auth_meth)
+            set_creds(**credentials)
+
+    def set_oauth_credentials(self, consumer_key, consumer_secret,
+                              token_key, token_secret):
+        """Store the OAuth consumer and token used to sign requests."""
+        self._creds = {'oauth': (
+            oauth.OAuthConsumer(consumer_key, consumer_secret),
+            oauth.OAuthToken(token_key, token_secret))}
+
+    def _ensure_connection(self):
+        """Create the connection object on first use.
+
+        'https' URLs get a certificate-verifying connection, anything else
+        a plain HTTP connection.
+        """
+        if self._conn is not None:
+            return
+        if self._url.scheme == 'https':
+            connClass = _VerifiedHTTPSConnection
+        else:
+            connClass = httplib.HTTPConnection
+        self._conn = connClass(self._url.hostname, self._url.port)
+
+    def close(self):
+        """Close and drop the connection, if there is one."""
+        if self._conn:
+            self._conn.close()
+            self._conn = None
+
+    # xxx retry mechanism?
+
+    def _error(self, respdic):
+        """Raise the exception described by a decoded JSON error body.
+
+        :param respdic: the decoded response body. Its "error" entry is
+            looked up in errors.wire_description_to_exc and its "message"
+            entry is passed to the exception found there.
+        :return: None when the body is not a JSON object or its error
+            description is unknown, so that the caller can fall back to a
+            generic error.
+        """
+        if not isinstance(respdic, dict):
+            # Valid JSON, but not the {"error": ...} object we understand.
+            return
+        descr = respdic.get("error")
+        exc_cls = errors.wire_description_to_exc.get(descr)
+        if exc_cls is not None:
+            message = respdic.get("message")
+            raise exc_cls(message)
+
+    def _response(self):
+        """Read the pending response and return it or raise for it.
+
+        :return: a (body, headers) tuple for status 200 and 201, with
+            headers as a dict.
+        :raise errors.Unavailable: for status 503.
+        :raise errors.HTTPError: for any other status, unless the body
+            of a known error status describes a more specific exception,
+            in which case that one is raised by _error().
+        """
+        resp = self._conn.getresponse()
+        body = resp.read()
+        headers = dict(resp.getheaders())
+        if resp.status in (200, 201):
+            return body, headers
+        elif resp.status in http_errors.ERROR_STATUSES:
+            try:
+                respdic = json.loads(body)
+            except ValueError:
+                # not JSON: fall through to the generic errors below
+                pass
+            else:
+                self._error(respdic)
+        # special case
+        if resp.status == 503:
+            raise errors.Unavailable(body, headers)
+        raise errors.HTTPError(resp.status, body, headers)
+
+    def _sign_request(self, method, url_query, params):
+        """Return the extra headers that authenticate a request.
+
+        :param method: the HTTP method.
+        :param url_query: the unquoted path of the request; it is appended
+            to the scheme and network location of the base URL.
+        :param params: dict of the (encoded) query parameters.
+        :return: the items of the OAuth header when OAuth credentials are
+            set, otherwise an empty list.
+        """
+        if 'oauth' in self._creds:
+            consumer, token = self._creds['oauth']
+            full_url = "%s://%s%s" % (self._url.scheme, self._url.netloc,
+                                      url_query)
+            oauth_req = oauth.OAuthRequest.from_consumer_and_token(
+                consumer, token,
+                http_method=method,
+                parameters=params,
+                http_url=full_url
+                )
+            oauth_req.sign_request(
+                self.oauth_signature_method, consumer, token)
+            # Authorization: OAuth ...
+            return oauth_req.to_header().items()
+        else:
+            return []
+
+    def _request(self, method, url_parts, params=None, body=None,
+                 content_type=None):
+        """Send a request, retrying while the server answers 503.
+
+        :param method: the HTTP method.
+        :param url_parts: sequence of path segments appended to the path
+            of the base URL; each one is fully percent-quoted.
+        :param params: optional dict of query parameters; keys and values
+            are encoded to UTF-8 (booleans as 'true'/'false').
+        :param body: optional request body. Anything that is not a string
+            is serialized to JSON and sent as 'application/json'.
+        :param content_type: optional content type of a string `body`.
+        :return: the (body, headers) tuple of the successful response.
+        :raise errors.Unavailable: if every attempt was answered with 503.
+            Other errors raised by _response() are not retried.
+        """
+        self._ensure_connection()
+        unquoted_url = url_query = self._url.path
+        if url_parts:
+            if not url_query.endswith('/'):
+                url_query += '/'
+                unquoted_url = url_query
+            url_query += '/'.join(urllib.quote(part, safe='')
+                                  for part in url_parts)
+            # oauth performs its own quoting
+            unquoted_url += '/'.join(url_parts)
+        encoded_params = {}
+        if params:
+            for key, value in params.items():
+                key = unicode(key).encode('utf-8')
+                encoded_params[key] = _encode_query_parameter(value)
+            url_query += ('?' + urllib.urlencode(encoded_params))
+        if body is not None and not isinstance(body, basestring):
+            body = json.dumps(body)
+            content_type = 'application/json'
+        headers = {}
+        if content_type:
+            headers['content-type'] = content_type
+        headers.update(
+            self._sign_request(method, unquoted_url, encoded_params))
+        # Always make at least one attempt, even if _delays was overridden
+        # with an empty sequence to switch retries off.
+        delays = self._delays or (0,)
+        # Kept in its own variable: the name bound by "except ... as" must
+        # not be relied upon once the except clause is over.
+        unavailable = None
+        for delay in delays:
+            try:
+                self._conn.request(method, url_query, body, headers)
+                return self._response()
+            except errors.Unavailable as e:
+                unavailable = e
+                sleep(delay)
+        raise unavailable
+
+    def _request_json(self, method, url_parts, params=None, body=None,
+                      content_type=None):
+        """Like _request(), but decode the response body as JSON.
+
+        :return: a (decoded body, headers) tuple.
+        """
+        res, headers = self._request(method, url_parts, params, body,
+                                     content_type)
+        return json.loads(res), headers