summaryrefslogtreecommitdiff
path: root/src/leap/bitmask/backend/backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/leap/bitmask/backend/backend.py')
-rw-r--r--src/leap/bitmask/backend/backend.py135
1 files changed, 68 insertions, 67 deletions
diff --git a/src/leap/bitmask/backend/backend.py b/src/leap/bitmask/backend/backend.py
index cff731ba..4a98d146 100644
--- a/src/leap/bitmask/backend/backend.py
+++ b/src/leap/bitmask/backend/backend.py
@@ -17,17 +17,15 @@
# FIXME this is missing module documentation. It would be fine to say a couple
# of lines about the whole backend architecture.
-# TODO use txzmq bindings instead.
-
import json
import os
-import threading
import time
import psutil
-from twisted.internet import defer, reactor, threads
+from twisted.internet import defer, reactor, threads, task
+import txzmq
import zmq
try:
from zmq.auth.thread import ThreadAuthenticator
@@ -35,51 +33,50 @@ except ImportError:
pass
from leap.bitmask.backend.api import API, PING_REQUEST
+from leap.bitmask.backend.signaler import Signaler
from leap.bitmask.backend.utils import get_backend_certificates
from leap.bitmask.config import flags
-from leap.bitmask.backend.signaler import Signaler
+from leap.bitmask.logs.utils import get_logger
-import logging
-logger = logging.getLogger(__name__)
+logger = get_logger()
-class Backend(object):
+class TxZmqREPConnection(object):
"""
- Backend server.
- Receives signals from backend_proxy and emit signals if needed.
+ A twisted based zmq rep connection.
"""
- # XXX we might want to make this configurable per-platform,
- # and use the most performant socket type on each one.
- if flags.ZMQ_HAS_CURVE:
- # XXX this should not be hardcoded. Make it configurable.
- PORT = '5556'
- BIND_ADDR = "tcp://127.0.0.1:%s" % PORT
- else:
- SOCKET_FILE = "/tmp/bitmask.socket.0"
- BIND_ADDR = "ipc://%s" % SOCKET_FILE
- PING_INTERVAL = 2 # secs
+ def __init__(self, server_address, process_request):
+ """
+ Initialize the connection.
- def __init__(self, frontend_pid=None):
+ :param server_address: The address of the backend zmq server.
+ :type server: str
+ :param process_request: A callable used to process incoming requests.
+ :type process_request: callable(messageParts)
"""
- Backend constructor, create needed instances.
+ self._server_address = server_address
+ self._process_request = process_request
+ self._zmq_factory = None
+ self._zmq_connection = None
+ self._init_txzmq()
+
+ def _init_txzmq(self):
"""
- self._signaler = Signaler()
+ Configure the txzmq components and connection.
+ """
+ self._zmq_factory = txzmq.ZmqFactory()
+ self._zmq_factory.registerForShutdown()
+ self._zmq_connection = txzmq.ZmqREPConnection(self._zmq_factory)
- self._frontend_pid = frontend_pid
+ context = self._zmq_factory.context
+ socket = self._zmq_connection.socket
- self._do_work = threading.Event() # used to stop the worker thread.
- self._zmq_socket = None
+ def _gotMessage(messageId, messageParts):
+ self._zmq_connection.reply(messageId, "OK")
+ self._process_request(messageParts)
- self._ongoing_defers = []
- self._init_zmq()
-
- def _init_zmq(self):
- """
- Configure the zmq components and connection.
- """
- context = zmq.Context()
- socket = context.socket(zmq.REP)
+ self._zmq_connection.gotMessage = _gotMessage
if flags.ZMQ_HAS_CURVE:
# Start an authenticator for this context.
@@ -95,37 +92,39 @@ class Backend(object):
socket.curve_secretkey = secret
socket.curve_server = True # must come before bind
- socket.bind(self.BIND_ADDR)
- if not flags.ZMQ_HAS_CURVE:
- os.chmod(self.SOCKET_FILE, 0600)
+ proto, addr = self._server_address.split('://') # tcp/ipc, ip/socket
+ socket.bind(self._server_address)
+ if proto == 'ipc':
+ os.chmod(addr, 0600)
- self._zmq_socket = socket
- def _worker(self):
- """
- Receive requests and send it to process.
+class Backend(object):
+ """
+ Backend server.
+ Receives signals from backend_proxy and emit signals if needed.
+ """
+ # XXX we might want to make this configurable per-platform,
+ # and use the most performant socket type on each one.
+ if flags.ZMQ_HAS_CURVE:
+ # XXX this should not be hardcoded. Make it configurable.
+ PORT = '5556'
+ BIND_ADDR = "tcp://127.0.0.1:%s" % PORT
+ else:
+ SOCKET_FILE = "/tmp/bitmask.socket.0"
+ BIND_ADDR = "ipc://%s" % SOCKET_FILE
- Note: we use a simple while since is less resource consuming than a
- Twisted's LoopingCall.
+ PING_INTERVAL = 2 # secs
+
+ def __init__(self, frontend_pid=None):
"""
- pid = self._frontend_pid
- check_wait = 0
- while self._do_work.is_set():
- # Wait for next request from client
- try:
- request = self._zmq_socket.recv(zmq.NOBLOCK)
- self._zmq_socket.send("OK")
- # logger.debug("Received request: '{0}'".format(request))
- self._process_request(request)
- except zmq.ZMQError as e:
- if e.errno != zmq.EAGAIN:
- raise
- time.sleep(0.01)
-
- check_wait += 0.01
- if pid is not None and check_wait > self.PING_INTERVAL:
- check_wait = 0
- self._check_frontend_alive()
+ Backend constructor, create needed instances.
+ """
+ self._signaler = Signaler()
+ self._frontend_pid = frontend_pid
+ self._frontend_checker = None
+ self._ongoing_defers = []
+ self._zmq_connection = TxZmqREPConnection(
+ self.BIND_ADDR, self._process_request)
def _check_frontend_alive(self):
"""
@@ -160,25 +159,27 @@ class Backend(object):
for d in self._ongoing_defers:
d.cancel()
+ logger.debug("Stopping the Twisted reactor...")
reactor.stop()
- logger.debug("Twisted reactor stopped.")
def run(self):
"""
Start the ZMQ server and run the loop to handle requests.
"""
self._signaler.start()
- self._do_work.set()
- threads.deferToThread(self._worker)
+ self._frontend_checker = task.LoopingCall(self._check_frontend_alive)
+ self._frontend_checker.start(self.PING_INTERVAL)
+ logger.debug("Starting Twisted reactor.")
reactor.run()
+ logger.debug("Finished Twisted reactor.")
def stop(self):
"""
Stop the server and the zmq request parse loop.
"""
- logger.debug("STOP received.")
+ logger.debug("Stopping the backend...")
self._signaler.stop()
- self._do_work.clear()
+ self._frontend_checker.stop()
threads.deferToThread(self._stop_reactor)
def _process_request(self, request_json):