summaryrefslogtreecommitdiff
path: root/service/pixelated/bitmask_libraries/certs.py
blob: 3ca554697da9942cf5f815016cf5ce86ff2227ce (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#
# Copyright (c) 2014 ThoughtWorks, Inc.
#
# Pixelated is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pixelated is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Pixelated. If not, see <http://www.gnu.org/licenses/>.
import os
import requests
import json
from leap.common import ca_bundle

from .config import AUTO_DETECT_CA_BUNDLE

# Bootstrap-verification overrides. None means "no override": see
# which_bootstrap_CA_bundle() and which_bootstrap_cert_fingerprint().
LEAP_CERT = None
LEAP_FINGERPRINT = None

# "certificates" directory located in the parent of the directory that
# contains this file (i.e. a sibling of this module's package directory).
PACKAGED_CERTS_HOME = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "certificates")


def init_leap_cert(leap_provider_cert, leap_provider_cert_fingerprint):
    """Set the module-level bootstrap verification overrides.

    Updates LEAP_CERT and LEAP_FINGERPRINT, which are what
    which_bootstrap_CA_bundle() and which_bootstrap_cert_fingerprint()
    return.

    Args:
        leap_provider_cert: path of a CA certificate to trust, or a falsy
            value. Ignored when a fingerprint is given.
        leap_provider_cert_fingerprint: fingerprint of the provider
            certificate, or None.

    Without a fingerprint, LEAP_CERT becomes ``leap_provider_cert`` (or
    True when that is falsy) and LEAP_FINGERPRINT is cleared. With a
    fingerprint, the fingerprint is stored and LEAP_CERT becomes False.
    NOTE(review): True/False look like they are meant as a ``verify``
    value for HTTP requests — confirm against callers.
    """
    # Without ``global`` these assignments would only create locals and the
    # module-level settings would silently remain None.
    global LEAP_CERT, LEAP_FINGERPRINT
    if leap_provider_cert_fingerprint is None:
        LEAP_CERT = leap_provider_cert or True
        LEAP_FINGERPRINT = None
    else:
        LEAP_FINGERPRINT = leap_provider_cert_fingerprint
        LEAP_CERT = False


def which_api_CA_bundle(provider):
    """Return, as a str, the CA bundle path used for API calls to *provider*."""
    bundle_path = LeapCertificate(provider).api_ca_bundle()
    return str(bundle_path)


def which_bootstrap_cert_fingerprint():
    """Return the module-level LEAP_FINGERPRINT override (None if unset)."""
    fingerprint = LEAP_FINGERPRINT
    return fingerprint


def which_bootstrap_CA_bundle(provider):
    """Return the CA bundle to use while bootstrapping *provider*.

    The LEAP_CERT override is returned unchanged whenever it is not None
    (it may be a path, True or False). Otherwise the auto-detected bundle
    path is returned as a str.
    """
    if LEAP_CERT is None:
        detected = LeapCertificate(provider).auto_detect_bootstrap_ca_bundle()
        return str(detected)
    return LEAP_CERT


def refresh_ca_bundle(provider):
    """Re-download *provider*'s API certificate into its cache location."""
    certificate = LeapCertificate(provider)
    certificate.refresh_ca_bundle()


class LeapCertificate(object):
    """Locate, download and cache the certificates of one LEAP provider.

    Cached files live below ``<leap_home>/providers/<server_name>/``: the
    bootstrap CA as ``<server_name>.ca.crt`` and the API certificate as
    ``keys/client/api.pem``.
    """

    def __init__(self, provider):
        # provider must expose .config, .server_name and
        # .fetch_valid_certificate(); all three are used below.
        self._config = provider.config
        self._server_name = provider.server_name
        self._provider = provider

    def auto_detect_bootstrap_ca_bundle(self):
        """Return the CA bundle used when bootstrapping the provider.

        If ``bootstrap_ca_cert_bundle`` is the AUTO_DETECT_CA_BUNDLE
        sentinel, a provider-specific certificate is looked up (downloading
        it if needed) with ``ca_bundle.where()`` as the fallback; any other
        configured value is returned unchanged.
        """
        if self._config.bootstrap_ca_cert_bundle == AUTO_DETECT_CA_BUNDLE:
            local_cert = self._local_bootstrap_server_cert()
            if local_cert:
                return local_cert
            return ca_bundle.where()
        return self._config.bootstrap_ca_cert_bundle

    def api_ca_bundle(self):
        """Return the CA bundle path for API calls.

        An explicitly configured ``ca_cert_bundle`` wins. Otherwise the
        cached ``api.pem`` is returned, downloading it first if missing.
        """
        if self._provider.config.ca_cert_bundle:
            return self._provider.config.ca_cert_bundle

        cert_file = self._api_cert_file()
        if not os.path.isfile(cert_file):
            self._download_server_cert(cert_file)
        return cert_file

    def refresh_ca_bundle(self):
        """Re-download the API certificate, overwriting any cached copy."""
        self._download_server_cert(self._api_cert_file())

    def _api_cert_file(self):
        """Return the path of the cached API certificate (``api.pem``)."""
        return os.path.join(self._api_certs_root_path(), 'api.pem')

    def _api_certs_root_path(self):
        """Return ``<leap_home>/providers/<server>/keys/client``, creating it."""
        path = os.path.join(self._provider.config.leap_home, 'providers',
                            self._server_name, 'keys', 'client')
        self._ensure_private_dir(path)
        return path

    def _local_bootstrap_server_cert(self):
        """Return a path to the provider's bootstrap CA certificate.

        Lookup order: the per-user cache, then the certificate packaged in
        PACKAGED_CERTS_HOME, and finally a download of the ``ca_cert_uri``
        advertised by ``https://<server>/provider.json`` into the cache.

        Raises:
            requests.HTTPError: if either download returns an HTTP error
                status, so an error page is never cached as a certificate
                (a cached file would otherwise be trusted forever).
        """
        cached_cert = self._bootstrap_certs_cert_file()
        if os.path.isfile(cached_cert):
            return cached_cert

        packaged_cert = os.path.join(PACKAGED_CERTS_HOME,
                                     '%s.ca.crt' % self._server_name)
        if os.path.exists(packaged_cert):
            return packaged_cert

        # No local copy: ask provider.json where the CA certificate lives.
        response = requests.get('https://%s/provider.json' % self._server_name)
        response.raise_for_status()
        provider_data = json.loads(response.content)
        ca_cert_uri = str(provider_data['ca_cert_uri'])

        response = requests.get(ca_cert_uri)
        response.raise_for_status()
        # response.content is the raw body bytes, so write in binary mode.
        with open(cached_cert, 'wb') as cert_out:
            cert_out.write(response.content)

        return cached_cert

    def _bootstrap_certs_cert_file(self):
        """Return ``<leap_home>/providers/<server>/<server>.ca.crt``.

        The containing directory is created if it does not exist yet.
        """
        path = os.path.join(self._provider.config.leap_home, 'providers',
                            self._server_name)
        self._ensure_private_dir(path)
        return os.path.join(path, '%s.ca.crt' % self._server_name)

    def _download_server_cert(self, cert_file_name):
        """Write the certificate returned by the provider to *cert_file_name*."""
        cert = self._provider.fetch_valid_certificate()
        with open(cert_file_name, 'w') as cert_out:
            cert_out.write(cert)

    @staticmethod
    def _ensure_private_dir(path):
        """Create *path* with requested mode 0o700 unless it is already a dir."""
        if os.path.isdir(path):
            return
        try:
            os.makedirs(path, 0o700)
        except OSError:
            # Another process may have created it between the check and
            # makedirs; only re-raise if the directory really is missing.
            if not os.path.isdir(path):
                raise