summaryrefslogtreecommitdiff
path: root/src/leap/config/providerconfig.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/config/providerconfig.py')
-rw-r--r--src/leap/config/providerconfig.py144
1 files changed, 144 insertions, 0 deletions
diff --git a/src/leap/config/providerconfig.py b/src/leap/config/providerconfig.py
new file mode 100644
index 00000000..c3c2c298
--- /dev/null
+++ b/src/leap/config/providerconfig.py
@@ -0,0 +1,144 @@
+# -*- coding: utf-8 -*-
+# providerconfig.py
+# Copyright (C) 2013 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+"""
+Provider configuration
+"""
+import logging
+import os
+
+from leap.config.baseconfig import BaseConfig, LocalizedKey
+from leap.config.provider_spec import leap_provider_spec
+
+logger = logging.getLogger(__name__)
+
+
+class ProviderConfig(BaseConfig):
+ """
+ Provider configuration abstraction class
+ """
+
+ def __init__(self):
+ BaseConfig.__init__(self)
+
+ def _get_spec(self):
+ """
+ Returns the spec object for the specific configuration
+ """
+ return leap_provider_spec
+
+ def get_api_uri(self):
+ return self._safe_get_value("api_uri")
+
+ def get_api_version(self):
+ return self._safe_get_value("api_version")
+
+ def get_ca_cert_fingerprint(self):
+ return self._safe_get_value("ca_cert_fingerprint")
+
+ def get_ca_cert_uri(self):
+ return self._safe_get_value("ca_cert_uri")
+
+ def get_default_language(self):
+ return self._safe_get_value("default_language")
+
+ @LocalizedKey
+ def get_description(self):
+ return self._safe_get_value("description")
+
+ def get_domain(self):
+ return self._safe_get_value("domain")
+
+ def get_enrollment_policy(self):
+ return self._safe_get_value("enrollment_policy")
+
+ def get_languages(self):
+ return self._safe_get_value("languages")
+
+ @LocalizedKey
+ def get_name(self):
+ return self._safe_get_value("name")
+
+ def get_services(self):
+ return self._safe_get_value("services")
+
+ def get_ca_cert_path(self, about_to_download=False):
+ """
+ Returns the path to the certificate for the current provider
+
+ @param about_to_download: defines wether we want the path to
+ download the cert or not. This helps avoid checking if the
+ cert exists because we are about to write it.
+ @type about_to_download: bool
+ """
+
+ cert_path = os.path.join(self.get_path_prefix(),
+ "leap",
+ "providers",
+ self.get_domain(),
+ "keys",
+ "ca",
+ "cacert.pem")
+
+ if not about_to_download:
+ assert os.path.exists(cert_path), \
+ "You need to download the certificate first"
+ logger.debug("Going to verify SSL against %s" % (cert_path,))
+
+ return cert_path
+
+ def provides_eip(self):
+ """
+ Returns True if this particular provider has the EIP
+ service. False otherwise
+ """
+ return "openvpn" in self.get_services()
+
+
+if __name__ == "__main__":
+ logger = logging.getLogger(name='leap')
+ logger.setLevel(logging.DEBUG)
+ console = logging.StreamHandler()
+ console.setLevel(logging.DEBUG)
+ formatter = logging.Formatter(
+ '%(asctime)s '
+ '- %(name)s - %(levelname)s - %(message)s')
+ console.setFormatter(formatter)
+ logger.addHandler(console)
+
+ provider = ProviderConfig()
+
+ try:
+ provider.get_api_version()
+ except Exception as e:
+ assert isinstance(e, AssertionError), "Expected an assert"
+ print "Safe value getting is working"
+
+ # standalone minitest
+ #if provider.load("provider_bad.json"):
+ if provider.load("leap/providers/bitmask.net/provider.json"):
+ print provider.get_api_version()
+ print provider.get_ca_cert_fingerprint()
+ print provider.get_ca_cert_uri()
+ print provider.get_default_language()
+ print provider.get_description()
+ print provider.get_description(lang="asd")
+ print provider.get_domain()
+ print provider.get_enrollment_policy()
+ print provider.get_languages()
+ print provider.get_name()
+ print provider.get_services()