From 634030e5bba3fe7c2ea3632fff252a60b471487a Mon Sep 17 00:00:00 2001
From: kali <kali@leap.se>
Date: Fri, 19 Oct 2012 09:05:14 +0900
Subject: ca cert fingerprint check + api cert verification

---
 src/leap/crypto/certs.py       | 20 +++++++++++++-----
 src/leap/eip/checks.py         | 20 ++++++++++++++----
 src/leap/gui/firstrunwizard.py | 48 ++++++++++++++++++++++++++++++------------
 3 files changed, 66 insertions(+), 22 deletions(-)

(limited to 'src')

diff --git a/src/leap/crypto/certs.py b/src/leap/crypto/certs.py
index ac9bd357..8908865d 100644
--- a/src/leap/crypto/certs.py
+++ b/src/leap/crypto/certs.py
@@ -2,6 +2,7 @@ import ctypes
 import socket
 
 import gnutls.connection
+import gnutls.crypto
 import gnutls.library
 
 
@@ -19,10 +20,19 @@ def get_https_cert_from_domain(domain):
     return cert
 
 
-def get_https_cert_fingerprint(domain, hash_type="SHA256", sep=":"):
+def get_cert_from_file(filepath):
+    with open(filepath) as f:
+        cert = gnutls.crypto.X509Certificate(f.read())
+    return cert
+
+
+def get_cert_fingerprint(domain=None, filepath=None,
+                         hash_type="SHA256", sep=":"):
     """
     @param domain: a domain name to get a fingerprint from
     @type domain: str
+    @param filepath: path to a file containing a PEM file
+    @type filepath: str
     @param hash_type: the hash function to be used in the fingerprint.
         must be one of SHA1, SHA224, SHA256, SHA384, SHA512
     @type hash_type: str
@@ -30,7 +40,10 @@ def get_https_cert_fingerprint(domain, hash_type="SHA256", sep=":"):
              containing the fingerprint.
     @rtype: string
     """
-    cert = get_https_cert_from_domain(domain)
+    if domain:
+        cert = get_https_cert_from_domain(domain)
+    if filepath:
+        cert = get_cert_from_file(filepath)
 
     _buffer = ctypes.create_string_buffer(64)
     buffer_length = ctypes.c_size_t(64)
@@ -56,6 +69,3 @@ def get_https_cert_fingerprint(domain, hash_type="SHA256", sep=":"):
     hex_fpr = sep.join(u"%02X" % ord(char) for char in fpr)
 
     return hex_fpr
-
-#if __name__ == "__main__":
-    #print get_https_cert_fingerprint('springbok')
diff --git a/src/leap/eip/checks.py b/src/leap/eip/checks.py
index e925e11c..1c29dab1 100644
--- a/src/leap/eip/checks.py
+++ b/src/leap/eip/checks.py
@@ -10,10 +10,11 @@ import gnutls.crypto
 import requests
 
 from leap import __branding as BRANDING
-from leap import certs
+from leap import certs as leapcerts
 from leap.base import config as baseconfig
 from leap.base import constants as baseconstants
 from leap.base import providers
+from leap.crypto import certs
 from leap.eip import config as eipconfig
 from leap.eip import constants as eipconstants
 from leap.eip import exceptions as eipexceptions
@@ -46,7 +47,7 @@ reachable and testable as a whole.
 def get_ca_cert():
     ca_file = BRANDING.get('provider_ca_file')
     if ca_file:
-        return certs.where(ca_file)
+        return leapcerts.where(ca_file)
 
 
 class ProviderCertChecker(object):
@@ -97,7 +98,18 @@ class ProviderCertChecker(object):
     def check_ca_cert_fingerprint(
             self, hash_type="SHA256",
             fingerprint=None):
-        pass
+        ca_cert_path = self.ca_cert_path
+        ca_cert_fpr = certs.get_cert_fingerprint(
+            filepath=ca_cert_path)
+        return ca_cert_fpr == fingerprint
+
+    def verify_api_https(self, uri):
+        assert uri.startswith('https://')
+        cacert = self.ca_cert_path
+        verify = cacert and cacert or True
+        req = self.fetcher.get(uri, verify=verify)
+        req.raise_for_status()
+        return True
 
     def download_ca_signature(self):
         # MVS+
@@ -268,7 +280,7 @@ class ProviderCertChecker(object):
 
     @property
     def ca_cert_path(self):
-        return self._get_ca_cert_path()
+        return self._get_ca_cert_path(self.domain)
 
     def _get_root_uri(self):
         return u"https://%s/" % baseconstants.DEFAULT_PROVIDER
diff --git a/src/leap/gui/firstrunwizard.py b/src/leap/gui/firstrunwizard.py
index e4293cf6..55338090 100755
--- a/src/leap/gui/firstrunwizard.py
+++ b/src/leap/gui/firstrunwizard.py
@@ -3,6 +3,8 @@ import logging
 import json
 import socket
 
+import requests
+
 import sip
 sip.setapi('QString', 2)
 sip.setapi('QVariant', 2)
@@ -411,8 +413,8 @@ class SelectProviderPage(QtGui.QWizardPage):
                 pass
             else:
                 self.set_validation_status(exc.usermessage)
-                fingerprint = certs.get_https_cert_fingerprint(
-                    domain, sep=" ")
+                fingerprint = certs.get_cert_fingerprint(
+                    domain=domain, sep=" ")
                 self.add_cert_info(fingerprint)
                 self.did_cert_check = True
                 self.completeChanged.emit()
@@ -545,24 +547,44 @@ class ProviderSetupPage(QtGui.QWizardPage):
             verify=False)
 
         self.set_status('Checking CA fingerprint')
-        self.progress.setValue(40)
-        ca_cert_fingerprint = pconfig.get('ca_cert_fingerprint')
+        self.progress.setValue(66)
+        ca_cert_fingerprint = pconfig.get('ca_cert_fingerprint', None)
 
         # XXX get fingerprint dict (types)
-        certchecker.check_ca_cert_fingerprint(
-            fingerprint=ca_cert_fingerprint)
-        time.sleep(2)
-
-        self.set_status('Fetching api https certificate')
-        self.progress.setValue(60)
-        time.sleep(2)
+        sha256_fpr = ca_cert_fingerprint.split('=')[1]
+
+        validate_fpr = certchecker.check_ca_cert_fingerprint(
+            fingerprint=sha256_fpr)
+        time.sleep(0.5)
+        if not validate_fpr:
+            # XXX update validationMsg
+            # should catch exception
+            return False
 
         self.set_status('Validating api certificate')
-        self.progress.setValue(80)
-        time.sleep(2)
+        self.progress.setValue(90)
+
+        api_uri = pconfig.get('api_uri', None)
+        try:
+            api_cert_verified = certchecker.verify_api_https(api_uri)
+        except requests.exceptions.SSLError as exc:
+            logger.error('BUG #638. %s' % exc.message)
+            # XXX RAISE! See #638
+            # bypassing until the hostname is fixed.
+            # We probably should raise yet-another-warning
+            # here saying user that the hostname "XX.XX.XX.XX' does not
+            # match 'foo.bar.baz'
+            api_cert_verified = True
+
+        if not api_cert_verified:
+            # XXX update validationMsg
+            # should catch exception
+            return False
+        time.sleep(0.5)
         #ca_cert_path = checker.ca_cert_path
 
         self.progress.setValue(100)
+        time.sleep(0.2)
 
     # pagewizard methods
 
-- 
cgit v1.2.3