# -*- coding: utf-8 -*-
# backend.py
# Copyright (C) 2013, 2014 LEAP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
import json
import threading
import time

from twisted.internet import defer, reactor, threads

import zmq
from zmq.auth.thread import ThreadAuthenticator

from leap.bitmask.backend.api import API
from leap.bitmask.backend.utils import get_backend_certificates
from leap.bitmask.backend.signaler import Signaler

import logging
logger = logging.getLogger(__name__)


class Backend(object):
    """
    Backend server.

    Receives requests from backend_proxy over a CURVE-secured ZMQ REP socket,
    runs the requested API method in a Twisted thread and emits signals
    through the Signaler if needed.
    """
    PORT = '5556'
    BIND_ADDR = "tcp://127.0.0.1:%s" % PORT

    def __init__(self):
        """
        Backend constructor, create needed instances.
        """
        self._signaler = Signaler()

        self._do_work = threading.Event()  # used to stop the worker thread.
        self._zmq_context = None
        self._zmq_auth = None
        self._zmq_socket = None

        # deferreds of the API calls that are still running in a thread.
        self._ongoing_defers = []
        self._init_zmq()

    def _init_zmq(self):
        """
        Configure the zmq components and connection.
        """
        context = zmq.Context()
        socket = context.socket(zmq.REP)

        # Start an authenticator for this context.
        auth = ThreadAuthenticator(context)
        auth.start()
        auth.allow('127.0.0.1')

        # CURVE_ALLOW_ANY: accept any client key, no certificate directory is
        # consulted. The server side still uses its own key pair.
        auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY)
        public, secret = get_backend_certificates()
        socket.curve_publickey = public
        socket.curve_secretkey = secret
        socket.curve_server = True  # must come before bind

        socket.bind(self.BIND_ADDR)

        # Keep references for the lifetime of the backend so that neither the
        # context nor the authenticator can be finalized once this method
        # returns.
        # NOTE(review): pyzmq's ThreadAuthenticator appears to stop its
        # thread when finalized - confirm for the pyzmq version in use.
        self._zmq_context = context
        self._zmq_auth = auth
        self._zmq_socket = socket

    def _worker(self):
        """
        Receive requests and send them to be processed.

        A request that cannot be processed is logged and skipped, so a single
        bad message does not terminate the loop.

        Note: we use a simple while since is less resource consuming than a
        Twisted's LoopingCall.
        """
        while self._do_work.is_set():
            # Poll for the next request from the client without blocking, so
            # the loop can notice when `_do_work` gets cleared.
            try:
                request = self._zmq_socket.recv(zmq.NOBLOCK)
                # a REP socket has to reply before it can receive again.
                self._zmq_socket.send(b"OK")
            except zmq.ZMQError as e:
                # EAGAIN just means that there is nothing to read yet.
                if e.errno != zmq.EAGAIN:
                    raise
            else:
                # logger.debug("Received request: '{0}'".format(request))
                try:
                    self._process_request(request)
                except Exception:
                    logger.exception(
                        "Error processing backend request, ignoring it.")
            time.sleep(0.01)

    def _stop_reactor(self):
        """
        Stop the Twisted reactor, but first wait a little for some threads to
        complete their work.

        Note: this method needs to be run in a different thread so the
        time.sleep() does not block and other threads can finish.
        i.e.:
            use threads.deferToThread(this_method) instead of this_method()
        """
        wait_max = 5  # seconds
        wait_step = 0.5
        wait = 0
        while self._ongoing_defers and wait < wait_max:
            time.sleep(wait_step)
            wait += wait_step
            msg = "Waiting for running threads to finish... {0}/{1}"
            msg = msg.format(wait, wait_max)
            logger.debug(msg)

        def cancel_pending_and_stop():
            # after a timeout we shut down the existing threads.
            # Iterate over a copy: cancelling fires `_done_action`, which
            # removes the deferred from `self._ongoing_defers`.
            for d in list(self._ongoing_defers):
                d.cancel()

            reactor.stop()
            logger.debug("Twisted reactor stopped.")

        # We are running outside of the reactor thread here, and the reactor
        # API must only be used from the reactor thread.
        reactor.callFromThread(cancel_pending_and_stop)

    def run(self):
        """
        Start the ZMQ server and run the loop to handle requests.

        This call blocks until the reactor is stopped.
        """
        self._signaler.start()
        self._do_work.set()
        threads.deferToThread(self._worker)
        reactor.run()

    def stop(self):
        """
        Stop the server and the zmq request parse loop.
        """
        logger.debug("STOP received.")
        self._signaler.stop()
        self._do_work.clear()
        threads.deferToThread(self._stop_reactor)

    def _process_request(self, request_json):
        """
        Process a request and call the according method with the given
        parameters.

        Calls to methods that are not listed in the API are logged and
        ignored.

        :param request_json: a json specification of a request.
        :type request_json: str

        :raises Exception: whatever error was hit while decoding the request
                           or reading its 'api_method' and 'arguments' keys.
        """
        try:
            # request = zmq.utils.jsonapi.loads(request_json)
            # We use stdlib's json to ensure that we get unicode strings
            request = json.loads(request_json)
            api_method = request['api_method']
            kwargs = request['arguments'] or None
        except Exception as e:
            # Format only once: the resulting message contains the request,
            # whose braces would be taken as replacement fields by a second
            # format() call.
            msg = "Malformed JSON data in Backend request '{0}'. Exc: {1!r}"
            logger.critical(msg.format(request_json, e))
            raise

        if api_method not in API:
            logger.error("Invalid API call '{0}'".format(api_method))
            return

        self._run_in_thread(api_method, kwargs)

    def _run_in_thread(self, api_method, kwargs):
        """
        Run the method name in a thread with the given arguments.

        :param api_method: the callable name to run in a thread.
        :type api_method: str
        :param kwargs: the arguments dict that will be sent to the callable,
                       None to call it without arguments.
        :type kwargs: dict or None
        """
        func = getattr(self, api_method)

        if kwargs is None:
            method = func
        else:
            method = lambda: func(**kwargs)

        # logger.debug("Running method: '{0}' "
        #            "with args: '{1}' in a thread".format(api_method, kwargs))

        # run the action in a thread and keep track of it.
        d = threads.deferToThread(method)

        # Track the deferred before hooking the callbacks, otherwise it could
        # fire first and be left in the list forever.
        self._ongoing_defers.append(d)

        # On success the deferred hands over the method's return value, which
        # is not a failure: `_done_action` expects None in that case.
        d.addCallback(lambda _result: self._done_action(None, d))
        d.addErrback(self._done_action, d)

    def _done_action(self, failure, d):
        """
        Remove the defer from the ongoing list, logging the failure if any.

        :param failure: the failure that triggered the errback.
                        None if no error.
        :type failure: twisted.python.failure.Failure
        :param d: defer to remove
        :type d: twisted.internet.defer.Deferred
        """
        if failure is not None:
            if failure.check(defer.CancelledError):
                logger.debug("A defer was cancelled.")
            else:
                logger.error("There was a failure - {0!r}".format(failure))
                logger.error(failure.getTraceback())

        if d in self._ongoing_defers:
            self._ongoing_defers.remove(d)