summaryrefslogtreecommitdiff
path: root/src/leap/crypto/certs.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/crypto/certs.py')
-rw-r--r--src/leap/crypto/certs.py97
1 files changed, 50 insertions, 47 deletions
diff --git a/src/leap/crypto/certs.py b/src/leap/crypto/certs.py
index 78f49fb0..cbb5725a 100644
--- a/src/leap/crypto/certs.py
+++ b/src/leap/crypto/certs.py
@@ -1,44 +1,55 @@
-import ctypes
+import logging
+import os
from StringIO import StringIO
-import socket
+import ssl
+import time
-import gnutls.connection
-import gnutls.crypto
-import gnutls.library
+from dateutil.parser import parse
+from OpenSSL import crypto
from leap.util.misc import null_check
+logger = logging.getLogger(__name__)
+
class BadCertError(Exception):
- """raised for malformed certs"""
+ """
+ raised for malformed certs
+ """
-def get_https_cert_from_domain(domain):
+class NoCertError(Exception):
"""
- @param domain: a domain name to get a certificate from.
+ raised for cert not found in given path
"""
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- cred = gnutls.connection.X509Credentials()
- session = gnutls.connection.ClientSession(sock, cred)
- session.connect((domain, 443))
- session.handshake()
- cert = session.peer_certificate
- return cert
+
+def get_https_cert_from_domain(domain, port=443):
+ """
+ @param domain: a domain name to get a certificate from.
+ """
+ cert = ssl.get_server_certificate((domain, port))
+ x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
+ return x509
def get_cert_from_file(_file):
- getcert = lambda f: gnutls.crypto.X509Certificate(f.read())
- if isinstance(_file, str):
+ null_check(_file, "pem file")
+ if isinstance(_file, (str, unicode)):
+ if not os.path.isfile(_file):
+ raise NoCertError
with open(_file) as f:
- cert = getcert(f)
+ cert = f.read()
else:
- cert = getcert(_file)
- return cert
+ cert = _file.read()
+ x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
+ return x509
def get_pkey_from_file(_file):
- getkey = lambda f: gnutls.crypto.X509PrivateKey(f.read())
+ getkey = lambda f: crypto.load_privatekey(
+ crypto.FILETYPE_PEM, f.read())
+
if isinstance(_file, str):
with open(_file) as f:
key = getkey(f)
@@ -48,6 +59,10 @@ def get_pkey_from_file(_file):
def can_load_cert_and_pkey(string):
+ """
+ loads certificate and private key from
+ a buffer
+ """
try:
f = StringIO(string)
cert = get_cert_from_file(f)
@@ -57,14 +72,14 @@ def can_load_cert_and_pkey(string):
null_check(cert, 'certificate')
null_check(key, 'private key')
- except:
- # XXX catch GNUTLSError?
+ except Exception as exc:
+ logger.error(type(exc), exc.message)
raise BadCertError
else:
return True
-def get_cert_fingerprint(domain=None, filepath=None,
+def get_cert_fingerprint(domain=None, port=443, filepath=None,
hash_type="SHA256", sep=":"):
"""
@param domain: a domain name to get a fingerprint from
@@ -79,31 +94,19 @@ def get_cert_fingerprint(domain=None, filepath=None,
@rtype: string
"""
if domain:
- cert = get_https_cert_from_domain(domain)
+ cert = get_https_cert_from_domain(domain, port=port)
if filepath:
cert = get_cert_from_file(filepath)
+ hex_fpr = cert.digest(hash_type)
+ return hex_fpr
- _buffer = ctypes.create_string_buffer(64)
- buffer_length = ctypes.c_size_t(64)
-
- SUPPORTED_DIGEST_FUN = ("SHA1", "SHA224", "SHA256", "SHA384", "SHA512")
- if hash_type in SUPPORTED_DIGEST_FUN:
- digestfunction = getattr(
- gnutls.library.constants,
- "GNUTLS_DIG_%s" % hash_type)
- else:
- # XXX improperlyconfigured or something
- raise Exception("digest function not supported")
-
- gnutls.library.functions.gnutls_x509_crt_get_fingerprint(
- cert._c_object, digestfunction,
- ctypes.byref(_buffer), ctypes.byref(buffer_length))
-
- # deinit
- #server_cert._X509Certificate__deinit(server_cert._c_object)
- # needed? is segfaulting
- fpr = ctypes.string_at(_buffer, buffer_length.value)
- hex_fpr = sep.join(u"%02X" % ord(char) for char in fpr)
+def get_time_boundaries(certfile):
+ cert = get_cert_from_file(certfile)
+ null_check(cert, 'certificate')
- return hex_fpr
+ fromts, tots = (cert.get_notBefore(), cert.get_notAfter())
+ from_, to_ = map(
+ lambda ts: time.gmtime(time.mktime(parse(ts).timetuple())),
+ (fromts, tots))
+ return from_, to_