# -*- coding: utf-8 -*- # backend.py # Copyright (C) 2013, 2014 LEAP # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. # FIXME this is missing module documentation. It would be fine to say a couple # of lines about the whole backend architecture. import json import os import time import psutil from twisted.internet import defer, reactor, threads, task import txzmq import zmq try: from zmq.auth.thread import ThreadAuthenticator except ImportError: pass from leap.bitmask.backend.api import API, PING_REQUEST from leap.bitmask.backend.signaler import Signaler from leap.bitmask.backend.utils import get_backend_certificates from leap.bitmask.config import flags from leap.bitmask.logs.utils import get_logger logger = get_logger() class TxZmqREPConnection(object): """ A twisted based zmq rep connection. """ def __init__(self, server_address, process_request): """ Initialize the connection. :param server_address: The address of the backend zmq server. :type server: str :param process_request: A callable used to process incoming requests. :type process_request: callable(messageParts) """ self._server_address = server_address self._process_request = process_request self._zmq_factory = None self._zmq_connection = None self._init_txzmq() def _init_txzmq(self): """ Configure the txzmq components and connection. """ self._zmq_factory = txzmq.ZmqFactory() self._zmq_factory.registerForShutdown() self._zmq_connection = txzmq.ZmqREPConnection(self._zmq_factory) context = self._zmq_factory.context socket = self._zmq_connection.socket def _gotMessage(messageId, messageParts): self._zmq_connection.reply(messageId, "OK") self._process_request(messageParts) self._zmq_connection.gotMessage = _gotMessage if flags.ZMQ_HAS_CURVE: # Start an authenticator for this context. auth = ThreadAuthenticator(context) auth.start() # XXX do not hardcode this here. auth.allow('127.0.0.1') # Tell authenticator to use the certificate in a directory auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY) public, secret = get_backend_certificates() socket.curve_publickey = public socket.curve_secretkey = secret socket.curve_server = True # must come before bind proto, addr = self._server_address.split('://') # tcp/ipc, ip/socket socket.bind(self._server_address) if proto == 'ipc': os.chmod(addr, 0600) class Backend(object): """ Backend server. Receives signals from backend_proxy and emit signals if needed. """ # XXX we might want to make this configurable per-platform, # and use the most performant socket type on each one. if flags.ZMQ_HAS_CURVE: # XXX this should not be hardcoded. Make it configurable. PORT = '5556' BIND_ADDR = "tcp://127.0.0.1:%s" % PORT else: SOCKET_FILE = "/tmp/bitmask.socket.0" BIND_ADDR = "ipc://%s" % SOCKET_FILE PING_INTERVAL = 2 # secs def __init__(self, frontend_pid=None): """ Backend constructor, create needed instances. """ self._signaler = Signaler() self._frontend_pid = frontend_pid self._frontend_checker = None self._ongoing_defers = [] self._zmq_connection = TxZmqREPConnection( self.BIND_ADDR, self._process_request) def _check_frontend_alive(self): """ Check if the frontend is alive and stop the backend if it is not. """ pid = self._frontend_pid if pid is not None and not psutil.pid_exists(pid): logger.critical("The frontend is down!") self.stop() def _stop_reactor(self): """ Stop the Twisted reactor, but first wait a little for some threads to complete their work. Note: this method needs to be run in a different thread so the time.sleep() does not block and other threads can finish. i.e.: use threads.deferToThread(this_method) instead of this_method() """ wait_max = 3 # seconds wait_step = 0.5 wait = 0 while self._ongoing_defers and wait < wait_max: time.sleep(wait_step) wait += wait_step msg = "Waiting for running threads to finish... {0}/{1}" msg = msg.format(wait, wait_max) logger.debug(msg) # after a timeout we shut down the existing threads. for d in self._ongoing_defers: d.cancel() logger.debug("Stopping the Twisted reactor...") reactor.stop() def run(self): """ Start the ZMQ server and run the loop to handle requests. """ self._signaler.start() self._frontend_checker = task.LoopingCall(self._check_frontend_alive) self._frontend_checker.start(self.PING_INTERVAL) logger.debug("Starting Twisted reactor.") reactor.run() logger.debug("Finished Twisted reactor.") def stop(self): """ Stop the server and the zmq request parse loop. """ logger.debug("Stopping the backend...") self._signaler.stop() self._frontend_checker.stop() threads.deferToThread(self._stop_reactor) def _process_request(self, request_json): """ Process a request and call the according method with the given parameters. :param request_json: a json specification of a request. :type request_json: str """ if request_json == PING_REQUEST: # do not process request if it's just a ping return try: # request = zmq.utils.jsonapi.loads(request_json) # We use stdlib's json to ensure that we get unicode strings request = json.loads(request_json) api_method = request['api_method'] kwargs = request['arguments'] or None except Exception as e: msg = "Malformed JSON data in Backend request '{0}'. Exc: {1!r}" msg = msg.format(request_json, e) msg = msg.format(request_json) logger.critical(msg) raise if api_method not in API: logger.error("Invalid API call '{0}'".format(api_method)) return self._run_in_thread(api_method, kwargs) def _run_in_thread(self, api_method, kwargs): """ Run the method name in a thread with the given arguments. :param api_method: the callable name to run in a thread. :type api_method: str :param kwargs: the arguments dict that will be sent to the callable. :type kwargs: tuple """ func = getattr(self, api_method) method = func if kwargs is not None: method = lambda: func(**kwargs) # logger.debug("Running method: '{0}' " # "with args: '{1}' in a thread".format(api_method, kwargs)) # run the action in a thread and keep track of it d = threads.deferToThread(method) d.addCallback(self._done_action, d) d.addErrback(self._done_action, d) self._ongoing_defers.append(d) def _done_action(self, failure, d): """ Remove the defer from the ongoing list. :param failure: the failure that triggered the errback. None if no error. :type failure: twisted.python.failure.Failure :param d: defer to remove :type d: twisted.internet.defer.Deferred """ if failure is not None: if failure.check(defer.CancelledError): logger.debug("A defer was cancelled.") else: logger.error("There was a failure - {0!r}".format(failure)) logger.error(failure.getTraceback()) if d in self._ongoing_defers: self._ongoing_defers.remove(d)