summaryrefslogtreecommitdiff
path: root/zmq/auth
diff options
context:
space:
mode:
Diffstat (limited to 'zmq/auth')
-rw-r--r--zmq/auth/__init__.py10
-rw-r--r--zmq/auth/base.py272
-rw-r--r--zmq/auth/certs.py119
-rw-r--r--zmq/auth/ioloop.py34
-rw-r--r--zmq/auth/thread.py184
5 files changed, 619 insertions, 0 deletions
diff --git a/zmq/auth/__init__.py b/zmq/auth/__init__.py
new file mode 100644
index 0000000..11d3ad6
--- /dev/null
+++ b/zmq/auth/__init__.py
@@ -0,0 +1,10 @@
+"""Utilities for ZAP authentication.
+
+To run authentication in a background thread, see :mod:`zmq.auth.thread`.
+For integration with the tornado eventloop, see :mod:`zmq.auth.ioloop`.
+
+.. versionadded:: 14.1
+"""
+
+from .base import *
+from .certs import *
diff --git a/zmq/auth/base.py b/zmq/auth/base.py
new file mode 100644
index 0000000..9b4aaed
--- /dev/null
+++ b/zmq/auth/base.py
@@ -0,0 +1,272 @@
+"""Base implementation of 0MQ authentication."""
+
+# Copyright (C) PyZMQ Developers
+# Distributed under the terms of the Modified BSD License.
+
+import logging
+
+import zmq
+from zmq.utils import z85
+from zmq.utils.strtypes import bytes, unicode, b, u
+from zmq.error import _check_version
+
+from .certs import load_certificates
+
+
+CURVE_ALLOW_ANY = '*'
+VERSION = b'1.0'
+
+class Authenticator(object):
+ """Implementation of ZAP authentication for zmq connections.
+
+ Note:
+ - libzmq provides four levels of security: default NULL (which the Authenticator does
+ not see), and authenticated NULL, PLAIN, and CURVE, which the Authenticator can see.
+ - until you add policies, all incoming NULL connections are allowed
+ (classic ZeroMQ behavior), and all PLAIN and CURVE connections are denied.
+ """
+
+ def __init__(self, context=None, encoding='utf-8', log=None):
+ _check_version((4,0), "security")
+ self.context = context or zmq.Context.instance()
+ self.encoding = encoding
+ self.allow_any = False
+ self.zap_socket = None
+ self.whitelist = set()
+ self.blacklist = set()
+ # passwords is a dict keyed by domain and contains values
+ # of dicts with username:password pairs.
+ self.passwords = {}
+ # certs is dict keyed by domain and contains values
+ # of dicts keyed by the public keys from the specified location.
+ self.certs = {}
+ self.log = log or logging.getLogger('zmq.auth')
+
+ def start(self):
+ """Create and bind the ZAP socket"""
+ self.zap_socket = self.context.socket(zmq.REP)
+ self.zap_socket.linger = 1
+ self.zap_socket.bind("inproc://zeromq.zap.01")
+
+ def stop(self):
+ """Close the ZAP socket"""
+ if self.zap_socket:
+ self.zap_socket.close()
+ self.zap_socket = None
+
+ def allow(self, *addresses):
+ """Allow (whitelist) IP address(es).
+
+ Connections from addresses not in the whitelist will be rejected.
+
+ - For NULL, all clients from this address will be accepted.
+ - For PLAIN and CURVE, they will be allowed to continue with authentication.
+
+ whitelist is mutually exclusive with blacklist.
+ """
+ if self.blacklist:
+ raise ValueError("Only use a whitelist or a blacklist, not both")
+ self.whitelist.update(addresses)
+
+ def deny(self, *addresses):
+ """Deny (blacklist) IP address(es).
+
+ Addresses not in the blacklist will be allowed to continue with authentication.
+
+ Blacklist is mutually exclusive with whitelist.
+ """
+ if self.whitelist:
+ raise ValueError("Only use a whitelist or a blacklist, not both")
+ self.blacklist.update(addresses)
+
+ def configure_plain(self, domain='*', passwords=None):
+ """Configure PLAIN authentication for a given domain.
+
+ PLAIN authentication uses a plain-text password file.
+ To cover all domains, use "*".
+ You can modify the password file at any time; it is reloaded automatically.
+ """
+ if passwords:
+ self.passwords[domain] = passwords
+
+ def configure_curve(self, domain='*', location=None):
+ """Configure CURVE authentication for a given domain.
+
+ CURVE authentication uses a directory that holds all public client certificates,
+ i.e. their public keys.
+
+ To cover all domains, use "*".
+
+ You can add and remove certificates in that directory at any time.
+
+ To allow all client keys without checking, specify CURVE_ALLOW_ANY for the location.
+ """
+ # If location is CURVE_ALLOW_ANY then allow all clients. Otherwise
+ # treat location as a directory that holds the certificates.
+ if location == CURVE_ALLOW_ANY:
+ self.allow_any = True
+ else:
+ self.allow_any = False
+ try:
+ self.certs[domain] = load_certificates(location)
+ except Exception as e:
+ self.log.error("Failed to load CURVE certs from %s: %s", location, e)
+
+ def handle_zap_message(self, msg):
+ """Perform ZAP authentication"""
+ if len(msg) < 6:
+ self.log.error("Invalid ZAP message, not enough frames: %r", msg)
+ if len(msg) < 2:
+ self.log.error("Not enough information to reply")
+ else:
+ self._send_zap_reply(msg[1], b"400", b"Not enough frames")
+ return
+
+ version, request_id, domain, address, identity, mechanism = msg[:6]
+ credentials = msg[6:]
+
+ domain = u(domain, self.encoding, 'replace')
+ address = u(address, self.encoding, 'replace')
+
+ if (version != VERSION):
+ self.log.error("Invalid ZAP version: %r", msg)
+ self._send_zap_reply(request_id, b"400", b"Invalid version")
+ return
+
+ self.log.debug("version: %r, request_id: %r, domain: %r,"
+ " address: %r, identity: %r, mechanism: %r",
+ version, request_id, domain,
+ address, identity, mechanism,
+ )
+
+
+ # Is address is explicitly whitelisted or blacklisted?
+ allowed = False
+ denied = False
+ reason = b"NO ACCESS"
+
+ if self.whitelist:
+ if address in self.whitelist:
+ allowed = True
+ self.log.debug("PASSED (whitelist) address=%s", address)
+ else:
+ denied = True
+ reason = b"Address not in whitelist"
+ self.log.debug("DENIED (not in whitelist) address=%s", address)
+
+ elif self.blacklist:
+ if address in self.blacklist:
+ denied = True
+ reason = b"Address is blacklisted"
+ self.log.debug("DENIED (blacklist) address=%s", address)
+ else:
+ allowed = True
+ self.log.debug("PASSED (not in blacklist) address=%s", address)
+
+ # Perform authentication mechanism-specific checks if necessary
+ username = u("user")
+ if not denied:
+
+ if mechanism == b'NULL' and not allowed:
+ # For NULL, we allow if the address wasn't blacklisted
+ self.log.debug("ALLOWED (NULL)")
+ allowed = True
+
+ elif mechanism == b'PLAIN':
+ # For PLAIN, even a whitelisted address must authenticate
+ if len(credentials) != 2:
+ self.log.error("Invalid PLAIN credentials: %r", credentials)
+ self._send_zap_reply(request_id, b"400", b"Invalid credentials")
+ return
+ username, password = [ u(c, self.encoding, 'replace') for c in credentials ]
+ allowed, reason = self._authenticate_plain(domain, username, password)
+
+ elif mechanism == b'CURVE':
+ # For CURVE, even a whitelisted address must authenticate
+ if len(credentials) != 1:
+ self.log.error("Invalid CURVE credentials: %r", credentials)
+ self._send_zap_reply(request_id, b"400", b"Invalid credentials")
+ return
+ key = credentials[0]
+ allowed, reason = self._authenticate_curve(domain, key)
+
+ if allowed:
+ self._send_zap_reply(request_id, b"200", b"OK", username)
+ else:
+ self._send_zap_reply(request_id, b"400", reason)
+
+ def _authenticate_plain(self, domain, username, password):
+ """PLAIN ZAP authentication"""
+ allowed = False
+ reason = b""
+ if self.passwords:
+ # If no domain is not specified then use the default domain
+ if not domain:
+ domain = '*'
+
+ if domain in self.passwords:
+ if username in self.passwords[domain]:
+ if password == self.passwords[domain][username]:
+ allowed = True
+ else:
+ reason = b"Invalid password"
+ else:
+ reason = b"Invalid username"
+ else:
+ reason = b"Invalid domain"
+
+ if allowed:
+ self.log.debug("ALLOWED (PLAIN) domain=%s username=%s password=%s",
+ domain, username, password,
+ )
+ else:
+ self.log.debug("DENIED %s", reason)
+
+ else:
+ reason = b"No passwords defined"
+ self.log.debug("DENIED (PLAIN) %s", reason)
+
+ return allowed, reason
+
+ def _authenticate_curve(self, domain, client_key):
+ """CURVE ZAP authentication"""
+ allowed = False
+ reason = b""
+ if self.allow_any:
+ allowed = True
+ reason = b"OK"
+ self.log.debug("ALLOWED (CURVE allow any client)")
+ else:
+ # If no explicit domain is specified then use the default domain
+ if not domain:
+ domain = '*'
+
+ if domain in self.certs:
+ # The certs dict stores keys in z85 format, convert binary key to z85 bytes
+ z85_client_key = z85.encode(client_key)
+ if z85_client_key in self.certs[domain] or self.certs[domain] == b'OK':
+ allowed = True
+ reason = b"OK"
+ else:
+ reason = b"Unknown key"
+
+ status = "ALLOWED" if allowed else "DENIED"
+ self.log.debug("%s (CURVE) domain=%s client_key=%s",
+ status, domain, z85_client_key,
+ )
+ else:
+ reason = b"Unknown domain"
+
+ return allowed, reason
+
+ def _send_zap_reply(self, request_id, status_code, status_text, user_id='user'):
+ """Send a ZAP reply to finish the authentication."""
+ user_id = user_id if status_code == b'200' else b''
+ if isinstance(user_id, unicode):
+ user_id = user_id.encode(self.encoding, 'replace')
+ metadata = b'' # not currently used
+ self.log.debug("ZAP reply code=%s text=%s", status_code, status_text)
+ reply = [VERSION, request_id, status_code, status_text, user_id, metadata]
+ self.zap_socket.send_multipart(reply)
+
+__all__ = ['Authenticator', 'CURVE_ALLOW_ANY']
diff --git a/zmq/auth/certs.py b/zmq/auth/certs.py
new file mode 100644
index 0000000..4d26ad7
--- /dev/null
+++ b/zmq/auth/certs.py
@@ -0,0 +1,119 @@
+"""0MQ authentication related functions and classes."""
+
+# Copyright (C) PyZMQ Developers
+# Distributed under the terms of the Modified BSD License.
+
+
+import datetime
+import glob
+import io
+import os
+import zmq
+from zmq.utils.strtypes import bytes, unicode, b, u
+
+
+_cert_secret_banner = u("""# **** Generated on {0} by pyzmq ****
+# ZeroMQ CURVE **Secret** Certificate
+# DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions.
+
+""")
+
+_cert_public_banner = u("""# **** Generated on {0} by pyzmq ****
+# ZeroMQ CURVE Public Certificate
+# Exchange securely, or use a secure mechanism to verify the contents
+# of this file after exchange. Store public certificates in your home
+# directory, in the .curve subdirectory.
+
+""")
+
+def _write_key_file(key_filename, banner, public_key, secret_key=None, metadata=None, encoding='utf-8'):
+ """Create a certificate file"""
+ if isinstance(public_key, bytes):
+ public_key = public_key.decode(encoding)
+ if isinstance(secret_key, bytes):
+ secret_key = secret_key.decode(encoding)
+ with io.open(key_filename, 'w', encoding='utf8') as f:
+ f.write(banner.format(datetime.datetime.now()))
+
+ f.write(u('metadata\n'))
+ if metadata:
+ for k, v in metadata.items():
+ if isinstance(v, bytes):
+ v = v.decode(encoding)
+ f.write(u(" {0} = {1}\n").format(k, v))
+
+ f.write(u('curve\n'))
+ f.write(u(" public-key = \"{0}\"\n").format(public_key))
+
+ if secret_key:
+ f.write(u(" secret-key = \"{0}\"\n").format(secret_key))
+
+
+def create_certificates(key_dir, name, metadata=None):
+ """Create zmq certificates.
+
+ Returns the file paths to the public and secret certificate files.
+ """
+ public_key, secret_key = zmq.curve_keypair()
+ base_filename = os.path.join(key_dir, name)
+ secret_key_file = "{0}.key_secret".format(base_filename)
+ public_key_file = "{0}.key".format(base_filename)
+ now = datetime.datetime.now()
+
+ _write_key_file(public_key_file,
+ _cert_public_banner.format(now),
+ public_key)
+
+ _write_key_file(secret_key_file,
+ _cert_secret_banner.format(now),
+ public_key,
+ secret_key=secret_key,
+ metadata=metadata)
+
+ return public_key_file, secret_key_file
+
+
+def load_certificate(filename):
+ """Load public and secret key from a zmq certificate.
+
+ Returns (public_key, secret_key)
+
+ If the certificate file only contains the public key,
+ secret_key will be None.
+ """
+ public_key = None
+ secret_key = None
+ if not os.path.exists(filename):
+ raise IOError("Invalid certificate file: {0}".format(filename))
+
+ with open(filename, 'rb') as f:
+ for line in f:
+ line = line.strip()
+ if line.startswith(b'#'):
+ continue
+ if line.startswith(b'public-key'):
+ public_key = line.split(b"=", 1)[1].strip(b' \t\'"')
+ if line.startswith(b'secret-key'):
+ secret_key = line.split(b"=", 1)[1].strip(b' \t\'"')
+ if public_key and secret_key:
+ break
+
+ return public_key, secret_key
+
+
+def load_certificates(directory='.'):
+ """Load public keys from all certificates in a directory"""
+ certs = {}
+ if not os.path.isdir(directory):
+ raise IOError("Invalid certificate directory: {0}".format(directory))
+ # Follow czmq pattern of public keys stored in *.key files.
+ glob_string = os.path.join(directory, "*.key")
+
+ cert_files = glob.glob(glob_string)
+ for cert_file in cert_files:
+ public_key, _ = load_certificate(cert_file)
+ if public_key:
+ certs[public_key] = 'OK'
+ return certs
+
+__all__ = ['create_certificates', 'load_certificate', 'load_certificates']
diff --git a/zmq/auth/ioloop.py b/zmq/auth/ioloop.py
new file mode 100644
index 0000000..1f448b4
--- /dev/null
+++ b/zmq/auth/ioloop.py
@@ -0,0 +1,34 @@
+"""ZAP Authenticator integrated with the tornado IOLoop.
+
+.. versionadded:: 14.1
+"""
+
+# Copyright (C) PyZMQ Developers
+# Distributed under the terms of the Modified BSD License.
+
+from zmq.eventloop import ioloop, zmqstream
+from .base import Authenticator
+
+
+class IOLoopAuthenticator(Authenticator):
+ """ZAP authentication for use in the tornado IOLoop"""
+
+ def __init__(self, context=None, encoding='utf-8', log=None, io_loop=None):
+ super(IOLoopAuthenticator, self).__init__(context)
+ self.zap_stream = None
+ self.io_loop = io_loop or ioloop.IOLoop.instance()
+
+ def start(self):
+ """Start ZAP authentication"""
+ super(IOLoopAuthenticator, self).start()
+ self.zap_stream = zmqstream.ZMQStream(self.zap_socket, self.io_loop)
+ self.zap_stream.on_recv(self.handle_zap_message)
+
+ def stop(self):
+ """Stop ZAP authentication"""
+ if self.zap_stream:
+ self.zap_stream.close()
+ self.zap_stream = None
+ super(IOLoopAuthenticator, self).stop()
+
+__all__ = ['IOLoopAuthenticator']
diff --git a/zmq/auth/thread.py b/zmq/auth/thread.py
new file mode 100644
index 0000000..8c3355a
--- /dev/null
+++ b/zmq/auth/thread.py
@@ -0,0 +1,184 @@
+"""ZAP Authenticator in a Python Thread.
+
+.. versionadded:: 14.1
+"""
+
+# Copyright (C) PyZMQ Developers
+# Distributed under the terms of the Modified BSD License.
+
+import logging
+from threading import Thread
+
+import zmq
+from zmq.utils import jsonapi
+from zmq.utils.strtypes import bytes, unicode, b, u
+
+from .base import Authenticator
+
+class AuthenticationThread(Thread):
+ """A Thread for running a zmq Authenticator
+
+ This is run in the background by ThreadedAuthenticator
+ """
+
+ def __init__(self, context, endpoint, encoding='utf-8', log=None):
+ super(AuthenticationThread, self).__init__()
+ self.context = context or zmq.Context.instance()
+ self.encoding = encoding
+ self.log = log = log or logging.getLogger('zmq.auth')
+ self.authenticator = Authenticator(context, encoding=encoding, log=log)
+
+ # create a socket to communicate back to main thread.
+ self.pipe = context.socket(zmq.PAIR)
+ self.pipe.linger = 1
+ self.pipe.connect(endpoint)
+
+ def run(self):
+ """ Start the Authentication Agent thread task """
+ self.authenticator.start()
+ zap = self.authenticator.zap_socket
+ poller = zmq.Poller()
+ poller.register(self.pipe, zmq.POLLIN)
+ poller.register(zap, zmq.POLLIN)
+ while True:
+ try:
+ socks = dict(poller.poll())
+ except zmq.ZMQError:
+ break # interrupted
+
+ if self.pipe in socks and socks[self.pipe] == zmq.POLLIN:
+ terminate = self._handle_pipe()
+ if terminate:
+ break
+
+ if zap in socks and socks[zap] == zmq.POLLIN:
+ self._handle_zap()
+
+ self.pipe.close()
+ self.authenticator.stop()
+
+ def _handle_zap(self):
+ """
+ Handle a message from the ZAP socket.
+ """
+ msg = self.authenticator.zap_socket.recv_multipart()
+ if not msg: return
+ self.authenticator.handle_zap_message(msg)
+
+ def _handle_pipe(self):
+ """
+ Handle a message from front-end API.
+ """
+ terminate = False
+
+ # Get the whole message off the pipe in one go
+ msg = self.pipe.recv_multipart()
+
+ if msg is None:
+ terminate = True
+ return terminate
+
+ command = msg[0]
+ self.log.debug("auth received API command %r", command)
+
+ if command == b'ALLOW':
+ addresses = [u(m, self.encoding) for m in msg[1:]]
+ try:
+ self.authenticator.allow(*addresses)
+ except Exception as e:
+ self.log.exception("Failed to allow %s", addresses)
+
+ elif command == b'DENY':
+ addresses = [u(m, self.encoding) for m in msg[1:]]
+ try:
+ self.authenticator.deny(*addresses)
+ except Exception as e:
+ self.log.exception("Failed to deny %s", addresses)
+
+ elif command == b'PLAIN':
+ domain = u(msg[1], self.encoding)
+ json_passwords = msg[2]
+ self.authenticator.configure_plain(domain, jsonapi.loads(json_passwords))
+
+ elif command == b'CURVE':
+ # For now we don't do anything with domains
+ domain = u(msg[1], self.encoding)
+
+ # If location is CURVE_ALLOW_ANY, allow all clients. Otherwise
+ # treat location as a directory that holds the certificates.
+ location = u(msg[2], self.encoding)
+ self.authenticator.configure_curve(domain, location)
+
+ elif command == b'TERMINATE':
+ terminate = True
+
+ else:
+ self.log.error("Invalid auth command from API: %r", command)
+
+ return terminate
+
+def _inherit_docstrings(cls):
+ """inherit docstrings from Authenticator, so we don't duplicate them"""
+ for name, method in cls.__dict__.items():
+ if name.startswith('_'):
+ continue
+ upstream_method = getattr(Authenticator, name, None)
+ if not method.__doc__:
+ method.__doc__ = upstream_method.__doc__
+ return cls
+
+@_inherit_docstrings
+class ThreadAuthenticator(object):
+ """Run ZAP authentication in a background thread"""
+
+ def __init__(self, context=None, encoding='utf-8', log=None):
+ self.context = context or zmq.Context.instance()
+ self.log = log
+ self.encoding = encoding
+ self.pipe = None
+ self.pipe_endpoint = "inproc://{0}.inproc".format(id(self))
+ self.thread = None
+
+ def allow(self, *addresses):
+ self.pipe.send_multipart([b'ALLOW'] + [b(a, self.encoding) for a in addresses])
+
+ def deny(self, *addresses):
+ self.pipe.send_multipart([b'DENY'] + [b(a, self.encoding) for a in addresses])
+
+ def configure_plain(self, domain='*', passwords=None):
+ self.pipe.send_multipart([b'PLAIN', b(domain, self.encoding), jsonapi.dumps(passwords or {})])
+
+ def configure_curve(self, domain='*', location=''):
+ domain = b(domain, self.encoding)
+ location = b(location, self.encoding)
+ self.pipe.send_multipart([b'CURVE', domain, location])
+
+ def start(self):
+ """Start the authentication thread"""
+ # create a socket to communicate with auth thread.
+ self.pipe = self.context.socket(zmq.PAIR)
+ self.pipe.linger = 1
+ self.pipe.bind(self.pipe_endpoint)
+ self.thread = AuthenticationThread(self.context, self.pipe_endpoint, encoding=self.encoding, log=self.log)
+ self.thread.start()
+
+ def stop(self):
+ """Stop the authentication thread"""
+ if self.pipe:
+ self.pipe.send(b'TERMINATE')
+ if self.is_alive():
+ self.thread.join()
+ self.thread = None
+ self.pipe.close()
+ self.pipe = None
+
+ def is_alive(self):
+ """Is the ZAP thread currently running?"""
+ if self.thread and self.thread.is_alive():
+ return True
+ return False
+
+ def __del__(self):
+ self.stop()
+
+__all__ = ['ThreadAuthenticator']