diff options
author | kali <kali@leap.se> | 2013-01-24 07:59:35 +0900 |
---|---|---|
committer | kali <kali@leap.se> | 2013-01-24 07:59:35 +0900 |
commit | 26d1849415402a5aa826c57519d40a19cc67c059 (patch) | |
tree | 96011407b7e0e7d0eefa1f1d7f81d7b863ab6172 /src/leap/crypto/certs.py | |
parent | 88159d703e9b75d3cb0c192e7d7ae92d9d8c67bc (diff) | |
parent | 73b73793d524b795279a697cad12c22a808f5c36 (diff) |
Merge branch 'feature/deprecate-gnutls' into develop
Diffstat (limited to 'src/leap/crypto/certs.py')
-rw-r--r-- | src/leap/crypto/certs.py | 86 |
1 files changed, 38 insertions, 48 deletions
diff --git a/src/leap/crypto/certs.py b/src/leap/crypto/certs.py index 78f49fb0..c2835878 100644 --- a/src/leap/crypto/certs.py +++ b/src/leap/crypto/certs.py @@ -1,44 +1,53 @@ -import ctypes +import logging +import os from StringIO import StringIO -import socket +import ssl -import gnutls.connection -import gnutls.crypto -import gnutls.library +from OpenSSL import crypto from leap.util.misc import null_check +logger = logging.getLogger(__name__) + class BadCertError(Exception): - """raised for malformed certs""" + """ + raised for malformed certs + """ -def get_https_cert_from_domain(domain): +class NoCertError(Exception): """ - @param domain: a domain name to get a certificate from. + raised for cert not found in given path """ - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - cred = gnutls.connection.X509Credentials() - session = gnutls.connection.ClientSession(sock, cred) - session.connect((domain, 443)) - session.handshake() - cert = session.peer_certificate - return cert + +def get_https_cert_from_domain(domain, port=443): + """ + @param domain: a domain name to get a certificate from. + """ + cert = ssl.get_server_certificate((domain, port)) + x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert) + return x509 def get_cert_from_file(_file): - getcert = lambda f: gnutls.crypto.X509Certificate(f.read()) + null_check(_file, "pem file") if isinstance(_file, str): + if not os.path.isfile(_file): + raise NoCertError with open(_file) as f: - cert = getcert(f) + cert = f.read() else: - cert = getcert(_file) - return cert + cert = _file.read() + x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert) + return x509 def get_pkey_from_file(_file): - getkey = lambda f: gnutls.crypto.X509PrivateKey(f.read()) + getkey = lambda f: crypto.load_privatekey( + crypto.FILETYPE_PEM, f.read()) + if isinstance(_file, str): with open(_file) as f: key = getkey(f) @@ -48,6 +57,10 @@ def get_pkey_from_file(_file): def can_load_cert_and_pkey(string): + """ + loads certificate and private key from + a buffer + """ try: f = StringIO(string) cert = get_cert_from_file(f) @@ -57,14 +70,14 @@ def can_load_cert_and_pkey(string): null_check(cert, 'certificate') null_check(key, 'private key') - except: - # XXX catch GNUTLSError? + except Exception as exc: + logger.error(type(exc), exc.message) raise BadCertError else: return True -def get_cert_fingerprint(domain=None, filepath=None, +def get_cert_fingerprint(domain=None, port=443, filepath=None, hash_type="SHA256", sep=":"): """ @param domain: a domain name to get a fingerprint from @@ -79,31 +92,8 @@ def get_cert_fingerprint(domain=None, filepath=None, @rtype: string """ if domain: - cert = get_https_cert_from_domain(domain) + cert = get_https_cert_from_domain(domain, port=port) if filepath: cert = get_cert_from_file(filepath) - - _buffer = ctypes.create_string_buffer(64) - buffer_length = ctypes.c_size_t(64) - - SUPPORTED_DIGEST_FUN = ("SHA1", "SHA224", "SHA256", "SHA384", "SHA512") - if hash_type in SUPPORTED_DIGEST_FUN: - digestfunction = getattr( - gnutls.library.constants, - "GNUTLS_DIG_%s" % hash_type) - else: - # XXX improperlyconfigured or something - raise Exception("digest function not supported") - - gnutls.library.functions.gnutls_x509_crt_get_fingerprint( - cert._c_object, digestfunction, - ctypes.byref(_buffer), ctypes.byref(buffer_length)) - - # deinit - #server_cert._X509Certificate__deinit(server_cert._c_object) - # needed? is segfaulting - - fpr = ctypes.string_at(_buffer, buffer_length.value) - hex_fpr = sep.join(u"%02X" % ord(char) for char in fpr) - + hex_fpr = cert.digest(hash_type) return hex_fpr |