diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/leap/crypto/certs.py | 86 | ||||
| -rw-r--r-- | src/leap/crypto/certs_gnutls.py | 112 | 
2 files changed, 150 insertions, 48 deletions
| diff --git a/src/leap/crypto/certs.py b/src/leap/crypto/certs.py index 78f49fb0..c2835878 100644 --- a/src/leap/crypto/certs.py +++ b/src/leap/crypto/certs.py @@ -1,44 +1,53 @@ -import ctypes +import logging +import os  from StringIO import StringIO -import socket +import ssl -import gnutls.connection -import gnutls.crypto -import gnutls.library +from OpenSSL import crypto  from leap.util.misc import null_check +logger = logging.getLogger(__name__) +  class BadCertError(Exception): -    """raised for malformed certs""" +    """ +    raised for malformed certs +    """ -def get_https_cert_from_domain(domain): +class NoCertError(Exception):      """ -    @param domain: a domain name to get a certificate from. +    raised for cert not found in given path      """ -    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) -    cred = gnutls.connection.X509Credentials() -    session = gnutls.connection.ClientSession(sock, cred) -    session.connect((domain, 443)) -    session.handshake() -    cert = session.peer_certificate -    return cert + +def get_https_cert_from_domain(domain, port=443): +    """ +    @param domain: a domain name to get a certificate from. +    """ +    cert = ssl.get_server_certificate((domain, port)) +    x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert) +    return x509  def get_cert_from_file(_file): -    getcert = lambda f: gnutls.crypto.X509Certificate(f.read()) +    null_check(_file, "pem file")      if isinstance(_file, str): +        if not os.path.isfile(_file): +            raise NoCertError          with open(_file) as f: -            cert = getcert(f) +            cert = f.read()      else: -        cert = getcert(_file) -    return cert +        cert = _file.read() +    x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert) +    return x509  def get_pkey_from_file(_file): -    getkey = lambda f: gnutls.crypto.X509PrivateKey(f.read()) +    getkey = lambda f: crypto.load_privatekey( +        crypto.FILETYPE_PEM, f.read()) +      if isinstance(_file, str):          with open(_file) as f:              key = getkey(f) @@ -48,6 +57,10 @@ def get_pkey_from_file(_file):  def can_load_cert_and_pkey(string): +    """ +    loads certificate and private key from +    a buffer +    """      try:          f = StringIO(string)          cert = get_cert_from_file(f) @@ -57,14 +70,14 @@ def can_load_cert_and_pkey(string):          null_check(cert, 'certificate')          null_check(key, 'private key') -    except: -        # XXX catch GNUTLSError? +    except Exception as exc: +        logger.error(type(exc), exc.message)          raise BadCertError      else:          return True -def get_cert_fingerprint(domain=None, filepath=None, +def get_cert_fingerprint(domain=None, port=443, filepath=None,                           hash_type="SHA256", sep=":"):      """      @param domain: a domain name to get a fingerprint from @@ -79,31 +92,8 @@ def get_cert_fingerprint(domain=None, filepath=None,      @rtype: string      """      if domain: -        cert = get_https_cert_from_domain(domain) +        cert = get_https_cert_from_domain(domain, port=port)      if filepath:          cert = get_cert_from_file(filepath) - -    _buffer = ctypes.create_string_buffer(64) -    buffer_length = ctypes.c_size_t(64) - -    SUPPORTED_DIGEST_FUN = ("SHA1", "SHA224", "SHA256", "SHA384", "SHA512") -    if hash_type in SUPPORTED_DIGEST_FUN: -        digestfunction = getattr( -            gnutls.library.constants, -            "GNUTLS_DIG_%s" % hash_type) -    else: -        # XXX improperlyconfigured or something -        raise Exception("digest function not supported") - -    gnutls.library.functions.gnutls_x509_crt_get_fingerprint( -        cert._c_object, digestfunction, -        ctypes.byref(_buffer), ctypes.byref(buffer_length)) - -    # deinit -    #server_cert._X509Certificate__deinit(server_cert._c_object) -    # needed? is segfaulting - -    fpr = ctypes.string_at(_buffer, buffer_length.value) -    hex_fpr = sep.join(u"%02X" % ord(char) for char in fpr) - +    hex_fpr = cert.digest(hash_type)      return hex_fpr diff --git a/src/leap/crypto/certs_gnutls.py b/src/leap/crypto/certs_gnutls.py new file mode 100644 index 00000000..20c0e043 --- /dev/null +++ b/src/leap/crypto/certs_gnutls.py @@ -0,0 +1,112 @@ +''' +We're using PyOpenSSL now + +import ctypes +from StringIO import StringIO +import socket + +import gnutls.connection +import gnutls.crypto +import gnutls.library + +from leap.util.misc import null_check + + +class BadCertError(Exception): +    """raised for malformed certs""" + + +def get_https_cert_from_domain(domain): +    """ +    @param domain: a domain name to get a certificate from. +    """ +    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +    cred = gnutls.connection.X509Credentials() + +    session = gnutls.connection.ClientSession(sock, cred) +    session.connect((domain, 443)) +    session.handshake() +    cert = session.peer_certificate +    return cert + + +def get_cert_from_file(_file): +    getcert = lambda f: gnutls.crypto.X509Certificate(f.read()) +    if isinstance(_file, str): +        with open(_file) as f: +            cert = getcert(f) +    else: +        cert = getcert(_file) +    return cert + + +def get_pkey_from_file(_file): +    getkey = lambda f: gnutls.crypto.X509PrivateKey(f.read()) +    if isinstance(_file, str): +        with open(_file) as f: +            key = getkey(f) +    else: +        key = getkey(_file) +    return key + + +def can_load_cert_and_pkey(string): +    try: +        f = StringIO(string) +        cert = get_cert_from_file(f) + +        f = StringIO(string) +        key = get_pkey_from_file(f) + +        null_check(cert, 'certificate') +        null_check(key, 'private key') +    except: +        # XXX catch GNUTLSError? +        raise BadCertError +    else: +        return True + +def get_cert_fingerprint(domain=None, filepath=None, +                         hash_type="SHA256", sep=":"): +    """ +    @param domain: a domain name to get a fingerprint from +    @type domain: str +    @param filepath: path to a file containing a PEM file +    @type filepath: str +    @param hash_type: the hash function to be used in the fingerprint. +        must be one of SHA1, SHA224, SHA256, SHA384, SHA512 +    @type hash_type: str +    @rparam: hex_fpr, a hexadecimal representation of a bytestring +             containing the fingerprint. +    @rtype: string +    """ +    if domain: +        cert = get_https_cert_from_domain(domain) +    if filepath: +        cert = get_cert_from_file(filepath) + +    _buffer = ctypes.create_string_buffer(64) +    buffer_length = ctypes.c_size_t(64) + +    SUPPORTED_DIGEST_FUN = ("SHA1", "SHA224", "SHA256", "SHA384", "SHA512") +    if hash_type in SUPPORTED_DIGEST_FUN: +        digestfunction = getattr( +            gnutls.library.constants, +            "GNUTLS_DIG_%s" % hash_type) +    else: +        # XXX improperlyconfigured or something +        raise Exception("digest function not supported") + +    gnutls.library.functions.gnutls_x509_crt_get_fingerprint( +        cert._c_object, digestfunction, +        ctypes.byref(_buffer), ctypes.byref(buffer_length)) + +    # deinit +    #server_cert._X509Certificate__deinit(server_cert._c_object) +    # needed? is segfaulting + +    fpr = ctypes.string_at(_buffer, buffer_length.value) +    hex_fpr = sep.join(u"%02X" % ord(char) for char in fpr) + +    return hex_fpr +''' | 
