import logging import os from StringIO import StringIO import ssl import time from dateutil.parser import parse from OpenSSL import crypto from leap.util.misc import null_check logger = logging.getLogger(__name__) class BadCertError(Exception): """ raised for malformed certs """ class NoCertError(Exception): """ raised for cert not found in given path """ def get_https_cert_from_domain(domain, port=443): """ @param domain: a domain name to get a certificate from. """ cert = ssl.get_server_certificate((domain, port)) x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert) return x509 def get_cert_from_file(_file): null_check(_file, "pem file") if isinstance(_file, (str, unicode)): if not os.path.isfile(_file): raise NoCertError with open(_file) as f: cert = f.read() else: cert = _file.read() x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert) return x509 def get_pkey_from_file(_file): getkey = lambda f: crypto.load_privatekey( crypto.FILETYPE_PEM, f.read()) if isinstance(_file, str): with open(_file) as f: key = getkey(f) else: key = getkey(_file) return key def can_load_cert_and_pkey(string): """ loads certificate and private key from a buffer """ try: f = StringIO(string) cert = get_cert_from_file(f) f = StringIO(string) key = get_pkey_from_file(f) null_check(cert, 'certificate') null_check(key, 'private key') except Exception as exc: logger.error(type(exc), exc.message) raise BadCertError else: return True def get_cert_fingerprint(domain=None, port=443, filepath=None, hash_type="SHA256", sep=":"): """ @param domain: a domain name to get a fingerprint from @type domain: str @param filepath: path to a file containing a PEM file @type filepath: str @param hash_type: the hash function to be used in the fingerprint. must be one of SHA1, SHA224, SHA256, SHA384, SHA512 @type hash_type: str @rparam: hex_fpr, a hexadecimal representation of a bytestring containing the fingerprint. @rtype: string """ if domain: cert = get_https_cert_from_domain(domain, port=port) if filepath: cert = get_cert_from_file(filepath) hex_fpr = cert.digest(hash_type) return hex_fpr def get_time_boundaries(certfile): cert = get_cert_from_file(certfile) null_check(cert, 'certificate') fromts, tots = (cert.get_notBefore(), cert.get_notAfter()) from_, to_ = map( lambda ts: time.gmtime(time.mktime(parse(ts).timetuple())), (fromts, tots)) return from_, to_