# -*- coding: utf-8 -*-
# providerbootstrapper.py
# Copyright (C) 2013 LEAP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""
Provider bootstrapping
"""
import logging
import socket
import os
import sys

import requests

from leap.bitmask.config import flags
from leap.bitmask.config.providerconfig import ProviderConfig, MissingCACert
from leap.bitmask.util.request_helpers import get_content
from leap.bitmask import util
from leap.bitmask.util.constants import REQUEST_TIMEOUT
from leap.bitmask.services.abstractbootstrapper import AbstractBootstrapper
from leap.bitmask import provider
from leap.common import ca_bundle
from leap.common.certs import get_digest
from leap.common.files import check_and_fix_urw_only, get_mtime, mkdir_p
from leap.common.check import leap_assert, leap_assert_type, leap_check

logger = logging.getLogger(__name__)


class UnsupportedProviderAPI(Exception):
    """
    Raised when attempting to use a provider with an incompatible API.
    """
    pass


class UnsupportedClientVersionError(Exception):
    """
    Raised when attempting to use a provider with an older
    client than supported.
    """
    pass


class WrongFingerprint(Exception):
    """
    Raised when a fingerprint comparison does not match.
    """
    pass


class ProviderBootstrapper(AbstractBootstrapper):
    """
    Given a provider URL performs a series of checks and emits signals
    after they are passed.
    If a check fails, the subsequent checks are not executed
    """

    MIN_CLIENT_VERSION = 'x-minimum-client-version'

    def __init__(self, signaler=None, bypass_checks=False):
        """
        Constructor for provider bootstrapper object

        :param signaler: Signaler object used to receive notifications
                         from the backend
        :type signaler: Signaler
        :param bypass_checks: Set to true if the app should bypass
                              first round of checks for CA
                              certificates at bootstrap
        :type bypass_checks: bool
        """
        AbstractBootstrapper.__init__(self, signaler, bypass_checks)

        self._domain = None
        self._provider_config = None
        self._download_if_needed = False

    @property
    def verify(self):
        """
        Verify parameter for requests.

        :returns: either False, if checks are skipped, or the
                  path to the ca bundle.
        :rtype: bool or str
        """
        if self._bypass_checks:
            verify = False
        else:
            verify = ca_bundle.where()
        return verify

    def _check_name_resolution(self):
        """
        Checks that the name resolution for the provider name works
        """
        leap_assert(self._domain, "Cannot check DNS without a domain")
        logger.debug("Checking name resolution for %r" % (self._domain))

        # We don't skip this check, since it's basic for the whole
        # system to work
        # err --- but we can do it after a failure, to diagnose what went
        # wrong. Right now we're just adding connection overhead. -- kali
        socket.gethostbyname(self._domain.encode('idna'))

    def _check_https(self, *args):
        """
        Checks that https is working and that the provided certificate
        checks out
        """
        leap_assert(self._domain, "Cannot check HTTPS without a domain")
        logger.debug("Checking https for %r" % (self._domain))

        # We don't skip this check, since it's basic for the whole
        # system to work.
        # err --- but we can do it after a failure, to diagnose what went
        # wrong. Right now we're just adding connection overhead. -- kali

        verify = self.verify
        if verify:
            verify = self.verify.encode(sys.getfilesystemencoding())

        try:
            uri = "https://{0}".format(self._domain.encode('idna'))
            res = self._session.get(uri, verify=verify,
                                    timeout=REQUEST_TIMEOUT)
            res.raise_for_status()
        except requests.exceptions.SSLError as exc:
            logger.exception(exc)
            self._err_msg = self.tr("Provider certificate could "
                                    "not be verified")
            raise
        except Exception as exc:
            # XXX careful!. The error might be also a SSL handshake
            # timeout error, in which case we should retry a couple of times
            # more, for cases where the ssl server gives high latencies.
            logger.exception(exc)
            self._err_msg = self.tr("Provider does not support HTTPS")
            raise

    def _download_provider_info(self, *args):
        """
        Downloads the provider.json defition
        """
        leap_assert(self._domain,
                    "Cannot download provider info without a domain")
        logger.debug("Downloading provider info for %r" % (self._domain))

        # --------------------------------------------------------------
        # TODO factor out with the download routines in services.
        # Watch out! We're handling the verify paramenter differently here.

        headers = {}
        domain = self._domain.encode(sys.getfilesystemencoding())
        provider_json = os.path.join(util.get_path_prefix(),
                                     "leap", "providers", domain,
                                     "provider.json")
        mtime = get_mtime(provider_json)

        if self._download_if_needed and mtime:
            headers['if-modified-since'] = mtime

        uri = "https://%s/%s" % (self._domain, "provider.json")
        verify = self.verify

        if mtime:  # the provider.json exists
        # So, we're getting it from the api.* and checking against
        # the provider ca.
            try:
                provider_config = ProviderConfig()
                provider_config.load(provider_json)
                uri = provider_config.get_api_uri() + '/provider.json'
                verify = provider_config.get_ca_cert_path()
            except MissingCACert:
                # no ca? then download from main domain again.
                pass

        if verify:
            verify = verify.encode(sys.getfilesystemencoding())
        logger.debug("Requesting for provider.json... "
                     "uri: {0}, verify: {1}, headers: {2}".format(
                         uri, verify, headers))
        res = self._session.get(uri.encode('idna'), verify=verify,
                                headers=headers, timeout=REQUEST_TIMEOUT)
        res.raise_for_status()
        logger.debug("Request status code: {0}".format(res.status_code))

        min_client_version = res.headers.get(self.MIN_CLIENT_VERSION, '0')

        # Not modified
        if res.status_code == 304:
            logger.debug("Provider definition has not been modified")
        # --------------------------------------------------------------
        # end refactor, more or less...
        # XXX Watch out, have to check the supported api yet.
        else:
            if flags.APP_VERSION_CHECK:
                # TODO split
                if not provider.supports_client(min_client_version):
                    self._signaler.signal(
                        self._signaler.PROV_UNSUPPORTED_CLIENT)
                    raise UnsupportedClientVersionError()

            provider_definition, mtime = get_content(res)

            provider_config = ProviderConfig()
            provider_config.load(data=provider_definition, mtime=mtime)
            provider_config.save(["leap", "providers",
                                  domain, "provider.json"])

            if flags.API_VERSION_CHECK:
                # TODO split
                api_version = provider_config.get_api_version()
                if provider.supports_api(api_version):
                    logger.debug("Provider definition has been modified")
                else:
                    api_supported = ', '.join(provider.SUPPORTED_APIS)
                    error = ('Unsupported provider API version. '
                             'Supported versions are: {0}. '
                             'Found: {1}.').format(api_supported, api_version)

                    logger.error(error)
                    self._signaler.signal(self._signaler.PROV_UNSUPPORTED_API)
                    raise UnsupportedProviderAPI(error)

    def run_provider_select_checks(self, domain, download_if_needed=False):
        """
        Populates the check queue.

        :param domain: domain to check
        :type domain: unicode

        :param download_if_needed: if True, makes the checks do not
                                   overwrite already downloaded data
        :type download_if_needed: bool
        """
        leap_assert(domain and len(domain) > 0, "We need a domain!")

        self._domain = ProviderConfig.sanitize_path_component(domain)
        self._download_if_needed = download_if_needed

        cb_chain = [
            (self._check_name_resolution,
             self._signaler.PROV_NAME_RESOLUTION_KEY),
            (self._check_https, self._signaler.PROV_HTTPS_CONNECTION_KEY),
            (self._download_provider_info,
             self._signaler.PROV_DOWNLOAD_PROVIDER_INFO_KEY)
        ]

        return self.addCallbackChain(cb_chain)

    def _should_proceed_cert(self):
        """
        Returns False if the certificate already exists for the given
        provider. True otherwise

        :rtype: bool
        """
        leap_assert(self._provider_config, "We need a provider config!")

        if not self._download_if_needed:
            return True

        return not os.path.exists(self._provider_config
                                  .get_ca_cert_path(about_to_download=True))

    def _download_ca_cert(self, *args):
        """
        Downloads the CA cert that is going to be used for the api URL
        """
        # XXX maybe we can skip this step if
        # we have a fresh one.
        leap_assert(self._provider_config, "Cannot download the ca cert "
                    "without a provider config!")

        logger.debug("Downloading ca cert for %r at %r" %
                     (self._domain, self._provider_config.get_ca_cert_uri()))

        if not self._should_proceed_cert():
            check_and_fix_urw_only(
                self._provider_config
                .get_ca_cert_path(about_to_download=True))
            return

        res = self._session.get(self._provider_config.get_ca_cert_uri(),
                                verify=self.verify,
                                timeout=REQUEST_TIMEOUT)
        res.raise_for_status()

        cert_path = self._provider_config.get_ca_cert_path(
            about_to_download=True)
        cert_dir = os.path.dirname(cert_path)
        mkdir_p(cert_dir)
        with open(cert_path, "w") as f:
            f.write(res.content)

        check_and_fix_urw_only(cert_path)

    def _check_ca_fingerprint(self, *args):
        """
        Checks the CA cert fingerprint against the one provided in the
        json definition
        """
        leap_assert(self._provider_config, "Cannot check the ca cert "
                    "without a provider config!")

        logger.debug("Checking ca fingerprint for %r and cert %r" %
                     (self._domain,
                      self._provider_config.get_ca_cert_path()))

        if not self._should_proceed_cert():
            return

        parts = self._provider_config.get_ca_cert_fingerprint().split(":")

        error_msg = "Wrong fingerprint format"
        leap_check(len(parts) == 2, error_msg, WrongFingerprint)

        method = parts[0].strip()
        fingerprint = parts[1].strip()
        cert_data = None
        with open(self._provider_config.get_ca_cert_path()) as f:
            cert_data = f.read()

        leap_assert(len(cert_data) > 0, "Could not read certificate data")
        digest = get_digest(cert_data, method)

        error_msg = "Downloaded certificate has a different fingerprint!"
        leap_check(digest == fingerprint, error_msg, WrongFingerprint)

    def _check_api_certificate(self, *args):
        """
        Tries to make an API call with the downloaded cert and checks
        if it validates against it
        """
        leap_assert(self._provider_config, "Cannot check the ca cert "
                    "without a provider config!")

        logger.debug("Checking api certificate for %s and cert %s" %
                     (self._provider_config.get_api_uri(),
                      self._provider_config.get_ca_cert_path()))

        if not self._should_proceed_cert():
            return

        test_uri = "%s/%s/cert" % (self._provider_config.get_api_uri(),
                                   self._provider_config.get_api_version())
        ca_cert_path = self._provider_config.get_ca_cert_path()
        ca_cert_path = ca_cert_path.encode(sys.getfilesystemencoding())
        res = self._session.get(test_uri, verify=ca_cert_path,
                                timeout=REQUEST_TIMEOUT)
        res.raise_for_status()

    def run_provider_setup_checks(self,
                                  provider_config,
                                  download_if_needed=False):
        """
        Starts the checks needed for a new provider setup.

        :param provider_config: Provider configuration
        :type provider_config: ProviderConfig

        :param download_if_needed: if True, makes the checks do not
                                   overwrite already downloaded data.
        :type download_if_needed: bool
        """
        leap_assert(provider_config, "We need a provider config!")
        leap_assert_type(provider_config, ProviderConfig)

        self._provider_config = provider_config
        self._download_if_needed = download_if_needed

        cb_chain = [
            (self._download_ca_cert, self._signaler.PROV_DOWNLOAD_CA_CERT_KEY),
            (self._check_ca_fingerprint,
             self._signaler.PROV_CHECK_CA_FINGERPRINT_KEY),
            (self._check_api_certificate,
             self._signaler.PROV_CHECK_API_CERTIFICATE_KEY)
        ]

        return self.addCallbackChain(cb_chain)