diff options
| author | Micah Anderson <micah@riseup.net> | 2014-11-11 11:53:55 -0500 | 
|---|---|---|
| committer | Micah Anderson <micah@riseup.net> | 2014-11-11 11:53:55 -0500 | 
| commit | 7d5c3dcd969161322deed6c43f8a6a3cb92c3369 (patch) | |
| tree | 109b05c88c7252d7609ef324d62ef9dd7f06123f /zmq | |
| parent | 44be832c5708baadd146cb954befbc3dcad8d463 (diff) | |
upgrade to 14.4.1upstream/14.4.1
Diffstat (limited to 'zmq')
125 files changed, 15720 insertions, 0 deletions
diff --git a/zmq/__init__.py b/zmq/__init__.py new file mode 100644 index 0000000..cf4a1f7 --- /dev/null +++ b/zmq/__init__.py @@ -0,0 +1,63 @@ +"""Python bindings for 0MQ.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import os +import sys +import glob + +# load bundled libzmq, if there is one: + +here = os.path.dirname(__file__) + +bundled = [] +bundled_sodium = [] +for ext in ('pyd', 'so', 'dll', 'dylib'): +    bundled_sodium.extend(glob.glob(os.path.join(here, 'libsodium*.%s*' % ext))) +    bundled.extend(glob.glob(os.path.join(here, 'libzmq*.%s*' % ext))) + +if bundled: +    import ctypes +    if bundled_sodium: +        if bundled[0].endswith('.pyd'): +            # a Windows Extension +            _libsodium = ctypes.cdll.LoadLibrary(bundled_sodium[0]) +        else: +            _libsodium = ctypes.CDLL(bundled_sodium[0], mode=ctypes.RTLD_GLOBAL) +    if bundled[0].endswith('.pyd'): +        # a Windows Extension +        _libzmq = ctypes.cdll.LoadLibrary(bundled[0]) +    else: +        _libzmq = ctypes.CDLL(bundled[0], mode=ctypes.RTLD_GLOBAL) +    del ctypes +else: +    import zipimport +    try: +        if isinstance(__loader__, zipimport.zipimporter): +            # a zipped pyzmq egg +            from zmq import libzmq as _libzmq +    except (NameError, ImportError): +        pass +    finally: +        del zipimport + +del os, sys, glob, here, bundled, bundled_sodium, ext + +# zmq top-level imports + +from zmq.backend import * +from zmq import sugar +from zmq.sugar import * +from zmq import devices + +def get_includes(): +    """Return a list of directories to include for linking against pyzmq with cython.""" +    from os.path import join, dirname, abspath, pardir +    base = dirname(__file__) +    parent = abspath(join(base, pardir)) +    return [ parent ] + [ join(parent, base, subdir) for subdir in ('utils',) ] + + +__all__ = ['get_includes'] + sugar.__all__ + backend.__all__ + diff --git a/zmq/auth/__init__.py b/zmq/auth/__init__.py new file mode 100644 index 0000000..11d3ad6 --- /dev/null +++ b/zmq/auth/__init__.py @@ -0,0 +1,10 @@ +"""Utilities for ZAP authentication. + +To run authentication in a background thread, see :mod:`zmq.auth.thread`. +For integration with the tornado eventloop, see :mod:`zmq.auth.ioloop`. + +.. versionadded:: 14.1 +""" + +from .base import * +from .certs import * diff --git a/zmq/auth/base.py b/zmq/auth/base.py new file mode 100644 index 0000000..9b4aaed --- /dev/null +++ b/zmq/auth/base.py @@ -0,0 +1,272 @@ +"""Base implementation of 0MQ authentication.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import logging + +import zmq +from zmq.utils import z85 +from zmq.utils.strtypes import bytes, unicode, b, u +from zmq.error import _check_version + +from .certs import load_certificates + + +CURVE_ALLOW_ANY = '*' +VERSION = b'1.0' + +class Authenticator(object): +    """Implementation of ZAP authentication for zmq connections. + +    Note: +    - libzmq provides four levels of security: default NULL (which the Authenticator does +      not see), and authenticated NULL, PLAIN, and CURVE, which the Authenticator can see. +    - until you add policies, all incoming NULL connections are allowed +    (classic ZeroMQ behavior), and all PLAIN and CURVE connections are denied. +    """ + +    def __init__(self, context=None, encoding='utf-8', log=None): +        _check_version((4,0), "security") +        self.context = context or zmq.Context.instance() +        self.encoding = encoding +        self.allow_any = False +        self.zap_socket = None +        self.whitelist = set() +        self.blacklist = set() +        # passwords is a dict keyed by domain and contains values +        # of dicts with username:password pairs. +        self.passwords = {} +        # certs is dict keyed by domain and contains values +        # of dicts keyed by the public keys from the specified location. +        self.certs = {} +        self.log = log or logging.getLogger('zmq.auth') +     +    def start(self): +        """Create and bind the ZAP socket""" +        self.zap_socket = self.context.socket(zmq.REP) +        self.zap_socket.linger = 1 +        self.zap_socket.bind("inproc://zeromq.zap.01") + +    def stop(self): +        """Close the ZAP socket""" +        if self.zap_socket: +            self.zap_socket.close() +        self.zap_socket = None + +    def allow(self, *addresses): +        """Allow (whitelist) IP address(es). +         +        Connections from addresses not in the whitelist will be rejected. +         +        - For NULL, all clients from this address will be accepted. +        - For PLAIN and CURVE, they will be allowed to continue with authentication. +         +        whitelist is mutually exclusive with blacklist. +        """ +        if self.blacklist: +            raise ValueError("Only use a whitelist or a blacklist, not both") +        self.whitelist.update(addresses) + +    def deny(self, *addresses): +        """Deny (blacklist) IP address(es). +         +        Addresses not in the blacklist will be allowed to continue with authentication. +         +        Blacklist is mutually exclusive with whitelist. +        """ +        if self.whitelist: +            raise ValueError("Only use a whitelist or a blacklist, not both") +        self.blacklist.update(addresses) + +    def configure_plain(self, domain='*', passwords=None): +        """Configure PLAIN authentication for a given domain. +         +        PLAIN authentication uses a plain-text password file. +        To cover all domains, use "*". +        You can modify the password file at any time; it is reloaded automatically. +        """ +        if passwords: +            self.passwords[domain] = passwords + +    def configure_curve(self, domain='*', location=None): +        """Configure CURVE authentication for a given domain. +         +        CURVE authentication uses a directory that holds all public client certificates, +        i.e. their public keys. +         +        To cover all domains, use "*". +         +        You can add and remove certificates in that directory at any time. +         +        To allow all client keys without checking, specify CURVE_ALLOW_ANY for the location. +        """ +        # If location is CURVE_ALLOW_ANY then allow all clients. Otherwise +        # treat location as a directory that holds the certificates. +        if location == CURVE_ALLOW_ANY: +            self.allow_any = True +        else: +            self.allow_any = False +            try: +                self.certs[domain] = load_certificates(location) +            except Exception as e: +                self.log.error("Failed to load CURVE certs from %s: %s", location, e) + +    def handle_zap_message(self, msg): +        """Perform ZAP authentication""" +        if len(msg) < 6: +            self.log.error("Invalid ZAP message, not enough frames: %r", msg) +            if len(msg) < 2: +                self.log.error("Not enough information to reply") +            else: +                self._send_zap_reply(msg[1], b"400", b"Not enough frames") +            return +         +        version, request_id, domain, address, identity, mechanism = msg[:6] +        credentials = msg[6:] +         +        domain = u(domain, self.encoding, 'replace') +        address = u(address, self.encoding, 'replace') + +        if (version != VERSION): +            self.log.error("Invalid ZAP version: %r", msg) +            self._send_zap_reply(request_id, b"400", b"Invalid version") +            return + +        self.log.debug("version: %r, request_id: %r, domain: %r," +                      " address: %r, identity: %r, mechanism: %r", +                      version, request_id, domain, +                      address, identity, mechanism, +        ) + + +        # Is address is explicitly whitelisted or blacklisted? +        allowed = False +        denied = False +        reason = b"NO ACCESS" + +        if self.whitelist: +            if address in self.whitelist: +                allowed = True +                self.log.debug("PASSED (whitelist) address=%s", address) +            else: +                denied = True +                reason = b"Address not in whitelist" +                self.log.debug("DENIED (not in whitelist) address=%s", address) + +        elif self.blacklist: +            if address in self.blacklist: +                denied = True +                reason = b"Address is blacklisted" +                self.log.debug("DENIED (blacklist) address=%s", address) +            else: +                allowed = True +                self.log.debug("PASSED (not in blacklist) address=%s", address) + +        # Perform authentication mechanism-specific checks if necessary +        username = u("user") +        if not denied: + +            if mechanism == b'NULL' and not allowed: +                # For NULL, we allow if the address wasn't blacklisted +                self.log.debug("ALLOWED (NULL)") +                allowed = True + +            elif mechanism == b'PLAIN': +                # For PLAIN, even a whitelisted address must authenticate +                if len(credentials) != 2: +                    self.log.error("Invalid PLAIN credentials: %r", credentials) +                    self._send_zap_reply(request_id, b"400", b"Invalid credentials") +                    return +                username, password = [ u(c, self.encoding, 'replace') for c in credentials ] +                allowed, reason = self._authenticate_plain(domain, username, password) + +            elif mechanism == b'CURVE': +                # For CURVE, even a whitelisted address must authenticate +                if len(credentials) != 1: +                    self.log.error("Invalid CURVE credentials: %r", credentials) +                    self._send_zap_reply(request_id, b"400", b"Invalid credentials") +                    return +                key = credentials[0] +                allowed, reason = self._authenticate_curve(domain, key) + +        if allowed: +            self._send_zap_reply(request_id, b"200", b"OK", username) +        else: +            self._send_zap_reply(request_id, b"400", reason) + +    def _authenticate_plain(self, domain, username, password): +        """PLAIN ZAP authentication""" +        allowed = False +        reason = b"" +        if self.passwords: +            # If no domain is not specified then use the default domain +            if not domain: +                domain = '*' + +            if domain in self.passwords: +                if username in self.passwords[domain]: +                    if password == self.passwords[domain][username]: +                        allowed = True +                    else: +                        reason = b"Invalid password" +                else: +                    reason = b"Invalid username" +            else: +                reason = b"Invalid domain" + +            if allowed: +                self.log.debug("ALLOWED (PLAIN) domain=%s username=%s password=%s", +                    domain, username, password, +                ) +            else: +                self.log.debug("DENIED %s", reason) + +        else: +            reason = b"No passwords defined" +            self.log.debug("DENIED (PLAIN) %s", reason) + +        return allowed, reason + +    def _authenticate_curve(self, domain, client_key): +        """CURVE ZAP authentication""" +        allowed = False +        reason = b"" +        if self.allow_any: +            allowed = True +            reason = b"OK" +            self.log.debug("ALLOWED (CURVE allow any client)") +        else: +            # If no explicit domain is specified then use the default domain +            if not domain: +                domain = '*' + +            if domain in self.certs: +                # The certs dict stores keys in z85 format, convert binary key to z85 bytes +                z85_client_key = z85.encode(client_key) +                if z85_client_key in self.certs[domain] or self.certs[domain] == b'OK': +                    allowed = True +                    reason = b"OK" +                else: +                    reason = b"Unknown key" + +                status = "ALLOWED" if allowed else "DENIED" +                self.log.debug("%s (CURVE) domain=%s client_key=%s", +                    status, domain, z85_client_key, +                ) +            else: +                reason = b"Unknown domain" + +        return allowed, reason + +    def _send_zap_reply(self, request_id, status_code, status_text, user_id='user'): +        """Send a ZAP reply to finish the authentication.""" +        user_id = user_id if status_code == b'200' else b'' +        if isinstance(user_id, unicode): +            user_id = user_id.encode(self.encoding, 'replace') +        metadata = b''  # not currently used +        self.log.debug("ZAP reply code=%s text=%s", status_code, status_text) +        reply = [VERSION, request_id, status_code, status_text, user_id, metadata] +        self.zap_socket.send_multipart(reply) + +__all__ = ['Authenticator', 'CURVE_ALLOW_ANY'] diff --git a/zmq/auth/certs.py b/zmq/auth/certs.py new file mode 100644 index 0000000..4d26ad7 --- /dev/null +++ b/zmq/auth/certs.py @@ -0,0 +1,119 @@ +"""0MQ authentication related functions and classes.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import datetime +import glob +import io +import os +import zmq +from zmq.utils.strtypes import bytes, unicode, b, u + + +_cert_secret_banner = u("""#   ****  Generated on {0} by pyzmq  **** +#   ZeroMQ CURVE **Secret** Certificate +#   DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions. + +""") + +_cert_public_banner = u("""#   ****  Generated on {0} by pyzmq  **** +#   ZeroMQ CURVE Public Certificate +#   Exchange securely, or use a secure mechanism to verify the contents +#   of this file after exchange. Store public certificates in your home +#   directory, in the .curve subdirectory. + +""") + +def _write_key_file(key_filename, banner, public_key, secret_key=None, metadata=None, encoding='utf-8'): +    """Create a certificate file""" +    if isinstance(public_key, bytes): +        public_key = public_key.decode(encoding) +    if isinstance(secret_key, bytes): +        secret_key = secret_key.decode(encoding) +    with io.open(key_filename, 'w', encoding='utf8') as f: +        f.write(banner.format(datetime.datetime.now())) + +        f.write(u('metadata\n')) +        if metadata: +            for k, v in metadata.items(): +                if isinstance(v, bytes): +                    v = v.decode(encoding) +                f.write(u("    {0} = {1}\n").format(k, v)) + +        f.write(u('curve\n')) +        f.write(u("    public-key = \"{0}\"\n").format(public_key)) + +        if secret_key: +            f.write(u("    secret-key = \"{0}\"\n").format(secret_key)) + + +def create_certificates(key_dir, name, metadata=None): +    """Create zmq certificates. +     +    Returns the file paths to the public and secret certificate files. +    """ +    public_key, secret_key = zmq.curve_keypair() +    base_filename = os.path.join(key_dir, name) +    secret_key_file = "{0}.key_secret".format(base_filename) +    public_key_file = "{0}.key".format(base_filename) +    now = datetime.datetime.now() + +    _write_key_file(public_key_file, +                    _cert_public_banner.format(now), +                    public_key) + +    _write_key_file(secret_key_file, +                    _cert_secret_banner.format(now), +                    public_key, +                    secret_key=secret_key, +                    metadata=metadata) + +    return public_key_file, secret_key_file + + +def load_certificate(filename): +    """Load public and secret key from a zmq certificate. +     +    Returns (public_key, secret_key) +     +    If the certificate file only contains the public key, +    secret_key will be None. +    """ +    public_key = None +    secret_key = None +    if not os.path.exists(filename): +        raise IOError("Invalid certificate file: {0}".format(filename)) + +    with open(filename, 'rb') as f: +        for line in f: +            line = line.strip() +            if line.startswith(b'#'): +                continue +            if line.startswith(b'public-key'): +                public_key = line.split(b"=", 1)[1].strip(b' \t\'"') +            if line.startswith(b'secret-key'): +                secret_key = line.split(b"=", 1)[1].strip(b' \t\'"') +            if public_key and secret_key: +                break +     +    return public_key, secret_key + + +def load_certificates(directory='.'): +    """Load public keys from all certificates in a directory""" +    certs = {} +    if not os.path.isdir(directory): +        raise IOError("Invalid certificate directory: {0}".format(directory)) +    # Follow czmq pattern of public keys stored in *.key files. +    glob_string = os.path.join(directory, "*.key") +     +    cert_files = glob.glob(glob_string) +    for cert_file in cert_files: +        public_key, _ = load_certificate(cert_file) +        if public_key: +            certs[public_key] = 'OK' +    return certs + +__all__ = ['create_certificates', 'load_certificate', 'load_certificates'] diff --git a/zmq/auth/ioloop.py b/zmq/auth/ioloop.py new file mode 100644 index 0000000..1f448b4 --- /dev/null +++ b/zmq/auth/ioloop.py @@ -0,0 +1,34 @@ +"""ZAP Authenticator integrated with the tornado IOLoop. + +.. versionadded:: 14.1 +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq.eventloop import ioloop, zmqstream +from .base import Authenticator + + +class IOLoopAuthenticator(Authenticator): +    """ZAP authentication for use in the tornado IOLoop""" + +    def __init__(self, context=None, encoding='utf-8', log=None, io_loop=None): +        super(IOLoopAuthenticator, self).__init__(context) +        self.zap_stream = None +        self.io_loop = io_loop or ioloop.IOLoop.instance() + +    def start(self): +        """Start ZAP authentication""" +        super(IOLoopAuthenticator, self).start() +        self.zap_stream = zmqstream.ZMQStream(self.zap_socket, self.io_loop) +        self.zap_stream.on_recv(self.handle_zap_message) + +    def stop(self): +        """Stop ZAP authentication""" +        if self.zap_stream: +            self.zap_stream.close() +            self.zap_stream = None +        super(IOLoopAuthenticator, self).stop() + +__all__ = ['IOLoopAuthenticator'] diff --git a/zmq/auth/thread.py b/zmq/auth/thread.py new file mode 100644 index 0000000..8c3355a --- /dev/null +++ b/zmq/auth/thread.py @@ -0,0 +1,184 @@ +"""ZAP Authenticator in a Python Thread. + +.. versionadded:: 14.1 +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import logging +from threading import Thread + +import zmq +from zmq.utils import jsonapi +from zmq.utils.strtypes import bytes, unicode, b, u + +from .base import Authenticator + +class AuthenticationThread(Thread): +    """A Thread for running a zmq Authenticator +     +    This is run in the background by ThreadedAuthenticator +    """ + +    def __init__(self, context, endpoint, encoding='utf-8', log=None): +        super(AuthenticationThread, self).__init__() +        self.context = context or zmq.Context.instance() +        self.encoding = encoding +        self.log = log = log or logging.getLogger('zmq.auth') +        self.authenticator = Authenticator(context, encoding=encoding, log=log) + +        # create a socket to communicate back to main thread. +        self.pipe = context.socket(zmq.PAIR) +        self.pipe.linger = 1 +        self.pipe.connect(endpoint) + +    def run(self): +        """ Start the Authentication Agent thread task """ +        self.authenticator.start() +        zap = self.authenticator.zap_socket +        poller = zmq.Poller() +        poller.register(self.pipe, zmq.POLLIN) +        poller.register(zap, zmq.POLLIN) +        while True: +            try: +                socks = dict(poller.poll()) +            except zmq.ZMQError: +                break  # interrupted + +            if self.pipe in socks and socks[self.pipe] == zmq.POLLIN: +                terminate = self._handle_pipe() +                if terminate: +                    break + +            if zap in socks and socks[zap] == zmq.POLLIN: +                self._handle_zap() + +        self.pipe.close() +        self.authenticator.stop() + +    def _handle_zap(self): +        """ +        Handle a message from the ZAP socket. +        """ +        msg = self.authenticator.zap_socket.recv_multipart() +        if not msg: return +        self.authenticator.handle_zap_message(msg) + +    def _handle_pipe(self): +        """ +        Handle a message from front-end API. +        """ +        terminate = False + +        # Get the whole message off the pipe in one go +        msg = self.pipe.recv_multipart() + +        if msg is None: +            terminate = True +            return terminate + +        command = msg[0] +        self.log.debug("auth received API command %r", command) + +        if command == b'ALLOW': +            addresses = [u(m, self.encoding) for m in msg[1:]] +            try: +                self.authenticator.allow(*addresses) +            except Exception as e: +                self.log.exception("Failed to allow %s", addresses) + +        elif command == b'DENY': +            addresses = [u(m, self.encoding) for m in msg[1:]] +            try: +                self.authenticator.deny(*addresses) +            except Exception as e: +                self.log.exception("Failed to deny %s", addresses) + +        elif command == b'PLAIN': +            domain = u(msg[1], self.encoding) +            json_passwords = msg[2] +            self.authenticator.configure_plain(domain, jsonapi.loads(json_passwords)) + +        elif command == b'CURVE': +            # For now we don't do anything with domains +            domain = u(msg[1], self.encoding) + +            # If location is CURVE_ALLOW_ANY, allow all clients. Otherwise +            # treat location as a directory that holds the certificates. +            location = u(msg[2], self.encoding) +            self.authenticator.configure_curve(domain, location) + +        elif command == b'TERMINATE': +            terminate = True + +        else: +            self.log.error("Invalid auth command from API: %r", command) + +        return terminate + +def _inherit_docstrings(cls): +    """inherit docstrings from Authenticator, so we don't duplicate them""" +    for name, method in cls.__dict__.items(): +        if name.startswith('_'): +            continue +        upstream_method = getattr(Authenticator, name, None) +        if not method.__doc__: +            method.__doc__ = upstream_method.__doc__ +    return cls + +@_inherit_docstrings +class ThreadAuthenticator(object): +    """Run ZAP authentication in a background thread""" + +    def __init__(self, context=None, encoding='utf-8', log=None): +        self.context = context or zmq.Context.instance() +        self.log = log +        self.encoding = encoding +        self.pipe = None +        self.pipe_endpoint = "inproc://{0}.inproc".format(id(self)) +        self.thread = None + +    def allow(self, *addresses): +        self.pipe.send_multipart([b'ALLOW'] + [b(a, self.encoding) for a in addresses]) + +    def deny(self, *addresses): +        self.pipe.send_multipart([b'DENY'] + [b(a, self.encoding) for a in addresses]) + +    def configure_plain(self, domain='*', passwords=None): +        self.pipe.send_multipart([b'PLAIN', b(domain, self.encoding), jsonapi.dumps(passwords or {})]) + +    def configure_curve(self, domain='*', location=''): +        domain = b(domain, self.encoding) +        location = b(location, self.encoding) +        self.pipe.send_multipart([b'CURVE', domain, location]) + +    def start(self): +        """Start the authentication thread""" +        # create a socket to communicate with auth thread. +        self.pipe = self.context.socket(zmq.PAIR) +        self.pipe.linger = 1 +        self.pipe.bind(self.pipe_endpoint) +        self.thread = AuthenticationThread(self.context, self.pipe_endpoint, encoding=self.encoding, log=self.log) +        self.thread.start() + +    def stop(self): +        """Stop the authentication thread""" +        if self.pipe: +            self.pipe.send(b'TERMINATE') +            if self.is_alive(): +                self.thread.join() +            self.thread = None +            self.pipe.close() +            self.pipe = None + +    def is_alive(self): +        """Is the ZAP thread currently running?""" +        if self.thread and self.thread.is_alive(): +            return True +        return False + +    def __del__(self): +        self.stop() + +__all__ = ['ThreadAuthenticator'] diff --git a/zmq/backend/__init__.py b/zmq/backend/__init__.py new file mode 100644 index 0000000..7cac725 --- /dev/null +++ b/zmq/backend/__init__.py @@ -0,0 +1,45 @@ +"""Import basic exposure of libzmq C API as a backend""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import os +import platform +import sys + +from zmq.utils.sixcerpt import reraise + +from .select import public_api, select_backend + +if 'PYZMQ_BACKEND' in os.environ: +    backend = os.environ['PYZMQ_BACKEND'] +    if backend in ('cython', 'cffi'): +        backend = 'zmq.backend.%s' % backend +    _ns = select_backend(backend) +else: +    # default to cython, fallback to cffi +    # (reverse on PyPy) +    if platform.python_implementation() == 'PyPy': +        first, second = ('zmq.backend.cffi', 'zmq.backend.cython') +    else: +        first, second = ('zmq.backend.cython', 'zmq.backend.cffi') + +    try: +        _ns = select_backend(first) +    except Exception: +        exc_info = sys.exc_info() +        exc = exc_info[1] +        try: +            _ns = select_backend(second) +        except ImportError: +            # prevent 'During handling of the above exception...' on py3 +            # can't use `raise ... from` on Python 2 +            if hasattr(exc, '__cause__'): +                exc.__cause__ = None +            # raise the *first* error, not the fallback +            reraise(*exc_info) + +globals().update(_ns) + +__all__ = public_api diff --git a/zmq/backend/cffi/__init__.py b/zmq/backend/cffi/__init__.py new file mode 100644 index 0000000..ca3164d --- /dev/null +++ b/zmq/backend/cffi/__init__.py @@ -0,0 +1,22 @@ +"""CFFI backend (for PyPY)""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq.backend.cffi import (constants, error, message, context, socket, +                           _poll, devices, utils) + +__all__ = [] +for submod in (constants, error, message, context, socket, +               _poll, devices, utils): +    __all__.extend(submod.__all__) + +from .constants import * +from .error import * +from .message import * +from .context import * +from .socket import * +from .devices import * +from ._poll import * +from ._cffi import zmq_version_info, ffi +from .utils import * diff --git a/zmq/backend/cffi/_cdefs.h b/zmq/backend/cffi/_cdefs.h new file mode 100644 index 0000000..d330057 --- /dev/null +++ b/zmq/backend/cffi/_cdefs.h @@ -0,0 +1,68 @@ +void zmq_version(int *major, int *minor, int *patch); + +void* zmq_socket(void *context, int type); +int zmq_close(void *socket); + +int zmq_bind(void *socket, const char *endpoint); +int zmq_connect(void *socket, const char *endpoint); + +int zmq_errno(void); +const char * zmq_strerror(int errnum); + +void* zmq_stopwatch_start(void); +unsigned long zmq_stopwatch_stop(void *watch); +void zmq_sleep(int seconds_); +int zmq_device(int device, void *frontend, void *backend); + +int zmq_unbind(void *socket, const char *endpoint); +int zmq_disconnect(void *socket, const char *endpoint); +void* zmq_ctx_new(); +int zmq_ctx_destroy(void *context); +int zmq_ctx_get(void *context, int opt); +int zmq_ctx_set(void *context, int opt, int optval); +int zmq_proxy(void *frontend, void *backend, void *capture); +int zmq_socket_monitor(void *socket, const char *addr, int events); + +int zmq_curve_keypair (char *z85_public_key, char *z85_secret_key); +int zmq_has (const char *capability); + +typedef struct { ...; } zmq_msg_t; +typedef ... zmq_free_fn; + +int zmq_msg_init(zmq_msg_t *msg); +int zmq_msg_init_size(zmq_msg_t *msg, size_t size); +int zmq_msg_init_data(zmq_msg_t *msg, +                      void *data, +                      size_t size, +                      zmq_free_fn *ffn, +                      void *hint); + +size_t zmq_msg_size(zmq_msg_t *msg); +void *zmq_msg_data(zmq_msg_t *msg); +int zmq_msg_close(zmq_msg_t *msg); + +int zmq_msg_send(zmq_msg_t *msg, void *socket, int flags); +int zmq_msg_recv(zmq_msg_t *msg, void *socket, int flags); + +int zmq_getsockopt(void *socket, +                   int option_name, +                   void *option_value, +                   size_t *option_len); + +int zmq_setsockopt(void *socket, +                   int option_name, +                   const void *option_value, +                   size_t option_len); +typedef struct +{ +    void *socket; +    int fd; +    short events; +    short revents; +} zmq_pollitem_t; + +int zmq_poll(zmq_pollitem_t *items, int nitems, long timeout); + +// miscellany +void * memcpy(void *restrict s1, const void *restrict s2, size_t n); +int get_ipc_path_max_len(void); diff --git a/zmq/backend/cffi/_cffi.py b/zmq/backend/cffi/_cffi.py new file mode 100644 index 0000000..c73ebf8 --- /dev/null +++ b/zmq/backend/cffi/_cffi.py @@ -0,0 +1,127 @@ +# coding: utf-8 +"""The main CFFI wrapping of libzmq""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import json +import os +from os.path import dirname, join +from cffi import FFI + +from zmq.utils.constant_names import all_names, no_prefix + + +base_zmq_version = (3,2,2) + +def load_compiler_config(): +    """load pyzmq compiler arguments""" +    import zmq +    zmq_dir = dirname(zmq.__file__) +    zmq_parent = dirname(zmq_dir) +     +    fname = join(zmq_dir, 'utils', 'compiler.json') +    if os.path.exists(fname): +        with open(fname) as f: +            cfg = json.load(f) +    else: +        cfg = {} +     +    cfg.setdefault("include_dirs", []) +    cfg.setdefault("library_dirs", []) +    cfg.setdefault("runtime_library_dirs", []) +    cfg.setdefault("libraries", ["zmq"]) +     +    # cast to str, because cffi can't handle unicode paths (?!) +    cfg['libraries'] = [str(lib) for lib in cfg['libraries']] +    for key in ("include_dirs", "library_dirs", "runtime_library_dirs"): +        # interpret paths relative to parent of zmq (like source tree) +        abs_paths = [] +        for p in cfg[key]: +            if p.startswith('zmq'): +                p = join(zmq_parent, p) +            abs_paths.append(str(p)) +        cfg[key] = abs_paths +    return cfg + + +def zmq_version_info(): +    """Get libzmq version as tuple of ints""" +    major = ffi.new('int*') +    minor = ffi.new('int*') +    patch = ffi.new('int*') + +    C.zmq_version(major, minor, patch) + +    return (int(major[0]), int(minor[0]), int(patch[0])) + + +cfg = load_compiler_config() +ffi = FFI() + +def _make_defines(names): +    _names = [] +    for name in names: +        define_line = "#define %s ..." % (name) +        _names.append(define_line) + +    return "\n".join(_names) + +c_constant_names = [] +for name in all_names: +    if no_prefix(name): +        c_constant_names.append(name) +    else: +        c_constant_names.append("ZMQ_" + name) + +# load ffi definitions +here = os.path.dirname(__file__) +with open(os.path.join(here, '_cdefs.h')) as f: +    _cdefs = f.read() + +with open(os.path.join(here, '_verify.c')) as f: +    _verify = f.read() + +ffi.cdef(_cdefs) +ffi.cdef(_make_defines(c_constant_names)) + +try: +    C = ffi.verify(_verify, +        modulename='_cffi_ext', +        libraries=cfg['libraries'], +        include_dirs=cfg['include_dirs'], +        library_dirs=cfg['library_dirs'], +        runtime_library_dirs=cfg['runtime_library_dirs'], +    ) +    _version_info = zmq_version_info() +except Exception as e: +    raise ImportError("PyZMQ CFFI backend couldn't find zeromq: %s\n" +    "Please check that you have zeromq headers and libraries." % e) + +if _version_info < (3,2,2): +    raise ImportError("PyZMQ CFFI backend requires zeromq >= 3.2.2," +        " but found %i.%i.%i" % _version_info +    ) + +nsp = new_sizet_pointer = lambda length: ffi.new('size_t*', length) + +new_uint64_pointer = lambda: (ffi.new('uint64_t*'), +                              nsp(ffi.sizeof('uint64_t'))) +new_int64_pointer = lambda: (ffi.new('int64_t*'), +                             nsp(ffi.sizeof('int64_t'))) +new_int_pointer = lambda: (ffi.new('int*'), +                           nsp(ffi.sizeof('int'))) +new_binary_data = lambda length: (ffi.new('char[%d]' % (length)), +                                  nsp(ffi.sizeof('char') * length)) + +value_uint64_pointer = lambda val : (ffi.new('uint64_t*', val), +                                     ffi.sizeof('uint64_t')) +value_int64_pointer = lambda val: (ffi.new('int64_t*', val), +                                   ffi.sizeof('int64_t')) +value_int_pointer = lambda val: (ffi.new('int*', val), +                                 ffi.sizeof('int')) +value_binary_data = lambda val, length: (ffi.new('char[%d]' % (length + 1), val), +                                         ffi.sizeof('char') * length) + +IPC_PATH_MAX_LEN = C.get_ipc_path_max_len() diff --git a/zmq/backend/cffi/_poll.py b/zmq/backend/cffi/_poll.py new file mode 100644 index 0000000..9bca34c --- /dev/null +++ b/zmq/backend/cffi/_poll.py @@ -0,0 +1,56 @@ +# coding: utf-8 +"""zmq poll function""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import C, ffi, zmq_version_info + +from .constants import * + +from zmq.error import _check_rc + + +def _make_zmq_pollitem(socket, flags): +    zmq_socket = socket._zmq_socket +    zmq_pollitem = ffi.new('zmq_pollitem_t*') +    zmq_pollitem.socket = zmq_socket +    zmq_pollitem.fd = 0 +    zmq_pollitem.events = flags +    zmq_pollitem.revents = 0 +    return zmq_pollitem[0] + +def _make_zmq_pollitem_fromfd(socket_fd, flags): +    zmq_pollitem = ffi.new('zmq_pollitem_t*') +    zmq_pollitem.socket = ffi.NULL +    zmq_pollitem.fd = socket_fd +    zmq_pollitem.events = flags +    zmq_pollitem.revents = 0 +    return zmq_pollitem[0] + +def zmq_poll(sockets, timeout): +    cffi_pollitem_list = [] +    low_level_to_socket_obj = {} +    for item in sockets: +        if isinstance(item[0], int): +            low_level_to_socket_obj[item[0]] = item +            cffi_pollitem_list.append(_make_zmq_pollitem_fromfd(item[0], item[1])) +        else: +            low_level_to_socket_obj[item[0]._zmq_socket] = item +            cffi_pollitem_list.append(_make_zmq_pollitem(item[0], item[1])) +    items = ffi.new('zmq_pollitem_t[]', cffi_pollitem_list) +    list_length = ffi.cast('int', len(cffi_pollitem_list)) +    c_timeout = ffi.cast('long', timeout) +    rc = C.zmq_poll(items, list_length, c_timeout) +    _check_rc(rc) +    result = [] +    for index in range(len(items)): +        if not items[index].socket == ffi.NULL: +            if items[index].revents > 0: +                result.append((low_level_to_socket_obj[items[index].socket][0], +                            items[index].revents)) +        else: +            result.append((items[index].fd, items[index].revents)) +    return result + +__all__ = ['zmq_poll'] diff --git a/zmq/backend/cffi/_verify.c b/zmq/backend/cffi/_verify.c new file mode 100644 index 0000000..547840e --- /dev/null +++ b/zmq/backend/cffi/_verify.c @@ -0,0 +1,12 @@ +#include <stdio.h> +#include <sys/un.h> +#include <string.h> + +#include <zmq.h> +#include <zmq_utils.h> +#include "zmq_compat.h" + +int get_ipc_path_max_len(void) { +    struct sockaddr_un *dummy; +    return sizeof(dummy->sun_path) - 1; +} diff --git a/zmq/backend/cffi/constants.py b/zmq/backend/cffi/constants.py new file mode 100644 index 0000000..ee293e7 --- /dev/null +++ b/zmq/backend/cffi/constants.py @@ -0,0 +1,15 @@ +# coding: utf-8 +"""zmq constants""" + +from ._cffi import C, c_constant_names +from zmq.utils.constant_names import all_names + +g = globals() +for cname in c_constant_names: +    if cname.startswith("ZMQ_"): +        name = cname[4:] +    else: +        name = cname +    g[name] = getattr(C, cname) + +__all__ = all_names diff --git a/zmq/backend/cffi/context.py b/zmq/backend/cffi/context.py new file mode 100644 index 0000000..16a7b25 --- /dev/null +++ b/zmq/backend/cffi/context.py @@ -0,0 +1,100 @@ +# coding: utf-8 +"""zmq Context class""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import weakref + +from ._cffi import C, ffi + +from .socket import * +from .constants import * + +from zmq.error import ZMQError, _check_rc + +class Context(object): +    _zmq_ctx = None +    _iothreads = None +    _closed = None +    _sockets = None +    _shadow = False + +    def __init__(self, io_threads=1, shadow=None): +         +        if shadow: +            self._zmq_ctx = ffi.cast("void *", shadow) +            self._shadow = True +        else: +            self._shadow = False +            if not io_threads >= 0: +                raise ZMQError(EINVAL) +         +            self._zmq_ctx = C.zmq_ctx_new() +        if self._zmq_ctx == ffi.NULL: +            raise ZMQError(C.zmq_errno()) +        if not shadow: +            C.zmq_ctx_set(self._zmq_ctx, IO_THREADS, io_threads) +        self._closed = False +        self._sockets = set() +     +    @property +    def underlying(self): +        """The address of the underlying libzmq context""" +        return int(ffi.cast('size_t', self._zmq_ctx)) +     +    @property +    def closed(self): +        return self._closed + +    def _add_socket(self, socket): +        ref = weakref.ref(socket) +        self._sockets.add(ref) +        return ref + +    def _rm_socket(self, ref): +        if ref in self._sockets: +            self._sockets.remove(ref) + +    def set(self, option, value): +        """set a context option +         +        see zmq_ctx_set +        """ +        rc = C.zmq_ctx_set(self._zmq_ctx, option, value) +        _check_rc(rc) + +    def get(self, option): +        """get context option +         +        see zmq_ctx_get +        """ +        rc = C.zmq_ctx_get(self._zmq_ctx, option) +        _check_rc(rc) +        return rc + +    def term(self): +        if self.closed: +            return + +        C.zmq_ctx_destroy(self._zmq_ctx) + +        self._zmq_ctx = None +        self._closed = True + +    def destroy(self, linger=None): +        if self.closed: +            return + +        sockets = self._sockets +        self._sockets = set() +        for s in sockets: +            s = s() +            if s and not s.closed: +                if linger: +                    s.setsockopt(LINGER, linger) +                s.close() +         +        self.term() + +__all__ = ['Context'] diff --git a/zmq/backend/cffi/devices.py b/zmq/backend/cffi/devices.py new file mode 100644 index 0000000..c7a514a --- /dev/null +++ b/zmq/backend/cffi/devices.py @@ -0,0 +1,24 @@ +# coding: utf-8 +"""zmq device functions""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import C, ffi, zmq_version_info +from .socket import Socket +from zmq.error import ZMQError, _check_rc + +def device(device_type, frontend, backend): +    rc = C.zmq_proxy(frontend._zmq_socket, backend._zmq_socket, ffi.NULL) +    _check_rc(rc) + +def proxy(frontend, backend, capture=None): +    if isinstance(capture, Socket): +        capture = capture._zmq_socket +    else: +        capture = ffi.NULL + +    rc = C.zmq_proxy(frontend._zmq_socket, backend._zmq_socket, capture) +    _check_rc(rc) + +__all__ = ['device', 'proxy'] diff --git a/zmq/backend/cffi/error.py b/zmq/backend/cffi/error.py new file mode 100644 index 0000000..3bb64de --- /dev/null +++ b/zmq/backend/cffi/error.py @@ -0,0 +1,13 @@ +"""zmq error functions""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import C, ffi + +def strerror(errno): +    return ffi.string(C.zmq_strerror(errno)) + +zmq_errno = C.zmq_errno + +__all__ = ['strerror', 'zmq_errno'] diff --git a/zmq/backend/cffi/message.py b/zmq/backend/cffi/message.py new file mode 100644 index 0000000..c35decb --- /dev/null +++ b/zmq/backend/cffi/message.py @@ -0,0 +1,69 @@ +"""Dummy Frame object""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import ffi, C + +import zmq +from zmq.utils.strtypes import unicode + +try: +    view = memoryview +except NameError: +    view = buffer + +_content = lambda x: x.tobytes() if type(x) == memoryview else x + +class Frame(object): +    _data = None +    tracker = None +    closed = False +    more = False +    buffer = None + + +    def __init__(self, data, track=False): +        try: +            view(data) +        except TypeError: +            raise + +        self._data = data + +        if isinstance(data, unicode): +            raise TypeError("Unicode objects not allowed. Only: str/bytes, " + +                            "buffer interfaces.") + +        self.more = False +        self.tracker = None +        self.closed = False +        if track: +            self.tracker = zmq.MessageTracker() + +        self.buffer = view(self.bytes) + +    @property +    def bytes(self): +        data = _content(self._data) +        return data + +    def __len__(self): +        return len(self.bytes) + +    def __eq__(self, other): +        return self.bytes == _content(other) + +    def __str__(self): +        if str is unicode: +            return self.bytes.decode() +        else: +            return self.bytes + +    @property +    def done(self): +        return True + +Message = Frame + +__all__ = ['Frame', 'Message'] diff --git a/zmq/backend/cffi/socket.py b/zmq/backend/cffi/socket.py new file mode 100644 index 0000000..3c42773 --- /dev/null +++ b/zmq/backend/cffi/socket.py @@ -0,0 +1,244 @@ +# coding: utf-8 +"""zmq Socket class""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import random +import codecs + +import errno as errno_mod + +from ._cffi import (C, ffi, new_uint64_pointer, new_int64_pointer, +                    new_int_pointer, new_binary_data, value_uint64_pointer, +                    value_int64_pointer, value_int_pointer, value_binary_data, +                    IPC_PATH_MAX_LEN) + +from .message import Frame +from .constants import * + +import zmq +from zmq.error import ZMQError, _check_rc, _check_version +from zmq.utils.strtypes import unicode + + +def new_pointer_from_opt(option, length=0): +    from zmq.sugar.constants import ( +        int64_sockopts, bytes_sockopts, +    ) +    if option in int64_sockopts: +        return new_int64_pointer() +    elif option in bytes_sockopts: +        return new_binary_data(length) +    else: +        # default +        return new_int_pointer() + +def value_from_opt_pointer(option, opt_pointer, length=0): +    from zmq.sugar.constants import ( +        int64_sockopts, bytes_sockopts, +    ) +    if option in int64_sockopts: +        return int(opt_pointer[0]) +    elif option in bytes_sockopts: +        return ffi.buffer(opt_pointer, length)[:] +    else: +        return int(opt_pointer[0]) + +def initialize_opt_pointer(option, value, length=0): +    from zmq.sugar.constants import ( +        int64_sockopts, bytes_sockopts, +    ) +    if option in int64_sockopts: +        return value_int64_pointer(value) +    elif option in bytes_sockopts: +        return value_binary_data(value, length) +    else: +        return value_int_pointer(value) + + +class Socket(object): +    context = None +    socket_type = None +    _zmq_socket = None +    _closed = None +    _ref = None +    _shadow = False + +    def __init__(self, context=None, socket_type=None, shadow=None): +        self.context = context +        if shadow is not None: +            self._zmq_socket = ffi.cast("void *", shadow) +            self._shadow = True +        else: +            self._shadow = False +            self._zmq_socket = C.zmq_socket(context._zmq_ctx, socket_type) +        if self._zmq_socket == ffi.NULL: +            raise ZMQError() +        self._closed = False +        if context: +            self._ref = context._add_socket(self) +     +    @property +    def underlying(self): +        """The address of the underlying libzmq socket""" +        return int(ffi.cast('size_t', self._zmq_socket)) +     +    @property +    def closed(self): +        return self._closed + +    def close(self, linger=None): +        rc = 0 +        if not self._closed and hasattr(self, '_zmq_socket'): +            if self._zmq_socket is not None: +                rc = C.zmq_close(self._zmq_socket) +            self._closed = True +            if self.context: +                self.context._rm_socket(self._ref) +        return rc + +    def bind(self, address): +        if isinstance(address, unicode): +            address = address.encode('utf8') +        rc = C.zmq_bind(self._zmq_socket, address) +        if rc < 0: +            if IPC_PATH_MAX_LEN and C.zmq_errno() == errno_mod.ENAMETOOLONG: +                # py3compat: address is bytes, but msg wants str +                if str is unicode: +                    address = address.decode('utf-8', 'replace') +                path = address.split('://', 1)[-1] +                msg = ('ipc path "{0}" is longer than {1} ' +                                'characters (sizeof(sockaddr_un.sun_path)).' +                                .format(path, IPC_PATH_MAX_LEN)) +                raise ZMQError(C.zmq_errno(), msg=msg) +            else: +                _check_rc(rc) + +    def unbind(self, address): +        _check_version((3,2), "unbind") +        if isinstance(address, unicode): +            address = address.encode('utf8') +        rc = C.zmq_unbind(self._zmq_socket, address) +        _check_rc(rc) + +    def connect(self, address): +        if isinstance(address, unicode): +            address = address.encode('utf8') +        rc = C.zmq_connect(self._zmq_socket, address) +        _check_rc(rc) + +    def disconnect(self, address): +        _check_version((3,2), "disconnect") +        if isinstance(address, unicode): +            address = address.encode('utf8') +        rc = C.zmq_disconnect(self._zmq_socket, address) +        _check_rc(rc) + +    def set(self, option, value): +        length = None +        if isinstance(value, unicode): +            raise TypeError("unicode not allowed, use bytes") +         +        if isinstance(value, bytes): +            if option not in zmq.constants.bytes_sockopts: +                raise TypeError("not a bytes sockopt: %s" % option) +            length = len(value) +         +        c_data = initialize_opt_pointer(option, value, length) + +        c_value_pointer = c_data[0] +        c_sizet = c_data[1] + +        rc = C.zmq_setsockopt(self._zmq_socket, +                               option, +                               ffi.cast('void*', c_value_pointer), +                               c_sizet) +        _check_rc(rc) + +    def get(self, option): +        c_data = new_pointer_from_opt(option, length=255) + +        c_value_pointer = c_data[0] +        c_sizet_pointer = c_data[1] + +        rc = C.zmq_getsockopt(self._zmq_socket, +                               option, +                               c_value_pointer, +                               c_sizet_pointer) +        _check_rc(rc) +         +        sz = c_sizet_pointer[0] +        v = value_from_opt_pointer(option, c_value_pointer, sz) +        if option != zmq.IDENTITY and option in zmq.constants.bytes_sockopts and v.endswith(b'\0'): +            v = v[:-1] +        return v + +    def send(self, message, flags=0, copy=False, track=False): +        if isinstance(message, unicode): +            raise TypeError("Message must be in bytes, not an unicode Object") + +        if isinstance(message, Frame): +            message = message.bytes + +        zmq_msg = ffi.new('zmq_msg_t*') +        c_message = ffi.new('char[]', message) +        rc = C.zmq_msg_init_size(zmq_msg, len(message)) +        C.memcpy(C.zmq_msg_data(zmq_msg), c_message, len(message)) + +        rc = C.zmq_msg_send(zmq_msg, self._zmq_socket, flags) +        C.zmq_msg_close(zmq_msg) +        _check_rc(rc) + +        if track: +            return zmq.MessageTracker() + +    def recv(self, flags=0, copy=True, track=False): +        zmq_msg = ffi.new('zmq_msg_t*') +        C.zmq_msg_init(zmq_msg) + +        rc = C.zmq_msg_recv(zmq_msg, self._zmq_socket, flags) + +        if rc < 0: +            C.zmq_msg_close(zmq_msg) +            _check_rc(rc) + +        _buffer = ffi.buffer(C.zmq_msg_data(zmq_msg), C.zmq_msg_size(zmq_msg)) +        value = _buffer[:] +        C.zmq_msg_close(zmq_msg) + +        frame = Frame(value, track=track) +        frame.more = self.getsockopt(RCVMORE) + +        if copy: +            return frame.bytes +        else: +            return frame +     +    def monitor(self, addr, events=-1): +        """s.monitor(addr, flags) + +        Start publishing socket events on inproc. +        See libzmq docs for zmq_monitor for details. +         +        Note: requires libzmq >= 3.2 +         +        Parameters +        ---------- +        addr : str +            The inproc url used for monitoring. Passing None as +            the addr will cause an existing socket monitor to be +            deregistered. +        events : int [default: zmq.EVENT_ALL] +            The zmq event bitmask for which events will be sent to the monitor. +        """ +         +        _check_version((3,2), "monitor") +        if events < 0: +            events = zmq.EVENT_ALL +        if addr is None: +            addr = ffi.NULL +        rc = C.zmq_socket_monitor(self._zmq_socket, addr, events) + + +__all__ = ['Socket', 'IPC_PATH_MAX_LEN'] diff --git a/zmq/backend/cffi/utils.py b/zmq/backend/cffi/utils.py new file mode 100644 index 0000000..fde7827 --- /dev/null +++ b/zmq/backend/cffi/utils.py @@ -0,0 +1,62 @@ +# coding: utf-8 +"""miscellaneous zmq_utils wrapping""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from ._cffi import ffi, C + +from zmq.error import ZMQError, _check_rc, _check_version +from zmq.utils.strtypes import unicode + +def has(capability): +    """Check for zmq capability by name (e.g. 'ipc', 'curve') +     +    .. versionadded:: libzmq-4.1 +    .. versionadded:: 14.1 +    """ +    _check_version((4,1), 'zmq.has') +    if isinstance(capability, unicode): +        capability = capability.encode('utf8') +    return bool(C.zmq_has(capability)) +     +def curve_keypair(): +    """generate a Z85 keypair for use with zmq.CURVE security +     +    Requires libzmq (≥ 4.0) to have been linked with libsodium. +     +    Returns +    ------- +    (public, secret) : two bytestrings +        The public and private keypair as 40 byte z85-encoded bytestrings. +    """ +    _check_version((3,2), "monitor") +    public = ffi.new('char[64]') +    private = ffi.new('char[64]') +    rc = C.zmq_curve_keypair(public, private) +    _check_rc(rc) +    return ffi.buffer(public)[:40], ffi.buffer(private)[:40] + + +class Stopwatch(object): +    def __init__(self): +        self.watch = ffi.NULL + +    def start(self): +        if self.watch == ffi.NULL: +            self.watch = C.zmq_stopwatch_start() +        else: +            raise ZMQError('Stopwatch is already runing.') + +    def stop(self): +        if self.watch == ffi.NULL: +            raise ZMQError('Must start the Stopwatch before calling stop.') +        else: +            time = C.zmq_stopwatch_stop(self.watch) +            self.watch = ffi.NULL +            return time + +    def sleep(self, seconds): +        C.zmq_sleep(seconds) + +__all__ = ['has', 'curve_keypair', 'Stopwatch'] diff --git a/zmq/backend/cython/__init__.py b/zmq/backend/cython/__init__.py new file mode 100644 index 0000000..e535818 --- /dev/null +++ b/zmq/backend/cython/__init__.py @@ -0,0 +1,23 @@ +"""Python bindings for core 0MQ objects.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Lesser GNU Public License (LGPL). + +from . import (constants, error, message, context, +                      socket, utils, _poll, _version, _device ) + +__all__ = [] +for submod in (constants, error, message, context, +               socket, utils, _poll, _version, _device): +    __all__.extend(submod.__all__) + +from .constants import * +from .error import * +from .message import * +from .context import * +from .socket import * +from ._poll import * +from .utils import * +from ._device import * +from ._version import * + diff --git a/zmq/backend/cython/_device.pyx b/zmq/backend/cython/_device.pyx new file mode 100644 index 0000000..eea0a00 --- /dev/null +++ b/zmq/backend/cython/_device.pyx @@ -0,0 +1,89 @@ +"""Python binding for 0MQ device function.""" + +# +#    Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +#    This file is part of pyzmq. +# +#    pyzmq is free software; you can redistribute it and/or modify it under +#    the terms of the Lesser GNU General Public License as published by +#    the Free Software Foundation; either version 3 of the License, or +#    (at your option) any later version. +# +#    pyzmq is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    Lesser GNU General Public License for more details. +# +#    You should have received a copy of the Lesser GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +from libzmq cimport zmq_device, zmq_proxy, ZMQ_VERSION_MAJOR +from zmq.backend.cython.socket cimport Socket as cSocket +from zmq.backend.cython.checkrc cimport _check_rc + +#----------------------------------------------------------------------------- +# Basic device API +#----------------------------------------------------------------------------- + +def device(int device_type, cSocket frontend, cSocket backend=None): +    """device(device_type, frontend, backend) + +    Start a zeromq device. +     +    .. deprecated:: libzmq-3.2 +        Use zmq.proxy + +    Parameters +    ---------- +    device_type : (QUEUE, FORWARDER, STREAMER) +        The type of device to start. +    frontend : Socket +        The Socket instance for the incoming traffic. +    backend : Socket +        The Socket instance for the outbound traffic. +    """ +    if ZMQ_VERSION_MAJOR >= 3: +        return proxy(frontend, backend) + +    cdef int rc = 0 +    with nogil: +        rc = zmq_device(device_type, frontend.handle, backend.handle) +    _check_rc(rc) +    return rc + +def proxy(cSocket frontend, cSocket backend, cSocket capture=None): +    """proxy(frontend, backend, capture) +     +    Start a zeromq proxy (replacement for device). +     +    .. versionadded:: libzmq-3.2 +    .. versionadded:: 13.0 +     +    Parameters +    ---------- +    frontend : Socket +        The Socket instance for the incoming traffic. +    backend : Socket +        The Socket instance for the outbound traffic. +    capture : Socket (optional) +        The Socket instance for capturing traffic. +    """ +    cdef int rc = 0 +    cdef void* capture_handle +    if isinstance(capture, cSocket): +        capture_handle = capture.handle +    else: +        capture_handle = NULL +    with nogil: +        rc = zmq_proxy(frontend.handle, backend.handle, capture_handle) +    _check_rc(rc) +    return rc + +__all__ = ['device', 'proxy'] + diff --git a/zmq/backend/cython/_poll.pyx b/zmq/backend/cython/_poll.pyx new file mode 100644 index 0000000..5bed46b --- /dev/null +++ b/zmq/backend/cython/_poll.pyx @@ -0,0 +1,137 @@ +"""0MQ polling related functions and classes.""" + +# +#    Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +#    This file is part of pyzmq. +# +#    pyzmq is free software; you can redistribute it and/or modify it under +#    the terms of the Lesser GNU General Public License as published by +#    the Free Software Foundation; either version 3 of the License, or +#    (at your option) any later version. +# +#    pyzmq is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    Lesser GNU General Public License for more details. +# +#    You should have received a copy of the Lesser GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +from libc.stdlib cimport free, malloc + +from libzmq cimport zmq_pollitem_t, ZMQ_VERSION_MAJOR +from libzmq cimport zmq_poll as zmq_poll_c +from socket cimport Socket + +import sys + +from zmq.backend.cython.checkrc cimport _check_rc + +#----------------------------------------------------------------------------- +# Polling related methods +#----------------------------------------------------------------------------- + +# version-independent typecheck for int/long +if sys.version_info[0] >= 3: +    int_t = int +else: +    int_t = (int,long) + + +def zmq_poll(sockets, long timeout=-1): +    """zmq_poll(sockets, timeout=-1) + +    Poll a set of 0MQ sockets, native file descs. or sockets. + +    Parameters +    ---------- +    sockets : list of tuples of (socket, flags) +        Each element of this list is a two-tuple containing a socket +        and a flags. The socket may be a 0MQ socket or any object with +        a ``fileno()`` method. The flags can be zmq.POLLIN (for detecting +        for incoming messages), zmq.POLLOUT (for detecting that send is OK) +        or zmq.POLLIN|zmq.POLLOUT for detecting both. +    timeout : int +        The number of milliseconds to poll for. Negative means no timeout. +    """ +    cdef int rc, i +    cdef zmq_pollitem_t *pollitems = NULL +    cdef int nsockets = <int>len(sockets) +    cdef Socket current_socket +     +    if nsockets == 0: +        return [] +     +    pollitems = <zmq_pollitem_t *>malloc(nsockets*sizeof(zmq_pollitem_t)) +    if pollitems == NULL: +        raise MemoryError("Could not allocate poll items") +         +    if ZMQ_VERSION_MAJOR < 3: +        # timeout is us in 2.x, ms in 3.x +        # expected input is ms (matches 3.x) +        timeout = 1000*timeout +     +    for i in range(nsockets): +        s, events = sockets[i] +        if isinstance(s, Socket): +            pollitems[i].socket = (<Socket>s).handle +            pollitems[i].events = events +            pollitems[i].revents = 0 +        elif isinstance(s, int_t): +            pollitems[i].socket = NULL +            pollitems[i].fd = s +            pollitems[i].events = events +            pollitems[i].revents = 0 +        elif hasattr(s, 'fileno'): +            try: +                fileno = int(s.fileno()) +            except: +                free(pollitems) +                raise ValueError('fileno() must return a valid integer fd') +            else: +                pollitems[i].socket = NULL +                pollitems[i].fd = fileno +                pollitems[i].events = events +                pollitems[i].revents = 0 +        else: +            free(pollitems) +            raise TypeError( +                "Socket must be a 0MQ socket, an integer fd or have " +                "a fileno() method: %r" % s +            ) +     + +    with nogil: +        rc = zmq_poll_c(pollitems, nsockets, timeout) +     +    if rc < 0: +        free(pollitems) +        _check_rc(rc) +     +    results = [] +    for i in range(nsockets): +        revents = pollitems[i].revents +        # for compatibility with select.poll: +        # - only return sockets with non-zero status +        # - return the fd for plain sockets +        if revents > 0: +            if pollitems[i].socket != NULL: +                s = sockets[i][0] +            else: +                s = pollitems[i].fd +            results.append((s, revents)) + +    free(pollitems) +    return results + +#----------------------------------------------------------------------------- +# Symbols to export +#----------------------------------------------------------------------------- + +__all__ = [ 'zmq_poll' ] diff --git a/zmq/backend/cython/_version.pyx b/zmq/backend/cython/_version.pyx new file mode 100644 index 0000000..02cf6fc --- /dev/null +++ b/zmq/backend/cython/_version.pyx @@ -0,0 +1,43 @@ +"""PyZMQ and 0MQ version functions.""" + +# +#    Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +#    This file is part of pyzmq. +# +#    pyzmq is free software; you can redistribute it and/or modify it under +#    the terms of the Lesser GNU General Public License as published by +#    the Free Software Foundation; either version 3 of the License, or +#    (at your option) any later version. +# +#    pyzmq is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    Lesser GNU General Public License for more details. +# +#    You should have received a copy of the Lesser GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +from libzmq cimport _zmq_version + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + +def zmq_version_info(): +    """zmq_version_info() + +    Return the version of ZeroMQ itself as a 3-tuple of ints. +    """ +    cdef int major, minor, patch +    _zmq_version(&major, &minor, &patch) +    return (major, minor, patch) + + +__all__ = ['zmq_version_info'] + diff --git a/zmq/backend/cython/checkrc.pxd b/zmq/backend/cython/checkrc.pxd new file mode 100644 index 0000000..3bf69fc --- /dev/null +++ b/zmq/backend/cython/checkrc.pxd @@ -0,0 +1,23 @@ +from libc.errno cimport EINTR, EAGAIN +from cpython cimport PyErr_CheckSignals +from libzmq cimport zmq_errno, ZMQ_ETERM + +cdef inline int _check_rc(int rc) except -1: +    """internal utility for checking zmq return condition +     +    and raising the appropriate Exception class +    """ +    cdef int errno = zmq_errno() +    PyErr_CheckSignals() +    if rc < 0: +        if errno == EAGAIN: +            from zmq.error import Again +            raise Again(errno) +        elif errno == ZMQ_ETERM: +            from zmq.error import ContextTerminated +            raise ContextTerminated(errno) +        else: +            from zmq.error import ZMQError +            raise ZMQError(errno) +        # return -1 +    return 0 diff --git a/zmq/backend/cython/constant_enums.pxi b/zmq/backend/cython/constant_enums.pxi new file mode 100644 index 0000000..3d0efd9 --- /dev/null +++ b/zmq/backend/cython/constant_enums.pxi @@ -0,0 +1,156 @@ +cdef extern from "zmq.h" nogil: + +    enum: ZMQ_VERSION +    enum: ZMQ_VERSION_MAJOR +    enum: ZMQ_VERSION_MINOR +    enum: ZMQ_VERSION_PATCH +    enum: ZMQ_NOBLOCK +    enum: ZMQ_DONTWAIT +    enum: ZMQ_POLLIN +    enum: ZMQ_POLLOUT +    enum: ZMQ_POLLERR +    enum: ZMQ_SNDMORE +    enum: ZMQ_STREAMER +    enum: ZMQ_FORWARDER +    enum: ZMQ_QUEUE +    enum: ZMQ_IO_THREADS_DFLT +    enum: ZMQ_MAX_SOCKETS_DFLT +    enum: ZMQ_POLLITEMS_DFLT +    enum: ZMQ_THREAD_PRIORITY_DFLT +    enum: ZMQ_THREAD_SCHED_POLICY_DFLT +    enum: ZMQ_PAIR +    enum: ZMQ_PUB +    enum: ZMQ_SUB +    enum: ZMQ_REQ +    enum: ZMQ_REP +    enum: ZMQ_DEALER +    enum: ZMQ_ROUTER +    enum: ZMQ_XREQ +    enum: ZMQ_XREP +    enum: ZMQ_PULL +    enum: ZMQ_PUSH +    enum: ZMQ_XPUB +    enum: ZMQ_XSUB +    enum: ZMQ_UPSTREAM +    enum: ZMQ_DOWNSTREAM +    enum: ZMQ_STREAM +    enum: ZMQ_EVENT_CONNECTED +    enum: ZMQ_EVENT_CONNECT_DELAYED +    enum: ZMQ_EVENT_CONNECT_RETRIED +    enum: ZMQ_EVENT_LISTENING +    enum: ZMQ_EVENT_BIND_FAILED +    enum: ZMQ_EVENT_ACCEPTED +    enum: ZMQ_EVENT_ACCEPT_FAILED +    enum: ZMQ_EVENT_CLOSED +    enum: ZMQ_EVENT_CLOSE_FAILED +    enum: ZMQ_EVENT_DISCONNECTED +    enum: ZMQ_EVENT_ALL +    enum: ZMQ_EVENT_MONITOR_STOPPED +    enum: ZMQ_NULL +    enum: ZMQ_PLAIN +    enum: ZMQ_CURVE +    enum: ZMQ_GSSAPI +    enum: ZMQ_EAGAIN "EAGAIN" +    enum: ZMQ_EINVAL "EINVAL" +    enum: ZMQ_EFAULT "EFAULT" +    enum: ZMQ_ENOMEM "ENOMEM" +    enum: ZMQ_ENODEV "ENODEV" +    enum: ZMQ_EMSGSIZE "EMSGSIZE" +    enum: ZMQ_EAFNOSUPPORT "EAFNOSUPPORT" +    enum: ZMQ_ENETUNREACH "ENETUNREACH" +    enum: ZMQ_ECONNABORTED "ECONNABORTED" +    enum: ZMQ_ECONNRESET "ECONNRESET" +    enum: ZMQ_ENOTCONN "ENOTCONN" +    enum: ZMQ_ETIMEDOUT "ETIMEDOUT" +    enum: ZMQ_EHOSTUNREACH "EHOSTUNREACH" +    enum: ZMQ_ENETRESET "ENETRESET" +    enum: ZMQ_HAUSNUMERO +    enum: ZMQ_ENOTSUP "ENOTSUP" +    enum: ZMQ_EPROTONOSUPPORT "EPROTONOSUPPORT" +    enum: ZMQ_ENOBUFS "ENOBUFS" +    enum: ZMQ_ENETDOWN "ENETDOWN" +    enum: ZMQ_EADDRINUSE "EADDRINUSE" +    enum: ZMQ_EADDRNOTAVAIL "EADDRNOTAVAIL" +    enum: ZMQ_ECONNREFUSED "ECONNREFUSED" +    enum: ZMQ_EINPROGRESS "EINPROGRESS" +    enum: ZMQ_ENOTSOCK "ENOTSOCK" +    enum: ZMQ_EFSM "EFSM" +    enum: ZMQ_ENOCOMPATPROTO "ENOCOMPATPROTO" +    enum: ZMQ_ETERM "ETERM" +    enum: ZMQ_EMTHREAD "EMTHREAD" +    enum: ZMQ_IO_THREADS +    enum: ZMQ_MAX_SOCKETS +    enum: ZMQ_SOCKET_LIMIT +    enum: ZMQ_THREAD_PRIORITY +    enum: ZMQ_THREAD_SCHED_POLICY +    enum: ZMQ_IDENTITY +    enum: ZMQ_SUBSCRIBE +    enum: ZMQ_UNSUBSCRIBE +    enum: ZMQ_LAST_ENDPOINT +    enum: ZMQ_TCP_ACCEPT_FILTER +    enum: ZMQ_PLAIN_USERNAME +    enum: ZMQ_PLAIN_PASSWORD +    enum: ZMQ_CURVE_PUBLICKEY +    enum: ZMQ_CURVE_SECRETKEY +    enum: ZMQ_CURVE_SERVERKEY +    enum: ZMQ_ZAP_DOMAIN +    enum: ZMQ_CONNECT_RID +    enum: ZMQ_GSSAPI_PRINCIPAL +    enum: ZMQ_GSSAPI_SERVICE_PRINCIPAL +    enum: ZMQ_SOCKS_PROXY +    enum: ZMQ_FD +    enum: ZMQ_IDENTITY_FD +    enum: ZMQ_RECONNECT_IVL_MAX +    enum: ZMQ_SNDTIMEO +    enum: ZMQ_RCVTIMEO +    enum: ZMQ_SNDHWM +    enum: ZMQ_RCVHWM +    enum: ZMQ_MULTICAST_HOPS +    enum: ZMQ_IPV4ONLY +    enum: ZMQ_ROUTER_BEHAVIOR +    enum: ZMQ_TCP_KEEPALIVE +    enum: ZMQ_TCP_KEEPALIVE_CNT +    enum: ZMQ_TCP_KEEPALIVE_IDLE +    enum: ZMQ_TCP_KEEPALIVE_INTVL +    enum: ZMQ_DELAY_ATTACH_ON_CONNECT +    enum: ZMQ_XPUB_VERBOSE +    enum: ZMQ_EVENTS +    enum: ZMQ_TYPE +    enum: ZMQ_LINGER +    enum: ZMQ_RECONNECT_IVL +    enum: ZMQ_BACKLOG +    enum: ZMQ_ROUTER_MANDATORY +    enum: ZMQ_FAIL_UNROUTABLE +    enum: ZMQ_ROUTER_RAW +    enum: ZMQ_IMMEDIATE +    enum: ZMQ_IPV6 +    enum: ZMQ_MECHANISM +    enum: ZMQ_PLAIN_SERVER +    enum: ZMQ_CURVE_SERVER +    enum: ZMQ_PROBE_ROUTER +    enum: ZMQ_REQ_RELAXED +    enum: ZMQ_REQ_CORRELATE +    enum: ZMQ_CONFLATE +    enum: ZMQ_ROUTER_HANDOVER +    enum: ZMQ_TOS +    enum: ZMQ_IPC_FILTER_PID +    enum: ZMQ_IPC_FILTER_UID +    enum: ZMQ_IPC_FILTER_GID +    enum: ZMQ_GSSAPI_SERVER +    enum: ZMQ_GSSAPI_PLAINTEXT +    enum: ZMQ_HANDSHAKE_IVL +    enum: ZMQ_XPUB_NODROP +    enum: ZMQ_AFFINITY +    enum: ZMQ_MAXMSGSIZE +    enum: ZMQ_HWM +    enum: ZMQ_SWAP +    enum: ZMQ_MCAST_LOOP +    enum: ZMQ_RECOVERY_IVL_MSEC +    enum: ZMQ_RATE +    enum: ZMQ_RECOVERY_IVL +    enum: ZMQ_SNDBUF +    enum: ZMQ_RCVBUF +    enum: ZMQ_RCVMORE +    enum: ZMQ_MORE +    enum: ZMQ_SRCFD +    enum: ZMQ_SHARED diff --git a/zmq/backend/cython/constants.pxi b/zmq/backend/cython/constants.pxi new file mode 100644 index 0000000..606e6cb --- /dev/null +++ b/zmq/backend/cython/constants.pxi @@ -0,0 +1,318 @@ +#----------------------------------------------------------------------------- +# Python module level constants +#----------------------------------------------------------------------------- + +VERSION = ZMQ_VERSION +VERSION_MAJOR = ZMQ_VERSION_MAJOR +VERSION_MINOR = ZMQ_VERSION_MINOR +VERSION_PATCH = ZMQ_VERSION_PATCH +NOBLOCK = ZMQ_NOBLOCK +DONTWAIT = ZMQ_DONTWAIT +POLLIN = ZMQ_POLLIN +POLLOUT = ZMQ_POLLOUT +POLLERR = ZMQ_POLLERR +SNDMORE = ZMQ_SNDMORE +STREAMER = ZMQ_STREAMER +FORWARDER = ZMQ_FORWARDER +QUEUE = ZMQ_QUEUE +IO_THREADS_DFLT = ZMQ_IO_THREADS_DFLT +MAX_SOCKETS_DFLT = ZMQ_MAX_SOCKETS_DFLT +POLLITEMS_DFLT = ZMQ_POLLITEMS_DFLT +THREAD_PRIORITY_DFLT = ZMQ_THREAD_PRIORITY_DFLT +THREAD_SCHED_POLICY_DFLT = ZMQ_THREAD_SCHED_POLICY_DFLT +PAIR = ZMQ_PAIR +PUB = ZMQ_PUB +SUB = ZMQ_SUB +REQ = ZMQ_REQ +REP = ZMQ_REP +DEALER = ZMQ_DEALER +ROUTER = ZMQ_ROUTER +XREQ = ZMQ_XREQ +XREP = ZMQ_XREP +PULL = ZMQ_PULL +PUSH = ZMQ_PUSH +XPUB = ZMQ_XPUB +XSUB = ZMQ_XSUB +UPSTREAM = ZMQ_UPSTREAM +DOWNSTREAM = ZMQ_DOWNSTREAM +STREAM = ZMQ_STREAM +EVENT_CONNECTED = ZMQ_EVENT_CONNECTED +EVENT_CONNECT_DELAYED = ZMQ_EVENT_CONNECT_DELAYED +EVENT_CONNECT_RETRIED = ZMQ_EVENT_CONNECT_RETRIED +EVENT_LISTENING = ZMQ_EVENT_LISTENING +EVENT_BIND_FAILED = ZMQ_EVENT_BIND_FAILED +EVENT_ACCEPTED = ZMQ_EVENT_ACCEPTED +EVENT_ACCEPT_FAILED = ZMQ_EVENT_ACCEPT_FAILED +EVENT_CLOSED = ZMQ_EVENT_CLOSED +EVENT_CLOSE_FAILED = ZMQ_EVENT_CLOSE_FAILED +EVENT_DISCONNECTED = ZMQ_EVENT_DISCONNECTED +EVENT_ALL = ZMQ_EVENT_ALL +EVENT_MONITOR_STOPPED = ZMQ_EVENT_MONITOR_STOPPED +globals()['NULL'] = ZMQ_NULL +PLAIN = ZMQ_PLAIN +CURVE = ZMQ_CURVE +GSSAPI = ZMQ_GSSAPI +EAGAIN = ZMQ_EAGAIN +EINVAL = ZMQ_EINVAL +EFAULT = ZMQ_EFAULT +ENOMEM = ZMQ_ENOMEM +ENODEV = ZMQ_ENODEV +EMSGSIZE = ZMQ_EMSGSIZE +EAFNOSUPPORT = ZMQ_EAFNOSUPPORT +ENETUNREACH = ZMQ_ENETUNREACH +ECONNABORTED = ZMQ_ECONNABORTED +ECONNRESET = ZMQ_ECONNRESET +ENOTCONN = ZMQ_ENOTCONN +ETIMEDOUT = ZMQ_ETIMEDOUT +EHOSTUNREACH = ZMQ_EHOSTUNREACH +ENETRESET = ZMQ_ENETRESET +HAUSNUMERO = ZMQ_HAUSNUMERO +ENOTSUP = ZMQ_ENOTSUP +EPROTONOSUPPORT = ZMQ_EPROTONOSUPPORT +ENOBUFS = ZMQ_ENOBUFS +ENETDOWN = ZMQ_ENETDOWN +EADDRINUSE = ZMQ_EADDRINUSE +EADDRNOTAVAIL = ZMQ_EADDRNOTAVAIL +ECONNREFUSED = ZMQ_ECONNREFUSED +EINPROGRESS = ZMQ_EINPROGRESS +ENOTSOCK = ZMQ_ENOTSOCK +EFSM = ZMQ_EFSM +ENOCOMPATPROTO = ZMQ_ENOCOMPATPROTO +ETERM = ZMQ_ETERM +EMTHREAD = ZMQ_EMTHREAD +IO_THREADS = ZMQ_IO_THREADS +MAX_SOCKETS = ZMQ_MAX_SOCKETS +SOCKET_LIMIT = ZMQ_SOCKET_LIMIT +THREAD_PRIORITY = ZMQ_THREAD_PRIORITY +THREAD_SCHED_POLICY = ZMQ_THREAD_SCHED_POLICY +IDENTITY = ZMQ_IDENTITY +SUBSCRIBE = ZMQ_SUBSCRIBE +UNSUBSCRIBE = ZMQ_UNSUBSCRIBE +LAST_ENDPOINT = ZMQ_LAST_ENDPOINT +TCP_ACCEPT_FILTER = ZMQ_TCP_ACCEPT_FILTER +PLAIN_USERNAME = ZMQ_PLAIN_USERNAME +PLAIN_PASSWORD = ZMQ_PLAIN_PASSWORD +CURVE_PUBLICKEY = ZMQ_CURVE_PUBLICKEY +CURVE_SECRETKEY = ZMQ_CURVE_SECRETKEY +CURVE_SERVERKEY = ZMQ_CURVE_SERVERKEY +ZAP_DOMAIN = ZMQ_ZAP_DOMAIN +CONNECT_RID = ZMQ_CONNECT_RID +GSSAPI_PRINCIPAL = ZMQ_GSSAPI_PRINCIPAL +GSSAPI_SERVICE_PRINCIPAL = ZMQ_GSSAPI_SERVICE_PRINCIPAL +SOCKS_PROXY = ZMQ_SOCKS_PROXY +FD = ZMQ_FD +IDENTITY_FD = ZMQ_IDENTITY_FD +RECONNECT_IVL_MAX = ZMQ_RECONNECT_IVL_MAX +SNDTIMEO = ZMQ_SNDTIMEO +RCVTIMEO = ZMQ_RCVTIMEO +SNDHWM = ZMQ_SNDHWM +RCVHWM = ZMQ_RCVHWM +MULTICAST_HOPS = ZMQ_MULTICAST_HOPS +IPV4ONLY = ZMQ_IPV4ONLY +ROUTER_BEHAVIOR = ZMQ_ROUTER_BEHAVIOR +TCP_KEEPALIVE = ZMQ_TCP_KEEPALIVE +TCP_KEEPALIVE_CNT = ZMQ_TCP_KEEPALIVE_CNT +TCP_KEEPALIVE_IDLE = ZMQ_TCP_KEEPALIVE_IDLE +TCP_KEEPALIVE_INTVL = ZMQ_TCP_KEEPALIVE_INTVL +DELAY_ATTACH_ON_CONNECT = ZMQ_DELAY_ATTACH_ON_CONNECT +XPUB_VERBOSE = ZMQ_XPUB_VERBOSE +EVENTS = ZMQ_EVENTS +TYPE = ZMQ_TYPE +LINGER = ZMQ_LINGER +RECONNECT_IVL = ZMQ_RECONNECT_IVL +BACKLOG = ZMQ_BACKLOG +ROUTER_MANDATORY = ZMQ_ROUTER_MANDATORY +FAIL_UNROUTABLE = ZMQ_FAIL_UNROUTABLE +ROUTER_RAW = ZMQ_ROUTER_RAW +IMMEDIATE = ZMQ_IMMEDIATE +IPV6 = ZMQ_IPV6 +MECHANISM = ZMQ_MECHANISM +PLAIN_SERVER = ZMQ_PLAIN_SERVER +CURVE_SERVER = ZMQ_CURVE_SERVER +PROBE_ROUTER = ZMQ_PROBE_ROUTER +REQ_RELAXED = ZMQ_REQ_RELAXED +REQ_CORRELATE = ZMQ_REQ_CORRELATE +CONFLATE = ZMQ_CONFLATE +ROUTER_HANDOVER = ZMQ_ROUTER_HANDOVER +TOS = ZMQ_TOS +IPC_FILTER_PID = ZMQ_IPC_FILTER_PID +IPC_FILTER_UID = ZMQ_IPC_FILTER_UID +IPC_FILTER_GID = ZMQ_IPC_FILTER_GID +GSSAPI_SERVER = ZMQ_GSSAPI_SERVER +GSSAPI_PLAINTEXT = ZMQ_GSSAPI_PLAINTEXT +HANDSHAKE_IVL = ZMQ_HANDSHAKE_IVL +XPUB_NODROP = ZMQ_XPUB_NODROP +AFFINITY = ZMQ_AFFINITY +MAXMSGSIZE = ZMQ_MAXMSGSIZE +HWM = ZMQ_HWM +SWAP = ZMQ_SWAP +MCAST_LOOP = ZMQ_MCAST_LOOP +RECOVERY_IVL_MSEC = ZMQ_RECOVERY_IVL_MSEC +RATE = ZMQ_RATE +RECOVERY_IVL = ZMQ_RECOVERY_IVL +SNDBUF = ZMQ_SNDBUF +RCVBUF = ZMQ_RCVBUF +RCVMORE = ZMQ_RCVMORE +MORE = ZMQ_MORE +SRCFD = ZMQ_SRCFD +SHARED = ZMQ_SHARED + +#----------------------------------------------------------------------------- +# Symbols to export +#----------------------------------------------------------------------------- +__all__ = [ +  "VERSION", +  "VERSION_MAJOR", +  "VERSION_MINOR", +  "VERSION_PATCH", +  "NOBLOCK", +  "DONTWAIT", +  "POLLIN", +  "POLLOUT", +  "POLLERR", +  "SNDMORE", +  "STREAMER", +  "FORWARDER", +  "QUEUE", +  "IO_THREADS_DFLT", +  "MAX_SOCKETS_DFLT", +  "POLLITEMS_DFLT", +  "THREAD_PRIORITY_DFLT", +  "THREAD_SCHED_POLICY_DFLT", +  "PAIR", +  "PUB", +  "SUB", +  "REQ", +  "REP", +  "DEALER", +  "ROUTER", +  "XREQ", +  "XREP", +  "PULL", +  "PUSH", +  "XPUB", +  "XSUB", +  "UPSTREAM", +  "DOWNSTREAM", +  "STREAM", +  "EVENT_CONNECTED", +  "EVENT_CONNECT_DELAYED", +  "EVENT_CONNECT_RETRIED", +  "EVENT_LISTENING", +  "EVENT_BIND_FAILED", +  "EVENT_ACCEPTED", +  "EVENT_ACCEPT_FAILED", +  "EVENT_CLOSED", +  "EVENT_CLOSE_FAILED", +  "EVENT_DISCONNECTED", +  "EVENT_ALL", +  "EVENT_MONITOR_STOPPED", +  "NULL", +  "PLAIN", +  "CURVE", +  "GSSAPI", +  "EAGAIN", +  "EINVAL", +  "EFAULT", +  "ENOMEM", +  "ENODEV", +  "EMSGSIZE", +  "EAFNOSUPPORT", +  "ENETUNREACH", +  "ECONNABORTED", +  "ECONNRESET", +  "ENOTCONN", +  "ETIMEDOUT", +  "EHOSTUNREACH", +  "ENETRESET", +  "HAUSNUMERO", +  "ENOTSUP", +  "EPROTONOSUPPORT", +  "ENOBUFS", +  "ENETDOWN", +  "EADDRINUSE", +  "EADDRNOTAVAIL", +  "ECONNREFUSED", +  "EINPROGRESS", +  "ENOTSOCK", +  "EFSM", +  "ENOCOMPATPROTO", +  "ETERM", +  "EMTHREAD", +  "IO_THREADS", +  "MAX_SOCKETS", +  "SOCKET_LIMIT", +  "THREAD_PRIORITY", +  "THREAD_SCHED_POLICY", +  "IDENTITY", +  "SUBSCRIBE", +  "UNSUBSCRIBE", +  "LAST_ENDPOINT", +  "TCP_ACCEPT_FILTER", +  "PLAIN_USERNAME", +  "PLAIN_PASSWORD", +  "CURVE_PUBLICKEY", +  "CURVE_SECRETKEY", +  "CURVE_SERVERKEY", +  "ZAP_DOMAIN", +  "CONNECT_RID", +  "GSSAPI_PRINCIPAL", +  "GSSAPI_SERVICE_PRINCIPAL", +  "SOCKS_PROXY", +  "FD", +  "IDENTITY_FD", +  "RECONNECT_IVL_MAX", +  "SNDTIMEO", +  "RCVTIMEO", +  "SNDHWM", +  "RCVHWM", +  "MULTICAST_HOPS", +  "IPV4ONLY", +  "ROUTER_BEHAVIOR", +  "TCP_KEEPALIVE", +  "TCP_KEEPALIVE_CNT", +  "TCP_KEEPALIVE_IDLE", +  "TCP_KEEPALIVE_INTVL", +  "DELAY_ATTACH_ON_CONNECT", +  "XPUB_VERBOSE", +  "EVENTS", +  "TYPE", +  "LINGER", +  "RECONNECT_IVL", +  "BACKLOG", +  "ROUTER_MANDATORY", +  "FAIL_UNROUTABLE", +  "ROUTER_RAW", +  "IMMEDIATE", +  "IPV6", +  "MECHANISM", +  "PLAIN_SERVER", +  "CURVE_SERVER", +  "PROBE_ROUTER", +  "REQ_RELAXED", +  "REQ_CORRELATE", +  "CONFLATE", +  "ROUTER_HANDOVER", +  "TOS", +  "IPC_FILTER_PID", +  "IPC_FILTER_UID", +  "IPC_FILTER_GID", +  "GSSAPI_SERVER", +  "GSSAPI_PLAINTEXT", +  "HANDSHAKE_IVL", +  "XPUB_NODROP", +  "AFFINITY", +  "MAXMSGSIZE", +  "HWM", +  "SWAP", +  "MCAST_LOOP", +  "RECOVERY_IVL_MSEC", +  "RATE", +  "RECOVERY_IVL", +  "SNDBUF", +  "RCVBUF", +  "RCVMORE", +  "MORE", +  "SRCFD", +  "SHARED", +] diff --git a/zmq/backend/cython/constants.pyx b/zmq/backend/cython/constants.pyx new file mode 100644 index 0000000..f924f03 --- /dev/null +++ b/zmq/backend/cython/constants.pyx @@ -0,0 +1,32 @@ +"""0MQ Constants.""" + +# +#    Copyright (c) 2010 Brian E. Granger & Min Ragan-Kelley +# +#    This file is part of pyzmq. +# +#    pyzmq is free software; you can redistribute it and/or modify it under +#    the terms of the Lesser GNU General Public License as published by +#    the Free Software Foundation; either version 3 of the License, or +#    (at your option) any later version. +# +#    pyzmq is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    Lesser GNU General Public License for more details. +# +#    You should have received a copy of the Lesser GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +from libzmq cimport * + +#----------------------------------------------------------------------------- +# Python module level constants +#----------------------------------------------------------------------------- + +include "constants.pxi" diff --git a/zmq/backend/cython/context.pxd b/zmq/backend/cython/context.pxd new file mode 100644 index 0000000..9c9267a --- /dev/null +++ b/zmq/backend/cython/context.pxd @@ -0,0 +1,41 @@ +"""0MQ Context class declaration.""" + +# +#    Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +#    This file is part of pyzmq. +# +#    pyzmq is free software; you can redistribute it and/or modify it under +#    the terms of the Lesser GNU General Public License as published by +#    the Free Software Foundation; either version 3 of the License, or +#    (at your option) any later version. +# +#    pyzmq is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    Lesser GNU General Public License for more details. +# +#    You should have received a copy of the Lesser GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + +cdef class Context: + +    cdef object __weakref__     # enable weakref +    cdef void *handle           # The C handle for the underlying zmq object. +    cdef bint _shadow           # whether the Context is a shadow wrapper of another +    cdef void **_sockets        # A C-array containg socket handles +    cdef size_t _n_sockets      # the number of sockets +    cdef size_t _max_sockets    # the size of the _sockets array +    cdef int _pid               # the pid of the process which created me (for fork safety) +     +    cdef public bint closed   # bool property for a closed context. +    cdef inline int _term(self) +    # helpers for events on _sockets in Socket.__cinit__()/close() +    cdef inline void _add_socket(self, void* handle) +    cdef inline void _remove_socket(self, void* handle) + diff --git a/zmq/backend/cython/context.pyx b/zmq/backend/cython/context.pyx new file mode 100644 index 0000000..b527e5d --- /dev/null +++ b/zmq/backend/cython/context.pyx @@ -0,0 +1,243 @@ +"""0MQ Context class.""" +# coding: utf-8 + +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Lesser GNU Public License (LGPL). + +from libc.stdlib cimport free, malloc, realloc + +from libzmq cimport * + +cdef extern from "getpid_compat.h": +    int getpid() + +from zmq.error import ZMQError +from zmq.backend.cython.checkrc cimport _check_rc + + +_instance = None + +cdef class Context: +    """Context(io_threads=1) + +    Manage the lifecycle of a 0MQ context. + +    Parameters +    ---------- +    io_threads : int +        The number of IO threads. +    """ +     +    # no-op for the signature +    def __init__(self, io_threads=1, shadow=0): +        pass +     +    def __cinit__(self, int io_threads=1, size_t shadow=0, **kwargs): +        self.handle = NULL +        self._sockets = NULL +        if shadow: +            self.handle = <void *>shadow +            self._shadow = True +        else: +            self._shadow = False +            if ZMQ_VERSION_MAJOR >= 3: +                self.handle = zmq_ctx_new() +            else: +                self.handle = zmq_init(io_threads) +         +        if self.handle == NULL: +            raise ZMQError() +         +        cdef int rc = 0 +        if ZMQ_VERSION_MAJOR >= 3 and not self._shadow: +            rc = zmq_ctx_set(self.handle, ZMQ_IO_THREADS, io_threads) +            _check_rc(rc) +         +        self.closed = False +        self._n_sockets = 0 +        self._max_sockets = 32 +         +        self._sockets = <void **>malloc(self._max_sockets*sizeof(void *)) +        if self._sockets == NULL: +            raise MemoryError("Could not allocate _sockets array") +         +        self._pid = getpid() +     +    def __dealloc__(self): +        """don't touch members in dealloc, just cleanup allocations""" +        cdef int rc +        if self._sockets != NULL: +            free(self._sockets) +            self._sockets = NULL +            self._n_sockets = 0 + +        # we can't call object methods in dealloc as it +        # might already be partially deleted +        if not self._shadow: +            self._term() +     +    cdef inline void _add_socket(self, void* handle): +        """Add a socket handle to be closed when Context terminates. +         +        This is to be called in the Socket constructor. +        """ +        if self._n_sockets >= self._max_sockets: +            self._max_sockets *= 2 +            self._sockets = <void **>realloc(self._sockets, self._max_sockets*sizeof(void *)) +            if self._sockets == NULL: +                raise MemoryError("Could not reallocate _sockets array") +         +        self._sockets[self._n_sockets] = handle +        self._n_sockets += 1 + +    cdef inline void _remove_socket(self, void* handle): +        """Remove a socket from the collected handles. +         +        This should be called by Socket.close, to prevent trying to +        close a socket a second time. +        """ +        cdef bint found = False +         +        for idx in range(self._n_sockets): +            if self._sockets[idx] == handle: +                found=True +                break +         +        if found: +            self._n_sockets -= 1 +            if self._n_sockets: +                # move last handle to closed socket's index +                self._sockets[idx] = self._sockets[self._n_sockets] +     +     +    @property +    def underlying(self): +        """The address of the underlying libzmq context""" +        return <size_t> self.handle +     +    # backward-compat, though nobody is using it +    _handle = underlying +     +    cdef inline int _term(self): +        cdef int rc=0 +        if self.handle != NULL and not self.closed and getpid() == self._pid: +            with nogil: +                rc = zmq_ctx_destroy(self.handle) +        self.handle = NULL +        return rc +     +    def term(self): +        """ctx.term() + +        Close or terminate the context. +         +        This can be called to close the context by hand. If this is not called, +        the context will automatically be closed when it is garbage collected. +        """ +        cdef int rc +        rc = self._term() +        self.closed = True +     +    def set(self, int option, optval): +        """ctx.set(option, optval) + +        Set a context option. + +        See the 0MQ API documentation for zmq_ctx_set +        for details on specific options. +         +        .. versionadded:: libzmq-3.2 +        .. versionadded:: 13.0 + +        Parameters +        ---------- +        option : int +            The option to set.  Available values will depend on your +            version of libzmq.  Examples include:: +             +                zmq.IO_THREADS, zmq.MAX_SOCKETS +         +        optval : int +            The value of the option to set. +        """ +        cdef int optval_int_c +        cdef int rc +        cdef char* optval_c + +        if self.closed: +            raise RuntimeError("Context has been destroyed") +         +        if not isinstance(optval, int): +            raise TypeError('expected int, got: %r' % optval) +        optval_int_c = optval +        rc = zmq_ctx_set(self.handle, option, optval_int_c) +        _check_rc(rc) + +    def get(self, int option): +        """ctx.get(option) + +        Get the value of a context option. + +        See the 0MQ API documentation for zmq_ctx_get +        for details on specific options. +         +        .. versionadded:: libzmq-3.2 +        .. versionadded:: 13.0 + +        Parameters +        ---------- +        option : int +            The option to get.  Available values will depend on your +            version of libzmq.  Examples include:: +             +                zmq.IO_THREADS, zmq.MAX_SOCKETS +             +        Returns +        ------- +        optval : int +            The value of the option as an integer. +        """ +        cdef int optval_int_c +        cdef size_t sz +        cdef int rc + +        if self.closed: +            raise RuntimeError("Context has been destroyed") + +        rc = zmq_ctx_get(self.handle, option) +        _check_rc(rc) + +        return rc + +    def destroy(self, linger=None): +        """ctx.destroy(linger=None) +         +        Close all sockets associated with this context, and then terminate +        the context. If linger is specified, +        the LINGER sockopt of the sockets will be set prior to closing. +         +        .. warning:: +         +            destroy involves calling ``zmq_close()``, which is **NOT** threadsafe. +            If there are active sockets in other threads, this must not be called. +        """ +         +        cdef int linger_c +        cdef bint setlinger=False +         +        if linger is not None: +            linger_c = linger +            setlinger=True + +        if self.handle != NULL and not self.closed and self._n_sockets: +            while self._n_sockets: +                if setlinger: +                    zmq_setsockopt(self._sockets[0], ZMQ_LINGER, &linger_c, sizeof(int)) +                rc = zmq_close(self._sockets[0]) +                if rc < 0 and zmq_errno() != ZMQ_ENOTSOCK: +                    raise ZMQError() +                self._n_sockets -= 1 +                self._sockets[0] = self._sockets[self._n_sockets] +        self.term() +     +__all__ = ['Context'] diff --git a/zmq/backend/cython/error.pyx b/zmq/backend/cython/error.pyx new file mode 100644 index 0000000..85e785f --- /dev/null +++ b/zmq/backend/cython/error.pyx @@ -0,0 +1,56 @@ +"""0MQ Error classes and functions.""" + +# +#    Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +#    This file is part of pyzmq. +# +#    pyzmq is free software; you can redistribute it and/or modify it under +#    the terms of the Lesser GNU General Public License as published by +#    the Free Software Foundation; either version 3 of the License, or +#    (at your option) any later version. +# +#    pyzmq is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    Lesser GNU General Public License for more details. +# +#    You should have received a copy of the Lesser GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +# allow const char* +cdef extern from *: +    ctypedef char* const_char_ptr "const char*" + +from libzmq cimport zmq_strerror, zmq_errno as zmq_errno_c + +from zmq.utils.strtypes import bytes + +def strerror(int errno): +    """strerror(errno) + +    Return the error string given the error number. +    """ +    cdef const_char_ptr str_e +    # char * will be a bytes object: +    str_e = zmq_strerror(errno) +    if str is bytes: +        # Python 2: str is bytes, so we already have the right type +        return str_e +    else: +        # Python 3: decode bytes to unicode str +        return str_e.decode() + +def zmq_errno(): +    """zmq_errno() +     +    Return the integer errno of the most recent zmq error. +    """ +    return zmq_errno_c() + +__all__ = ['strerror', 'zmq_errno'] diff --git a/zmq/backend/cython/libzmq.pxd b/zmq/backend/cython/libzmq.pxd new file mode 100644 index 0000000..e42f6d6 --- /dev/null +++ b/zmq/backend/cython/libzmq.pxd @@ -0,0 +1,110 @@ +"""All the C imports for 0MQ""" + +# +#    Copyright (c) 2010 Brian E. Granger & Min Ragan-Kelley +# +#    This file is part of pyzmq. +# +#    pyzmq is free software; you can redistribute it and/or modify it under +#    the terms of the Lesser GNU General Public License as published by +#    the Free Software Foundation; either version 3 of the License, or +#    (at your option) any later version. +# +#    pyzmq is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    Lesser GNU General Public License for more details. +# +#    You should have received a copy of the Lesser GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +#----------------------------------------------------------------------------- +# Import the C header files +#----------------------------------------------------------------------------- + +cdef extern from *: +    ctypedef void* const_void_ptr "const void *" +    ctypedef char* const_char_ptr "const char *" + +cdef extern from "zmq_compat.h": +    ctypedef signed long long int64_t "pyzmq_int64_t" + +include "constant_enums.pxi" + +cdef extern from "zmq.h" nogil: + +    void _zmq_version "zmq_version"(int *major, int *minor, int *patch) +     +    ctypedef int fd_t "ZMQ_FD_T" +     +    enum: errno +    char *zmq_strerror (int errnum) +    int zmq_errno() + +    void *zmq_ctx_new () +    int zmq_ctx_destroy (void *context) +    int zmq_ctx_set (void *context, int option, int optval) +    int zmq_ctx_get (void *context, int option) +    void *zmq_init (int io_threads) +    int zmq_term (void *context) +     +    # blackbox def for zmq_msg_t +    ctypedef void * zmq_msg_t "zmq_msg_t" +     +    ctypedef void zmq_free_fn(void *data, void *hint) +     +    int zmq_msg_init (zmq_msg_t *msg) +    int zmq_msg_init_size (zmq_msg_t *msg, size_t size) +    int zmq_msg_init_data (zmq_msg_t *msg, void *data, +        size_t size, zmq_free_fn *ffn, void *hint) +    int zmq_msg_send (zmq_msg_t *msg, void *s, int flags) +    int zmq_msg_recv (zmq_msg_t *msg, void *s, int flags) +    int zmq_msg_close (zmq_msg_t *msg) +    int zmq_msg_move (zmq_msg_t *dest, zmq_msg_t *src) +    int zmq_msg_copy (zmq_msg_t *dest, zmq_msg_t *src) +    void *zmq_msg_data (zmq_msg_t *msg) +    size_t zmq_msg_size (zmq_msg_t *msg) +    int zmq_msg_more (zmq_msg_t *msg) +    int zmq_msg_get (zmq_msg_t *msg, int option) +    int zmq_msg_set (zmq_msg_t *msg, int option, int optval) +    const_char_ptr zmq_msg_gets (zmq_msg_t *msg, const_char_ptr property) +    int zmq_has (const_char_ptr capability) + +    void *zmq_socket (void *context, int type) +    int zmq_close (void *s) +    int zmq_setsockopt (void *s, int option, void *optval, size_t optvallen) +    int zmq_getsockopt (void *s, int option, void *optval, size_t *optvallen) +    int zmq_bind (void *s, char *addr) +    int zmq_connect (void *s, char *addr) +    int zmq_unbind (void *s, char *addr) +    int zmq_disconnect (void *s, char *addr) + +    int zmq_socket_monitor (void *s, char *addr, int flags) +     +    # send/recv +    int zmq_sendbuf (void *s, const_void_ptr buf, size_t n, int flags) +    int zmq_recvbuf (void *s, void *buf, size_t n, int flags) + +    ctypedef struct zmq_pollitem_t: +        void *socket +        int fd +        short events +        short revents + +    int zmq_poll (zmq_pollitem_t *items, int nitems, long timeout) + +    int zmq_device (int device_, void *insocket_, void *outsocket_) +    int zmq_proxy (void *frontend, void *backend, void *capture) + +cdef extern from "zmq_utils.h" nogil: + +    void *zmq_stopwatch_start () +    unsigned long zmq_stopwatch_stop (void *watch_) +    void zmq_sleep (int seconds_) +    int zmq_curve_keypair (char *z85_public_key, char *z85_secret_key) + diff --git a/zmq/backend/cython/message.pxd b/zmq/backend/cython/message.pxd new file mode 100644 index 0000000..4781195 --- /dev/null +++ b/zmq/backend/cython/message.pxd @@ -0,0 +1,63 @@ +"""0MQ Message related class declarations.""" + +# +#    Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +#    This file is part of pyzmq. +# +#    pyzmq is free software; you can redistribute it and/or modify it under +#    the terms of the Lesser GNU General Public License as published by +#    the Free Software Foundation; either version 3 of the License, or +#    (at your option) any later version. +# +#    pyzmq is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    Lesser GNU General Public License for more details. +# +#    You should have received a copy of the Lesser GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +from cpython cimport PyBytes_FromStringAndSize + +from libzmq cimport zmq_msg_t, zmq_msg_data, zmq_msg_size + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + +cdef class MessageTracker(object): + +    cdef set events  # Message Event objects to track. +    cdef set peers   # Other Message or MessageTracker objects. + + +cdef class Frame: + +    cdef zmq_msg_t zmq_msg +    cdef object _data      # The actual message data as a Python object. +    cdef object _buffer    # A Python Buffer/View of the message contents +    cdef object _bytes     # A bytes/str copy of the message. +    cdef bint _failed_init # Flag to handle failed zmq_msg_init +    cdef public object tracker_event  # Event for use with zmq_free_fn. +    cdef public object tracker        # MessageTracker object. +    cdef public bint more             # whether RCVMORE was set + +    cdef Frame fast_copy(self) # Create shallow copy of Message object. +    cdef object _getbuffer(self) # Construct self._buffer. + + +cdef inline object copy_zmq_msg_bytes(zmq_msg_t *zmq_msg): +    """ Copy the data from a zmq_msg_t """ +    cdef char *data_c = NULL +    cdef Py_ssize_t data_len_c +    data_c = <char *>zmq_msg_data(zmq_msg) +    data_len_c = zmq_msg_size(zmq_msg) +    return PyBytes_FromStringAndSize(data_c, data_len_c) + + diff --git a/zmq/backend/cython/message.pyx b/zmq/backend/cython/message.pyx new file mode 100644 index 0000000..312ae12 --- /dev/null +++ b/zmq/backend/cython/message.pyx @@ -0,0 +1,381 @@ +"""0MQ Message related classes.""" + +# +#    Copyright (c) 2013 Brian E. Granger & Min Ragan-Kelley +# +#    This file is part of pyzmq. +# +#    pyzmq is free software; you can redistribute it and/or modify it under +#    the terms of the Lesser GNU General Public License as published by +#    the Free Software Foundation; either version 3 of the License, or +#    (at your option) any later version. +# +#    pyzmq is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    Lesser GNU General Public License for more details. +# +#    You should have received a copy of the Lesser GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +# get version-independent aliases: +cdef extern from "pyversion_compat.h": +    pass + +from cpython cimport Py_DECREF, Py_INCREF + +from buffers cimport asbuffer_r, viewfromobject_r + +cdef extern from "Python.h": +    ctypedef int Py_ssize_t + +from libzmq cimport * + +from libc.stdio cimport fprintf, stderr as cstderr +from libc.stdlib cimport malloc, free +from libc.string cimport memcpy + +import time + +try: +    # below 3.3 +    from threading import _Event as Event +except (ImportError, AttributeError): +    # python throws ImportError, cython throws AttributeError +    from threading import Event + +import zmq +from zmq.error import _check_version +from zmq.backend.cython.checkrc cimport _check_rc +from zmq.utils.strtypes import bytes,unicode,basestring + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + +ctypedef struct zhint: +    void *ctx +    size_t id + +cdef void free_python_msg(void *data, void *vhint) nogil: +    """A pure-C function for DECREF'ing Python-owned message data. +     +    Sends a message on a PUSH socket +     +    The hint is a `zhint` struct with two values: +     +    ctx (void *): pointer to the Garbage Collector's context +    id (size_t): the id to be used to construct a zmq_msg_t that should be sent on a PUSH socket, +       signaling the Garbage Collector to remove its reference to the object. +     +    - A PUSH socket is created in the context, +    - it is connected to the garbage collector inproc channel, +    - it sends the gc message +    - the PUSH socket is closed +     +    When the Garbage Collector's PULL socket receives the message, +    it deletes its reference to the object, +    allowing Python to free the memory. +    """ +    cdef void *push +    cdef zmq_msg_t msg +    cdef zhint *hint = <zhint *> vhint +    if hint != NULL: +        zmq_msg_init_size(&msg, sizeof(size_t)) +        memcpy(zmq_msg_data(&msg), &hint.id, sizeof(size_t)) +         +        push = zmq_socket(hint.ctx, ZMQ_PUSH) +        if push == NULL: +            # this will happen if the context has been terminated +            return +        rc = zmq_connect(push, "inproc://pyzmq.gc.01") +        if rc < 0: +            fprintf(cstderr, "pyzmq-gc connect failed: %s\n", zmq_strerror(zmq_errno())) +            return +         +        rc = zmq_msg_send(&msg, push, 0) +        if rc < 0: +            fprintf(cstderr, "pyzmq-gc send failed: %s\n", zmq_strerror(zmq_errno())) +         +        zmq_msg_close(&msg) +        zmq_close(push) +        free(hint) + +gc = None + +cdef class Frame: +    """Frame(data=None, track=False) + +    A zmq message Frame class for non-copy send/recvs. + +    This class is only needed if you want to do non-copying send and recvs. +    When you pass a string to this class, like ``Frame(s)``, the  +    ref-count of `s` is increased by two: once because the Frame saves `s` as  +    an instance attribute and another because a ZMQ message is created that +    points to the buffer of `s`. This second ref-count increase makes sure +    that `s` lives until all messages that use it have been sent. Once 0MQ +    sends all the messages and it doesn't need the buffer of s, 0MQ will call +    ``Py_DECREF(s)``. + +    Parameters +    ---------- + +    data : object, optional +        any object that provides the buffer interface will be used to +        construct the 0MQ message data. +    track : bool [default: False] +        whether a MessageTracker_ should be created to track this object. +        Tracking a message has a cost at creation, because it creates a threadsafe +        Event object. +     +    """ + +    def __cinit__(self, object data=None, track=False, **kwargs): +        cdef int rc +        cdef char *data_c = NULL +        cdef Py_ssize_t data_len_c=0 +        cdef zhint *hint + +        # init more as False +        self.more = False + +        # Save the data object in case the user wants the the data as a str. +        self._data = data +        self._failed_init = True  # bool switch for dealloc +        self._buffer = None       # buffer view of data +        self._bytes = None        # bytes copy of data + +        # Event and MessageTracker for monitoring when zmq is done with data: +        if track: +            evt = Event() +            self.tracker_event = evt +            self.tracker = zmq.MessageTracker(evt) +        else: +            self.tracker_event = None +            self.tracker = None + +        if isinstance(data, unicode): +            raise TypeError("Unicode objects not allowed. Only: str/bytes, buffer interfaces.") + +        if data is None: +            rc = zmq_msg_init(&self.zmq_msg) +            _check_rc(rc) +            self._failed_init = False +            return +        else: +            asbuffer_r(data, <void **>&data_c, &data_len_c) +         +        # create the hint for zmq_free_fn +        # two pointers: the gc context and a message to be sent to the gc PULL socket +        # allows libzmq to signal to Python when it is done with Python-owned memory. +        global gc +        if gc is None: +            from zmq.utils.garbage import gc +         +        hint = <zhint *> malloc(sizeof(zhint)) +        hint.id = gc.store(data, self.tracker_event) +        hint.ctx = <void *> <size_t> gc._context.underlying +         +        rc = zmq_msg_init_data( +                &self.zmq_msg, <void *>data_c, data_len_c,  +                <zmq_free_fn *>free_python_msg, <void *>hint +            ) +        if rc != 0: +            free(hint) +            _check_rc(rc) +        self._failed_init = False +     +    def __init__(self, object data=None, track=False): +        """Enforce signature""" +        pass + +    def __dealloc__(self): +        cdef int rc +        if self._failed_init: +            return +        # This simply decreases the 0MQ ref-count of zmq_msg. +        with nogil: +            rc = zmq_msg_close(&self.zmq_msg) +        _check_rc(rc) +     +    # buffer interface code adapted from petsc4py by Lisandro Dalcin, a BSD project +     +    def __getbuffer__(self, Py_buffer* buffer, int flags): +        # new-style (memoryview) buffer interface +        buffer.buf = zmq_msg_data(&self.zmq_msg) +        buffer.len = zmq_msg_size(&self.zmq_msg) +         +        buffer.obj = self +        buffer.readonly = 1 +        buffer.format = "B" +        buffer.ndim = 0 +        buffer.shape = NULL +        buffer.strides = NULL +        buffer.suboffsets = NULL +        buffer.itemsize = 1 +        buffer.internal = NULL +     +    def __getsegcount__(self, Py_ssize_t *lenp): +        # required for getreadbuffer +        if lenp != NULL: +            lenp[0] = zmq_msg_size(&self.zmq_msg) +        return 1 +     +    def __getreadbuffer__(self, Py_ssize_t idx, void **p): +        # old-style (buffer) interface +        cdef char *data_c = NULL +        cdef Py_ssize_t data_len_c +        if idx != 0: +            raise SystemError("accessing non-existent buffer segment") +        # read-only, because we don't want to allow +        # editing of the message in-place +        data_c = <char *>zmq_msg_data(&self.zmq_msg) +        data_len_c = zmq_msg_size(&self.zmq_msg) +        if p != NULL: +            p[0] = <void*>data_c +        return data_len_c +     +    # end buffer interface +     +    def __copy__(self): +        """Create a shallow copy of the message. + +        This does not copy the contents of the Frame, just the pointer. +        This will increment the 0MQ ref count of the message, but not +        the ref count of the Python object. That is only done once when +        the Python is first turned into a 0MQ message. +        """ +        return self.fast_copy() + +    cdef Frame fast_copy(self): +        """Fast, cdef'd version of shallow copy of the Frame.""" +        cdef Frame new_msg +        new_msg = Frame() +        # This does not copy the contents, but just increases the ref-count  +        # of the zmq_msg by one. +        zmq_msg_copy(&new_msg.zmq_msg, &self.zmq_msg) +        # Copy the ref to data so the copy won't create a copy when str is +        # called. +        if self._data is not None: +            new_msg._data = self._data +        if self._buffer is not None: +            new_msg._buffer = self._buffer +        if self._bytes is not None: +            new_msg._bytes = self._bytes + +        # Frame copies share the tracker and tracker_event +        new_msg.tracker_event = self.tracker_event +        new_msg.tracker = self.tracker + +        return new_msg + +    def __len__(self): +        """Return the length of the message in bytes.""" +        cdef size_t sz +        sz = zmq_msg_size(&self.zmq_msg) +        return sz +        # return <int>zmq_msg_size(&self.zmq_msg) + +    def __str__(self): +        """Return the str form of the message.""" +        if isinstance(self._data, bytes): +            b = self._data +        else: +            b = self.bytes +        if str is unicode: +            return b.decode() +        else: +            return b + +    cdef inline object _getbuffer(self): +        """Create a Python buffer/view of the message data. + +        This will be called only once, the first time the `buffer` property +        is accessed. Subsequent calls use a cached copy. +        """ +        if self._data is None: +            return viewfromobject_r(self) +        else: +            return viewfromobject_r(self._data) +     +    @property +    def buffer(self): +        """A read-only buffer view of the message contents.""" +        if self._buffer is None: +            self._buffer = self._getbuffer() +        return self._buffer + +    @property +    def bytes(self): +        """The message content as a Python bytes object. + +        The first time this property is accessed, a copy of the message  +        contents is made. From then on that same copy of the message is +        returned. +        """ +        if self._bytes is None: +            self._bytes = copy_zmq_msg_bytes(&self.zmq_msg) +        return self._bytes +     +    def set(self, int option, int value): +        """Frame.set(option, value) +         +        Set a Frame option. +         +        See the 0MQ API documentation for zmq_msg_set +        for details on specific options. +         +        .. versionadded:: libzmq-3.2 +        .. versionadded:: 13.0 +        """ +        cdef int rc = zmq_msg_set(&self.zmq_msg, option, value) +        _check_rc(rc) + +    def get(self, option): +        """Frame.get(option) + +        Get a Frame option or property. + +        See the 0MQ API documentation for zmq_msg_get and zmq_msg_gets +        for details on specific options. + +        .. versionadded:: libzmq-3.2 +        .. versionadded:: 13.0 +         +        .. versionchanged:: 14.3 +            add support for zmq_msg_gets (requires libzmq-4.1) +        """ +        cdef int rc = 0 +        cdef char *property_c = NULL +        cdef Py_ssize_t property_len_c = 0 + +        # zmq_msg_get +        if isinstance(option, int): +            rc = zmq_msg_get(&self.zmq_msg, option) +            _check_rc(rc) +            return rc + +        # zmq_msg_gets +        _check_version((4,1), "get string properties") +        if isinstance(option, unicode): +            option = option.encode('utf8') +         +        if not isinstance(option, bytes): +            raise TypeError("expected str, got: %r" % option) +         +        property_c = option +         +        cdef const char *result = <char *>zmq_msg_gets(&self.zmq_msg, property_c) +        if result == NULL: +            _check_rc(-1) +        return result.decode('utf8') + +# legacy Message name +Message = Frame + +__all__ = ['Frame', 'Message'] diff --git a/zmq/backend/cython/rebuffer.pyx b/zmq/backend/cython/rebuffer.pyx new file mode 100644 index 0000000..402e3b6 --- /dev/null +++ b/zmq/backend/cython/rebuffer.pyx @@ -0,0 +1,104 @@ +""" +Utility for changing itemsize of memoryviews, and getting +numpy arrays from byte-arrays that should be interpreted with a different +itemsize. + +Authors +------- +* MinRK +""" + +#----------------------------------------------------------------------------- +#  Copyright (c) 2010-2012 Brian Granger, Min Ragan-Kelley +# +#  This file is part of pyzmq +# +#  Distributed under the terms of the New BSD License.  The full license is in +#  the file COPYING.BSD, distributed as part of this software. +#----------------------------------------------------------------------------- + +from libc.stdlib cimport malloc +from buffers cimport * + +cdef inline object _rebuffer(object obj, char * format, int itemsize): +    """clobber the format & itemsize of a 1-D + +    This is the Python 3 model, but will work on Python >= 2.6. Currently, +    we use it only on >= 3.0. +    """ +    cdef Py_buffer view +    cdef int flags = PyBUF_SIMPLE +    cdef int mode = 0 +    # cdef Py_ssize_t *shape, *strides, *suboffsets +     +    mode = check_buffer(obj) +    if mode == 0: +        raise TypeError("%r does not provide a buffer interface."%obj) + +    if mode == 3: +        flags = PyBUF_ANY_CONTIGUOUS +        if format: +            flags |= PyBUF_FORMAT +        PyObject_GetBuffer(obj, &view, flags) +        assert view.ndim <= 1, "Can only reinterpret 1-D memoryviews" +        assert view.len % itemsize == 0, "Buffer of length %i not divisible into items of size %i"%(view.len, itemsize) +        # hack the format +        view.ndim = 1 +        view.format = format +        view.itemsize = itemsize +        view.strides = <Py_ssize_t *>malloc(sizeof(Py_ssize_t)) +        view.strides[0] = itemsize +        view.shape = <Py_ssize_t *>malloc(sizeof(Py_ssize_t)) +        view.shape[0] = view.len/itemsize +        view.suboffsets = <Py_ssize_t *>malloc(sizeof(Py_ssize_t)) +        view.suboffsets[0] = 0 +        # for debug: make buffer writable, for zero-copy testing +        # view.readonly = 0 +         +        return PyMemoryView_FromBuffer(&view) +    else: +        raise TypeError("This funciton is only for new-style buffer objects.") + +def rebuffer(obj, format, itemsize): +    """Change the itemsize of a memoryview. +     +    Only for 1D contiguous buffers. +    """ +    return _rebuffer(obj, format, itemsize) + +def array_from_buffer(view, dtype, shape): +    """Get a numpy array from a memoryview, regardless of the itemsize of the original +    memoryview.  This is important, because pyzmq does not send memoryview shape data +    over the wire, so we need to change the memoryview itemsize before calling +    asarray. +    """ +    import numpy +    A = numpy.array([],dtype=dtype) +    ref = viewfromobject(A,0) +    fmt = ref.format.encode() +    buf = viewfromobject(view, 0) +    buf = _rebuffer(view, fmt, ref.itemsize) +    return numpy.asarray(buf, dtype=dtype).reshape(shape) + +def print_view_info(obj): +    """simple utility for printing info on a new-style buffer object""" +    cdef Py_buffer view +    cdef int flags = PyBUF_ANY_CONTIGUOUS|PyBUF_FORMAT +    cdef int mode = 0 +     +    mode = check_buffer(obj) +    if mode == 0: +        raise TypeError("%r does not provide a buffer interface."%obj) + +    if mode == 3: +        PyObject_GetBuffer(obj, &view, flags) +        print <size_t>view.buf, view.len, view.format, view.ndim, +        if view.ndim: +            if view.shape: +                print view.shape[0], +            if view.strides: +                print view.strides[0], +            if view.suboffsets: +                print view.suboffsets[0], +        print +    
\ No newline at end of file diff --git a/zmq/backend/cython/socket.pxd b/zmq/backend/cython/socket.pxd new file mode 100644 index 0000000..b8a331e --- /dev/null +++ b/zmq/backend/cython/socket.pxd @@ -0,0 +1,47 @@ +"""0MQ Socket class declaration.""" + +# +#    Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +#    This file is part of pyzmq. +# +#    pyzmq is free software; you can redistribute it and/or modify it under +#    the terms of the Lesser GNU General Public License as published by +#    the Free Software Foundation; either version 3 of the License, or +#    (at your option) any later version. +# +#    pyzmq is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    Lesser GNU General Public License for more details. +# +#    You should have received a copy of the Lesser GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +from context cimport Context + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + + +cdef class Socket: + +    cdef object __weakref__     # enable weakref +    cdef void *handle           # The C handle for the underlying zmq object. +    cdef bint _shadow           # whether the Socket is a shadow wrapper of another +    # Hold on to a reference to the context to make sure it is not garbage +    # collected until the socket it done with it. +    cdef public Context context # The zmq Context object that owns this. +    cdef public bint _closed    # bool property for a closed socket. +    cdef int _pid               # the pid of the process which created me (for fork safety) + +    # cpdef methods for direct-cython access: +    cpdef object send(self, object data, int flags=*, copy=*, track=*) +    cpdef object recv(self, int flags=*, copy=*, track=*) + diff --git a/zmq/backend/cython/socket.pyx b/zmq/backend/cython/socket.pyx new file mode 100644 index 0000000..9b9ec36 --- /dev/null +++ b/zmq/backend/cython/socket.pyx @@ -0,0 +1,672 @@ +"""0MQ Socket class.""" + +# +#    Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +#    This file is part of pyzmq. +# +#    pyzmq is free software; you can redistribute it and/or modify it under +#    the terms of the Lesser GNU General Public License as published by +#    the Free Software Foundation; either version 3 of the License, or +#    (at your option) any later version. +# +#    pyzmq is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    Lesser GNU General Public License for more details. +# +#    You should have received a copy of the Lesser GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Cython Imports +#----------------------------------------------------------------------------- + +# get version-independent aliases: +cdef extern from "pyversion_compat.h": +    pass + +from libc.errno cimport ENAMETOOLONG +from libc.string cimport memcpy + +from cpython cimport PyBytes_FromStringAndSize +from cpython cimport PyBytes_AsString, PyBytes_Size +from cpython cimport Py_DECREF, Py_INCREF + +from buffers cimport asbuffer_r, viewfromobject_r + +from libzmq cimport * +from message cimport Frame, copy_zmq_msg_bytes + +from context cimport Context + +cdef extern from "Python.h": +    ctypedef int Py_ssize_t + +cdef extern from "ipcmaxlen.h": +    int get_ipc_path_max_len() + +cdef extern from "getpid_compat.h": +    int getpid() + + +#----------------------------------------------------------------------------- +# Python Imports +#----------------------------------------------------------------------------- + +import copy as copy_mod +import time +import sys +import random +import struct +import codecs + +from zmq.utils import jsonapi + +try: +    import cPickle +    pickle = cPickle +except: +    cPickle = None +    import pickle + +import zmq +from zmq.backend.cython import constants +from zmq.backend.cython.constants import * +from zmq.backend.cython.checkrc cimport _check_rc +from zmq.error import ZMQError, ZMQBindError, _check_version +from zmq.utils.strtypes import bytes,unicode,basestring + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + +IPC_PATH_MAX_LEN = get_ipc_path_max_len() + +# inline some small socket submethods: +# true methods frequently cannot be inlined, acc. Cython docs + +cdef inline _check_closed(Socket s): +    """raise ENOTSUP if socket is closed +     +    Does not do a deep check +    """ +    if s._closed: +        raise ZMQError(ENOTSOCK) + +cdef inline _check_closed_deep(Socket s): +    """thorough check of whether the socket has been closed, +    even if by another entity (e.g. ctx.destroy). +     +    Only used by the `closed` property. +     +    returns True if closed, False otherwise +    """ +    cdef int rc +    cdef int errno +    cdef int stype +    cdef size_t sz=sizeof(int) +    if s._closed: +        return True +    else: +        rc = zmq_getsockopt(s.handle, ZMQ_TYPE, <void *>&stype, &sz) +        if rc < 0 and zmq_errno() == ENOTSOCK: +            s._closed = True +            return True +        else: +            _check_rc(rc) +    return False + +cdef inline Frame _recv_frame(void *handle, int flags=0, track=False): +    """Receive a message in a non-copying manner and return a Frame.""" +    cdef int rc +    msg = zmq.Frame(track=track) +    cdef Frame cmsg = msg + +    with nogil: +        rc = zmq_msg_recv(&cmsg.zmq_msg, handle, flags) +     +    _check_rc(rc) +    return msg + +cdef inline object _recv_copy(void *handle, int flags=0): +    """Receive a message and return a copy""" +    cdef zmq_msg_t zmq_msg +    with nogil: +        zmq_msg_init (&zmq_msg) +        rc = zmq_msg_recv(&zmq_msg, handle, flags) +    _check_rc(rc) +    msg_bytes = copy_zmq_msg_bytes(&zmq_msg) +    zmq_msg_close(&zmq_msg) +    return msg_bytes + +cdef inline object _send_frame(void *handle, Frame msg, int flags=0): +    """Send a Frame on this socket in a non-copy manner.""" +    cdef int rc +    cdef Frame msg_copy + +    # Always copy so the original message isn't garbage collected. +    # This doesn't do a real copy, just a reference. +    msg_copy = msg.fast_copy() + +    with nogil: +        rc = zmq_msg_send(&msg_copy.zmq_msg, handle, flags) + +    _check_rc(rc) +    return msg.tracker + + +cdef inline object _send_copy(void *handle, object msg, int flags=0): +    """Send a message on this socket by copying its content.""" +    cdef int rc, rc2 +    cdef zmq_msg_t data +    cdef char *msg_c +    cdef Py_ssize_t msg_c_len=0 + +    # copy to c array: +    asbuffer_r(msg, <void **>&msg_c, &msg_c_len) + +    # Copy the msg before sending. This avoids any complications with +    # the GIL, etc. +    # If zmq_msg_init_* fails we must not call zmq_msg_close (Bus Error) +    rc = zmq_msg_init_size(&data, msg_c_len) + +    _check_rc(rc) + +    with nogil: +        memcpy(zmq_msg_data(&data), msg_c, zmq_msg_size(&data)) +        rc = zmq_msg_send(&data, handle, flags) +        rc2 = zmq_msg_close(&data) +    _check_rc(rc) +    _check_rc(rc2) + + +cdef class Socket: +    """Socket(context, socket_type) + +    A 0MQ socket. + +    These objects will generally be constructed via the socket() method of a Context object. +     +    Note: 0MQ Sockets are *not* threadsafe. **DO NOT** share them across threads. +     +    Parameters +    ---------- +    context : Context +        The 0MQ Context this Socket belongs to. +    socket_type : int +        The socket type, which can be any of the 0MQ socket types:  +        REQ, REP, PUB, SUB, PAIR, DEALER, ROUTER, PULL, PUSH, XPUB, XSUB. +     +    See Also +    -------- +    .Context.socket : method for creating a socket bound to a Context. +    """ +     +    # no-op for the signature +    def __init__(self, context=None, socket_type=-1, shadow=0): +        pass +     +    def __cinit__(self, Context context=None, int socket_type=-1, size_t shadow=0, *args, **kwargs): +        cdef Py_ssize_t c_handle + +        self.handle = NULL +        self.context = context +        if shadow: +            self._shadow = True +            self.handle = <void *>shadow +        else: +            if context is None: +                raise TypeError("context must be specified") +            if socket_type < 0: +                raise TypeError("socket_type must be specified") +            self._shadow = False +            self.handle = zmq_socket(context.handle, socket_type) +        if self.handle == NULL: +            raise ZMQError() +        self._closed = False +        self._pid = getpid() +        if context: +            context._add_socket(self.handle) + +    def __dealloc__(self): +        """remove from context's list +         +        But be careful that context might not exist if called during gc +        """ +        if self.handle != NULL and not self._shadow and getpid() == self._pid: +            # during gc, self.context might be NULL +            if self.context and not self.context.closed: +                self.context._remove_socket(self.handle) +     +    @property +    def underlying(self): +        """The address of the underlying libzmq socket""" +        return <size_t> self.handle +     +    @property +    def closed(self): +        return _check_closed_deep(self) +     +    def close(self, linger=None): +        """s.close(linger=None) + +        Close the socket. +         +        If linger is specified, LINGER sockopt will be set prior to closing. + +        This can be called to close the socket by hand. If this is not +        called, the socket will automatically be closed when it is +        garbage collected. +        """ +        cdef int rc=0 +        cdef int linger_c +        cdef bint setlinger=False +         +        if linger is not None: +            linger_c = linger +            setlinger=True +         +        if self.handle != NULL and not self._closed and getpid() == self._pid: +            if setlinger: +                zmq_setsockopt(self.handle, ZMQ_LINGER, &linger_c, sizeof(int)) +            rc = zmq_close(self.handle) +            if rc != 0 and zmq_errno() != ENOTSOCK: +                # ignore ENOTSOCK (closed by Context) +                _check_rc(rc) +            self._closed = True +            # during gc, self.context might be NULL +            if self.context: +                self.context._remove_socket(self.handle) +            self.handle = NULL + +    def set(self, int option, optval): +        """s.set(option, optval) + +        Set socket options. + +        See the 0MQ API documentation for details on specific options. + +        Parameters +        ---------- +        option : int +            The option to set.  Available values will depend on your +            version of libzmq.  Examples include:: +             +                zmq.SUBSCRIBE, UNSUBSCRIBE, IDENTITY, HWM, LINGER, FD +         +        optval : int or bytes +            The value of the option to set. +        """ +        cdef int64_t optval_int64_c +        cdef int optval_int_c +        cdef int rc +        cdef char* optval_c +        cdef Py_ssize_t sz + +        _check_closed(self) +        if isinstance(optval, unicode): +            raise TypeError("unicode not allowed, use setsockopt_string") + +        if option in zmq.constants.bytes_sockopts: +            if not isinstance(optval, bytes): +                raise TypeError('expected bytes, got: %r' % optval) +            optval_c = PyBytes_AsString(optval) +            sz = PyBytes_Size(optval) +            rc = zmq_setsockopt( +                    self.handle, option, +                    optval_c, sz +                ) +        elif option in zmq.constants.int64_sockopts: +            if not isinstance(optval, int): +                raise TypeError('expected int, got: %r' % optval) +            optval_int64_c = optval +            rc = zmq_setsockopt( +                    self.handle, option, +                    &optval_int64_c, sizeof(int64_t) +                ) +        else: +            # default is to assume int, which is what most new sockopts will be +            # this lets pyzmq work with newer libzmq which may add constants +            # pyzmq has not yet added, rather than artificially raising. Invalid +            # sockopts will still raise just the same, but it will be libzmq doing +            # the raising. +            if not isinstance(optval, int): +                raise TypeError('expected int, got: %r' % optval) +            optval_int_c = optval +            rc = zmq_setsockopt( +                    self.handle, option, +                    &optval_int_c, sizeof(int) +                ) + +        _check_rc(rc) + +    def get(self, int option): +        """s.get(option) + +        Get the value of a socket option. + +        See the 0MQ API documentation for details on specific options. + +        Parameters +        ---------- +        option : int +            The option to get.  Available values will depend on your +            version of libzmq.  Examples include:: +             +                zmq.IDENTITY, HWM, LINGER, FD, EVENTS + +        Returns +        ------- +        optval : int or bytes +            The value of the option as a bytestring or int. +        """ +        cdef int64_t optval_int64_c +        cdef int optval_int_c +        cdef fd_t optval_fd_c +        cdef char identity_str_c [255] +        cdef size_t sz +        cdef int rc + +        _check_closed(self) + +        if option in zmq.constants.bytes_sockopts: +            sz = 255 +            rc = zmq_getsockopt(self.handle, option, <void *>identity_str_c, &sz) +            _check_rc(rc) +            # strip null-terminated strings *except* identity +            if option != ZMQ_IDENTITY and sz > 0 and (<char *>identity_str_c)[sz-1] == b'\0': +                sz -= 1 +            result = PyBytes_FromStringAndSize(<char *>identity_str_c, sz) +        elif option in zmq.constants.int64_sockopts: +            sz = sizeof(int64_t) +            rc = zmq_getsockopt(self.handle, option, <void *>&optval_int64_c, &sz) +            _check_rc(rc) +            result = optval_int64_c +        elif option in zmq.constants.fd_sockopts: +            sz = sizeof(fd_t) +            rc = zmq_getsockopt(self.handle, option, <void *>&optval_fd_c, &sz) +            _check_rc(rc) +            result = optval_fd_c +        else: +            # default is to assume int, which is what most new sockopts will be +            # this lets pyzmq work with newer libzmq which may add constants +            # pyzmq has not yet added, rather than artificially raising. Invalid +            # sockopts will still raise just the same, but it will be libzmq doing +            # the raising. +            sz = sizeof(int) +            rc = zmq_getsockopt(self.handle, option, <void *>&optval_int_c, &sz) +            _check_rc(rc) +            result = optval_int_c + +        return result +     +    def bind(self, addr): +        """s.bind(addr) + +        Bind the socket to an address. + +        This causes the socket to listen on a network port. Sockets on the +        other side of this connection will use ``Socket.connect(addr)`` to +        connect to this socket. + +        Parameters +        ---------- +        addr : str +            The address string. This has the form 'protocol://interface:port', +            for example 'tcp://127.0.0.1:5555'. Protocols supported include +            tcp, udp, pgm, epgm, inproc and ipc. If the address is unicode, it is +            encoded to utf-8 first. +        """ +        cdef int rc +        cdef char* c_addr + +        _check_closed(self) +        if isinstance(addr, unicode): +            addr = addr.encode('utf-8') +        if not isinstance(addr, bytes): +            raise TypeError('expected str, got: %r' % addr) +        c_addr = addr +        rc = zmq_bind(self.handle, c_addr) +        if rc != 0: +            if IPC_PATH_MAX_LEN and zmq_errno() == ENAMETOOLONG: +                # py3compat: addr is bytes, but msg wants str +                if str is unicode: +                    addr = addr.decode('utf-8', 'replace') +                path = addr.split('://', 1)[-1] +                msg = ('ipc path "{0}" is longer than {1} ' +                                'characters (sizeof(sockaddr_un.sun_path)). ' +                                'zmq.IPC_PATH_MAX_LEN constant can be used ' +                                'to check addr length (if it is defined).' +                                .format(path, IPC_PATH_MAX_LEN)) +                raise ZMQError(msg=msg) +        _check_rc(rc) + +    def connect(self, addr): +        """s.connect(addr) + +        Connect to a remote 0MQ socket. + +        Parameters +        ---------- +        addr : str +            The address string. This has the form 'protocol://interface:port', +            for example 'tcp://127.0.0.1:5555'. Protocols supported are +            tcp, upd, pgm, inproc and ipc. If the address is unicode, it is +            encoded to utf-8 first. +        """ +        cdef int rc +        cdef char* c_addr + +        _check_closed(self) +        if isinstance(addr, unicode): +            addr = addr.encode('utf-8') +        if not isinstance(addr, bytes): +            raise TypeError('expected str, got: %r' % addr) +        c_addr = addr +         +        rc = zmq_connect(self.handle, c_addr) +        if rc != 0: +            raise ZMQError() + +    def unbind(self, addr): +        """s.unbind(addr) +         +        Unbind from an address (undoes a call to bind). +         +        .. versionadded:: libzmq-3.2 +        .. versionadded:: 13.0 + +        Parameters +        ---------- +        addr : str +            The address string. This has the form 'protocol://interface:port', +            for example 'tcp://127.0.0.1:5555'. Protocols supported are +            tcp, upd, pgm, inproc and ipc. If the address is unicode, it is +            encoded to utf-8 first. +        """ +        cdef int rc +        cdef char* c_addr + +        _check_version((3,2), "unbind") +        _check_closed(self) +        if isinstance(addr, unicode): +            addr = addr.encode('utf-8') +        if not isinstance(addr, bytes): +            raise TypeError('expected str, got: %r' % addr) +        c_addr = addr +         +        rc = zmq_unbind(self.handle, c_addr) +        if rc != 0: +            raise ZMQError() + +    def disconnect(self, addr): +        """s.disconnect(addr) + +        Disconnect from a remote 0MQ socket (undoes a call to connect). +         +        .. versionadded:: libzmq-3.2 +        .. versionadded:: 13.0 + +        Parameters +        ---------- +        addr : str +            The address string. This has the form 'protocol://interface:port', +            for example 'tcp://127.0.0.1:5555'. Protocols supported are +            tcp, upd, pgm, inproc and ipc. If the address is unicode, it is +            encoded to utf-8 first. +        """ +        cdef int rc +        cdef char* c_addr +         +        _check_version((3,2), "disconnect") +        _check_closed(self) +        if isinstance(addr, unicode): +            addr = addr.encode('utf-8') +        if not isinstance(addr, bytes): +            raise TypeError('expected str, got: %r' % addr) +        c_addr = addr +         +        rc = zmq_disconnect(self.handle, c_addr) +        if rc != 0: +            raise ZMQError() + +    def monitor(self, addr, int events=ZMQ_EVENT_ALL): +        """s.monitor(addr, flags) + +        Start publishing socket events on inproc. +        See libzmq docs for zmq_monitor for details. +         +        While this function is available from libzmq 3.2, +        pyzmq cannot parse monitor messages from libzmq prior to 4.0. +         +        .. versionadded: libzmq-3.2 +        .. versionadded: 14.0 +         +        Parameters +        ---------- +        addr : str +            The inproc url used for monitoring. Passing None as +            the addr will cause an existing socket monitor to be +            deregistered. +        events : int [default: zmq.EVENT_ALL] +            The zmq event bitmask for which events will be sent to the monitor. +        """ +        cdef int rc, c_flags +        cdef char* c_addr = NULL +         +        _check_version((3,2), "monitor") +        if addr is not None: +            if isinstance(addr, unicode): +                addr = addr.encode('utf-8') +            if not isinstance(addr, bytes): +                raise TypeError('expected str, got: %r' % addr) +            c_addr = addr +        c_flags = events +        rc = zmq_socket_monitor(self.handle, c_addr, c_flags) +        _check_rc(rc) + +    #------------------------------------------------------------------------- +    # Sending and receiving messages +    #------------------------------------------------------------------------- + +    cpdef object send(self, object data, int flags=0, copy=True, track=False): +        """s.send(data, flags=0, copy=True, track=False) + +        Send a message on this socket. + +        This queues the message to be sent by the IO thread at a later time. + +        Parameters +        ---------- +        data : object, str, Frame +            The content of the message. +        flags : int +            Any supported flag: NOBLOCK, SNDMORE. +        copy : bool +            Should the message be sent in a copying or non-copying manner. +        track : bool +            Should the message be tracked for notification that ZMQ has +            finished with it? (ignored if copy=True) + +        Returns +        ------- +        None : if `copy` or not track +            None if message was sent, raises an exception otherwise. +        MessageTracker : if track and not copy +            a MessageTracker object, whose `pending` property will +            be True until the send is completed. +         +        Raises +        ------ +        TypeError +            If a unicode object is passed +        ValueError +            If `track=True`, but an untracked Frame is passed. +        ZMQError +            If the send does not succeed for any reason. +         +        """ +        _check_closed(self) +         +        if isinstance(data, unicode): +            raise TypeError("unicode not allowed, use send_string") +         +        if copy: +            # msg.bytes never returns the input data object +            # it is always a copy, but always the same copy +            if isinstance(data, Frame): +                data = data.buffer +            return _send_copy(self.handle, data, flags) +        else: +            if isinstance(data, Frame): +                if track and not data.tracker: +                    raise ValueError('Not a tracked message') +                msg = data +            else: +                msg = Frame(data, track=track) +            return _send_frame(self.handle, msg, flags) + +    cpdef object recv(self, int flags=0, copy=True, track=False): +        """s.recv(flags=0, copy=True, track=False) + +        Receive a message. + +        Parameters +        ---------- +        flags : int +            Any supported flag: NOBLOCK. If NOBLOCK is set, this method +            will raise a ZMQError with EAGAIN if a message is not ready. +            If NOBLOCK is not set, then this method will block until a +            message arrives. +        copy : bool +            Should the message be received in a copying or non-copying manner? +            If False a Frame object is returned, if True a string copy of +            message is returned. +        track : bool +            Should the message be tracked for notification that ZMQ has +            finished with it? (ignored if copy=True) + +        Returns +        ------- +        msg : bytes, Frame +            The received message frame.  If `copy` is False, then it will be a Frame, +            otherwise it will be bytes. +             +        Raises +        ------ +        ZMQError +            for any of the reasons zmq_msg_recv might fail. +        """ +        _check_closed(self) +         +        if copy: +            return _recv_copy(self.handle, flags) +        else: +            frame = _recv_frame(self.handle, flags, track) +            frame.more = self.getsockopt(zmq.RCVMORE) +            return frame +     + +__all__ = ['Socket', 'IPC_PATH_MAX_LEN'] diff --git a/zmq/backend/cython/utils.pxd b/zmq/backend/cython/utils.pxd new file mode 100644 index 0000000..1d7117f --- /dev/null +++ b/zmq/backend/cython/utils.pxd @@ -0,0 +1,29 @@ +"""Wrap zmq_utils.h""" + +# +#    Copyright (c) 2010 Brian E. Granger & Min Ragan-Kelley +# +#    This file is part of pyzmq. +# +#    pyzmq is free software; you can redistribute it and/or modify it under +#    the terms of the Lesser GNU General Public License as published by +#    the Free Software Foundation; either version 3 of the License, or +#    (at your option) any later version. +# +#    pyzmq is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    Lesser GNU General Public License for more details. +# +#    You should have received a copy of the Lesser GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + + +cdef class Stopwatch: +    cdef void *watch # The C handle for the underlying zmq object + diff --git a/zmq/backend/cython/utils.pyx b/zmq/backend/cython/utils.pyx new file mode 100644 index 0000000..68976e3 --- /dev/null +++ b/zmq/backend/cython/utils.pyx @@ -0,0 +1,119 @@ +"""0MQ utils.""" + +# +#    Copyright (c) 2010-2011 Brian E. Granger & Min Ragan-Kelley +# +#    This file is part of pyzmq. +# +#    pyzmq is free software; you can redistribute it and/or modify it under +#    the terms of the Lesser GNU General Public License as published by +#    the Free Software Foundation; either version 3 of the License, or +#    (at your option) any later version. +# +#    pyzmq is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    Lesser GNU General Public License for more details. +# +#    You should have received a copy of the Lesser GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. +# + +from libzmq cimport ( +    zmq_stopwatch_start, zmq_stopwatch_stop, zmq_sleep, zmq_curve_keypair, +    zmq_has, const_char_ptr +) +from zmq.error import ZMQError, _check_rc, _check_version +from zmq.utils.strtypes import unicode + +def has(capability): +    """Check for zmq capability by name (e.g. 'ipc', 'curve') +     +    .. versionadded:: libzmq-4.1 +    .. versionadded:: 14.1 +    """ +    _check_version((4,1), 'zmq.has') +    cdef bytes ccap +    if isinstance(capability, unicode): +        capability = capability.encode('utf8') +    ccap = capability +    return bool(zmq_has(ccap)) + +def curve_keypair(): +    """generate a Z85 keypair for use with zmq.CURVE security +     +    Requires libzmq (≥ 4.0) to have been linked with libsodium. +     +    .. versionadded:: libzmq-4.0 +    .. versionadded:: 14.0 +     +    Returns +    ------- +    (public, secret) : two bytestrings +        The public and private keypair as 40 byte z85-encoded bytestrings. +    """ +    cdef int rc +    cdef char[64] public_key +    cdef char[64] secret_key +    _check_version((4,0), "curve_keypair") +    rc = zmq_curve_keypair (public_key, secret_key) +    _check_rc(rc) +    return public_key, secret_key + + +cdef class Stopwatch: +    """Stopwatch() + +    A simple stopwatch based on zmq_stopwatch_start/stop. + +    This class should be used for benchmarking and timing 0MQ code. +    """ + +    def __cinit__(self): +        self.watch = NULL + +    def __dealloc__(self): +        # copy of self.stop() we can't call object methods in dealloc as it +        # might already be partially deleted +        if self.watch: +            zmq_stopwatch_stop(self.watch) +            self.watch = NULL + +    def start(self): +        """s.start() + +        Start the stopwatch. +        """ +        if self.watch == NULL: +            self.watch = zmq_stopwatch_start() +        else: +            raise ZMQError('Stopwatch is already running.') + +    def stop(self): +        """s.stop() + +        Stop the stopwatch. +         +        Returns +        ------- +        t : unsigned long int +            the number of microseconds since ``start()`` was called. +        """ +        cdef unsigned long time +        if self.watch == NULL: +            raise ZMQError('Must start the Stopwatch before calling stop.') +        else: +            time = zmq_stopwatch_stop(self.watch) +            self.watch = NULL +            return time + +    def sleep(self, int seconds): +        """s.sleep(seconds) + +        Sleep for an integer number of seconds. +        """ +        with nogil: +            zmq_sleep(seconds) + + +__all__ = ['has', 'curve_keypair', 'Stopwatch'] diff --git a/zmq/backend/select.py b/zmq/backend/select.py new file mode 100644 index 0000000..0a2e09a --- /dev/null +++ b/zmq/backend/select.py @@ -0,0 +1,39 @@ +"""Import basic exposure of libzmq C API as a backend""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +public_api = [ +    'Context', +    'Socket', +    'Frame', +    'Message', +    'Stopwatch', +    'device', +    'proxy', +    'zmq_poll', +    'strerror', +    'zmq_errno', +    'has', +    'curve_keypair', +    'constants', +    'zmq_version_info', +    'IPC_PATH_MAX_LEN', +] + +def select_backend(name): +    """Select the pyzmq backend""" +    try: +        mod = __import__(name, fromlist=public_api) +    except ImportError: +        raise +    except Exception as e: +        import sys +        from zmq.utils.sixcerpt import reraise +        exc_info = sys.exc_info() +        reraise(ImportError, ImportError("Importing %s failed with %s" % (name, e)), exc_info[2]) +     +    ns = {} +    for key in public_api: +        ns[key] = getattr(mod, key) +    return ns diff --git a/zmq/devices/__init__.py b/zmq/devices/__init__.py new file mode 100644 index 0000000..2371596 --- /dev/null +++ b/zmq/devices/__init__.py @@ -0,0 +1,16 @@ +"""0MQ Device classes for running in background threads or processes.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq import device +from zmq.devices import basedevice, proxydevice, monitoredqueue, monitoredqueuedevice + +from zmq.devices.basedevice import * +from zmq.devices.proxydevice import * +from zmq.devices.monitoredqueue import * +from zmq.devices.monitoredqueuedevice import * + +__all__ = ['device'] +for submod in (basedevice, proxydevice, monitoredqueue, monitoredqueuedevice): +    __all__.extend(submod.__all__) diff --git a/zmq/devices/basedevice.py b/zmq/devices/basedevice.py new file mode 100644 index 0000000..7ba1b7a --- /dev/null +++ b/zmq/devices/basedevice.py @@ -0,0 +1,229 @@ +"""Classes for running 0MQ Devices in the background.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import time +from threading import Thread +from multiprocessing import Process + +from zmq import device, QUEUE, Context, ETERM, ZMQError + + +class Device: +    """A 0MQ Device to be run in the background. +     +    You do not pass Socket instances to this, but rather Socket types:: + +        Device(device_type, in_socket_type, out_socket_type) + +    For instance:: + +        dev = Device(zmq.QUEUE, zmq.DEALER, zmq.ROUTER) + +    Similar to zmq.device, but socket types instead of sockets themselves are +    passed, and the sockets are created in the work thread, to avoid issues +    with thread safety. As a result, additional bind_{in|out} and +    connect_{in|out} methods and setsockopt_{in|out} allow users to specify +    connections for the sockets. +     +    Parameters +    ---------- +    device_type : int +        The 0MQ Device type +    {in|out}_type : int +        zmq socket types, to be passed later to context.socket(). e.g. +        zmq.PUB, zmq.SUB, zmq.REQ. If out_type is < 0, then in_socket is used +        for both in_socket and out_socket. +         +    Methods +    ------- +    bind_{in_out}(iface) +        passthrough for ``{in|out}_socket.bind(iface)``, to be called in the thread +    connect_{in_out}(iface) +        passthrough for ``{in|out}_socket.connect(iface)``, to be called in the +        thread +    setsockopt_{in_out}(opt,value) +        passthrough for ``{in|out}_socket.setsockopt(opt, value)``, to be called in +        the thread +     +    Attributes +    ---------- +    daemon : int +        sets whether the thread should be run as a daemon +        Default is true, because if it is false, the thread will not +        exit unless it is killed +    context_factory : callable (class attribute) +        Function for creating the Context. This will be Context.instance +        in ThreadDevices, and Context in ProcessDevices.  The only reason +        it is not instance() in ProcessDevices is that there may be a stale +        Context instance already initialized, and the forked environment +        should *never* try to use it. +    """ +     +    context_factory = Context.instance +    """Callable that returns a context. Typically either Context.instance or Context, +    depending on whether the device should share the global instance or not. +    """ + +    def __init__(self, device_type=QUEUE, in_type=None, out_type=None): +        self.device_type = device_type +        if in_type is None: +            raise TypeError("in_type must be specified") +        if out_type is None: +            raise TypeError("out_type must be specified") +        self.in_type = in_type +        self.out_type = out_type +        self._in_binds = [] +        self._in_connects = [] +        self._in_sockopts = [] +        self._out_binds = [] +        self._out_connects = [] +        self._out_sockopts = [] +        self.daemon = True +        self.done = False +     +    def bind_in(self, addr): +        """Enqueue ZMQ address for binding on in_socket. + +        See zmq.Socket.bind for details. +        """ +        self._in_binds.append(addr) +     +    def connect_in(self, addr): +        """Enqueue ZMQ address for connecting on in_socket. + +        See zmq.Socket.connect for details. +        """ +        self._in_connects.append(addr) +     +    def setsockopt_in(self, opt, value): +        """Enqueue setsockopt(opt, value) for in_socket + +        See zmq.Socket.setsockopt for details. +        """ +        self._in_sockopts.append((opt, value)) +     +    def bind_out(self, addr): +        """Enqueue ZMQ address for binding on out_socket. + +        See zmq.Socket.bind for details. +        """ +        self._out_binds.append(addr) +     +    def connect_out(self, addr): +        """Enqueue ZMQ address for connecting on out_socket. + +        See zmq.Socket.connect for details. +        """ +        self._out_connects.append(addr) +     +    def setsockopt_out(self, opt, value): +        """Enqueue setsockopt(opt, value) for out_socket + +        See zmq.Socket.setsockopt for details. +        """ +        self._out_sockopts.append((opt, value)) +     +    def _setup_sockets(self): +        ctx = self.context_factory() +         +        self._context = ctx +         +        # create the sockets +        ins = ctx.socket(self.in_type) +        if self.out_type < 0: +            outs = ins +        else: +            outs = ctx.socket(self.out_type) +         +        # set sockopts (must be done first, in case of zmq.IDENTITY) +        for opt,value in self._in_sockopts: +            ins.setsockopt(opt, value) +        for opt,value in self._out_sockopts: +            outs.setsockopt(opt, value) +         +        for iface in self._in_binds: +            ins.bind(iface) +        for iface in self._out_binds: +            outs.bind(iface) +         +        for iface in self._in_connects: +            ins.connect(iface) +        for iface in self._out_connects: +            outs.connect(iface) +         +        return ins,outs +     +    def run_device(self): +        """The runner method. + +        Do not call me directly, instead call ``self.start()``, just like a Thread. +        """ +        ins,outs = self._setup_sockets() +        device(self.device_type, ins, outs) +     +    def run(self): +        """wrap run_device in try/catch ETERM""" +        try: +            self.run_device() +        except ZMQError as e: +            if e.errno == ETERM: +                # silence TERM errors, because this should be a clean shutdown +                pass +            else: +                raise +        finally: +            self.done = True +     +    def start(self): +        """Start the device. Override me in subclass for other launchers.""" +        return self.run() + +    def join(self,timeout=None): +        """wait for me to finish, like Thread.join. +         +        Reimplemented appropriately by subclasses.""" +        tic = time.time() +        toc = tic +        while not self.done and not (timeout is not None and toc-tic > timeout): +            time.sleep(.001) +            toc = time.time() + + +class BackgroundDevice(Device): +    """Base class for launching Devices in background processes and threads.""" + +    launcher=None +    _launch_class=None + +    def start(self): +        self.launcher = self._launch_class(target=self.run) +        self.launcher.daemon = self.daemon +        return self.launcher.start() + +    def join(self, timeout=None): +        return self.launcher.join(timeout=timeout) + + +class ThreadDevice(BackgroundDevice): +    """A Device that will be run in a background Thread. + +    See Device for details. +    """ +    _launch_class=Thread + +class ProcessDevice(BackgroundDevice): +    """A Device that will be run in a background Process. + +    See Device for details. +    """ +    _launch_class=Process +    context_factory = Context +    """Callable that returns a context. Typically either Context.instance or Context, +    depending on whether the device should share the global instance or not. +    """ + + +__all__ = ['Device', 'ThreadDevice', 'ProcessDevice'] diff --git a/zmq/devices/monitoredqueue.pxd b/zmq/devices/monitoredqueue.pxd new file mode 100644 index 0000000..1e26ed8 --- /dev/null +++ b/zmq/devices/monitoredqueue.pxd @@ -0,0 +1,177 @@ +"""MonitoredQueue class declarations. + +Authors +------- +* MinRK +* Brian Granger +""" + +# +#    Copyright (c) 2010 Min Ragan-Kelley, Brian Granger +# +#    This file is part of pyzmq, but is derived and adapted from zmq_queue.cpp +#    originally from libzmq-2.1.6, used under LGPLv3 +# +#    pyzmq is free software; you can redistribute it and/or modify it under +#    the terms of the Lesser GNU General Public License as published by +#    the Free Software Foundation; either version 3 of the License, or +#    (at your option) any later version. +# +#    pyzmq is distributed in the hope that it will be useful, +#    but WITHOUT ANY WARRANTY; without even the implied warranty of +#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +#    Lesser GNU General Public License for more details. +# +#    You should have received a copy of the Lesser GNU General Public License +#    along with this program.  If not, see <http://www.gnu.org/licenses/>. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +from libzmq cimport * + +#----------------------------------------------------------------------------- +# MonitoredQueue C functions +#----------------------------------------------------------------------------- + +cdef inline int _relay(void *insocket_, void *outsocket_, void *sidesocket_,  +                zmq_msg_t msg, zmq_msg_t side_msg, zmq_msg_t id_msg, +                bint swap_ids) nogil: +    cdef int rc +    cdef int64_t flag_2 +    cdef int flag_3 +    cdef int flags +    cdef bint more +    cdef size_t flagsz +    cdef void * flag_ptr +     +    if ZMQ_VERSION_MAJOR < 3: +        flagsz = sizeof (int64_t) +        flag_ptr = &flag_2 +    else: +        flagsz = sizeof (int) +        flag_ptr = &flag_3 +     +    if swap_ids:# both router, must send second identity first +        # recv two ids into msg, id_msg +        rc = zmq_msg_recv(&msg, insocket_, 0) +        if rc < 0: return rc +         +        rc = zmq_msg_recv(&id_msg, insocket_, 0) +        if rc < 0: return rc + +        # send second id (id_msg) first +        #!!!! always send a copy before the original !!!! +        rc = zmq_msg_copy(&side_msg, &id_msg) +        if rc < 0: return rc +        rc = zmq_msg_send(&side_msg, outsocket_, ZMQ_SNDMORE) +        if rc < 0: return rc +        rc = zmq_msg_send(&id_msg, sidesocket_, ZMQ_SNDMORE) +        if rc < 0: return rc +        # send first id (msg) second +        rc = zmq_msg_copy(&side_msg, &msg) +        if rc < 0: return rc +        rc = zmq_msg_send(&side_msg, outsocket_, ZMQ_SNDMORE) +        if rc < 0: return rc +        rc = zmq_msg_send(&msg, sidesocket_, ZMQ_SNDMORE) +        if rc < 0: return rc +    while (True): +        rc = zmq_msg_recv(&msg, insocket_, 0) +        if rc < 0: return rc +        # assert (rc == 0) +        rc = zmq_getsockopt (insocket_, ZMQ_RCVMORE, flag_ptr, &flagsz) +        if rc < 0: return rc +        flags = 0 +        if ZMQ_VERSION_MAJOR < 3: +            if flag_2: +                flags |= ZMQ_SNDMORE +        else: +            if flag_3: +                flags |= ZMQ_SNDMORE +            # LABEL has been removed: +            # rc = zmq_getsockopt (insocket_, ZMQ_RCVLABEL, flag_ptr, &flagsz) +            # if flag_3: +            #     flags |= ZMQ_SNDLABEL +        # assert (rc == 0) + +        rc = zmq_msg_copy(&side_msg, &msg) +        if rc < 0: return rc +        if flags: +            rc = zmq_msg_send(&side_msg, outsocket_, flags) +            if rc < 0: return rc +            # only SNDMORE for side-socket +            rc = zmq_msg_send(&msg, sidesocket_, ZMQ_SNDMORE) +            if rc < 0: return rc +        else: +            rc = zmq_msg_send(&side_msg, outsocket_, 0) +            if rc < 0: return rc +            rc = zmq_msg_send(&msg, sidesocket_, 0) +            if rc < 0: return rc +            break +    return rc + +# the MonitoredQueue C function, adapted from zmq::queue.cpp : +cdef inline int c_monitored_queue (void *insocket_, void *outsocket_, +                        void *sidesocket_, zmq_msg_t *in_msg_ptr,  +                        zmq_msg_t *out_msg_ptr, int swap_ids) nogil: +    """The actual C function for a monitored queue device.  + +    See ``monitored_queue()`` for details. +    """ +     +    cdef zmq_msg_t msg +    cdef int rc = zmq_msg_init (&msg) +    cdef zmq_msg_t id_msg +    rc = zmq_msg_init (&id_msg) +    if rc < 0: return rc +    cdef zmq_msg_t side_msg +    rc = zmq_msg_init (&side_msg) +    if rc < 0: return rc +     +    cdef zmq_pollitem_t items [2] +    items [0].socket = insocket_ +    items [0].fd = 0 +    items [0].events = ZMQ_POLLIN +    items [0].revents = 0 +    items [1].socket = outsocket_ +    items [1].fd = 0 +    items [1].events = ZMQ_POLLIN +    items [1].revents = 0 +    # I don't think sidesocket should be polled? +    # items [2].socket = sidesocket_ +    # items [2].fd = 0 +    # items [2].events = ZMQ_POLLIN +    # items [2].revents = 0 +     +    while (True): +     +        # //  Wait while there are either requests or replies to process. +        rc = zmq_poll (&items [0], 2, -1) +        if rc < 0: return rc +        # //  The algorithm below asumes ratio of request and replies processed +        # //  under full load to be 1:1. Although processing requests replies +        # //  first is tempting it is suspectible to DoS attacks (overloading +        # //  the system with unsolicited replies). +        #  +        # //  Process a request. +        if (items [0].revents & ZMQ_POLLIN): +            # send in_prefix to side socket +            rc = zmq_msg_copy(&side_msg, in_msg_ptr) +            if rc < 0: return rc +            rc = zmq_msg_send(&side_msg, sidesocket_, ZMQ_SNDMORE) +            if rc < 0: return rc +            # relay the rest of the message +            rc = _relay(insocket_, outsocket_, sidesocket_, msg, side_msg, id_msg, swap_ids) +            if rc < 0: return rc +        if (items [1].revents & ZMQ_POLLIN): +            # send out_prefix to side socket +            rc = zmq_msg_copy(&side_msg, out_msg_ptr) +            if rc < 0: return rc +            rc = zmq_msg_send(&side_msg, sidesocket_, ZMQ_SNDMORE) +            if rc < 0: return rc +            # relay the rest of the message +            rc = _relay(outsocket_, insocket_, sidesocket_, msg, side_msg, id_msg, swap_ids) +            if rc < 0: return rc +    return rc diff --git a/zmq/devices/monitoredqueue.py b/zmq/devices/monitoredqueue.py new file mode 100644 index 0000000..c6d9142 --- /dev/null +++ b/zmq/devices/monitoredqueue.py @@ -0,0 +1,37 @@ +"""pure Python monitored_queue function + +For use when Cython extension is unavailable (PyPy). + +Authors +------- +* MinRK +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import zmq + +def _relay(ins, outs, sides, prefix, swap_ids): +    msg = ins.recv_multipart() +    if swap_ids: +        msg[:2] = msg[:2][::-1] +    outs.send_multipart(msg) +    sides.send_multipart([prefix] + msg) + +def monitored_queue(in_socket, out_socket, mon_socket, +                    in_prefix=b'in', out_prefix=b'out'): +     +    swap_ids = in_socket.type == zmq.ROUTER and out_socket.type == zmq.ROUTER +     +    poller = zmq.Poller() +    poller.register(in_socket, zmq.POLLIN) +    poller.register(out_socket, zmq.POLLIN) +    while True: +        events = dict(poller.poll()) +        if in_socket in events: +            _relay(in_socket, out_socket, mon_socket, in_prefix, swap_ids) +        if out_socket in events: +            _relay(out_socket, in_socket, mon_socket, out_prefix, swap_ids) + +__all__ = ['monitored_queue'] diff --git a/zmq/devices/monitoredqueue.pyx b/zmq/devices/monitoredqueue.pyx new file mode 100644 index 0000000..d5fec64 --- /dev/null +++ b/zmq/devices/monitoredqueue.pyx @@ -0,0 +1,103 @@ +"""MonitoredQueue classes and functions. + +Authors +------- +* MinRK +* Brian Granger +""" + +#----------------------------------------------------------------------------- +#  Copyright (c) 2010-2012 Brian Granger, Min Ragan-Kelley +# +#  This file is part of pyzmq +# +#  Distributed under the terms of the New BSD License.  The full license is in +#  the file COPYING.BSD, distributed as part of this software. +#----------------------------------------------------------------------------- + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +cdef extern from "Python.h": +    ctypedef int Py_ssize_t + +from libc.string cimport memcpy + +from buffers cimport asbuffer_r +from libzmq cimport * + +from zmq.backend.cython.socket cimport Socket +from zmq.backend.cython.checkrc cimport _check_rc + +from zmq import ROUTER, ZMQError + +#----------------------------------------------------------------------------- +# MonitoredQueue functions +#----------------------------------------------------------------------------- + + +def monitored_queue(Socket in_socket, Socket out_socket, Socket mon_socket, +                    bytes in_prefix=b'in', bytes out_prefix=b'out'): +    """monitored_queue(in_socket, out_socket, mon_socket, +                       in_prefix=b'in', out_prefix=b'out') +     +    Start a monitored queue device. +     +    A monitored queue is very similar to the zmq.proxy device (monitored queue came first). +     +    Differences from zmq.proxy: +     +    - monitored_queue supports both in and out being ROUTER sockets +      (via swapping IDENTITY prefixes). +    - monitor messages are prefixed, making in and out messages distinguishable. +     +    Parameters +    ---------- +    in_socket : Socket +        One of the sockets to the Queue. Its messages will be prefixed with +        'in'. +    out_socket : Socket +        One of the sockets to the Queue. Its messages will be prefixed with +        'out'. The only difference between in/out socket is this prefix. +    mon_socket : Socket +        This socket sends out every message received by each of the others +        with an in/out prefix specifying which one it was. +    in_prefix : str +        Prefix added to broadcast messages from in_socket. +    out_prefix : str +        Prefix added to broadcast messages from out_socket. +    """ +     +    cdef void *ins=in_socket.handle +    cdef void *outs=out_socket.handle +    cdef void *mons=mon_socket.handle +    cdef zmq_msg_t in_msg +    cdef zmq_msg_t out_msg +    cdef bint swap_ids +    cdef char *msg_c = NULL +    cdef Py_ssize_t msg_c_len +    cdef int rc + +    # force swap_ids if both ROUTERs +    swap_ids = (in_socket.type == ROUTER and out_socket.type == ROUTER) +     +    # build zmq_msg objects from str prefixes +    asbuffer_r(in_prefix, <void **>&msg_c, &msg_c_len) +    rc = zmq_msg_init_size(&in_msg, msg_c_len) +    _check_rc(rc) +     +    memcpy(zmq_msg_data(&in_msg), msg_c, zmq_msg_size(&in_msg)) +     +    asbuffer_r(out_prefix, <void **>&msg_c, &msg_c_len) +     +    rc = zmq_msg_init_size(&out_msg, msg_c_len) +    _check_rc(rc) +     +    with nogil: +        memcpy(zmq_msg_data(&out_msg), msg_c, zmq_msg_size(&out_msg)) +        rc = c_monitored_queue(ins, outs, mons, &in_msg, &out_msg, swap_ids) +    _check_rc(rc) +    return rc + +__all__ = ['monitored_queue'] diff --git a/zmq/devices/monitoredqueuedevice.py b/zmq/devices/monitoredqueuedevice.py new file mode 100644 index 0000000..9723f86 --- /dev/null +++ b/zmq/devices/monitoredqueuedevice.py @@ -0,0 +1,66 @@ +"""MonitoredQueue classes and functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from zmq import ZMQError, PUB +from zmq.devices.proxydevice import ProxyBase, Proxy, ThreadProxy, ProcessProxy +from zmq.devices.monitoredqueue import monitored_queue + + +class MonitoredQueueBase(ProxyBase): +    """Base class for overriding methods.""" +     +    _in_prefix = b'' +    _out_prefix = b'' +     +    def __init__(self, in_type, out_type, mon_type=PUB, in_prefix=b'in', out_prefix=b'out'): +         +        ProxyBase.__init__(self, in_type=in_type, out_type=out_type, mon_type=mon_type) +         +        self._in_prefix = in_prefix +        self._out_prefix = out_prefix + +    def run_device(self): +        ins,outs,mons = self._setup_sockets() +        monitored_queue(ins, outs, mons, self._in_prefix, self._out_prefix) + + +class MonitoredQueue(MonitoredQueueBase, Proxy): +    """Class for running monitored_queue in the background. + +    See zmq.devices.Device for most of the spec. MonitoredQueue differs from Proxy, +    only in that it adds a ``prefix`` to messages sent on the monitor socket, +    with a different prefix for each direction. +     +    MQ also supports ROUTER on both sides, which zmq.proxy does not. + +    If a message arrives on `in_sock`, it will be prefixed with `in_prefix` on the monitor socket. +    If it arrives on out_sock, it will be prefixed with `out_prefix`. + +    A PUB socket is the most logical choice for the mon_socket, but it is not required. +    """ +    pass + + +class ThreadMonitoredQueue(MonitoredQueueBase, ThreadProxy): +    """Run zmq.monitored_queue in a background thread. +     +    See MonitoredQueue and Proxy for details. +    """ +    pass + + +class ProcessMonitoredQueue(MonitoredQueueBase, ProcessProxy): +    """Run zmq.monitored_queue in a background thread. +     +    See MonitoredQueue and Proxy for details. +    """ + + +__all__ = [ +    'MonitoredQueue', +    'ThreadMonitoredQueue', +    'ProcessMonitoredQueue' +] diff --git a/zmq/devices/proxydevice.py b/zmq/devices/proxydevice.py new file mode 100644 index 0000000..68be3f1 --- /dev/null +++ b/zmq/devices/proxydevice.py @@ -0,0 +1,90 @@ +"""Proxy classes and functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import zmq +from zmq.devices.basedevice import Device, ThreadDevice, ProcessDevice + + +class ProxyBase(object): +    """Base class for overriding methods.""" +     +    def __init__(self, in_type, out_type, mon_type=zmq.PUB): +         +        Device.__init__(self, in_type=in_type, out_type=out_type) +        self.mon_type = mon_type +        self._mon_binds = [] +        self._mon_connects = [] +        self._mon_sockopts = [] + +    def bind_mon(self, addr): +        """Enqueue ZMQ address for binding on mon_socket. + +        See zmq.Socket.bind for details. +        """ +        self._mon_binds.append(addr) + +    def connect_mon(self, addr): +        """Enqueue ZMQ address for connecting on mon_socket. + +        See zmq.Socket.bind for details. +        """ +        self._mon_connects.append(addr) + +    def setsockopt_mon(self, opt, value): +        """Enqueue setsockopt(opt, value) for mon_socket + +        See zmq.Socket.setsockopt for details. +        """ +        self._mon_sockopts.append((opt, value)) + +    def _setup_sockets(self): +        ins,outs = Device._setup_sockets(self) +        ctx = self._context +        mons = ctx.socket(self.mon_type) +         +        # set sockopts (must be done first, in case of zmq.IDENTITY) +        for opt,value in self._mon_sockopts: +            mons.setsockopt(opt, value) +         +        for iface in self._mon_binds: +            mons.bind(iface) +         +        for iface in self._mon_connects: +            mons.connect(iface) +         +        return ins,outs,mons +     +    def run_device(self): +        ins,outs,mons = self._setup_sockets() +        zmq.proxy(ins, outs, mons) + +class Proxy(ProxyBase, Device): +    """Threadsafe Proxy object. + +    See zmq.devices.Device for most of the spec. This subclass adds a +    <method>_mon version of each <method>_{in|out} method, for configuring the +    monitor socket. + +    A Proxy is a 3-socket ZMQ Device that functions just like a +    QUEUE, except each message is also sent out on the monitor socket. + +    A PUB socket is the most logical choice for the mon_socket, but it is not required. +    """ +    pass + +class ThreadProxy(ProxyBase, ThreadDevice): +    """Proxy in a Thread. See Proxy for more.""" +    pass + +class ProcessProxy(ProxyBase, ProcessDevice): +    """Proxy in a Process. See Proxy for more.""" +    pass + + +__all__ = [ +    'Proxy', +    'ThreadProxy', +    'ProcessProxy', +] diff --git a/zmq/error.py b/zmq/error.py new file mode 100644 index 0000000..48cdaaf --- /dev/null +++ b/zmq/error.py @@ -0,0 +1,164 @@ +"""0MQ Error classes and functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +class ZMQBaseError(Exception): +    """Base exception class for 0MQ errors in Python.""" +    pass + +class ZMQError(ZMQBaseError): +    """Wrap an errno style error. + +    Parameters +    ---------- +    errno : int +        The ZMQ errno or None.  If None, then ``zmq_errno()`` is called and +        used. +    msg : string +        Description of the error or None. +    """ +    errno = None + +    def __init__(self, errno=None, msg=None): +        """Wrap an errno style error. + +        Parameters +        ---------- +        errno : int +            The ZMQ errno or None.  If None, then ``zmq_errno()`` is called and +            used. +        msg : string +            Description of the error or None. +        """ +        from zmq.backend import strerror, zmq_errno +        if errno is None: +            errno = zmq_errno() +        if isinstance(errno, int): +            self.errno = errno +            if msg is None: +                self.strerror = strerror(errno) +            else: +                self.strerror = msg +        else: +            if msg is None: +                self.strerror = str(errno) +            else: +                self.strerror = msg +        # flush signals, because there could be a SIGINT +        # waiting to pounce, resulting in uncaught exceptions. +        # Doing this here means getting SIGINT during a blocking +        # libzmq call will raise a *catchable* KeyboardInterrupt +        # PyErr_CheckSignals() + +    def __str__(self): +        return self.strerror +     +    def __repr__(self): +        return "ZMQError('%s')"%self.strerror + + +class ZMQBindError(ZMQBaseError): +    """An error for ``Socket.bind_to_random_port()``. +     +    See Also +    -------- +    .Socket.bind_to_random_port +    """ +    pass + + +class NotDone(ZMQBaseError): +    """Raised when timeout is reached while waiting for 0MQ to finish with a Message +     +    See Also +    -------- +    .MessageTracker.wait : object for tracking when ZeroMQ is done +    """ +    pass + + +class ContextTerminated(ZMQError): +    """Wrapper for zmq.ETERM +     +    .. versionadded:: 13.0 +    """ +    pass + + +class Again(ZMQError): +    """Wrapper for zmq.EAGAIN +     +    .. versionadded:: 13.0 +    """ +    pass + + +def _check_rc(rc, errno=None): +    """internal utility for checking zmq return condition +     +    and raising the appropriate Exception class +    """ +    if rc < 0: +        from zmq.backend import zmq_errno +        if errno is None: +            errno = zmq_errno() +        from zmq import EAGAIN, ETERM +        if errno == EAGAIN: +            raise Again(errno) +        elif errno == ETERM: +            raise ContextTerminated(errno) +        else: +            raise ZMQError(errno) + +_zmq_version_info = None +_zmq_version = None + +class ZMQVersionError(NotImplementedError): +    """Raised when a feature is not provided by the linked version of libzmq. +     +    .. versionadded:: 14.2 +    """ +    min_version = None +    def __init__(self, min_version, msg='Feature'): +        global _zmq_version +        if _zmq_version is None: +            from zmq import zmq_version +            _zmq_version = zmq_version() +        self.msg = msg +        self.min_version = min_version +        self.version = _zmq_version +     +    def __repr__(self): +        return "ZMQVersionError('%s')" % str(self) +     +    def __str__(self): +        return "%s requires libzmq >= %s, have %s" % (self.msg, self.min_version, self.version) + + +def _check_version(min_version_info, msg='Feature'): +    """Check for libzmq +     +    raises ZMQVersionError if current zmq version is not at least min_version +     +    min_version_info is a tuple of integers, and will be compared against zmq.zmq_version_info(). +    """ +    global _zmq_version_info +    if _zmq_version_info is None: +        from zmq import zmq_version_info +        _zmq_version_info = zmq_version_info() +    if _zmq_version_info < min_version_info: +        min_version = '.'.join(str(v) for v in min_version_info) +        raise ZMQVersionError(min_version, msg) + + +__all__ = [ +    'ZMQBaseError', +    'ZMQBindError', +    'ZMQError', +    'NotDone', +    'ContextTerminated', +    'Again', +    'ZMQVersionError', +] diff --git a/zmq/eventloop/__init__.py b/zmq/eventloop/__init__.py new file mode 100644 index 0000000..568e8e8 --- /dev/null +++ b/zmq/eventloop/__init__.py @@ -0,0 +1,5 @@ +"""A Tornado based event loop for PyZMQ.""" + +from zmq.eventloop.ioloop import IOLoop + +__all__ = ['IOLoop']
\ No newline at end of file diff --git a/zmq/eventloop/ioloop.py b/zmq/eventloop/ioloop.py new file mode 100644 index 0000000..35f4c41 --- /dev/null +++ b/zmq/eventloop/ioloop.py @@ -0,0 +1,193 @@ +# coding: utf-8 +"""tornado IOLoop API with zmq compatibility + +If you have tornado ≥ 3.0, this is a subclass of tornado's IOLoop, +otherwise we ship a minimal subset of tornado in zmq.eventloop.minitornado. + +The minimal shipped version of tornado's IOLoop does not include +support for concurrent futures - this will only be available if you +have tornado ≥ 3.0. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from __future__ import absolute_import, division, with_statement + +import os +import time +import warnings + +from zmq import ( +    Poller, +    POLLIN, POLLOUT, POLLERR, +    ZMQError, ETERM, +) + +try: +    import tornado +    tornado_version = tornado.version_info +except (ImportError, AttributeError): +    tornado_version = () + +try: +    # tornado ≥ 3 +    from tornado.ioloop import PollIOLoop, PeriodicCallback +    from tornado.log import gen_log +except ImportError: +    from .minitornado.ioloop import PollIOLoop, PeriodicCallback +    from .minitornado.log import gen_log + + +class DelayedCallback(PeriodicCallback): +    """Schedules the given callback to be called once. + +    The callback is called once, after callback_time milliseconds. + +    `start` must be called after the DelayedCallback is created. +     +    The timeout is calculated from when `start` is called. +    """ +    def __init__(self, callback, callback_time, io_loop=None): +        # PeriodicCallback require callback_time to be positive +        warnings.warn("""DelayedCallback is deprecated. +        Use loop.add_timeout instead.""", DeprecationWarning) +        callback_time = max(callback_time, 1e-3) +        super(DelayedCallback, self).__init__(callback, callback_time, io_loop) +     +    def start(self): +        """Starts the timer.""" +        self._running = True +        self._firstrun = True +        self._next_timeout = time.time() + self.callback_time / 1000.0 +        self.io_loop.add_timeout(self._next_timeout, self._run) +     +    def _run(self): +        if not self._running: return +        self._running = False +        try: +            self.callback() +        except Exception: +            gen_log.error("Error in delayed callback", exc_info=True) + + +class ZMQPoller(object): +    """A poller that can be used in the tornado IOLoop. +     +    This simply wraps a regular zmq.Poller, scaling the timeout +    by 1000, so that it is in seconds rather than milliseconds. +    """ +     +    def __init__(self): +        self._poller = Poller() +     +    @staticmethod +    def _map_events(events): +        """translate IOLoop.READ/WRITE/ERROR event masks into zmq.POLLIN/OUT/ERR""" +        z_events = 0 +        if events & IOLoop.READ: +            z_events |= POLLIN +        if events & IOLoop.WRITE: +            z_events |= POLLOUT +        if events & IOLoop.ERROR: +            z_events |= POLLERR +        return z_events +     +    @staticmethod +    def _remap_events(z_events): +        """translate zmq.POLLIN/OUT/ERR event masks into IOLoop.READ/WRITE/ERROR""" +        events = 0 +        if z_events & POLLIN: +            events |= IOLoop.READ +        if z_events & POLLOUT: +            events |= IOLoop.WRITE +        if z_events & POLLERR: +            events |= IOLoop.ERROR +        return events +     +    def register(self, fd, events): +        return self._poller.register(fd, self._map_events(events)) +     +    def modify(self, fd, events): +        return self._poller.modify(fd, self._map_events(events)) +     +    def unregister(self, fd): +        return self._poller.unregister(fd) +     +    def poll(self, timeout): +        """poll in seconds rather than milliseconds. +         +        Event masks will be IOLoop.READ/WRITE/ERROR +        """ +        z_events = self._poller.poll(1000*timeout) +        return [ (fd,self._remap_events(evt)) for (fd,evt) in z_events ] +     +    def close(self): +        pass + + +class ZMQIOLoop(PollIOLoop): +    """ZMQ subclass of tornado's IOLoop""" +    def initialize(self, impl=None, **kwargs): +        impl = ZMQPoller() if impl is None else impl +        super(ZMQIOLoop, self).initialize(impl=impl, **kwargs) +     +    @staticmethod +    def instance(): +        """Returns a global `IOLoop` instance. +         +        Most applications have a single, global `IOLoop` running on the +        main thread.  Use this method to get this instance from +        another thread.  To get the current thread's `IOLoop`, use `current()`. +        """ +        # install ZMQIOLoop as the active IOLoop implementation +        # when using tornado 3 +        if tornado_version >= (3,): +            PollIOLoop.configure(ZMQIOLoop) +        return PollIOLoop.instance() +     +    def start(self): +        try: +            super(ZMQIOLoop, self).start() +        except ZMQError as e: +            if e.errno == ETERM: +                # quietly return on ETERM +                pass +            else: +                raise e + + +if tornado_version >= (3,0) and tornado_version < (3,1): +    def backport_close(self, all_fds=False): +        """backport IOLoop.close to 3.0 from 3.1 (supports fd.close() method)""" +        from zmq.eventloop.minitornado.ioloop import PollIOLoop as mini_loop +        return mini_loop.close.__get__(self)(all_fds) +    ZMQIOLoop.close = backport_close + + +# public API name +IOLoop = ZMQIOLoop + + +def install(): +    """set the tornado IOLoop instance with the pyzmq IOLoop. +     +    After calling this function, tornado's IOLoop.instance() and pyzmq's +    IOLoop.instance() will return the same object. +     +    An assertion error will be raised if tornado's IOLoop has been initialized +    prior to calling this function. +    """ +    from tornado import ioloop +    # check if tornado's IOLoop is already initialized to something other +    # than the pyzmq IOLoop instance: +    assert (not ioloop.IOLoop.initialized()) or \ +        ioloop.IOLoop.instance() is IOLoop.instance(), "tornado IOLoop already initialized" +     +    if tornado_version >= (3,): +        # tornado 3 has an official API for registering new defaults, yay! +        ioloop.IOLoop.configure(ZMQIOLoop) +    else: +        # we have to set the global instance explicitly +        ioloop.IOLoop._instance = IOLoop.instance() + diff --git a/zmq/eventloop/minitornado/__init__.py b/zmq/eventloop/minitornado/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/zmq/eventloop/minitornado/__init__.py diff --git a/zmq/eventloop/minitornado/concurrent.py b/zmq/eventloop/minitornado/concurrent.py new file mode 100644 index 0000000..519b23d --- /dev/null +++ b/zmq/eventloop/minitornado/concurrent.py @@ -0,0 +1,11 @@ +"""pyzmq does not ship tornado's futures, +this just raises informative NotImplementedErrors to avoid having to change too much code. +""" + +class NotImplementedFuture(object): +    def __init__(self, *args, **kwargs): +        raise NotImplementedError("pyzmq does not ship tornado's Futures, " +            "install tornado >= 3.0 for future support." +        ) + +Future = TracebackFuture = NotImplementedFuture diff --git a/zmq/eventloop/minitornado/ioloop.py b/zmq/eventloop/minitornado/ioloop.py new file mode 100644 index 0000000..710a3ec --- /dev/null +++ b/zmq/eventloop/minitornado/ioloop.py @@ -0,0 +1,829 @@ +#!/usr/bin/env python +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +#     http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""An I/O event loop for non-blocking sockets. + +Typical applications will use a single `IOLoop` object, in the +`IOLoop.instance` singleton.  The `IOLoop.start` method should usually +be called at the end of the ``main()`` function.  Atypical applications may +use more than one `IOLoop`, such as one `IOLoop` per thread, or per `unittest` +case. + +In addition to I/O events, the `IOLoop` can also schedule time-based events. +`IOLoop.add_timeout` is a non-blocking alternative to `time.sleep`. +""" + +from __future__ import absolute_import, division, print_function, with_statement + +import datetime +import errno +import functools +import heapq +import logging +import numbers +import os +import select +import sys +import threading +import time +import traceback + +from .concurrent import Future, TracebackFuture +from .log import app_log, gen_log +from . import stack_context +from .util import Configurable + +try: +    import signal +except ImportError: +    signal = None + +try: +    import thread  # py2 +except ImportError: +    import _thread as thread  # py3 + +from .platform.auto import set_close_exec, Waker + + +class TimeoutError(Exception): +    pass + + +class IOLoop(Configurable): +    """A level-triggered I/O loop. + +    We use ``epoll`` (Linux) or ``kqueue`` (BSD and Mac OS X) if they +    are available, or else we fall back on select(). If you are +    implementing a system that needs to handle thousands of +    simultaneous connections, you should use a system that supports +    either ``epoll`` or ``kqueue``. + +    Example usage for a simple TCP server:: + +        import errno +        import functools +        import ioloop +        import socket + +        def connection_ready(sock, fd, events): +            while True: +                try: +                    connection, address = sock.accept() +                except socket.error, e: +                    if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN): +                        raise +                    return +                connection.setblocking(0) +                handle_connection(connection, address) + +        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) +        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +        sock.setblocking(0) +        sock.bind(("", port)) +        sock.listen(128) + +        io_loop = ioloop.IOLoop.instance() +        callback = functools.partial(connection_ready, sock) +        io_loop.add_handler(sock.fileno(), callback, io_loop.READ) +        io_loop.start() + +    """ +    # Constants from the epoll module +    _EPOLLIN = 0x001 +    _EPOLLPRI = 0x002 +    _EPOLLOUT = 0x004 +    _EPOLLERR = 0x008 +    _EPOLLHUP = 0x010 +    _EPOLLRDHUP = 0x2000 +    _EPOLLONESHOT = (1 << 30) +    _EPOLLET = (1 << 31) + +    # Our events map exactly to the epoll events +    NONE = 0 +    READ = _EPOLLIN +    WRITE = _EPOLLOUT +    ERROR = _EPOLLERR | _EPOLLHUP + +    # Global lock for creating global IOLoop instance +    _instance_lock = threading.Lock() + +    _current = threading.local() + +    @staticmethod +    def instance(): +        """Returns a global `IOLoop` instance. + +        Most applications have a single, global `IOLoop` running on the +        main thread.  Use this method to get this instance from +        another thread.  To get the current thread's `IOLoop`, use `current()`. +        """ +        if not hasattr(IOLoop, "_instance"): +            with IOLoop._instance_lock: +                if not hasattr(IOLoop, "_instance"): +                    # New instance after double check +                    IOLoop._instance = IOLoop() +        return IOLoop._instance + +    @staticmethod +    def initialized(): +        """Returns true if the singleton instance has been created.""" +        return hasattr(IOLoop, "_instance") + +    def install(self): +        """Installs this `IOLoop` object as the singleton instance. + +        This is normally not necessary as `instance()` will create +        an `IOLoop` on demand, but you may want to call `install` to use +        a custom subclass of `IOLoop`. +        """ +        assert not IOLoop.initialized() +        IOLoop._instance = self + +    @staticmethod +    def current(): +        """Returns the current thread's `IOLoop`. + +        If an `IOLoop` is currently running or has been marked as current +        by `make_current`, returns that instance.  Otherwise returns +        `IOLoop.instance()`, i.e. the main thread's `IOLoop`. + +        A common pattern for classes that depend on ``IOLoops`` is to use +        a default argument to enable programs with multiple ``IOLoops`` +        but not require the argument for simpler applications:: + +            class MyClass(object): +                def __init__(self, io_loop=None): +                    self.io_loop = io_loop or IOLoop.current() + +        In general you should use `IOLoop.current` as the default when +        constructing an asynchronous object, and use `IOLoop.instance` +        when you mean to communicate to the main thread from a different +        one. +        """ +        current = getattr(IOLoop._current, "instance", None) +        if current is None: +            return IOLoop.instance() +        return current + +    def make_current(self): +        """Makes this the `IOLoop` for the current thread. + +        An `IOLoop` automatically becomes current for its thread +        when it is started, but it is sometimes useful to call +        `make_current` explictly before starting the `IOLoop`, +        so that code run at startup time can find the right +        instance. +        """ +        IOLoop._current.instance = self + +    @staticmethod +    def clear_current(): +        IOLoop._current.instance = None + +    @classmethod +    def configurable_base(cls): +        return IOLoop + +    @classmethod +    def configurable_default(cls): +        # this is the only patch to IOLoop: +        from zmq.eventloop.ioloop import ZMQIOLoop +        return ZMQIOLoop +        # the remainder of this method is unused, +        # but left for preservation reasons +        if hasattr(select, "epoll"): +            from tornado.platform.epoll import EPollIOLoop +            return EPollIOLoop +        if hasattr(select, "kqueue"): +            # Python 2.6+ on BSD or Mac +            from tornado.platform.kqueue import KQueueIOLoop +            return KQueueIOLoop +        from tornado.platform.select import SelectIOLoop +        return SelectIOLoop + +    def initialize(self): +        pass + +    def close(self, all_fds=False): +        """Closes the `IOLoop`, freeing any resources used. + +        If ``all_fds`` is true, all file descriptors registered on the +        IOLoop will be closed (not just the ones created by the +        `IOLoop` itself). + +        Many applications will only use a single `IOLoop` that runs for the +        entire lifetime of the process.  In that case closing the `IOLoop` +        is not necessary since everything will be cleaned up when the +        process exits.  `IOLoop.close` is provided mainly for scenarios +        such as unit tests, which create and destroy a large number of +        ``IOLoops``. + +        An `IOLoop` must be completely stopped before it can be closed.  This +        means that `IOLoop.stop()` must be called *and* `IOLoop.start()` must +        be allowed to return before attempting to call `IOLoop.close()`. +        Therefore the call to `close` will usually appear just after +        the call to `start` rather than near the call to `stop`. + +        .. versionchanged:: 3.1 +           If the `IOLoop` implementation supports non-integer objects +           for "file descriptors", those objects will have their +           ``close`` method when ``all_fds`` is true. +        """ +        raise NotImplementedError() + +    def add_handler(self, fd, handler, events): +        """Registers the given handler to receive the given events for fd. + +        The ``events`` argument is a bitwise or of the constants +        ``IOLoop.READ``, ``IOLoop.WRITE``, and ``IOLoop.ERROR``. + +        When an event occurs, ``handler(fd, events)`` will be run. +        """ +        raise NotImplementedError() + +    def update_handler(self, fd, events): +        """Changes the events we listen for fd.""" +        raise NotImplementedError() + +    def remove_handler(self, fd): +        """Stop listening for events on fd.""" +        raise NotImplementedError() + +    def set_blocking_signal_threshold(self, seconds, action): +        """Sends a signal if the `IOLoop` is blocked for more than +        ``s`` seconds. + +        Pass ``seconds=None`` to disable.  Requires Python 2.6 on a unixy +        platform. + +        The action parameter is a Python signal handler.  Read the +        documentation for the `signal` module for more information. +        If ``action`` is None, the process will be killed if it is +        blocked for too long. +        """ +        raise NotImplementedError() + +    def set_blocking_log_threshold(self, seconds): +        """Logs a stack trace if the `IOLoop` is blocked for more than +        ``s`` seconds. + +        Equivalent to ``set_blocking_signal_threshold(seconds, +        self.log_stack)`` +        """ +        self.set_blocking_signal_threshold(seconds, self.log_stack) + +    def log_stack(self, signal, frame): +        """Signal handler to log the stack trace of the current thread. + +        For use with `set_blocking_signal_threshold`. +        """ +        gen_log.warning('IOLoop blocked for %f seconds in\n%s', +                        self._blocking_signal_threshold, +                        ''.join(traceback.format_stack(frame))) + +    def start(self): +        """Starts the I/O loop. + +        The loop will run until one of the callbacks calls `stop()`, which +        will make the loop stop after the current event iteration completes. +        """ +        raise NotImplementedError() + +    def stop(self): +        """Stop the I/O loop. + +        If the event loop is not currently running, the next call to `start()` +        will return immediately. + +        To use asynchronous methods from otherwise-synchronous code (such as +        unit tests), you can start and stop the event loop like this:: + +          ioloop = IOLoop() +          async_method(ioloop=ioloop, callback=ioloop.stop) +          ioloop.start() + +        ``ioloop.start()`` will return after ``async_method`` has run +        its callback, whether that callback was invoked before or +        after ``ioloop.start``. + +        Note that even after `stop` has been called, the `IOLoop` is not +        completely stopped until `IOLoop.start` has also returned. +        Some work that was scheduled before the call to `stop` may still +        be run before the `IOLoop` shuts down. +        """ +        raise NotImplementedError() + +    def run_sync(self, func, timeout=None): +        """Starts the `IOLoop`, runs the given function, and stops the loop. + +        If the function returns a `.Future`, the `IOLoop` will run +        until the future is resolved.  If it raises an exception, the +        `IOLoop` will stop and the exception will be re-raised to the +        caller. + +        The keyword-only argument ``timeout`` may be used to set +        a maximum duration for the function.  If the timeout expires, +        a `TimeoutError` is raised. + +        This method is useful in conjunction with `tornado.gen.coroutine` +        to allow asynchronous calls in a ``main()`` function:: + +            @gen.coroutine +            def main(): +                # do stuff... + +            if __name__ == '__main__': +                IOLoop.instance().run_sync(main) +        """ +        future_cell = [None] + +        def run(): +            try: +                result = func() +            except Exception: +                future_cell[0] = TracebackFuture() +                future_cell[0].set_exc_info(sys.exc_info()) +            else: +                if isinstance(result, Future): +                    future_cell[0] = result +                else: +                    future_cell[0] = Future() +                    future_cell[0].set_result(result) +            self.add_future(future_cell[0], lambda future: self.stop()) +        self.add_callback(run) +        if timeout is not None: +            timeout_handle = self.add_timeout(self.time() + timeout, self.stop) +        self.start() +        if timeout is not None: +            self.remove_timeout(timeout_handle) +        if not future_cell[0].done(): +            raise TimeoutError('Operation timed out after %s seconds' % timeout) +        return future_cell[0].result() + +    def time(self): +        """Returns the current time according to the `IOLoop`'s clock. + +        The return value is a floating-point number relative to an +        unspecified time in the past. + +        By default, the `IOLoop`'s time function is `time.time`.  However, +        it may be configured to use e.g. `time.monotonic` instead. +        Calls to `add_timeout` that pass a number instead of a +        `datetime.timedelta` should use this function to compute the +        appropriate time, so they can work no matter what time function +        is chosen. +        """ +        return time.time() + +    def add_timeout(self, deadline, callback): +        """Runs the ``callback`` at the time ``deadline`` from the I/O loop. + +        Returns an opaque handle that may be passed to +        `remove_timeout` to cancel. + +        ``deadline`` may be a number denoting a time (on the same +        scale as `IOLoop.time`, normally `time.time`), or a +        `datetime.timedelta` object for a deadline relative to the +        current time. + +        Note that it is not safe to call `add_timeout` from other threads. +        Instead, you must use `add_callback` to transfer control to the +        `IOLoop`'s thread, and then call `add_timeout` from there. +        """ +        raise NotImplementedError() + +    def remove_timeout(self, timeout): +        """Cancels a pending timeout. + +        The argument is a handle as returned by `add_timeout`.  It is +        safe to call `remove_timeout` even if the callback has already +        been run. +        """ +        raise NotImplementedError() + +    def add_callback(self, callback, *args, **kwargs): +        """Calls the given callback on the next I/O loop iteration. + +        It is safe to call this method from any thread at any time, +        except from a signal handler.  Note that this is the **only** +        method in `IOLoop` that makes this thread-safety guarantee; all +        other interaction with the `IOLoop` must be done from that +        `IOLoop`'s thread.  `add_callback()` may be used to transfer +        control from other threads to the `IOLoop`'s thread. + +        To add a callback from a signal handler, see +        `add_callback_from_signal`. +        """ +        raise NotImplementedError() + +    def add_callback_from_signal(self, callback, *args, **kwargs): +        """Calls the given callback on the next I/O loop iteration. + +        Safe for use from a Python signal handler; should not be used +        otherwise. + +        Callbacks added with this method will be run without any +        `.stack_context`, to avoid picking up the context of the function +        that was interrupted by the signal. +        """ +        raise NotImplementedError() + +    def add_future(self, future, callback): +        """Schedules a callback on the ``IOLoop`` when the given +        `.Future` is finished. + +        The callback is invoked with one argument, the +        `.Future`. +        """ +        assert isinstance(future, Future) +        callback = stack_context.wrap(callback) +        future.add_done_callback( +            lambda future: self.add_callback(callback, future)) + +    def _run_callback(self, callback): +        """Runs a callback with error handling. + +        For use in subclasses. +        """ +        try: +            callback() +        except Exception: +            self.handle_callback_exception(callback) + +    def handle_callback_exception(self, callback): +        """This method is called whenever a callback run by the `IOLoop` +        throws an exception. + +        By default simply logs the exception as an error.  Subclasses +        may override this method to customize reporting of exceptions. + +        The exception itself is not passed explicitly, but is available +        in `sys.exc_info`. +        """ +        app_log.error("Exception in callback %r", callback, exc_info=True) + + +class PollIOLoop(IOLoop): +    """Base class for IOLoops built around a select-like function. + +    For concrete implementations, see `tornado.platform.epoll.EPollIOLoop` +    (Linux), `tornado.platform.kqueue.KQueueIOLoop` (BSD and Mac), or +    `tornado.platform.select.SelectIOLoop` (all platforms). +    """ +    def initialize(self, impl, time_func=None): +        super(PollIOLoop, self).initialize() +        self._impl = impl +        if hasattr(self._impl, 'fileno'): +            set_close_exec(self._impl.fileno()) +        self.time_func = time_func or time.time +        self._handlers = {} +        self._events = {} +        self._callbacks = [] +        self._callback_lock = threading.Lock() +        self._timeouts = [] +        self._cancellations = 0 +        self._running = False +        self._stopped = False +        self._closing = False +        self._thread_ident = None +        self._blocking_signal_threshold = None + +        # Create a pipe that we send bogus data to when we want to wake +        # the I/O loop when it is idle +        self._waker = Waker() +        self.add_handler(self._waker.fileno(), +                         lambda fd, events: self._waker.consume(), +                         self.READ) + +    def close(self, all_fds=False): +        with self._callback_lock: +            self._closing = True +        self.remove_handler(self._waker.fileno()) +        if all_fds: +            for fd in self._handlers.keys(): +                try: +                    close_method = getattr(fd, 'close', None) +                    if close_method is not None: +                        close_method() +                    else: +                        os.close(fd) +                except Exception: +                    gen_log.debug("error closing fd %s", fd, exc_info=True) +        self._waker.close() +        self._impl.close() + +    def add_handler(self, fd, handler, events): +        self._handlers[fd] = stack_context.wrap(handler) +        self._impl.register(fd, events | self.ERROR) + +    def update_handler(self, fd, events): +        self._impl.modify(fd, events | self.ERROR) + +    def remove_handler(self, fd): +        self._handlers.pop(fd, None) +        self._events.pop(fd, None) +        try: +            self._impl.unregister(fd) +        except Exception: +            gen_log.debug("Error deleting fd from IOLoop", exc_info=True) + +    def set_blocking_signal_threshold(self, seconds, action): +        if not hasattr(signal, "setitimer"): +            gen_log.error("set_blocking_signal_threshold requires a signal module " +                          "with the setitimer method") +            return +        self._blocking_signal_threshold = seconds +        if seconds is not None: +            signal.signal(signal.SIGALRM, +                          action if action is not None else signal.SIG_DFL) + +    def start(self): +        if not logging.getLogger().handlers: +            # The IOLoop catches and logs exceptions, so it's +            # important that log output be visible.  However, python's +            # default behavior for non-root loggers (prior to python +            # 3.2) is to print an unhelpful "no handlers could be +            # found" message rather than the actual log entry, so we +            # must explicitly configure logging if we've made it this +            # far without anything. +            logging.basicConfig() +        if self._stopped: +            self._stopped = False +            return +        old_current = getattr(IOLoop._current, "instance", None) +        IOLoop._current.instance = self +        self._thread_ident = thread.get_ident() +        self._running = True + +        # signal.set_wakeup_fd closes a race condition in event loops: +        # a signal may arrive at the beginning of select/poll/etc +        # before it goes into its interruptible sleep, so the signal +        # will be consumed without waking the select.  The solution is +        # for the (C, synchronous) signal handler to write to a pipe, +        # which will then be seen by select. +        # +        # In python's signal handling semantics, this only matters on the +        # main thread (fortunately, set_wakeup_fd only works on the main +        # thread and will raise a ValueError otherwise). +        # +        # If someone has already set a wakeup fd, we don't want to +        # disturb it.  This is an issue for twisted, which does its +        # SIGCHILD processing in response to its own wakeup fd being +        # written to.  As long as the wakeup fd is registered on the IOLoop, +        # the loop will still wake up and everything should work. +        old_wakeup_fd = None +        if hasattr(signal, 'set_wakeup_fd') and os.name == 'posix': +            # requires python 2.6+, unix.  set_wakeup_fd exists but crashes +            # the python process on windows. +            try: +                old_wakeup_fd = signal.set_wakeup_fd(self._waker.write_fileno()) +                if old_wakeup_fd != -1: +                    # Already set, restore previous value.  This is a little racy, +                    # but there's no clean get_wakeup_fd and in real use the +                    # IOLoop is just started once at the beginning. +                    signal.set_wakeup_fd(old_wakeup_fd) +                    old_wakeup_fd = None +            except ValueError:  # non-main thread +                pass + +        while True: +            poll_timeout = 3600.0 + +            # Prevent IO event starvation by delaying new callbacks +            # to the next iteration of the event loop. +            with self._callback_lock: +                callbacks = self._callbacks +                self._callbacks = [] +            for callback in callbacks: +                self._run_callback(callback) + +            if self._timeouts: +                now = self.time() +                while self._timeouts: +                    if self._timeouts[0].callback is None: +                        # the timeout was cancelled +                        heapq.heappop(self._timeouts) +                        self._cancellations -= 1 +                    elif self._timeouts[0].deadline <= now: +                        timeout = heapq.heappop(self._timeouts) +                        self._run_callback(timeout.callback) +                    else: +                        seconds = self._timeouts[0].deadline - now +                        poll_timeout = min(seconds, poll_timeout) +                        break +                if (self._cancellations > 512 +                        and self._cancellations > (len(self._timeouts) >> 1)): +                    # Clean up the timeout queue when it gets large and it's +                    # more than half cancellations. +                    self._cancellations = 0 +                    self._timeouts = [x for x in self._timeouts +                                      if x.callback is not None] +                    heapq.heapify(self._timeouts) + +            if self._callbacks: +                # If any callbacks or timeouts called add_callback, +                # we don't want to wait in poll() before we run them. +                poll_timeout = 0.0 + +            if not self._running: +                break + +            if self._blocking_signal_threshold is not None: +                # clear alarm so it doesn't fire while poll is waiting for +                # events. +                signal.setitimer(signal.ITIMER_REAL, 0, 0) + +            try: +                event_pairs = self._impl.poll(poll_timeout) +            except Exception as e: +                # Depending on python version and IOLoop implementation, +                # different exception types may be thrown and there are +                # two ways EINTR might be signaled: +                # * e.errno == errno.EINTR +                # * e.args is like (errno.EINTR, 'Interrupted system call') +                if (getattr(e, 'errno', None) == errno.EINTR or +                    (isinstance(getattr(e, 'args', None), tuple) and +                     len(e.args) == 2 and e.args[0] == errno.EINTR)): +                    continue +                else: +                    raise + +            if self._blocking_signal_threshold is not None: +                signal.setitimer(signal.ITIMER_REAL, +                                 self._blocking_signal_threshold, 0) + +            # Pop one fd at a time from the set of pending fds and run +            # its handler. Since that handler may perform actions on +            # other file descriptors, there may be reentrant calls to +            # this IOLoop that update self._events +            self._events.update(event_pairs) +            while self._events: +                fd, events = self._events.popitem() +                try: +                    self._handlers[fd](fd, events) +                except (OSError, IOError) as e: +                    if e.args[0] == errno.EPIPE: +                        # Happens when the client closes the connection +                        pass +                    else: +                        app_log.error("Exception in I/O handler for fd %s", +                                      fd, exc_info=True) +                except Exception: +                    app_log.error("Exception in I/O handler for fd %s", +                                  fd, exc_info=True) +        # reset the stopped flag so another start/stop pair can be issued +        self._stopped = False +        if self._blocking_signal_threshold is not None: +            signal.setitimer(signal.ITIMER_REAL, 0, 0) +        IOLoop._current.instance = old_current +        if old_wakeup_fd is not None: +            signal.set_wakeup_fd(old_wakeup_fd) + +    def stop(self): +        self._running = False +        self._stopped = True +        self._waker.wake() + +    def time(self): +        return self.time_func() + +    def add_timeout(self, deadline, callback): +        timeout = _Timeout(deadline, stack_context.wrap(callback), self) +        heapq.heappush(self._timeouts, timeout) +        return timeout + +    def remove_timeout(self, timeout): +        # Removing from a heap is complicated, so just leave the defunct +        # timeout object in the queue (see discussion in +        # http://docs.python.org/library/heapq.html). +        # If this turns out to be a problem, we could add a garbage +        # collection pass whenever there are too many dead timeouts. +        timeout.callback = None +        self._cancellations += 1 + +    def add_callback(self, callback, *args, **kwargs): +        with self._callback_lock: +            if self._closing: +                raise RuntimeError("IOLoop is closing") +            list_empty = not self._callbacks +            self._callbacks.append(functools.partial( +                stack_context.wrap(callback), *args, **kwargs)) +        if list_empty and thread.get_ident() != self._thread_ident: +            # If we're in the IOLoop's thread, we know it's not currently +            # polling.  If we're not, and we added the first callback to an +            # empty list, we may need to wake it up (it may wake up on its +            # own, but an occasional extra wake is harmless).  Waking +            # up a polling IOLoop is relatively expensive, so we try to +            # avoid it when we can. +            self._waker.wake() + +    def add_callback_from_signal(self, callback, *args, **kwargs): +        with stack_context.NullContext(): +            if thread.get_ident() != self._thread_ident: +                # if the signal is handled on another thread, we can add +                # it normally (modulo the NullContext) +                self.add_callback(callback, *args, **kwargs) +            else: +                # If we're on the IOLoop's thread, we cannot use +                # the regular add_callback because it may deadlock on +                # _callback_lock.  Blindly insert into self._callbacks. +                # This is safe because the GIL makes list.append atomic. +                # One subtlety is that if the signal interrupted the +                # _callback_lock block in IOLoop.start, we may modify +                # either the old or new version of self._callbacks, +                # but either way will work. +                self._callbacks.append(functools.partial( +                    stack_context.wrap(callback), *args, **kwargs)) + + +class _Timeout(object): +    """An IOLoop timeout, a UNIX timestamp and a callback""" + +    # Reduce memory overhead when there are lots of pending callbacks +    __slots__ = ['deadline', 'callback'] + +    def __init__(self, deadline, callback, io_loop): +        if isinstance(deadline, numbers.Real): +            self.deadline = deadline +        elif isinstance(deadline, datetime.timedelta): +            self.deadline = io_loop.time() + _Timeout.timedelta_to_seconds(deadline) +        else: +            raise TypeError("Unsupported deadline %r" % deadline) +        self.callback = callback + +    @staticmethod +    def timedelta_to_seconds(td): +        """Equivalent to td.total_seconds() (introduced in python 2.7).""" +        return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / float(10 ** 6) + +    # Comparison methods to sort by deadline, with object id as a tiebreaker +    # to guarantee a consistent ordering.  The heapq module uses __le__ +    # in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons +    # use __lt__). +    def __lt__(self, other): +        return ((self.deadline, id(self)) < +                (other.deadline, id(other))) + +    def __le__(self, other): +        return ((self.deadline, id(self)) <= +                (other.deadline, id(other))) + + +class PeriodicCallback(object): +    """Schedules the given callback to be called periodically. + +    The callback is called every ``callback_time`` milliseconds. + +    `start` must be called after the `PeriodicCallback` is created. +    """ +    def __init__(self, callback, callback_time, io_loop=None): +        self.callback = callback +        if callback_time <= 0: +            raise ValueError("Periodic callback must have a positive callback_time") +        self.callback_time = callback_time +        self.io_loop = io_loop or IOLoop.current() +        self._running = False +        self._timeout = None + +    def start(self): +        """Starts the timer.""" +        self._running = True +        self._next_timeout = self.io_loop.time() +        self._schedule_next() + +    def stop(self): +        """Stops the timer.""" +        self._running = False +        if self._timeout is not None: +            self.io_loop.remove_timeout(self._timeout) +            self._timeout = None + +    def _run(self): +        if not self._running: +            return +        try: +            self.callback() +        except Exception: +            app_log.error("Error in periodic callback", exc_info=True) +        self._schedule_next() + +    def _schedule_next(self): +        if self._running: +            current_time = self.io_loop.time() +            while self._next_timeout <= current_time: +                self._next_timeout += self.callback_time / 1000.0 +            self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run) diff --git a/zmq/eventloop/minitornado/log.py b/zmq/eventloop/minitornado/log.py new file mode 100644 index 0000000..49051e8 --- /dev/null +++ b/zmq/eventloop/minitornado/log.py @@ -0,0 +1,6 @@ +"""minimal subset of tornado.log for zmq.eventloop.minitornado""" + +import logging + +app_log = logging.getLogger("tornado.application") +gen_log = logging.getLogger("tornado.general") diff --git a/zmq/eventloop/minitornado/platform/__init__.py b/zmq/eventloop/minitornado/platform/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/zmq/eventloop/minitornado/platform/__init__.py diff --git a/zmq/eventloop/minitornado/platform/auto.py b/zmq/eventloop/minitornado/platform/auto.py new file mode 100644 index 0000000..b40ccd9 --- /dev/null +++ b/zmq/eventloop/minitornado/platform/auto.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +# +# Copyright 2011 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +#     http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Implementation of platform-specific functionality. + +For each function or class described in `tornado.platform.interface`, +the appropriate platform-specific implementation exists in this module. +Most code that needs access to this functionality should do e.g.:: + +    from tornado.platform.auto import set_close_exec +""" + +from __future__ import absolute_import, division, print_function, with_statement + +import os + +if os.name == 'nt': +    from .common import Waker +    from .windows import set_close_exec +else: +    from .posix import set_close_exec, Waker + +try: +    # monotime monkey-patches the time module to have a monotonic function +    # in versions of python before 3.3. +    import monotime +except ImportError: +    pass +try: +    from time import monotonic as monotonic_time +except ImportError: +    monotonic_time = None diff --git a/zmq/eventloop/minitornado/platform/common.py b/zmq/eventloop/minitornado/platform/common.py new file mode 100644 index 0000000..2d75dc1 --- /dev/null +++ b/zmq/eventloop/minitornado/platform/common.py @@ -0,0 +1,91 @@ +"""Lowest-common-denominator implementations of platform functionality.""" +from __future__ import absolute_import, division, print_function, with_statement + +import errno +import socket + +from . import interface + + +class Waker(interface.Waker): +    """Create an OS independent asynchronous pipe. + +    For use on platforms that don't have os.pipe() (or where pipes cannot +    be passed to select()), but do have sockets.  This includes Windows +    and Jython. +    """ +    def __init__(self): +        # Based on Zope async.py: http://svn.zope.org/zc.ngi/trunk/src/zc/ngi/async.py + +        self.writer = socket.socket() +        # Disable buffering -- pulling the trigger sends 1 byte, +        # and we want that sent immediately, to wake up ASAP. +        self.writer.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + +        count = 0 +        while 1: +            count += 1 +            # Bind to a local port; for efficiency, let the OS pick +            # a free port for us. +            # Unfortunately, stress tests showed that we may not +            # be able to connect to that port ("Address already in +            # use") despite that the OS picked it.  This appears +            # to be a race bug in the Windows socket implementation. +            # So we loop until a connect() succeeds (almost always +            # on the first try).  See the long thread at +            # http://mail.zope.org/pipermail/zope/2005-July/160433.html +            # for hideous details. +            a = socket.socket() +            a.bind(("127.0.0.1", 0)) +            a.listen(1) +            connect_address = a.getsockname()  # assigned (host, port) pair +            try: +                self.writer.connect(connect_address) +                break    # success +            except socket.error as detail: +                if (not hasattr(errno, 'WSAEADDRINUSE') or +                        detail[0] != errno.WSAEADDRINUSE): +                    # "Address already in use" is the only error +                    # I've seen on two WinXP Pro SP2 boxes, under +                    # Pythons 2.3.5 and 2.4.1. +                    raise +                # (10048, 'Address already in use') +                # assert count <= 2 # never triggered in Tim's tests +                if count >= 10:  # I've never seen it go above 2 +                    a.close() +                    self.writer.close() +                    raise socket.error("Cannot bind trigger!") +                # Close `a` and try again.  Note:  I originally put a short +                # sleep() here, but it didn't appear to help or hurt. +                a.close() + +        self.reader, addr = a.accept() +        self.reader.setblocking(0) +        self.writer.setblocking(0) +        a.close() +        self.reader_fd = self.reader.fileno() + +    def fileno(self): +        return self.reader.fileno() + +    def write_fileno(self): +        return self.writer.fileno() + +    def wake(self): +        try: +            self.writer.send(b"x") +        except (IOError, socket.error): +            pass + +    def consume(self): +        try: +            while True: +                result = self.reader.recv(1024) +                if not result: +                    break +        except (IOError, socket.error): +            pass + +    def close(self): +        self.reader.close() +        self.writer.close() diff --git a/zmq/eventloop/minitornado/platform/interface.py b/zmq/eventloop/minitornado/platform/interface.py new file mode 100644 index 0000000..07da6ba --- /dev/null +++ b/zmq/eventloop/minitornado/platform/interface.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# +# Copyright 2011 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +#     http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Interfaces for platform-specific functionality. + +This module exists primarily for documentation purposes and as base classes +for other tornado.platform modules.  Most code should import the appropriate +implementation from `tornado.platform.auto`. +""" + +from __future__ import absolute_import, division, print_function, with_statement + + +def set_close_exec(fd): +    """Sets the close-on-exec bit (``FD_CLOEXEC``)for a file descriptor.""" +    raise NotImplementedError() + + +class Waker(object): +    """A socket-like object that can wake another thread from ``select()``. + +    The `~tornado.ioloop.IOLoop` will add the Waker's `fileno()` to +    its ``select`` (or ``epoll`` or ``kqueue``) calls.  When another +    thread wants to wake up the loop, it calls `wake`.  Once it has woken +    up, it will call `consume` to do any necessary per-wake cleanup.  When +    the ``IOLoop`` is closed, it closes its waker too. +    """ +    def fileno(self): +        """Returns the read file descriptor for this waker. + +        Must be suitable for use with ``select()`` or equivalent on the +        local platform. +        """ +        raise NotImplementedError() + +    def write_fileno(self): +        """Returns the write file descriptor for this waker.""" +        raise NotImplementedError() + +    def wake(self): +        """Triggers activity on the waker's file descriptor.""" +        raise NotImplementedError() + +    def consume(self): +        """Called after the listen has woken up to do any necessary cleanup.""" +        raise NotImplementedError() + +    def close(self): +        """Closes the waker's file descriptor(s).""" +        raise NotImplementedError() diff --git a/zmq/eventloop/minitornado/platform/posix.py b/zmq/eventloop/minitornado/platform/posix.py new file mode 100644 index 0000000..ccffbb6 --- /dev/null +++ b/zmq/eventloop/minitornado/platform/posix.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python +# +# Copyright 2011 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +#     http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Posix implementations of platform-specific functionality.""" + +from __future__ import absolute_import, division, print_function, with_statement + +import fcntl +import os + +from . import interface + + +def set_close_exec(fd): +    flags = fcntl.fcntl(fd, fcntl.F_GETFD) +    fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) + + +def _set_nonblocking(fd): +    flags = fcntl.fcntl(fd, fcntl.F_GETFL) +    fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + +class Waker(interface.Waker): +    def __init__(self): +        r, w = os.pipe() +        _set_nonblocking(r) +        _set_nonblocking(w) +        set_close_exec(r) +        set_close_exec(w) +        self.reader = os.fdopen(r, "rb", 0) +        self.writer = os.fdopen(w, "wb", 0) + +    def fileno(self): +        return self.reader.fileno() + +    def write_fileno(self): +        return self.writer.fileno() + +    def wake(self): +        try: +            self.writer.write(b"x") +        except IOError: +            pass + +    def consume(self): +        try: +            while True: +                result = self.reader.read() +                if not result: +                    break +        except IOError: +            pass + +    def close(self): +        self.reader.close() +        self.writer.close() diff --git a/zmq/eventloop/minitornado/platform/windows.py b/zmq/eventloop/minitornado/platform/windows.py new file mode 100644 index 0000000..817bdca --- /dev/null +++ b/zmq/eventloop/minitornado/platform/windows.py @@ -0,0 +1,20 @@ +# NOTE: win32 support is currently experimental, and not recommended +# for production use. + + +from __future__ import absolute_import, division, print_function, with_statement +import ctypes +import ctypes.wintypes + +# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx +SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation +SetHandleInformation.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD) +SetHandleInformation.restype = ctypes.wintypes.BOOL + +HANDLE_FLAG_INHERIT = 0x00000001 + + +def set_close_exec(fd): +    success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0) +    if not success: +        raise ctypes.GetLastError() diff --git a/zmq/eventloop/minitornado/stack_context.py b/zmq/eventloop/minitornado/stack_context.py new file mode 100644 index 0000000..226d804 --- /dev/null +++ b/zmq/eventloop/minitornado/stack_context.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python +# +# Copyright 2010 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +#     http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""`StackContext` allows applications to maintain threadlocal-like state +that follows execution as it moves to other execution contexts. + +The motivating examples are to eliminate the need for explicit +``async_callback`` wrappers (as in `tornado.web.RequestHandler`), and to +allow some additional context to be kept for logging. + +This is slightly magic, but it's an extension of the idea that an +exception handler is a kind of stack-local state and when that stack +is suspended and resumed in a new context that state needs to be +preserved.  `StackContext` shifts the burden of restoring that state +from each call site (e.g.  wrapping each `.AsyncHTTPClient` callback +in ``async_callback``) to the mechanisms that transfer control from +one context to another (e.g. `.AsyncHTTPClient` itself, `.IOLoop`, +thread pools, etc). + +Example usage:: + +    @contextlib.contextmanager +    def die_on_error(): +        try: +            yield +        except Exception: +            logging.error("exception in asynchronous operation",exc_info=True) +            sys.exit(1) + +    with StackContext(die_on_error): +        # Any exception thrown here *or in callback and its desendents* +        # will cause the process to exit instead of spinning endlessly +        # in the ioloop. +        http_client.fetch(url, callback) +    ioloop.start() + +Most applications shouln't have to work with `StackContext` directly. +Here are a few rules of thumb for when it's necessary: + +* If you're writing an asynchronous library that doesn't rely on a +  stack_context-aware library like `tornado.ioloop` or `tornado.iostream` +  (for example, if you're writing a thread pool), use +  `.stack_context.wrap()` before any asynchronous operations to capture the +  stack context from where the operation was started. + +* If you're writing an asynchronous library that has some shared +  resources (such as a connection pool), create those shared resources +  within a ``with stack_context.NullContext():`` block.  This will prevent +  ``StackContexts`` from leaking from one request to another. + +* If you want to write something like an exception handler that will +  persist across asynchronous calls, create a new `StackContext` (or +  `ExceptionStackContext`), and make your asynchronous calls in a ``with`` +  block that references your `StackContext`. +""" + +from __future__ import absolute_import, division, print_function, with_statement + +import sys +import threading + +from .util import raise_exc_info + + +class StackContextInconsistentError(Exception): +    pass + + +class _State(threading.local): +    def __init__(self): +        self.contexts = (tuple(), None) +_state = _State() + + +class StackContext(object): +    """Establishes the given context as a StackContext that will be transferred. + +    Note that the parameter is a callable that returns a context +    manager, not the context itself.  That is, where for a +    non-transferable context manager you would say:: + +      with my_context(): + +    StackContext takes the function itself rather than its result:: + +      with StackContext(my_context): + +    The result of ``with StackContext() as cb:`` is a deactivation +    callback.  Run this callback when the StackContext is no longer +    needed to ensure that it is not propagated any further (note that +    deactivating a context does not affect any instances of that +    context that are currently pending).  This is an advanced feature +    and not necessary in most applications. +    """ +    def __init__(self, context_factory): +        self.context_factory = context_factory +        self.contexts = [] +        self.active = True + +    def _deactivate(self): +        self.active = False + +    # StackContext protocol +    def enter(self): +        context = self.context_factory() +        self.contexts.append(context) +        context.__enter__() + +    def exit(self, type, value, traceback): +        context = self.contexts.pop() +        context.__exit__(type, value, traceback) + +    # Note that some of this code is duplicated in ExceptionStackContext +    # below.  ExceptionStackContext is more common and doesn't need +    # the full generality of this class. +    def __enter__(self): +        self.old_contexts = _state.contexts +        self.new_contexts = (self.old_contexts[0] + (self,), self) +        _state.contexts = self.new_contexts + +        try: +            self.enter() +        except: +            _state.contexts = self.old_contexts +            raise + +        return self._deactivate + +    def __exit__(self, type, value, traceback): +        try: +            self.exit(type, value, traceback) +        finally: +            final_contexts = _state.contexts +            _state.contexts = self.old_contexts + +            # Generator coroutines and with-statements with non-local +            # effects interact badly.  Check here for signs of +            # the stack getting out of sync. +            # Note that this check comes after restoring _state.context +            # so that if it fails things are left in a (relatively) +            # consistent state. +            if final_contexts is not self.new_contexts: +                raise StackContextInconsistentError( +                    'stack_context inconsistency (may be caused by yield ' +                    'within a "with StackContext" block)') + +            # Break up a reference to itself to allow for faster GC on CPython. +            self.new_contexts = None + + +class ExceptionStackContext(object): +    """Specialization of StackContext for exception handling. + +    The supplied ``exception_handler`` function will be called in the +    event of an uncaught exception in this context.  The semantics are +    similar to a try/finally clause, and intended use cases are to log +    an error, close a socket, or similar cleanup actions.  The +    ``exc_info`` triple ``(type, value, traceback)`` will be passed to the +    exception_handler function. + +    If the exception handler returns true, the exception will be +    consumed and will not be propagated to other exception handlers. +    """ +    def __init__(self, exception_handler): +        self.exception_handler = exception_handler +        self.active = True + +    def _deactivate(self): +        self.active = False + +    def exit(self, type, value, traceback): +        if type is not None: +            return self.exception_handler(type, value, traceback) + +    def __enter__(self): +        self.old_contexts = _state.contexts +        self.new_contexts = (self.old_contexts[0], self) +        _state.contexts = self.new_contexts + +        return self._deactivate + +    def __exit__(self, type, value, traceback): +        try: +            if type is not None: +                return self.exception_handler(type, value, traceback) +        finally: +            final_contexts = _state.contexts +            _state.contexts = self.old_contexts + +            if final_contexts is not self.new_contexts: +                raise StackContextInconsistentError( +                    'stack_context inconsistency (may be caused by yield ' +                    'within a "with StackContext" block)') + +            # Break up a reference to itself to allow for faster GC on CPython. +            self.new_contexts = None + + +class NullContext(object): +    """Resets the `StackContext`. + +    Useful when creating a shared resource on demand (e.g. an +    `.AsyncHTTPClient`) where the stack that caused the creating is +    not relevant to future operations. +    """ +    def __enter__(self): +        self.old_contexts = _state.contexts +        _state.contexts = (tuple(), None) + +    def __exit__(self, type, value, traceback): +        _state.contexts = self.old_contexts + + +def _remove_deactivated(contexts): +    """Remove deactivated handlers from the chain""" +    # Clean ctx handlers +    stack_contexts = tuple([h for h in contexts[0] if h.active]) + +    # Find new head +    head = contexts[1] +    while head is not None and not head.active: +        head = head.old_contexts[1] + +    # Process chain +    ctx = head +    while ctx is not None: +        parent = ctx.old_contexts[1] + +        while parent is not None: +            if parent.active: +                break +            ctx.old_contexts = parent.old_contexts +            parent = parent.old_contexts[1] + +        ctx = parent + +    return (stack_contexts, head) + + +def wrap(fn): +    """Returns a callable object that will restore the current `StackContext` +    when executed. + +    Use this whenever saving a callback to be executed later in a +    different execution context (either in a different thread or +    asynchronously in the same thread). +    """ +    # Check if function is already wrapped +    if fn is None or hasattr(fn, '_wrapped'): +        return fn + +    # Capture current stack head +    # TODO: Any other better way to store contexts and update them in wrapped function? +    cap_contexts = [_state.contexts] + +    def wrapped(*args, **kwargs): +        ret = None +        try: +            # Capture old state +            current_state = _state.contexts + +            # Remove deactivated items +            cap_contexts[0] = contexts = _remove_deactivated(cap_contexts[0]) + +            # Force new state +            _state.contexts = contexts + +            # Current exception +            exc = (None, None, None) +            top = None + +            # Apply stack contexts +            last_ctx = 0 +            stack = contexts[0] + +            # Apply state +            for n in stack: +                try: +                    n.enter() +                    last_ctx += 1 +                except: +                    # Exception happened. Record exception info and store top-most handler +                    exc = sys.exc_info() +                    top = n.old_contexts[1] + +            # Execute callback if no exception happened while restoring state +            if top is None: +                try: +                    ret = fn(*args, **kwargs) +                except: +                    exc = sys.exc_info() +                    top = contexts[1] + +            # If there was exception, try to handle it by going through the exception chain +            if top is not None: +                exc = _handle_exception(top, exc) +            else: +                # Otherwise take shorter path and run stack contexts in reverse order +                while last_ctx > 0: +                    last_ctx -= 1 +                    c = stack[last_ctx] + +                    try: +                        c.exit(*exc) +                    except: +                        exc = sys.exc_info() +                        top = c.old_contexts[1] +                        break +                else: +                    top = None + +                # If if exception happened while unrolling, take longer exception handler path +                if top is not None: +                    exc = _handle_exception(top, exc) + +            # If exception was not handled, raise it +            if exc != (None, None, None): +                raise_exc_info(exc) +        finally: +            _state.contexts = current_state +        return ret + +    wrapped._wrapped = True +    return wrapped + + +def _handle_exception(tail, exc): +    while tail is not None: +        try: +            if tail.exit(*exc): +                exc = (None, None, None) +        except: +            exc = sys.exc_info() + +        tail = tail.old_contexts[1] + +    return exc + + +def run_with_stack_context(context, func): +    """Run a coroutine ``func`` in the given `StackContext`. + +    It is not safe to have a ``yield`` statement within a ``with StackContext`` +    block, so it is difficult to use stack context with `.gen.coroutine`. +    This helper function runs the function in the correct context while +    keeping the ``yield`` and ``with`` statements syntactically separate. + +    Example:: + +        @gen.coroutine +        def incorrect(): +            with StackContext(ctx): +                # ERROR: this will raise StackContextInconsistentError +                yield other_coroutine() + +        @gen.coroutine +        def correct(): +            yield run_with_stack_context(StackContext(ctx), other_coroutine) + +    .. versionadded:: 3.1 +    """ +    with context: +        return func() diff --git a/zmq/eventloop/minitornado/util.py b/zmq/eventloop/minitornado/util.py new file mode 100644 index 0000000..c1e2eb9 --- /dev/null +++ b/zmq/eventloop/minitornado/util.py @@ -0,0 +1,184 @@ +"""Miscellaneous utility functions and classes. + +This module is used internally by Tornado.  It is not necessarily expected +that the functions and classes defined here will be useful to other +applications, but they are documented here in case they are. + +The one public-facing part of this module is the `Configurable` class +and its `~Configurable.configure` method, which becomes a part of the +interface of its subclasses, including `.AsyncHTTPClient`, `.IOLoop`, +and `.Resolver`. +""" + +from __future__ import absolute_import, division, print_function, with_statement + +import sys + + +def import_object(name): +    """Imports an object by name. + +    import_object('x') is equivalent to 'import x'. +    import_object('x.y.z') is equivalent to 'from x.y import z'. + +    >>> import tornado.escape +    >>> import_object('tornado.escape') is tornado.escape +    True +    >>> import_object('tornado.escape.utf8') is tornado.escape.utf8 +    True +    >>> import_object('tornado') is tornado +    True +    >>> import_object('tornado.missing_module') +    Traceback (most recent call last): +        ... +    ImportError: No module named missing_module +    """ +    if name.count('.') == 0: +        return __import__(name, None, None) + +    parts = name.split('.') +    obj = __import__('.'.join(parts[:-1]), None, None, [parts[-1]], 0) +    try: +        return getattr(obj, parts[-1]) +    except AttributeError: +        raise ImportError("No module named %s" % parts[-1]) + + +# Fake unicode literal support:  Python 3.2 doesn't have the u'' marker for +# literal strings, and alternative solutions like "from __future__ import +# unicode_literals" have other problems (see PEP 414).  u() can be applied +# to ascii strings that include \u escapes (but they must not contain +# literal non-ascii characters). +if type('') is not type(b''): +    def u(s): +        return s +    bytes_type = bytes +    unicode_type = str +    basestring_type = str +else: +    def u(s): +        return s.decode('unicode_escape') +    bytes_type = str +    unicode_type = unicode +    basestring_type = basestring + + +if sys.version_info > (3,): +    exec(""" +def raise_exc_info(exc_info): +    raise exc_info[1].with_traceback(exc_info[2]) + +def exec_in(code, glob, loc=None): +    if isinstance(code, str): +        code = compile(code, '<string>', 'exec', dont_inherit=True) +    exec(code, glob, loc) +""") +else: +    exec(""" +def raise_exc_info(exc_info): +    raise exc_info[0], exc_info[1], exc_info[2] + +def exec_in(code, glob, loc=None): +    if isinstance(code, basestring): +        # exec(string) inherits the caller's future imports; compile +        # the string first to prevent that. +        code = compile(code, '<string>', 'exec', dont_inherit=True) +    exec code in glob, loc +""") + + +class Configurable(object): +    """Base class for configurable interfaces. + +    A configurable interface is an (abstract) class whose constructor +    acts as a factory function for one of its implementation subclasses. +    The implementation subclass as well as optional keyword arguments to +    its initializer can be set globally at runtime with `configure`. + +    By using the constructor as the factory method, the interface +    looks like a normal class, `isinstance` works as usual, etc.  This +    pattern is most useful when the choice of implementation is likely +    to be a global decision (e.g. when `~select.epoll` is available, +    always use it instead of `~select.select`), or when a +    previously-monolithic class has been split into specialized +    subclasses. + +    Configurable subclasses must define the class methods +    `configurable_base` and `configurable_default`, and use the instance +    method `initialize` instead of ``__init__``. +    """ +    __impl_class = None +    __impl_kwargs = None + +    def __new__(cls, **kwargs): +        base = cls.configurable_base() +        args = {} +        if cls is base: +            impl = cls.configured_class() +            if base.__impl_kwargs: +                args.update(base.__impl_kwargs) +        else: +            impl = cls +        args.update(kwargs) +        instance = super(Configurable, cls).__new__(impl) +        # initialize vs __init__ chosen for compatiblity with AsyncHTTPClient +        # singleton magic.  If we get rid of that we can switch to __init__ +        # here too. +        instance.initialize(**args) +        return instance + +    @classmethod +    def configurable_base(cls): +        """Returns the base class of a configurable hierarchy. + +        This will normally return the class in which it is defined. +        (which is *not* necessarily the same as the cls classmethod parameter). +        """ +        raise NotImplementedError() + +    @classmethod +    def configurable_default(cls): +        """Returns the implementation class to be used if none is configured.""" +        raise NotImplementedError() + +    def initialize(self): +        """Initialize a `Configurable` subclass instance. + +        Configurable classes should use `initialize` instead of ``__init__``. +        """ + +    @classmethod +    def configure(cls, impl, **kwargs): +        """Sets the class to use when the base class is instantiated. + +        Keyword arguments will be saved and added to the arguments passed +        to the constructor.  This can be used to set global defaults for +        some parameters. +        """ +        base = cls.configurable_base() +        if isinstance(impl, (unicode_type, bytes_type)): +            impl = import_object(impl) +        if impl is not None and not issubclass(impl, cls): +            raise ValueError("Invalid subclass of %s" % cls) +        base.__impl_class = impl +        base.__impl_kwargs = kwargs + +    @classmethod +    def configured_class(cls): +        """Returns the currently configured class.""" +        base = cls.configurable_base() +        if cls.__impl_class is None: +            base.__impl_class = cls.configurable_default() +        return base.__impl_class + +    @classmethod +    def _save_configuration(cls): +        base = cls.configurable_base() +        return (base.__impl_class, base.__impl_kwargs) + +    @classmethod +    def _restore_configuration(cls, saved): +        base = cls.configurable_base() +        base.__impl_class = saved[0] +        base.__impl_kwargs = saved[1] + diff --git a/zmq/eventloop/zmqstream.py b/zmq/eventloop/zmqstream.py new file mode 100644 index 0000000..86a97e4 --- /dev/null +++ b/zmq/eventloop/zmqstream.py @@ -0,0 +1,529 @@ +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +#     http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""A utility class to send to and recv from a non-blocking socket.""" + +from __future__ import with_statement + +import sys + +import zmq +from zmq.utils import jsonapi + +try: +    import cPickle as pickle +except ImportError: +    import pickle + +from .ioloop import IOLoop + +try: +    # gen_log will only import from >= 3.0 +    from tornado.log import gen_log +    from tornado import stack_context +except ImportError: +    from .minitornado.log import gen_log +    from .minitornado import stack_context + +try: +    from queue import Queue +except ImportError: +    from Queue import Queue + +from zmq.utils.strtypes import bytes, unicode, basestring + +try: +    callable +except NameError: +    callable = lambda obj: hasattr(obj, '__call__') + + +class ZMQStream(object): +    """A utility class to register callbacks when a zmq socket sends and receives +     +    For use with zmq.eventloop.ioloop + +    There are three main methods +     +    Methods: +     +    * **on_recv(callback, copy=True):** +        register a callback to be run every time the socket has something to receive +    * **on_send(callback):** +        register a callback to be run every time you call send +    * **send(self, msg, flags=0, copy=False, callback=None):** +        perform a send that will trigger the callback +        if callback is passed, on_send is also called. +         +        There are also send_multipart(), send_json(), send_pyobj() +     +    Three other methods for deactivating the callbacks: +     +    * **stop_on_recv():** +        turn off the recv callback +    * **stop_on_send():** +        turn off the send callback +     +    which simply call ``on_<evt>(None)``. +     +    The entire socket interface, excluding direct recv methods, is also +    provided, primarily through direct-linking the methods. +    e.g. +     +    >>> stream.bind is stream.socket.bind +    True +     +    """ +     +    socket = None +    io_loop = None +    poller = None +     +    def __init__(self, socket, io_loop=None): +        self.socket = socket +        self.io_loop = io_loop or IOLoop.instance() +        self.poller = zmq.Poller() +         +        self._send_queue = Queue() +        self._recv_callback = None +        self._send_callback = None +        self._close_callback = None +        self._recv_copy = False +        self._flushed = False +         +        self._state = self.io_loop.ERROR +        self._init_io_state() +         +        # shortcircuit some socket methods +        self.bind = self.socket.bind +        self.bind_to_random_port = self.socket.bind_to_random_port +        self.connect = self.socket.connect +        self.setsockopt = self.socket.setsockopt +        self.getsockopt = self.socket.getsockopt +        self.setsockopt_string = self.socket.setsockopt_string +        self.getsockopt_string = self.socket.getsockopt_string +        self.setsockopt_unicode = self.socket.setsockopt_unicode +        self.getsockopt_unicode = self.socket.getsockopt_unicode +     +     +    def stop_on_recv(self): +        """Disable callback and automatic receiving.""" +        return self.on_recv(None) +     +    def stop_on_send(self): +        """Disable callback on sending.""" +        return self.on_send(None) +     +    def stop_on_err(self): +        """DEPRECATED, does nothing""" +        gen_log.warn("on_err does nothing, and will be removed") +     +    def on_err(self, callback): +        """DEPRECATED, does nothing""" +        gen_log.warn("on_err does nothing, and will be removed") +     +    def on_recv(self, callback, copy=True): +        """Register a callback for when a message is ready to recv. +         +        There can be only one callback registered at a time, so each +        call to `on_recv` replaces previously registered callbacks. +         +        on_recv(None) disables recv event polling. +         +        Use on_recv_stream(callback) instead, to register a callback that will receive +        both this ZMQStream and the message, instead of just the message. +         +        Parameters +        ---------- +         +        callback : callable +            callback must take exactly one argument, which will be a +            list, as returned by socket.recv_multipart() +            if callback is None, recv callbacks are disabled. +        copy : bool +            copy is passed directly to recv, so if copy is False, +            callback will receive Message objects. If copy is True, +            then callback will receive bytes/str objects. +         +        Returns : None +        """ +         +        self._check_closed() +        assert callback is None or callable(callback) +        self._recv_callback = stack_context.wrap(callback) +        self._recv_copy = copy +        if callback is None: +            self._drop_io_state(self.io_loop.READ) +        else: +            self._add_io_state(self.io_loop.READ) +     +    def on_recv_stream(self, callback, copy=True): +        """Same as on_recv, but callback will get this stream as first argument +         +        callback must take exactly two arguments, as it will be called as:: +         +            callback(stream, msg) +         +        Useful when a single callback should be used with multiple streams. +        """ +        if callback is None: +            self.stop_on_recv() +        else: +            self.on_recv(lambda msg: callback(self, msg), copy=copy) +     +    def on_send(self, callback): +        """Register a callback to be called on each send +         +        There will be two arguments:: +         +            callback(msg, status) +         +        * `msg` will be the list of sendable objects that was just sent +        * `status` will be the return result of socket.send_multipart(msg) - +          MessageTracker or None. +         +        Non-copying sends return a MessageTracker object whose +        `done` attribute will be True when the send is complete. +        This allows users to track when an object is safe to write to +        again. +         +        The second argument will always be None if copy=True +        on the send. +         +        Use on_send_stream(callback) to register a callback that will be passed +        this ZMQStream as the first argument, in addition to the other two. +         +        on_send(None) disables recv event polling. +         +        Parameters +        ---------- +         +        callback : callable +            callback must take exactly two arguments, which will be +            the message being sent (always a list), +            and the return result of socket.send_multipart(msg) - +            MessageTracker or None. +             +            if callback is None, send callbacks are disabled. +        """ +         +        self._check_closed() +        assert callback is None or callable(callback) +        self._send_callback = stack_context.wrap(callback) +         +     +    def on_send_stream(self, callback): +        """Same as on_send, but callback will get this stream as first argument +         +        Callback will be passed three arguments:: +         +            callback(stream, msg, status) +         +        Useful when a single callback should be used with multiple streams. +        """ +        if callback is None: +            self.stop_on_send() +        else: +            self.on_send(lambda msg, status: callback(self, msg, status)) +         +         +    def send(self, msg, flags=0, copy=True, track=False, callback=None): +        """Send a message, optionally also register a new callback for sends. +        See zmq.socket.send for details. +        """ +        return self.send_multipart([msg], flags=flags, copy=copy, track=track, callback=callback) + +    def send_multipart(self, msg, flags=0, copy=True, track=False, callback=None): +        """Send a multipart message, optionally also register a new callback for sends. +        See zmq.socket.send_multipart for details. +        """ +        kwargs = dict(flags=flags, copy=copy, track=track) +        self._send_queue.put((msg, kwargs)) +        callback = callback or self._send_callback +        if callback is not None: +            self.on_send(callback) +        else: +            # noop callback +            self.on_send(lambda *args: None) +        self._add_io_state(self.io_loop.WRITE) +     +    def send_string(self, u, flags=0, encoding='utf-8', callback=None): +        """Send a unicode message with an encoding. +        See zmq.socket.send_unicode for details. +        """ +        if not isinstance(u, basestring): +            raise TypeError("unicode/str objects only") +        return self.send(u.encode(encoding), flags=flags, callback=callback) +     +    send_unicode = send_string +     +    def send_json(self, obj, flags=0, callback=None): +        """Send json-serialized version of an object. +        See zmq.socket.send_json for details. +        """ +        if jsonapi is None: +            raise ImportError('jsonlib{1,2}, json or simplejson library is required.') +        else: +            msg = jsonapi.dumps(obj) +            return self.send(msg, flags=flags, callback=callback) + +    def send_pyobj(self, obj, flags=0, protocol=-1, callback=None): +        """Send a Python object as a message using pickle to serialize. + +        See zmq.socket.send_json for details. +        """ +        msg = pickle.dumps(obj, protocol) +        return self.send(msg, flags, callback=callback) +     +    def _finish_flush(self): +        """callback for unsetting _flushed flag.""" +        self._flushed = False +     +    def flush(self, flag=zmq.POLLIN|zmq.POLLOUT, limit=None): +        """Flush pending messages. + +        This method safely handles all pending incoming and/or outgoing messages, +        bypassing the inner loop, passing them to the registered callbacks. + +        A limit can be specified, to prevent blocking under high load. + +        flush will return the first time ANY of these conditions are met: +            * No more events matching the flag are pending. +            * the total number of events handled reaches the limit. + +        Note that if ``flag|POLLIN != 0``, recv events will be flushed even if no callback +        is registered, unlike normal IOLoop operation. This allows flush to be +        used to remove *and ignore* incoming messages. + +        Parameters +        ---------- +        flag : int, default=POLLIN|POLLOUT +                0MQ poll flags. +                If flag|POLLIN,  recv events will be flushed. +                If flag|POLLOUT, send events will be flushed. +                Both flags can be set at once, which is the default. +        limit : None or int, optional +                The maximum number of messages to send or receive. +                Both send and recv count against this limit. + +        Returns +        ------- +        int : count of events handled (both send and recv) +        """ +        self._check_closed() +        # unset self._flushed, so callbacks will execute, in case flush has +        # already been called this iteration +        already_flushed = self._flushed +        self._flushed = False +        # initialize counters +        count = 0 +        def update_flag(): +            """Update the poll flag, to prevent registering POLLOUT events +            if we don't have pending sends.""" +            return flag & zmq.POLLIN | (self.sending() and flag & zmq.POLLOUT) +        flag = update_flag() +        if not flag: +            # nothing to do +            return 0 +        self.poller.register(self.socket, flag) +        events = self.poller.poll(0) +        while events and (not limit or count < limit): +            s,event = events[0] +            if event & zmq.POLLIN: # receiving +                self._handle_recv() +                count += 1 +                if self.socket is None: +                    # break if socket was closed during callback +                    break +            if event & zmq.POLLOUT and self.sending(): +                self._handle_send() +                count += 1 +                if self.socket is None: +                    # break if socket was closed during callback +                    break +             +            flag = update_flag() +            if flag: +                self.poller.register(self.socket, flag) +                events = self.poller.poll(0) +            else: +                events = [] +        if count: # only bypass loop if we actually flushed something +            # skip send/recv callbacks this iteration +            self._flushed = True +            # reregister them at the end of the loop +            if not already_flushed: # don't need to do it again +                self.io_loop.add_callback(self._finish_flush) +        elif already_flushed: +            self._flushed = True + +        # update ioloop poll state, which may have changed +        self._rebuild_io_state() +        return count +     +    def set_close_callback(self, callback): +        """Call the given callback when the stream is closed.""" +        self._close_callback = stack_context.wrap(callback) +     +    def close(self, linger=None): +        """Close this stream.""" +        if self.socket is not None: +            self.io_loop.remove_handler(self.socket) +            self.socket.close(linger) +            self.socket = None +            if self._close_callback: +                self._run_callback(self._close_callback) + +    def receiving(self): +        """Returns True if we are currently receiving from the stream.""" +        return self._recv_callback is not None + +    def sending(self): +        """Returns True if we are currently sending to the stream.""" +        return not self._send_queue.empty() + +    def closed(self): +        return self.socket is None + +    def _run_callback(self, callback, *args, **kwargs): +        """Wrap running callbacks in try/except to allow us to +        close our socket.""" +        try: +            # Use a NullContext to ensure that all StackContexts are run +            # inside our blanket exception handler rather than outside. +            with stack_context.NullContext(): +                callback(*args, **kwargs) +        except: +            gen_log.error("Uncaught exception, closing connection.", +                          exc_info=True) +            # Close the socket on an uncaught exception from a user callback +            # (It would eventually get closed when the socket object is +            # gc'd, but we don't want to rely on gc happening before we +            # run out of file descriptors) +            self.close() +            # Re-raise the exception so that IOLoop.handle_callback_exception +            # can see it and log the error +            raise + +    def _handle_events(self, fd, events): +        """This method is the actual handler for IOLoop, that gets called whenever +        an event on my socket is posted. It dispatches to _handle_recv, etc.""" +        # print "handling events" +        if not self.socket: +            gen_log.warning("Got events for closed stream %s", fd) +            return +        try: +            # dispatch events: +            if events & IOLoop.ERROR: +                gen_log.error("got POLLERR event on ZMQStream, which doesn't make sense") +                return +            if events & IOLoop.READ: +                self._handle_recv() +                if not self.socket: +                    return +            if events & IOLoop.WRITE: +                self._handle_send() +                if not self.socket: +                    return + +            # rebuild the poll state +            self._rebuild_io_state() +        except: +            gen_log.error("Uncaught exception, closing connection.", +                          exc_info=True) +            self.close() +            raise +             +    def _handle_recv(self): +        """Handle a recv event.""" +        if self._flushed: +            return +        try: +            msg = self.socket.recv_multipart(zmq.NOBLOCK, copy=self._recv_copy) +        except zmq.ZMQError as e: +            if e.errno == zmq.EAGAIN: +                # state changed since poll event +                pass +            else: +                gen_log.error("RECV Error: %s"%zmq.strerror(e.errno)) +        else: +            if self._recv_callback: +                callback = self._recv_callback +                # self._recv_callback = None +                self._run_callback(callback, msg) +                 +        # self.update_state() +         + +    def _handle_send(self): +        """Handle a send event.""" +        if self._flushed: +            return +        if not self.sending(): +            gen_log.error("Shouldn't have handled a send event") +            return +         +        msg, kwargs = self._send_queue.get() +        try: +            status = self.socket.send_multipart(msg, **kwargs) +        except zmq.ZMQError as e: +            gen_log.error("SEND Error: %s", e) +            status = e +        if self._send_callback: +            callback = self._send_callback +            self._run_callback(callback, msg, status) +         +        # self.update_state() +     +    def _check_closed(self): +        if not self.socket: +            raise IOError("Stream is closed") +     +    def _rebuild_io_state(self): +        """rebuild io state based on self.sending() and receiving()""" +        if self.socket is None: +            return +        state = self.io_loop.ERROR +        if self.receiving(): +            state |= self.io_loop.READ +        if self.sending(): +            state |= self.io_loop.WRITE +        if state != self._state: +            self._state = state +            self._update_handler(state) +     +    def _add_io_state(self, state): +        """Add io_state to poller.""" +        if not self._state & state: +            self._state = self._state | state +            self._update_handler(self._state) +     +    def _drop_io_state(self, state): +        """Stop poller from watching an io_state.""" +        if self._state & state: +            self._state = self._state & (~state) +            self._update_handler(self._state) +     +    def _update_handler(self, state): +        """Update IOLoop handler with state.""" +        if self.socket is None: +            return +        self.io_loop.update_handler(self.socket, state) +     +    def _init_io_state(self): +        """initialize the ioloop event handler""" +        with stack_context.NullContext(): +            self.io_loop.add_handler(self.socket, self._handle_events, self._state) + diff --git a/zmq/green/__init__.py b/zmq/green/__init__.py new file mode 100644 index 0000000..ff7e596 --- /dev/null +++ b/zmq/green/__init__.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +#----------------------------------------------------------------------------- +#  Copyright (C) 2011-2012 Travis Cline +# +#  This file is part of pyzmq +#  It is adapted from upstream project zeromq_gevent under the New BSD License +# +#  Distributed under the terms of the New BSD License.  The full license is in +#  the file COPYING.BSD, distributed as part of this software. +#----------------------------------------------------------------------------- + +"""zmq.green - gevent compatibility with zeromq. + +Usage +----- + +Instead of importing zmq directly, do so in the following manner: + +.. + +    import zmq.green as zmq + + +Any calls that would have blocked the current thread will now only block the +current green thread. + +This compatibility is accomplished by ensuring the nonblocking flag is set +before any blocking operation and the ØMQ file descriptor is polled internally +to trigger needed events. +""" + +from zmq import * +from zmq.green.core import _Context, _Socket +from zmq.green.poll import _Poller +Context = _Context +Socket = _Socket +Poller = _Poller + +from zmq.green.device import device + diff --git a/zmq/green/core.py b/zmq/green/core.py new file mode 100644 index 0000000..9fc73e3 --- /dev/null +++ b/zmq/green/core.py @@ -0,0 +1,287 @@ +#----------------------------------------------------------------------------- +#  Copyright (C) 2011-2012 Travis Cline +# +#  This file is part of pyzmq +#  It is adapted from upstream project zeromq_gevent under the New BSD License +# +#  Distributed under the terms of the New BSD License.  The full license is in +#  the file COPYING.BSD, distributed as part of this software. +#----------------------------------------------------------------------------- + +"""This module wraps the :class:`Socket` and :class:`Context` found in :mod:`pyzmq <zmq>` to be non blocking +""" + +from __future__ import print_function + +import sys +import time +import warnings + +import zmq + +from zmq import Context as _original_Context +from zmq import Socket as _original_Socket +from .poll import _Poller + +import gevent +from gevent.event import AsyncResult +from gevent.hub import get_hub + +if hasattr(zmq, 'RCVTIMEO'): +    TIMEOS = (zmq.RCVTIMEO, zmq.SNDTIMEO) +else: +    TIMEOS = () + +def _stop(evt): +    """simple wrapper for stopping an Event, allowing for method rename in gevent 1.0""" +    try: +        evt.stop() +    except AttributeError as e: +        # gevent<1.0 compat +        evt.cancel() + +class _Socket(_original_Socket): +    """Green version of :class:`zmq.Socket` + +    The following methods are overridden: + +        * send +        * recv + +    To ensure that the ``zmq.NOBLOCK`` flag is set and that sending or receiving +    is deferred to the hub if a ``zmq.EAGAIN`` (retry) error is raised. +     +    The `__state_changed` method is triggered when the zmq.FD for the socket is +    marked as readable and triggers the necessary read and write events (which +    are waited for in the recv and send methods). + +    Some double underscore prefixes are used to minimize pollution of +    :class:`zmq.Socket`'s namespace. +    """ +    __in_send_multipart = False +    __in_recv_multipart = False +    __writable = None +    __readable = None +    _state_event = None +    _gevent_bug_timeout = 11.6 # timeout for not trusting gevent +    _debug_gevent = False # turn on if you think gevent is missing events +    _poller_class = _Poller +     +    def __init__(self, context, socket_type): +        _original_Socket.__init__(self, context, socket_type) +        self.__in_send_multipart = False +        self.__in_recv_multipart = False +        self.__setup_events() +         + +    def __del__(self): +        self.close() + +    def close(self, linger=None): +        super(_Socket, self).close(linger) +        self.__cleanup_events() + +    def __cleanup_events(self): +        # close the _state_event event, keeps the number of active file descriptors down +        if getattr(self, '_state_event', None): +            _stop(self._state_event) +            self._state_event = None +        # if the socket has entered a close state resume any waiting greenlets +        self.__writable.set() +        self.__readable.set() + +    def __setup_events(self): +        self.__readable = AsyncResult() +        self.__writable = AsyncResult() +        self.__readable.set() +        self.__writable.set() +         +        try: +            self._state_event = get_hub().loop.io(self.getsockopt(zmq.FD), 1) # read state watcher +            self._state_event.start(self.__state_changed) +        except AttributeError: +            # for gevent<1.0 compatibility +            from gevent.core import read_event +            self._state_event = read_event(self.getsockopt(zmq.FD), self.__state_changed, persist=True) + +    def __state_changed(self, event=None, _evtype=None): +        if self.closed: +            self.__cleanup_events() +            return +        try: +            # avoid triggering __state_changed from inside __state_changed +            events = super(_Socket, self).getsockopt(zmq.EVENTS) +        except zmq.ZMQError as exc: +            self.__writable.set_exception(exc) +            self.__readable.set_exception(exc) +        else: +            if events & zmq.POLLOUT: +                self.__writable.set() +            if events & zmq.POLLIN: +                self.__readable.set() + +    def _wait_write(self): +        assert self.__writable.ready(), "Only one greenlet can be waiting on this event" +        self.__writable = AsyncResult() +        # timeout is because libzmq cannot be trusted to properly signal a new send event: +        # this is effectively a maximum poll interval of 1s +        tic = time.time() +        dt = self._gevent_bug_timeout +        if dt: +            timeout = gevent.Timeout(seconds=dt) +        else: +            timeout = None +        try: +            if timeout: +                timeout.start() +            self.__writable.get(block=True) +        except gevent.Timeout as t: +            if t is not timeout: +                raise +            toc = time.time() +            # gevent bug: get can raise timeout even on clean return +            # don't display zmq bug warning for gevent bug (this is getting ridiculous) +            if self._debug_gevent and timeout and toc-tic > dt and \ +                    self.getsockopt(zmq.EVENTS) & zmq.POLLOUT: +                print("BUG: gevent may have missed a libzmq send event on %i!" % self.FD, file=sys.stderr) +        finally: +            if timeout: +                timeout.cancel() +            self.__writable.set() + +    def _wait_read(self): +        assert self.__readable.ready(), "Only one greenlet can be waiting on this event" +        self.__readable = AsyncResult() +        # timeout is because libzmq cannot always be trusted to play nice with libevent. +        # I can only confirm that this actually happens for send, but lets be symmetrical +        # with our dirty hacks. +        # this is effectively a maximum poll interval of 1s +        tic = time.time() +        dt = self._gevent_bug_timeout +        if dt: +            timeout = gevent.Timeout(seconds=dt) +        else: +            timeout = None +        try: +            if timeout: +                timeout.start() +            self.__readable.get(block=True) +        except gevent.Timeout as t: +            if t is not timeout: +                raise +            toc = time.time() +            # gevent bug: get can raise timeout even on clean return +            # don't display zmq bug warning for gevent bug (this is getting ridiculous) +            if self._debug_gevent and timeout and toc-tic > dt and \ +                    self.getsockopt(zmq.EVENTS) & zmq.POLLIN: +                print("BUG: gevent may have missed a libzmq recv event on %i!" % self.FD, file=sys.stderr) +        finally: +            if timeout: +                timeout.cancel() +            self.__readable.set() + +    def send(self, data, flags=0, copy=True, track=False): +        """send, which will only block current greenlet +         +        state_changed always fires exactly once (success or fail) at the +        end of this method. +        """ +         +        # if we're given the NOBLOCK flag act as normal and let the EAGAIN get raised +        if flags & zmq.NOBLOCK: +            try: +                msg = super(_Socket, self).send(data, flags, copy, track) +            finally: +                if not self.__in_send_multipart: +                    self.__state_changed() +            return msg +        # ensure the zmq.NOBLOCK flag is part of flags +        flags |= zmq.NOBLOCK +        while True: # Attempt to complete this operation indefinitely, blocking the current greenlet +            try: +                # attempt the actual call +                msg = super(_Socket, self).send(data, flags, copy, track) +            except zmq.ZMQError as e: +                # if the raised ZMQError is not EAGAIN, reraise +                if e.errno != zmq.EAGAIN: +                    if not self.__in_send_multipart: +                        self.__state_changed() +                    raise +            else: +                if not self.__in_send_multipart: +                    self.__state_changed() +                return msg +            # defer to the event loop until we're notified the socket is writable +            self._wait_write() + +    def recv(self, flags=0, copy=True, track=False): +        """recv, which will only block current greenlet +         +        state_changed always fires exactly once (success or fail) at the +        end of this method. +        """ +        if flags & zmq.NOBLOCK: +            try: +                msg = super(_Socket, self).recv(flags, copy, track) +            finally: +                if not self.__in_recv_multipart: +                    self.__state_changed() +            return msg +         +        flags |= zmq.NOBLOCK +        while True: +            try: +                msg = super(_Socket, self).recv(flags, copy, track) +            except zmq.ZMQError as e: +                if e.errno != zmq.EAGAIN: +                    if not self.__in_recv_multipart: +                        self.__state_changed() +                    raise +            else: +                if not self.__in_recv_multipart: +                    self.__state_changed() +                return msg +            self._wait_read() +     +    def send_multipart(self, *args, **kwargs): +        """wrap send_multipart to prevent state_changed on each partial send""" +        self.__in_send_multipart = True +        try: +            msg = super(_Socket, self).send_multipart(*args, **kwargs) +        finally: +            self.__in_send_multipart = False +            self.__state_changed() +        return msg +     +    def recv_multipart(self, *args, **kwargs): +        """wrap recv_multipart to prevent state_changed on each partial recv""" +        self.__in_recv_multipart = True +        try: +            msg = super(_Socket, self).recv_multipart(*args, **kwargs) +        finally: +            self.__in_recv_multipart = False +            self.__state_changed() +        return msg +     +    def get(self, opt): +        """trigger state_changed on getsockopt(EVENTS)""" +        if opt in TIMEOS: +            warnings.warn("TIMEO socket options have no effect in zmq.green", UserWarning) +        optval = super(_Socket, self).get(opt) +        if opt == zmq.EVENTS: +            self.__state_changed() +        return optval +     +    def set(self, opt, val): +        """set socket option""" +        if opt in TIMEOS: +            warnings.warn("TIMEO socket options have no effect in zmq.green", UserWarning) +        return super(_Socket, self).set(opt, val) + + +class _Context(_original_Context): +    """Replacement for :class:`zmq.Context` + +    Ensures that the greened Socket above is used in calls to `socket`. +    """ +    _socket_class = _Socket diff --git a/zmq/green/device.py b/zmq/green/device.py new file mode 100644 index 0000000..4b07023 --- /dev/null +++ b/zmq/green/device.py @@ -0,0 +1,32 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import zmq +from zmq.green import Poller + +def device(device_type, isocket, osocket): +    """Start a zeromq device (gevent-compatible). +     +    Unlike the true zmq.device, this does not release the GIL. + +    Parameters +    ---------- +    device_type : (QUEUE, FORWARDER, STREAMER) +        The type of device to start (ignored). +    isocket : Socket +        The Socket instance for the incoming traffic. +    osocket : Socket +        The Socket instance for the outbound traffic. +    """ +    p = Poller() +    if osocket == -1: +        osocket = isocket +    p.register(isocket, zmq.POLLIN) +    p.register(osocket, zmq.POLLIN) +     +    while True: +        events = dict(p.poll()) +        if isocket in events: +            osocket.send_multipart(isocket.recv_multipart()) +        if osocket in events: +            isocket.send_multipart(osocket.recv_multipart()) diff --git a/zmq/green/eventloop/__init__.py b/zmq/green/eventloop/__init__.py new file mode 100644 index 0000000..c5150ef --- /dev/null +++ b/zmq/green/eventloop/__init__.py @@ -0,0 +1,3 @@ +from zmq.green.eventloop.ioloop import IOLoop + +__all__ = ['IOLoop']
\ No newline at end of file diff --git a/zmq/green/eventloop/ioloop.py b/zmq/green/eventloop/ioloop.py new file mode 100644 index 0000000..e12fd5e --- /dev/null +++ b/zmq/green/eventloop/ioloop.py @@ -0,0 +1,33 @@ +from zmq.eventloop.ioloop import * +from zmq.green import Poller + +RealIOLoop = IOLoop +RealZMQPoller = ZMQPoller + +class IOLoop(RealIOLoop): +     +    def initialize(self, impl=None): +        impl = _poll() if impl is None else impl +        super(IOLoop, self).initialize(impl) + +    @staticmethod +    def instance(): +        """Returns a global `IOLoop` instance. +         +        Most applications have a single, global `IOLoop` running on the +        main thread.  Use this method to get this instance from +        another thread.  To get the current thread's `IOLoop`, use `current()`. +        """ +        # install this class as the active IOLoop implementation +        # when using tornado 3 +        if tornado_version >= (3,): +            PollIOLoop.configure(IOLoop) +        return PollIOLoop.instance() + + +class ZMQPoller(RealZMQPoller): +    """gevent-compatible version of ioloop.ZMQPoller""" +    def __init__(self): +        self._poller = Poller() + +_poll = ZMQPoller diff --git a/zmq/green/eventloop/zmqstream.py b/zmq/green/eventloop/zmqstream.py new file mode 100644 index 0000000..90fbd1f --- /dev/null +++ b/zmq/green/eventloop/zmqstream.py @@ -0,0 +1,11 @@ +from zmq.eventloop.zmqstream import * + +from zmq.green.eventloop.ioloop import IOLoop + +RealZMQStream = ZMQStream + +class ZMQStream(RealZMQStream): +     +    def __init__(self, socket, io_loop=None): +        io_loop = io_loop or IOLoop.instance() +        super(ZMQStream, self).__init__(socket, io_loop=io_loop) diff --git a/zmq/green/poll.py b/zmq/green/poll.py new file mode 100644 index 0000000..8f01612 --- /dev/null +++ b/zmq/green/poll.py @@ -0,0 +1,95 @@ +import zmq +import gevent +from gevent import select + +from zmq import Poller as _original_Poller + + +class _Poller(_original_Poller): +    """Replacement for :class:`zmq.Poller` + +    Ensures that the greened Poller below is used in calls to +    :meth:`zmq.Poller.poll`. +    """ +    _gevent_bug_timeout = 1.33 # minimum poll interval, for working around gevent bug + +    def _get_descriptors(self): +        """Returns three elements tuple with socket descriptors ready +        for gevent.select.select +        """ +        rlist = [] +        wlist = [] +        xlist = [] + +        for socket, flags in self.sockets: +            if isinstance(socket, zmq.Socket): +                rlist.append(socket.getsockopt(zmq.FD)) +                continue +            elif isinstance(socket, int): +                fd = socket +            elif hasattr(socket, 'fileno'): +                try: +                    fd = int(socket.fileno()) +                except: +                    raise ValueError('fileno() must return an valid integer fd') +            else: +                raise TypeError('Socket must be a 0MQ socket, an integer fd ' +                                'or have a fileno() method: %r' % socket) + +            if flags & zmq.POLLIN: +                rlist.append(fd) +            if flags & zmq.POLLOUT: +                wlist.append(fd) +            if flags & zmq.POLLERR: +                xlist.append(fd) + +        return (rlist, wlist, xlist) + +    def poll(self, timeout=-1): +        """Overridden method to ensure that the green version of +        Poller is used. + +        Behaves the same as :meth:`zmq.core.Poller.poll` +        """ + +        if timeout is None: +            timeout = -1 + +        if timeout < 0: +            timeout = -1 + +        rlist = None +        wlist = None +        xlist = None + +        if timeout > 0: +            tout = gevent.Timeout.start_new(timeout/1000.0) + +        try: +            # Loop until timeout or events available +            rlist, wlist, xlist = self._get_descriptors() +            while True: +                events = super(_Poller, self).poll(0) +                if events or timeout == 0: +                    return events + +                # wait for activity on sockets in a green way +                # set a minimum poll frequency, +                # because gevent < 1.0 cannot be trusted to catch edge-triggered FD events +                _bug_timeout = gevent.Timeout.start_new(self._gevent_bug_timeout) +                try: +                    select.select(rlist, wlist, xlist) +                except gevent.Timeout as t: +                    if t is not _bug_timeout: +                        raise +                finally: +                    _bug_timeout.cancel() + +        except gevent.Timeout as t: +            if t is not tout: +                raise +            return [] +        finally: +           if timeout > 0: +               tout.cancel() + diff --git a/zmq/log/__init__.py b/zmq/log/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/zmq/log/__init__.py diff --git a/zmq/log/handlers.py b/zmq/log/handlers.py new file mode 100644 index 0000000..5ff21bf --- /dev/null +++ b/zmq/log/handlers.py @@ -0,0 +1,146 @@ +"""pyzmq logging handlers. + +This mainly defines the PUBHandler object for publishing logging messages over +a zmq.PUB socket. + +The PUBHandler can be used with the regular logging module, as in:: + +    >>> import logging +    >>> handler = PUBHandler('tcp://127.0.0.1:12345') +    >>> handler.root_topic = 'foo' +    >>> logger = logging.getLogger('foobar') +    >>> logger.setLevel(logging.DEBUG) +    >>> logger.addHandler(handler) + +After this point, all messages logged by ``logger`` will be published on the +PUB socket. + +Code adapted from StarCluster: + +    http://github.com/jtriley/StarCluster/blob/master/starcluster/logger.py +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import logging +from logging import INFO, DEBUG, WARN, ERROR, FATAL + +import zmq +from zmq.utils.strtypes import bytes, unicode, cast_bytes + + +TOPIC_DELIM="::" # delimiter for splitting topics on the receiving end. + + +class PUBHandler(logging.Handler): +    """A basic logging handler that emits log messages through a PUB socket. + +    Takes a PUB socket already bound to interfaces or an interface to bind to. + +    Example:: + +        sock = context.socket(zmq.PUB) +        sock.bind('inproc://log') +        handler = PUBHandler(sock) + +    Or:: + +        handler = PUBHandler('inproc://loc') + +    These are equivalent. + +    Log messages handled by this handler are broadcast with ZMQ topics +    ``this.root_topic`` comes first, followed by the log level +    (DEBUG,INFO,etc.), followed by any additional subtopics specified in the +    message by: log.debug("subtopic.subsub::the real message") +    """ +    root_topic="" +    socket = None +     +    formatters = { +        logging.DEBUG: logging.Formatter( +        "%(levelname)s %(filename)s:%(lineno)d - %(message)s\n"), +        logging.INFO: logging.Formatter("%(message)s\n"), +        logging.WARN: logging.Formatter( +        "%(levelname)s %(filename)s:%(lineno)d - %(message)s\n"), +        logging.ERROR: logging.Formatter( +        "%(levelname)s %(filename)s:%(lineno)d - %(message)s - %(exc_info)s\n"), +        logging.CRITICAL: logging.Formatter( +        "%(levelname)s %(filename)s:%(lineno)d - %(message)s\n")} +     +    def __init__(self, interface_or_socket, context=None): +        logging.Handler.__init__(self) +        if isinstance(interface_or_socket, zmq.Socket): +            self.socket = interface_or_socket +            self.ctx = self.socket.context +        else: +            self.ctx = context or zmq.Context() +            self.socket = self.ctx.socket(zmq.PUB) +            self.socket.bind(interface_or_socket) + +    def format(self,record): +        """Format a record.""" +        return self.formatters[record.levelno].format(record) + +    def emit(self, record): +        """Emit a log message on my socket.""" +        try: +            topic, record.msg = record.msg.split(TOPIC_DELIM,1) +        except Exception: +            topic = "" +        try: +            bmsg = cast_bytes(self.format(record)) +        except Exception: +            self.handleError(record) +            return +         +        topic_list = [] + +        if self.root_topic: +            topic_list.append(self.root_topic) + +        topic_list.append(record.levelname) + +        if topic: +            topic_list.append(topic) + +        btopic = b'.'.join(cast_bytes(t) for t in topic_list) + +        self.socket.send_multipart([btopic, bmsg]) + + +class TopicLogger(logging.Logger): +    """A simple wrapper that takes an additional argument to log methods. + +    All the regular methods exist, but instead of one msg argument, two +    arguments: topic, msg are passed. + +    That is:: + +        logger.debug('msg') + +    Would become:: + +        logger.debug('topic.sub', 'msg') +    """ +    def log(self, level, topic, msg, *args, **kwargs): +        """Log 'msg % args' with level and topic. + +        To pass exception information, use the keyword argument exc_info +        with a True value:: + +            logger.log(level, "zmq.fun", "We have a %s",  +                    "mysterious problem", exc_info=1) +        """ +        logging.Logger.log(self, level, '%s::%s'%(topic,msg), *args, **kwargs) + +# Generate the methods of TopicLogger, since they are just adding a +# topic prefix to a message. +for name in "debug warn warning error critical fatal".split(): +    meth = getattr(logging.Logger,name) +    setattr(TopicLogger, name,  +            lambda self, level, topic, msg, *args, **kwargs:  +                meth(self, level, topic+TOPIC_DELIM+msg,*args, **kwargs)) +     diff --git a/zmq/ssh/__init__.py b/zmq/ssh/__init__.py new file mode 100644 index 0000000..57f0956 --- /dev/null +++ b/zmq/ssh/__init__.py @@ -0,0 +1 @@ +from zmq.ssh.tunnel import * diff --git a/zmq/ssh/forward.py b/zmq/ssh/forward.py new file mode 100644 index 0000000..2d61946 --- /dev/null +++ b/zmq/ssh/forward.py @@ -0,0 +1,91 @@ +# +# This file is adapted from a paramiko demo, and thus licensed under LGPL 2.1. +# Original Copyright (C) 2003-2007  Robey Pointer <robeypointer@gmail.com> +# Edits Copyright (C) 2010 The IPython Team +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA  02111-1301  USA. + +""" +Sample script showing how to do local port forwarding over paramiko. + +This script connects to the requested SSH server and sets up local port +forwarding (the openssh -L option) from a local port through a tunneled +connection to a destination reachable from the SSH server machine. +""" + +from __future__ import print_function + +import logging +import select +try:  # Python 3 +    import socketserver +except ImportError:  # Python 2 +    import SocketServer as socketserver + +logger = logging.getLogger('ssh') + +class ForwardServer (socketserver.ThreadingTCPServer): +    daemon_threads = True +    allow_reuse_address = True +     + +class Handler (socketserver.BaseRequestHandler): + +    def handle(self): +        try: +            chan = self.ssh_transport.open_channel('direct-tcpip', +                                                   (self.chain_host, self.chain_port), +                                                   self.request.getpeername()) +        except Exception as e: +            logger.debug('Incoming request to %s:%d failed: %s' % (self.chain_host, +                                                              self.chain_port, +                                                              repr(e))) +            return +        if chan is None: +            logger.debug('Incoming request to %s:%d was rejected by the SSH server.' % +                    (self.chain_host, self.chain_port)) +            return + +        logger.debug('Connected!  Tunnel open %r -> %r -> %r' % (self.request.getpeername(), +                                                            chan.getpeername(), (self.chain_host, self.chain_port))) +        while True: +            r, w, x = select.select([self.request, chan], [], []) +            if self.request in r: +                data = self.request.recv(1024) +                if len(data) == 0: +                    break +                chan.send(data) +            if chan in r: +                data = chan.recv(1024) +                if len(data) == 0: +                    break +                self.request.send(data) +        chan.close() +        self.request.close() +        logger.debug('Tunnel closed ') + + +def forward_tunnel(local_port, remote_host, remote_port, transport): +    # this is a little convoluted, but lets me configure things for the Handler +    # object.  (SocketServer doesn't give Handlers any way to access the outer +    # server normally.) +    class SubHander (Handler): +        chain_host = remote_host +        chain_port = remote_port +        ssh_transport = transport +    ForwardServer(('127.0.0.1', local_port), SubHander).serve_forever() + + +__all__ = ['forward_tunnel'] diff --git a/zmq/ssh/tunnel.py b/zmq/ssh/tunnel.py new file mode 100644 index 0000000..6400352 --- /dev/null +++ b/zmq/ssh/tunnel.py @@ -0,0 +1,368 @@ +"""Basic ssh tunnel utilities, and convenience functions for tunneling +zeromq connections. +""" + +# Copyright (C) 2010-2011  IPython Development Team +# Copyright (C) 2011- PyZMQ Developers +# +# Redistributed from IPython under the terms of the BSD License. + + +from __future__ import print_function + +import atexit +import os +import signal +import socket +import sys +import warnings +from getpass import getpass, getuser +from multiprocessing import Process + +try: +    with warnings.catch_warnings(): +        warnings.simplefilter('ignore', DeprecationWarning) +        import paramiko +except ImportError: +    paramiko = None +else: +    from .forward import forward_tunnel + + +try: +    import pexpect +except ImportError: +    pexpect = None + + +_random_ports = set() + +def select_random_ports(n): +    """Selects and return n random ports that are available.""" +    ports = [] +    for i in range(n): +        sock = socket.socket() +        sock.bind(('', 0)) +        while sock.getsockname()[1] in _random_ports: +            sock.close() +            sock = socket.socket() +            sock.bind(('', 0)) +        ports.append(sock) +    for i, sock in enumerate(ports): +        port = sock.getsockname()[1] +        sock.close() +        ports[i] = port +        _random_ports.add(port) +    return ports + + +#----------------------------------------------------------------------------- +# Check for passwordless login +#----------------------------------------------------------------------------- + +def try_passwordless_ssh(server, keyfile, paramiko=None): +    """Attempt to make an ssh connection without a password. +    This is mainly used for requiring password input only once +    when many tunnels may be connected to the same server. +     +    If paramiko is None, the default for the platform is chosen. +    """ +    if paramiko is None: +        paramiko = sys.platform == 'win32' +    if not paramiko: +        f = _try_passwordless_openssh +    else: +        f = _try_passwordless_paramiko +    return f(server, keyfile) + +def _try_passwordless_openssh(server, keyfile): +    """Try passwordless login with shell ssh command.""" +    if pexpect is None: +        raise ImportError("pexpect unavailable, use paramiko") +    cmd = 'ssh -f '+ server +    if keyfile: +        cmd += ' -i ' + keyfile +    cmd += ' exit' +     +    # pop SSH_ASKPASS from env +    env = os.environ.copy() +    env.pop('SSH_ASKPASS', None) +     +    p = pexpect.spawn(cmd, env=env) +    while True: +        try: +            p.expect('[Pp]assword:', timeout=.1) +        except pexpect.TIMEOUT: +            continue +        except pexpect.EOF: +            return True +        else: +            return False + +def _try_passwordless_paramiko(server, keyfile): +    """Try passwordless login with paramiko.""" +    if paramiko is None: +        msg = "Paramiko unavaliable, " +        if sys.platform == 'win32': +            msg += "Paramiko is required for ssh tunneled connections on Windows." +        else: +            msg += "use OpenSSH." +        raise ImportError(msg) +    username, server, port = _split_server(server) +    client = paramiko.SSHClient() +    client.load_system_host_keys() +    client.set_missing_host_key_policy(paramiko.WarningPolicy()) +    try: +        client.connect(server, port, username=username, key_filename=keyfile, +               look_for_keys=True) +    except paramiko.AuthenticationException: +        return False +    else: +        client.close() +        return True + + +def tunnel_connection(socket, addr, server, keyfile=None, password=None, paramiko=None, timeout=60): +    """Connect a socket to an address via an ssh tunnel. +     +    This is a wrapper for socket.connect(addr), when addr is not accessible +    from the local machine.  It simply creates an ssh tunnel using the remaining args, +    and calls socket.connect('tcp://localhost:lport') where lport is the randomly +    selected local port of the tunnel. +     +    """ +    new_url, tunnel = open_tunnel(addr, server, keyfile=keyfile, password=password, paramiko=paramiko, timeout=timeout) +    socket.connect(new_url) +    return tunnel + + +def open_tunnel(addr, server, keyfile=None, password=None, paramiko=None, timeout=60): +    """Open a tunneled connection from a 0MQ url. +     +    For use inside tunnel_connection. +     +    Returns +    ------- +     +    (url, tunnel) : (str, object) +        The 0MQ url that has been forwarded, and the tunnel object +    """ +     +    lport = select_random_ports(1)[0] +    transport, addr = addr.split('://') +    ip,rport = addr.split(':') +    rport = int(rport) +    if paramiko is None: +        paramiko = sys.platform == 'win32' +    if paramiko: +        tunnelf = paramiko_tunnel +    else: +        tunnelf = openssh_tunnel +     +    tunnel = tunnelf(lport, rport, server, remoteip=ip, keyfile=keyfile, password=password, timeout=timeout) +    return 'tcp://127.0.0.1:%i'%lport, tunnel + +def openssh_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=60): +    """Create an ssh tunnel using command-line ssh that connects port lport +    on this machine to localhost:rport on server.  The tunnel +    will automatically close when not in use, remaining open +    for a minimum of timeout seconds for an initial connection. +     +    This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`, +    as seen from `server`. +     +    keyfile and password may be specified, but ssh config is checked for defaults. +     +    Parameters +    ---------- +     +    lport : int +        local port for connecting to the tunnel from this machine. +    rport : int +        port on the remote machine to connect to. +    server : str +        The ssh server to connect to. The full ssh server string will be parsed. +        user@server:port +    remoteip : str [Default: 127.0.0.1] +        The remote ip, specifying the destination of the tunnel. +        Default is localhost, which means that the tunnel would redirect +        localhost:lport on this machine to localhost:rport on the *server*. +         +    keyfile : str; path to public key file +        This specifies a key to be used in ssh login, default None. +        Regular default ssh keys will be used without specifying this argument. +    password : str;  +        Your ssh password to the ssh server. Note that if this is left None, +        you will be prompted for it if passwordless key based login is unavailable. +    timeout : int [default: 60] +        The time (in seconds) after which no activity will result in the tunnel +        closing.  This prevents orphaned tunnels from running forever. +    """ +    if pexpect is None: +        raise ImportError("pexpect unavailable, use paramiko_tunnel") +    ssh="ssh " +    if keyfile: +        ssh += "-i " + keyfile +     +    if ':' in server: +        server, port = server.split(':') +        ssh += " -p %s" % port +     +    cmd = "%s -O check %s" % (ssh, server) +    (output, exitstatus) = pexpect.run(cmd, withexitstatus=True) +    if not exitstatus: +        pid = int(output[output.find("(pid=")+5:output.find(")")])  +        cmd = "%s -O forward -L 127.0.0.1:%i:%s:%i %s" % ( +            ssh, lport, remoteip, rport, server) +        (output, exitstatus) = pexpect.run(cmd, withexitstatus=True) +        if not exitstatus: +            atexit.register(_stop_tunnel, cmd.replace("-O forward", "-O cancel", 1)) +            return pid +    cmd = "%s -f -S none -L 127.0.0.1:%i:%s:%i %s sleep %i" % ( +        ssh, lport, remoteip, rport, server, timeout) +     +    # pop SSH_ASKPASS from env +    env = os.environ.copy() +    env.pop('SSH_ASKPASS', None) +     +    tunnel = pexpect.spawn(cmd, env=env) +    failed = False +    while True: +        try: +            tunnel.expect('[Pp]assword:', timeout=.1) +        except pexpect.TIMEOUT: +            continue +        except pexpect.EOF: +            if tunnel.exitstatus: +                print(tunnel.exitstatus) +                print(tunnel.before) +                print(tunnel.after) +                raise RuntimeError("tunnel '%s' failed to start"%(cmd)) +            else: +                return tunnel.pid +        else: +            if failed: +                print("Password rejected, try again") +                password=None +            if password is None: +                password = getpass("%s's password: "%(server)) +            tunnel.sendline(password) +            failed = True +     +def _stop_tunnel(cmd): +    pexpect.run(cmd) + +def _split_server(server): +    if '@' in server: +        username,server = server.split('@', 1) +    else: +        username = getuser() +    if ':' in server: +        server, port = server.split(':') +        port = int(port) +    else: +        port = 22 +    return username, server, port + +def paramiko_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=60): +    """launch a tunner with paramiko in a subprocess. This should only be used +    when shell ssh is unavailable (e.g. Windows). +     +    This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`, +    as seen from `server`. +     +    If you are familiar with ssh tunnels, this creates the tunnel: +     +    ssh server -L localhost:lport:remoteip:rport +     +    keyfile and password may be specified, but ssh config is checked for defaults. +     +     +    Parameters +    ---------- +     +    lport : int +        local port for connecting to the tunnel from this machine. +    rport : int +        port on the remote machine to connect to. +    server : str +        The ssh server to connect to. The full ssh server string will be parsed. +        user@server:port +    remoteip : str [Default: 127.0.0.1] +        The remote ip, specifying the destination of the tunnel. +        Default is localhost, which means that the tunnel would redirect +        localhost:lport on this machine to localhost:rport on the *server*. +         +    keyfile : str; path to public key file +        This specifies a key to be used in ssh login, default None. +        Regular default ssh keys will be used without specifying this argument. +    password : str;  +        Your ssh password to the ssh server. Note that if this is left None, +        you will be prompted for it if passwordless key based login is unavailable. +    timeout : int [default: 60] +        The time (in seconds) after which no activity will result in the tunnel +        closing.  This prevents orphaned tunnels from running forever. +     +    """ +    if paramiko is None: +        raise ImportError("Paramiko not available") +     +    if password is None: +        if not _try_passwordless_paramiko(server, keyfile): +            password = getpass("%s's password: "%(server)) + +    p = Process(target=_paramiko_tunnel,  +            args=(lport, rport, server, remoteip),  +            kwargs=dict(keyfile=keyfile, password=password)) +    p.daemon=False +    p.start() +    atexit.register(_shutdown_process, p) +    return p +     +def _shutdown_process(p): +    if p.is_alive(): +        p.terminate() + +def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None): +    """Function for actually starting a paramiko tunnel, to be passed +    to multiprocessing.Process(target=this), and not called directly. +    """ +    username, server, port = _split_server(server) +    client = paramiko.SSHClient() +    client.load_system_host_keys() +    client.set_missing_host_key_policy(paramiko.WarningPolicy()) + +    try: +        client.connect(server, port, username=username, key_filename=keyfile, +                       look_for_keys=True, password=password) +#    except paramiko.AuthenticationException: +#        if password is None: +#            password = getpass("%s@%s's password: "%(username, server)) +#            client.connect(server, port, username=username, password=password) +#        else: +#            raise +    except Exception as e: +        print('*** Failed to connect to %s:%d: %r' % (server, port, e)) +        sys.exit(1) + +    # Don't let SIGINT kill the tunnel subprocess +    signal.signal(signal.SIGINT, signal.SIG_IGN) + +    try: +        forward_tunnel(lport, remoteip, rport, client.get_transport()) +    except KeyboardInterrupt: +        print('SIGINT: Port forwarding stopped cleanly') +        sys.exit(0) +    except Exception as e: +        print("Port forwarding stopped uncleanly: %s"%e) +        sys.exit(255) + +if sys.platform == 'win32': +    ssh_tunnel = paramiko_tunnel +else: +    ssh_tunnel = openssh_tunnel + +     +__all__ = ['tunnel_connection', 'ssh_tunnel', 'openssh_tunnel', 'paramiko_tunnel', 'try_passwordless_ssh'] + + diff --git a/zmq/sugar/__init__.py b/zmq/sugar/__init__.py new file mode 100644 index 0000000..d0510a4 --- /dev/null +++ b/zmq/sugar/__init__.py @@ -0,0 +1,27 @@ +"""pure-Python sugar wrappers for core 0MQ objects.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from zmq.sugar import ( +    constants, context, frame, poll, socket, tracker, version +) +from zmq import error + +__all__ = ['constants'] +for submod in ( +    constants, context, error, frame, poll, socket, tracker, version +): +    __all__.extend(submod.__all__) + +from zmq.error import * +from zmq.sugar.context import * +from zmq.sugar.tracker import * +from zmq.sugar.socket import * +from zmq.sugar.constants import * +from zmq.sugar.frame import * +from zmq.sugar.poll import * +# from zmq.sugar.stopwatch import * +# from zmq.sugar._device import * +from zmq.sugar.version import * diff --git a/zmq/sugar/attrsettr.py b/zmq/sugar/attrsettr.py new file mode 100644 index 0000000..4bbd36d --- /dev/null +++ b/zmq/sugar/attrsettr.py @@ -0,0 +1,52 @@ +# coding: utf-8 +"""Mixin for mapping set/getattr to self.set/get""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from . import constants + +class AttributeSetter(object): +     +    def __setattr__(self, key, value): +        """set zmq options by attribute""" +         +        # regular setattr only allowed for class-defined attributes +        for obj in [self] + self.__class__.mro(): +            if key in obj.__dict__: +                object.__setattr__(self, key, value) +                return +         +        upper_key = key.upper() +        try: +            opt = getattr(constants, upper_key) +        except AttributeError: +            raise AttributeError("%s has no such option: %s" % ( +                self.__class__.__name__, upper_key) +            ) +        else: +            self._set_attr_opt(upper_key, opt, value) +     +    def _set_attr_opt(self, name, opt, value): +        """override if setattr should do something other than call self.set""" +        self.set(opt, value) +     +    def __getattr__(self, key): +        """get zmq options by attribute""" +        upper_key = key.upper() +        try: +            opt = getattr(constants, upper_key) +        except AttributeError: +            raise AttributeError("%s has no such option: %s" % ( +                self.__class__.__name__, upper_key) +            ) +        else: +            return self._get_attr_opt(upper_key, opt) + +    def _get_attr_opt(self, name, opt): +        """override if getattr should do something other than call self.get""" +        return self.get(opt) +     + +__all__ = ['AttributeSetter'] diff --git a/zmq/sugar/constants.py b/zmq/sugar/constants.py new file mode 100644 index 0000000..8828117 --- /dev/null +++ b/zmq/sugar/constants.py @@ -0,0 +1,98 @@ +"""0MQ Constants.""" + +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Modified BSD License. + +from zmq.backend import constants +from zmq.utils.constant_names import ( +    base_names, +    switched_sockopt_names, +    int_sockopt_names, +    int64_sockopt_names, +    bytes_sockopt_names, +    fd_sockopt_names, +    ctx_opt_names, +    msg_opt_names, +) + +#----------------------------------------------------------------------------- +# Python module level constants +#----------------------------------------------------------------------------- + +__all__ = [ +    'int_sockopts', +    'int64_sockopts', +    'bytes_sockopts', +    'ctx_opts', +    'ctx_opt_names', +    ] + +int_sockopts    = set() +int64_sockopts  = set() +bytes_sockopts  = set() +fd_sockopts     = set() +ctx_opts        = set() +msg_opts        = set() + + +if constants.VERSION < 30000: +    int64_sockopt_names.extend(switched_sockopt_names) +else: +    int_sockopt_names.extend(switched_sockopt_names) +     +_UNDEFINED = -9999 + +def _add_constant(name, container=None): +    """add a constant to be defined +     +    optionally add it to one of the sets for use in get/setopt checkers +    """ +    c = getattr(constants, name, _UNDEFINED) +    if c == _UNDEFINED: +        return +    globals()[name] = c +    __all__.append(name) +    if container is not None: +        container.add(c) +    return c +     +for name in base_names: +    _add_constant(name) + +for name in int_sockopt_names: +    _add_constant(name, int_sockopts) + +for name in int64_sockopt_names: +    _add_constant(name, int64_sockopts) + +for name in bytes_sockopt_names: +    _add_constant(name, bytes_sockopts) + +for name in fd_sockopt_names: +    _add_constant(name, fd_sockopts) + +for name in ctx_opt_names: +    _add_constant(name, ctx_opts) + +for name in msg_opt_names: +    _add_constant(name, msg_opts) + +# ensure some aliases are always defined +aliases = [ +    ('DONTWAIT', 'NOBLOCK'), +    ('XREQ', 'DEALER'), +    ('XREP', 'ROUTER'), +] +for group in aliases: +    undefined = set() +    found = None +    for name in group: +        value = getattr(constants, name, -1) +        if value != -1: +            found = value +        else: +            undefined.add(name) +    if found is not None: +        for name in undefined: +            globals()[name] = found +            __all__.append(name) diff --git a/zmq/sugar/context.py b/zmq/sugar/context.py new file mode 100644 index 0000000..86a9c5d --- /dev/null +++ b/zmq/sugar/context.py @@ -0,0 +1,192 @@ +# coding: utf-8 +"""Python bindings for 0MQ.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import atexit +import weakref + +from zmq.backend import Context as ContextBase +from . import constants +from .attrsettr import AttributeSetter +from .constants import ENOTSUP, ctx_opt_names +from .socket import Socket +from zmq.error import ZMQError + +from zmq.utils.interop import cast_int_addr + + +class Context(ContextBase, AttributeSetter): +    """Create a zmq Context +     +    A zmq Context creates sockets via its ``ctx.socket`` method. +    """ +    sockopts = None +    _instance = None +    _shadow = False +    _exiting = False +     +    def __init__(self, io_threads=1, **kwargs): +        super(Context, self).__init__(io_threads=io_threads, **kwargs) +        if kwargs.get('shadow', False): +            self._shadow = True +        else: +            self._shadow = False +        self.sockopts = {} +         +        self._exiting = False +        if not self._shadow: +            ctx_ref = weakref.ref(self) +            def _notify_atexit(): +                ctx = ctx_ref() +                if ctx is not None: +                    ctx._exiting = True +            atexit.register(_notify_atexit) +     +    def __del__(self): +        """deleting a Context should terminate it, without trying non-threadsafe destroy""" +        if not self._shadow and not self._exiting: +            self.term() +     +    def __enter__(self): +        return self +     +    def __exit__(self, *args, **kwargs): +        self.term() +     +    @classmethod +    def shadow(cls, address): +        """Shadow an existing libzmq context +         +        address is the integer address of the libzmq context +        or an FFI pointer to it. +         +        .. versionadded:: 14.1 +        """ +        address = cast_int_addr(address) +        return cls(shadow=address) +     +    @classmethod +    def shadow_pyczmq(cls, ctx): +        """Shadow an existing pyczmq context +         +        ctx is the FFI `zctx_t *` pointer +         +        .. versionadded:: 14.1 +        """ +        from pyczmq import zctx +         +        underlying = zctx.underlying(ctx) +        address = cast_int_addr(underlying) +        return cls(shadow=address) + +    # static method copied from tornado IOLoop.instance +    @classmethod +    def instance(cls, io_threads=1): +        """Returns a global Context instance. + +        Most single-threaded applications have a single, global Context. +        Use this method instead of passing around Context instances +        throughout your code. + +        A common pattern for classes that depend on Contexts is to use +        a default argument to enable programs with multiple Contexts +        but not require the argument for simpler applications: + +            class MyClass(object): +                def __init__(self, context=None): +                    self.context = context or Context.instance() +        """ +        if cls._instance is None or cls._instance.closed: +            cls._instance = cls(io_threads=io_threads) +        return cls._instance +     +    #------------------------------------------------------------------------- +    # Hooks for ctxopt completion +    #------------------------------------------------------------------------- +     +    def __dir__(self): +        keys = dir(self.__class__) + +        for collection in ( +            ctx_opt_names, +        ): +            keys.extend(collection) +        return keys + +    #------------------------------------------------------------------------- +    # Creating Sockets +    #------------------------------------------------------------------------- + +    @property +    def _socket_class(self): +        return Socket +     +    def socket(self, socket_type): +        """Create a Socket associated with this Context. + +        Parameters +        ---------- +        socket_type : int +            The socket type, which can be any of the 0MQ socket types: +            REQ, REP, PUB, SUB, PAIR, DEALER, ROUTER, PULL, PUSH, etc. +        """ +        if self.closed: +            raise ZMQError(ENOTSUP) +        s = self._socket_class(self, socket_type) +        for opt, value in self.sockopts.items(): +            try: +                s.setsockopt(opt, value) +            except ZMQError: +                # ignore ZMQErrors, which are likely for socket options +                # that do not apply to a particular socket type, e.g. +                # SUBSCRIBE for non-SUB sockets. +                pass +        return s +     +    def setsockopt(self, opt, value): +        """set default socket options for new sockets created by this Context +         +        .. versionadded:: 13.0 +        """ +        self.sockopts[opt] = value +     +    def getsockopt(self, opt): +        """get default socket options for new sockets created by this Context +         +        .. versionadded:: 13.0 +        """ +        return self.sockopts[opt] +     +    def _set_attr_opt(self, name, opt, value): +        """set default sockopts as attributes""" +        if name in constants.ctx_opt_names: +            return self.set(opt, value) +        else: +            self.sockopts[opt] = value +     +    def _get_attr_opt(self, name, opt): +        """get default sockopts as attributes""" +        if name in constants.ctx_opt_names: +            return self.get(opt) +        else: +            if opt not in self.sockopts: +                raise AttributeError(name) +            else: +                return self.sockopts[opt] +     +    def __delattr__(self, key): +        """delete default sockopts as attributes""" +        key = key.upper() +        try: +            opt = getattr(constants, key) +        except AttributeError: +            raise AttributeError("no such socket option: %s" % key) +        else: +            if opt not in self.sockopts: +                raise AttributeError(key) +            else: +                del self.sockopts[opt] + +__all__ = ['Context'] diff --git a/zmq/sugar/frame.py b/zmq/sugar/frame.py new file mode 100644 index 0000000..9f556c8 --- /dev/null +++ b/zmq/sugar/frame.py @@ -0,0 +1,19 @@ +# coding: utf-8 +"""0MQ Frame pure Python methods.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from .attrsettr import AttributeSetter +from zmq.backend import Frame as FrameBase + + +class Frame(FrameBase, AttributeSetter): +    def __getitem__(self, key): +        # map Frame['User-Id'] to Frame.get('User-Id') +        return self.get(key) + +# keep deprecated alias +Message = Frame +__all__ = ['Frame', 'Message']
\ No newline at end of file diff --git a/zmq/sugar/poll.py b/zmq/sugar/poll.py new file mode 100644 index 0000000..c7b1d1b --- /dev/null +++ b/zmq/sugar/poll.py @@ -0,0 +1,161 @@ +"""0MQ polling related functions and classes.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import zmq +from zmq.backend import zmq_poll +from .constants import POLLIN, POLLOUT, POLLERR + +#----------------------------------------------------------------------------- +# Polling related methods +#----------------------------------------------------------------------------- + + +class Poller(object): +    """A stateful poll interface that mirrors Python's built-in poll.""" +    sockets = None +    _map = {} + +    def __init__(self): +        self.sockets = [] +        self._map = {} +     +    def __contains__(self, socket): +        return socket in self._map + +    def register(self, socket, flags=POLLIN|POLLOUT): +        """p.register(socket, flags=POLLIN|POLLOUT) + +        Register a 0MQ socket or native fd for I/O monitoring. +         +        register(s,0) is equivalent to unregister(s). + +        Parameters +        ---------- +        socket : zmq.Socket or native socket +            A zmq.Socket or any Python object having a ``fileno()``  +            method that returns a valid file descriptor. +        flags : int +            The events to watch for.  Can be POLLIN, POLLOUT or POLLIN|POLLOUT. +            If `flags=0`, socket will be unregistered. +        """ +        if flags: +            if socket in self._map: +                idx = self._map[socket] +                self.sockets[idx] = (socket, flags) +            else: +                idx = len(self.sockets) +                self.sockets.append((socket, flags)) +                self._map[socket] = idx +        elif socket in self._map: +            # uregister sockets registered with no events +            self.unregister(socket) +        else: +            # ignore new sockets with no events +            pass + +    def modify(self, socket, flags=POLLIN|POLLOUT): +        """Modify the flags for an already registered 0MQ socket or native fd.""" +        self.register(socket, flags) + +    def unregister(self, socket): +        """Remove a 0MQ socket or native fd for I/O monitoring. + +        Parameters +        ---------- +        socket : Socket +            The socket instance to stop polling. +        """ +        idx = self._map.pop(socket) +        self.sockets.pop(idx) +        # shift indices after deletion +        for socket, flags in self.sockets[idx:]: +            self._map[socket] -= 1 + +    def poll(self, timeout=None): +        """Poll the registered 0MQ or native fds for I/O. + +        Parameters +        ---------- +        timeout : float, int +            The timeout in milliseconds. If None, no `timeout` (infinite). This +            is in milliseconds to be compatible with ``select.poll()``. The +            underlying zmq_poll uses microseconds and we convert to that in +            this function. +         +        Returns +        ------- +        events : list of tuples +            The list of events that are ready to be processed. +            This is a list of tuples of the form ``(socket, event)``, where the 0MQ Socket +            or integer fd is the first element, and the poll event mask (POLLIN, POLLOUT) is the second. +            It is common to call ``events = dict(poller.poll())``, +            which turns the list of tuples into a mapping of ``socket : event``. +        """ +        if timeout is None or timeout < 0: +            timeout = -1 +        elif isinstance(timeout, float): +            timeout = int(timeout) +        return zmq_poll(self.sockets, timeout=timeout) + + +def select(rlist, wlist, xlist, timeout=None): +    """select(rlist, wlist, xlist, timeout=None) -> (rlist, wlist, xlist) + +    Return the result of poll as a lists of sockets ready for r/w/exception. + +    This has the same interface as Python's built-in ``select.select()`` function. + +    Parameters +    ---------- +    timeout : float, int, optional +        The timeout in seconds. If None, no timeout (infinite). This is in seconds to be +        compatible with ``select.select()``. The underlying zmq_poll uses microseconds +        and we convert to that in this function. +    rlist : list of sockets/FDs +        sockets/FDs to be polled for read events +    wlist : list of sockets/FDs +        sockets/FDs to be polled for write events +    xlist : list of sockets/FDs +        sockets/FDs to be polled for error events +     +    Returns +    ------- +    (rlist, wlist, xlist) : tuple of lists of sockets (length 3) +        Lists correspond to sockets available for read/write/error events respectively. +    """ +    if timeout is None: +        timeout = -1 +    # Convert from sec -> us for zmq_poll. +    # zmq_poll accepts 3.x style timeout in ms +    timeout = int(timeout*1000.0) +    if timeout < 0: +        timeout = -1 +    sockets = [] +    for s in set(rlist + wlist + xlist): +        flags = 0 +        if s in rlist: +            flags |= POLLIN +        if s in wlist: +            flags |= POLLOUT +        if s in xlist: +            flags |= POLLERR +        sockets.append((s, flags)) +    return_sockets = zmq_poll(sockets, timeout) +    rlist, wlist, xlist = [], [], [] +    for s, flags in return_sockets: +        if flags & POLLIN: +            rlist.append(s) +        if flags & POLLOUT: +            wlist.append(s) +        if flags & POLLERR: +            xlist.append(s) +    return rlist, wlist, xlist + +#----------------------------------------------------------------------------- +# Symbols to export +#----------------------------------------------------------------------------- + +__all__ = [ 'Poller', 'select' ] diff --git a/zmq/sugar/socket.py b/zmq/sugar/socket.py new file mode 100644 index 0000000..ab56deb --- /dev/null +++ b/zmq/sugar/socket.py @@ -0,0 +1,491 @@ +# coding: utf-8 +"""0MQ Socket pure Python methods.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import codecs +import random +import warnings + +import zmq +from zmq.backend import Socket as SocketBase +from .poll import Poller +from . import constants +from .attrsettr import AttributeSetter +from zmq.error import ZMQError, ZMQBindError +from zmq.utils import jsonapi +from zmq.utils.strtypes import bytes,unicode,basestring +from zmq.utils.interop import cast_int_addr + +from .constants import ( +    SNDMORE, ENOTSUP, POLLIN, +    int64_sockopt_names, +    int_sockopt_names, +    bytes_sockopt_names, +    fd_sockopt_names, +) +try: +    import cPickle +    pickle = cPickle +except: +    cPickle = None +    import pickle + + +class Socket(SocketBase, AttributeSetter): +    """The ZMQ socket object +     +    To create a Socket, first create a Context:: +     +        ctx = zmq.Context.instance() +     +    then call ``ctx.socket(socket_type)``:: +     +        s = ctx.socket(zmq.ROUTER) +     +    """ +    _shadow = False +     +    def __del__(self): +        if not self._shadow: +            self.close() +     +    # socket as context manager: +    def __enter__(self): +        """Sockets are context managers +         +        .. versionadded:: 14.4 +        """ +        return self +     +    def __exit__(self, *args, **kwargs): +        self.close() +     +    #------------------------------------------------------------------------- +    # Socket creation +    #------------------------------------------------------------------------- +     +    @classmethod +    def shadow(cls, address): +        """Shadow an existing libzmq socket +         +        address is the integer address of the libzmq socket +        or an FFI pointer to it. +         +        .. versionadded:: 14.1 +        """ +        address = cast_int_addr(address) +        return cls(shadow=address) +     +    #------------------------------------------------------------------------- +    # Deprecated aliases +    #------------------------------------------------------------------------- +     +    @property +    def socket_type(self): +        warnings.warn("Socket.socket_type is deprecated, use Socket.type", +            DeprecationWarning +        ) +        return self.type +     +    #------------------------------------------------------------------------- +    # Hooks for sockopt completion +    #------------------------------------------------------------------------- +     +    def __dir__(self): +        keys = dir(self.__class__) +        for collection in ( +            bytes_sockopt_names, +            int_sockopt_names, +            int64_sockopt_names, +            fd_sockopt_names, +        ): +            keys.extend(collection) +        return keys +     +    #------------------------------------------------------------------------- +    # Getting/Setting options +    #------------------------------------------------------------------------- +    setsockopt = SocketBase.set +    getsockopt = SocketBase.get +     +    def set_string(self, option, optval, encoding='utf-8'): +        """set socket options with a unicode object +         +        This is simply a wrapper for setsockopt to protect from encoding ambiguity. + +        See the 0MQ documentation for details on specific options. +         +        Parameters +        ---------- +        option : int +            The name of the option to set. Can be any of: SUBSCRIBE,  +            UNSUBSCRIBE, IDENTITY +        optval : unicode string (unicode on py2, str on py3) +            The value of the option to set. +        encoding : str +            The encoding to be used, default is utf8 +        """ +        if not isinstance(optval, unicode): +            raise TypeError("unicode strings only") +        return self.set(option, optval.encode(encoding)) +     +    setsockopt_unicode = setsockopt_string = set_string +     +    def get_string(self, option, encoding='utf-8'): +        """get the value of a socket option + +        See the 0MQ documentation for details on specific options. + +        Parameters +        ---------- +        option : int +            The option to retrieve. + +        Returns +        ------- +        optval : unicode string (unicode on py2, str on py3) +            The value of the option as a unicode string. +        """ +     +        if option not in constants.bytes_sockopts: +            raise TypeError("option %i will not return a string to be decoded"%option) +        return self.getsockopt(option).decode(encoding) +     +    getsockopt_unicode = getsockopt_string = get_string +     +    def bind_to_random_port(self, addr, min_port=49152, max_port=65536, max_tries=100): +        """bind this socket to a random port in a range + +        Parameters +        ---------- +        addr : str +            The address string without the port to pass to ``Socket.bind()``. +        min_port : int, optional +            The minimum port in the range of ports to try (inclusive). +        max_port : int, optional +            The maximum port in the range of ports to try (exclusive). +        max_tries : int, optional +            The maximum number of bind attempts to make. + +        Returns +        ------- +        port : int +            The port the socket was bound to. +     +        Raises +        ------ +        ZMQBindError +            if `max_tries` reached before successful bind +        """ +        for i in range(max_tries): +            try: +                port = random.randrange(min_port, max_port) +                self.bind('%s:%s' % (addr, port)) +            except ZMQError as exception: +                if not exception.errno == zmq.EADDRINUSE: +                    raise +            else: +                return port +        raise ZMQBindError("Could not bind socket to random port.") +     +    def get_hwm(self): +        """get the High Water Mark +         +        On libzmq ≥ 3, this gets SNDHWM if available, otherwise RCVHWM +        """ +        major = zmq.zmq_version_info()[0] +        if major >= 3: +            # return sndhwm, fallback on rcvhwm +            try: +                return self.getsockopt(zmq.SNDHWM) +            except zmq.ZMQError as e: +                pass +             +            return self.getsockopt(zmq.RCVHWM) +        else: +            return self.getsockopt(zmq.HWM) +     +    def set_hwm(self, value): +        """set the High Water Mark +         +        On libzmq ≥ 3, this sets both SNDHWM and RCVHWM +        """ +        major = zmq.zmq_version_info()[0] +        if major >= 3: +            raised = None +            try: +                self.sndhwm = value +            except Exception as e: +                raised = e +            try: +                self.rcvhwm = value +            except Exception: +                raised = e +             +            if raised: +                raise raised +        else: +            return self.setsockopt(zmq.HWM, value) +     +    hwm = property(get_hwm, set_hwm, +        """property for High Water Mark +         +        Setting hwm sets both SNDHWM and RCVHWM as appropriate. +        It gets SNDHWM if available, otherwise RCVHWM. +        """ +    ) +     +    #------------------------------------------------------------------------- +    # Sending and receiving messages +    #------------------------------------------------------------------------- + +    def send_multipart(self, msg_parts, flags=0, copy=True, track=False): +        """send a sequence of buffers as a multipart message +         +        The zmq.SNDMORE flag is added to all msg parts before the last. + +        Parameters +        ---------- +        msg_parts : iterable +            A sequence of objects to send as a multipart message. Each element +            can be any sendable object (Frame, bytes, buffer-providers) +        flags : int, optional +            SNDMORE is handled automatically for frames before the last. +        copy : bool, optional +            Should the frame(s) be sent in a copying or non-copying manner. +        track : bool, optional +            Should the frame(s) be tracked for notification that ZMQ has +            finished with it (ignored if copy=True). +     +        Returns +        ------- +        None : if copy or not track +        MessageTracker : if track and not copy +            a MessageTracker object, whose `pending` property will +            be True until the last send is completed. +        """ +        for msg in msg_parts[:-1]: +            self.send(msg, SNDMORE|flags, copy=copy, track=track) +        # Send the last part without the extra SNDMORE flag. +        return self.send(msg_parts[-1], flags, copy=copy, track=track) + +    def recv_multipart(self, flags=0, copy=True, track=False): +        """receive a multipart message as a list of bytes or Frame objects + +        Parameters +        ---------- +        flags : int, optional +            Any supported flag: NOBLOCK. If NOBLOCK is set, this method +            will raise a ZMQError with EAGAIN if a message is not ready. +            If NOBLOCK is not set, then this method will block until a +            message arrives. +        copy : bool, optional +            Should the message frame(s) be received in a copying or non-copying manner? +            If False a Frame object is returned for each part, if True a copy of +            the bytes is made for each frame. +        track : bool, optional +            Should the message frame(s) be tracked for notification that ZMQ has +            finished with it? (ignored if copy=True) +         +        Returns +        ------- +        msg_parts : list +            A list of frames in the multipart message; either Frames or bytes, +            depending on `copy`. +     +        """ +        parts = [self.recv(flags, copy=copy, track=track)] +        # have first part already, only loop while more to receive +        while self.getsockopt(zmq.RCVMORE): +            part = self.recv(flags, copy=copy, track=track) +            parts.append(part) +     +        return parts + +    def send_string(self, u, flags=0, copy=True, encoding='utf-8'): +        """send a Python unicode string as a message with an encoding +     +        0MQ communicates with raw bytes, so you must encode/decode +        text (unicode on py2, str on py3) around 0MQ. +         +        Parameters +        ---------- +        u : Python unicode string (unicode on py2, str on py3) +            The unicode string to send. +        flags : int, optional +            Any valid send flag. +        encoding : str [default: 'utf-8'] +            The encoding to be used +        """ +        if not isinstance(u, basestring): +            raise TypeError("unicode/str objects only") +        return self.send(u.encode(encoding), flags=flags, copy=copy) +     +    send_unicode = send_string +     +    def recv_string(self, flags=0, encoding='utf-8'): +        """receive a unicode string, as sent by send_string +     +        Parameters +        ---------- +        flags : int +            Any valid recv flag. +        encoding : str [default: 'utf-8'] +            The encoding to be used + +        Returns +        ------- +        s : unicode string (unicode on py2, str on py3) +            The Python unicode string that arrives as encoded bytes. +        """ +        b = self.recv(flags=flags) +        return b.decode(encoding) +     +    recv_unicode = recv_string +     +    def send_pyobj(self, obj, flags=0, protocol=-1): +        """send a Python object as a message using pickle to serialize + +        Parameters +        ---------- +        obj : Python object +            The Python object to send. +        flags : int +            Any valid send flag. +        protocol : int +            The pickle protocol number to use. Default of -1 will select +            the highest supported number. Use 0 for multiple platform +            support. +        """ +        msg = pickle.dumps(obj, protocol) +        return self.send(msg, flags) + +    def recv_pyobj(self, flags=0): +        """receive a Python object as a message using pickle to serialize + +        Parameters +        ---------- +        flags : int +            Any valid recv flag. + +        Returns +        ------- +        obj : Python object +            The Python object that arrives as a message. +        """ +        s = self.recv(flags) +        return pickle.loads(s) + +    def send_json(self, obj, flags=0, **kwargs): +        """send a Python object as a message using json to serialize +         +        Keyword arguments are passed on to json.dumps +         +        Parameters +        ---------- +        obj : Python object +            The Python object to send +        flags : int +            Any valid send flag +        """ +        msg = jsonapi.dumps(obj, **kwargs) +        return self.send(msg, flags) + +    def recv_json(self, flags=0, **kwargs): +        """receive a Python object as a message using json to serialize + +        Keyword arguments are passed on to json.loads +         +        Parameters +        ---------- +        flags : int +            Any valid recv flag. + +        Returns +        ------- +        obj : Python object +            The Python object that arrives as a message. +        """ +        msg = self.recv(flags) +        return jsonapi.loads(msg, **kwargs) +     +    _poller_class = Poller + +    def poll(self, timeout=None, flags=POLLIN): +        """poll the socket for events +         +        The default is to poll forever for incoming +        events.  Timeout is in milliseconds, if specified. + +        Parameters +        ---------- +        timeout : int [default: None] +            The timeout (in milliseconds) to wait for an event. If unspecified +            (or specified None), will wait forever for an event. +        flags : bitfield (int) [default: POLLIN] +            The event flags to poll for (any combination of POLLIN|POLLOUT). +            The default is to check for incoming events (POLLIN). + +        Returns +        ------- +        events : bitfield (int) +            The events that are ready and waiting.  Will be 0 if no events were ready +            by the time timeout was reached. +        """ + +        if self.closed: +            raise ZMQError(ENOTSUP) + +        p = self._poller_class() +        p.register(self, flags) +        evts = dict(p.poll(timeout)) +        # return 0 if no events, otherwise return event bitfield +        return evts.get(self, 0) + +    def get_monitor_socket(self, events=None, addr=None): +        """Return a connected PAIR socket ready to receive the event notifications. +         +        .. versionadded:: libzmq-4.0 +        .. versionadded:: 14.0 +         +        Parameters +        ---------- +        events : bitfield (int) [default: ZMQ_EVENTS_ALL] +            The bitmask defining which events are wanted. +        addr :  string [default: None] +            The optional endpoint for the monitoring sockets. + +        Returns +        ------- +        socket :  (PAIR) +            The socket is already connected and ready to receive messages. +        """ +        # safe-guard, method only available on libzmq >= 4 +        if zmq.zmq_version_info() < (4,): +            raise NotImplementedError("get_monitor_socket requires libzmq >= 4, have %s" % zmq.zmq_version()) +        if addr is None: +            # create endpoint name from internal fd +            addr = "inproc://monitor.s-%d" % self.FD +        if events is None: +            # use all events +            events = zmq.EVENT_ALL +        # attach monitoring socket +        self.monitor(addr, events) +        # create new PAIR socket and connect it +        ret = self.context.socket(zmq.PAIR) +        ret.connect(addr) +        return ret + +    def disable_monitor(self): +        """Shutdown the PAIR socket (created using get_monitor_socket) +        that is serving socket events. +         +        .. versionadded:: 14.4 +        """ +        self.monitor(None, 0) + + +__all__ = ['Socket'] diff --git a/zmq/sugar/tracker.py b/zmq/sugar/tracker.py new file mode 100644 index 0000000..fb8c007 --- /dev/null +++ b/zmq/sugar/tracker.py @@ -0,0 +1,120 @@ +"""Tracker for zero-copy messages with 0MQ.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time + +try: +    # below 3.3 +    from threading import _Event as Event +except (ImportError, AttributeError): +    # python throws ImportError, cython throws AttributeError +    from threading import Event + +from zmq.error import NotDone +from zmq.backend import Frame + +class MessageTracker(object): +    """MessageTracker(*towatch) + +    A class for tracking if 0MQ is done using one or more messages. + +    When you send a 0MQ message, it is not sent immediately. The 0MQ IO thread +    sends the message at some later time. Often you want to know when 0MQ has +    actually sent the message though. This is complicated by the fact that +    a single 0MQ message can be sent multiple times using different sockets. +    This class allows you to track all of the 0MQ usages of a message. + +    Parameters +    ---------- +    *towatch : tuple of Event, MessageTracker, Message instances. +        This list of objects to track. This class can track the low-level +        Events used by the Message class, other MessageTrackers or +        actual Messages. +    """ +    events = None +    peers = None + +    def __init__(self, *towatch): +        """MessageTracker(*towatch) + +        Create a message tracker to track a set of mesages. + +        Parameters +        ---------- +        *towatch : tuple of Event, MessageTracker, Message instances. +            This list of objects to track. This class can track the low-level +            Events used by the Message class, other MessageTrackers or  +            actual Messages. +        """ +        self.events = set() +        self.peers = set() +        for obj in towatch: +            if isinstance(obj, Event): +                self.events.add(obj) +            elif isinstance(obj, MessageTracker): +                self.peers.add(obj) +            elif isinstance(obj, Frame): +                if not obj.tracker: +                    raise ValueError("Not a tracked message") +                self.peers.add(obj.tracker) +            else: +                raise TypeError("Require Events or Message Frames, not %s"%type(obj)) +     +    @property +    def done(self): +        """Is 0MQ completely done with the message(s) being tracked?""" +        for evt in self.events: +            if not evt.is_set(): +                return False +        for pm in self.peers: +            if not pm.done: +                return False +        return True +     +    def wait(self, timeout=-1): +        """mt.wait(timeout=-1) + +        Wait for 0MQ to be done with the message or until `timeout`. + +        Parameters +        ---------- +        timeout : float [default: -1, wait forever] +            Maximum time in (s) to wait before raising NotDone. + +        Returns +        ------- +        None +            if done before `timeout` +         +        Raises +        ------ +        NotDone +            if `timeout` reached before I am done. +        """ +        tic = time.time() +        if timeout is False or timeout < 0: +            remaining = 3600*24*7 # a week +        else: +            remaining = timeout +        done = False +        for evt in self.events: +            if remaining < 0: +                raise NotDone +            evt.wait(timeout=remaining) +            if not evt.is_set(): +                raise NotDone +            toc = time.time() +            remaining -= (toc-tic) +            tic = toc +         +        for peer in self.peers: +            if remaining < 0: +                raise NotDone +            peer.wait(timeout=remaining) +            toc = time.time() +            remaining -= (toc-tic) +            tic = toc + +__all__ = ['MessageTracker']
\ No newline at end of file diff --git a/zmq/sugar/version.py b/zmq/sugar/version.py new file mode 100644 index 0000000..f4c4b57 --- /dev/null +++ b/zmq/sugar/version.py @@ -0,0 +1,48 @@ +"""PyZMQ and 0MQ version functions.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from zmq.backend import zmq_version_info + + +VERSION_MAJOR = 14 +VERSION_MINOR = 4 +VERSION_PATCH = 1 +VERSION_EXTRA = "" +__version__ = '%i.%i.%i' % (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH) + +if VERSION_EXTRA: +    __version__ = "%s-%s" % (__version__, VERSION_EXTRA) +    version_info = (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH, float('inf')) +else: +    version_info = (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH) + +__revision__ = '' + +def pyzmq_version(): +    """return the version of pyzmq as a string""" +    if __revision__: +        return '@'.join([__version__,__revision__[:6]]) +    else: +        return __version__ + +def pyzmq_version_info(): +    """return the pyzmq version as a tuple of at least three numbers +     +    If pyzmq is a development version, `inf` will be appended after the third integer. +    """ +    return version_info + + +def zmq_version(): +    """return the version of libzmq as a string""" +    return "%i.%i.%i" % zmq_version_info() + + +__all__ = ['zmq_version', 'zmq_version_info', +           'pyzmq_version','pyzmq_version_info', +           '__version__', '__revision__' +] + diff --git a/zmq/tests/__init__.py b/zmq/tests/__init__.py new file mode 100644 index 0000000..325a3f1 --- /dev/null +++ b/zmq/tests/__init__.py @@ -0,0 +1,211 @@ +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Modified BSD License. + +import functools +import sys +import time +from threading import Thread + +from unittest import TestCase + +import zmq +from zmq.utils import jsonapi + +try: +    import gevent +    from zmq import green as gzmq +    have_gevent = True +except ImportError: +    have_gevent = False + +try: +    from unittest import SkipTest +except ImportError: +    try: +        from nose import SkipTest +    except ImportError: +        class SkipTest(Exception): +            pass + +PYPY = 'PyPy' in sys.version + +#----------------------------------------------------------------------------- +# skip decorators (directly from unittest) +#----------------------------------------------------------------------------- + +_id = lambda x: x + +def skip(reason): +    """ +    Unconditionally skip a test. +    """ +    def decorator(test_item): +        if not (isinstance(test_item, type) and issubclass(test_item, TestCase)): +            @functools.wraps(test_item) +            def skip_wrapper(*args, **kwargs): +                raise SkipTest(reason) +            test_item = skip_wrapper + +        test_item.__unittest_skip__ = True +        test_item.__unittest_skip_why__ = reason +        return test_item +    return decorator + +def skip_if(condition, reason="Skipped"): +    """ +    Skip a test if the condition is true. +    """ +    if condition: +        return skip(reason) +    return _id + +skip_pypy = skip_if(PYPY, "Doesn't work on PyPy") + +#----------------------------------------------------------------------------- +# Base test class +#----------------------------------------------------------------------------- + +class BaseZMQTestCase(TestCase): +    green = False +     +    @property +    def Context(self): +        if self.green: +            return gzmq.Context +        else: +            return zmq.Context +     +    def socket(self, socket_type): +        s = self.context.socket(socket_type) +        self.sockets.append(s) +        return s +     +    def setUp(self): +        if self.green and not have_gevent: +                raise SkipTest("requires gevent") +        self.context = self.Context.instance() +        self.sockets = [] +     +    def tearDown(self): +        contexts = set([self.context]) +        while self.sockets: +            sock = self.sockets.pop() +            contexts.add(sock.context) # in case additional contexts are created +            sock.close(0) +        for ctx in contexts: +            t = Thread(target=ctx.term) +            t.daemon = True +            t.start() +            t.join(timeout=2) +            if t.is_alive(): +                # reset Context.instance, so the failure to term doesn't corrupt subsequent tests +                zmq.sugar.context.Context._instance = None +                raise RuntimeError("context could not terminate, open sockets likely remain in test") + +    def create_bound_pair(self, type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'): +        """Create a bound socket pair using a random port.""" +        s1 = self.context.socket(type1) +        s1.setsockopt(zmq.LINGER, 0) +        port = s1.bind_to_random_port(interface) +        s2 = self.context.socket(type2) +        s2.setsockopt(zmq.LINGER, 0) +        s2.connect('%s:%s' % (interface, port)) +        self.sockets.extend([s1,s2]) +        return s1, s2 + +    def ping_pong(self, s1, s2, msg): +        s1.send(msg) +        msg2 = s2.recv() +        s2.send(msg2) +        msg3 = s1.recv() +        return msg3 + +    def ping_pong_json(self, s1, s2, o): +        if jsonapi.jsonmod is None: +            raise SkipTest("No json library") +        s1.send_json(o) +        o2 = s2.recv_json() +        s2.send_json(o2) +        o3 = s1.recv_json() +        return o3 + +    def ping_pong_pyobj(self, s1, s2, o): +        s1.send_pyobj(o) +        o2 = s2.recv_pyobj() +        s2.send_pyobj(o2) +        o3 = s1.recv_pyobj() +        return o3 + +    def assertRaisesErrno(self, errno, func, *args, **kwargs): +        try: +            func(*args, **kwargs) +        except zmq.ZMQError as e: +            self.assertEqual(e.errno, errno, "wrong error raised, expected '%s' \ +got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno))) +        else: +            self.fail("Function did not raise any error") +     +    def _select_recv(self, multipart, socket, **kwargs): +        """call recv[_multipart] in a way that raises if there is nothing to receive""" +        if zmq.zmq_version_info() >= (3,1,0): +            # zmq 3.1 has a bug, where poll can return false positives, +            # so we wait a little bit just in case +            # See LIBZMQ-280 on JIRA +            time.sleep(0.1) +         +        r,w,x = zmq.select([socket], [], [], timeout=5) +        assert len(r) > 0, "Should have received a message" +        kwargs['flags'] = zmq.DONTWAIT | kwargs.get('flags', 0) +         +        recv = socket.recv_multipart if multipart else socket.recv +        return recv(**kwargs) +         +    def recv(self, socket, **kwargs): +        """call recv in a way that raises if there is nothing to receive""" +        return self._select_recv(False, socket, **kwargs) + +    def recv_multipart(self, socket, **kwargs): +        """call recv_multipart in a way that raises if there is nothing to receive""" +        return self._select_recv(True, socket, **kwargs) +     + +class PollZMQTestCase(BaseZMQTestCase): +    pass + +class GreenTest: +    """Mixin for making green versions of test classes""" +    green = True +     +    def assertRaisesErrno(self, errno, func, *args, **kwargs): +        if errno == zmq.EAGAIN: +            raise SkipTest("Skipping because we're green.") +        try: +            func(*args, **kwargs) +        except zmq.ZMQError: +            e = sys.exc_info()[1] +            self.assertEqual(e.errno, errno, "wrong error raised, expected '%s' \ +got '%s'" % (zmq.ZMQError(errno), zmq.ZMQError(e.errno))) +        else: +            self.fail("Function did not raise any error") + +    def tearDown(self): +        contexts = set([self.context]) +        while self.sockets: +            sock = self.sockets.pop() +            contexts.add(sock.context) # in case additional contexts are created +            sock.close() +        try: +            gevent.joinall([gevent.spawn(ctx.term) for ctx in contexts], timeout=2, raise_error=True) +        except gevent.Timeout: +            raise RuntimeError("context could not terminate, open sockets likely remain in test") +     +    def skip_green(self): +        raise SkipTest("Skipping because we are green") + +def skip_green(f): +    def skipping_test(self, *args, **kwargs): +        if self.green: +            raise SkipTest("Skipping because we are green") +        else: +            return f(self, *args, **kwargs) +    return skipping_test diff --git a/zmq/tests/test_auth.py b/zmq/tests/test_auth.py new file mode 100644 index 0000000..d350f61 --- /dev/null +++ b/zmq/tests/test_auth.py @@ -0,0 +1,431 @@ +# -*- coding: utf8 -*- + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import logging +import os +import shutil +import sys +import tempfile + +import zmq.auth +from zmq.auth.ioloop import IOLoopAuthenticator +from zmq.auth.thread import ThreadAuthenticator + +from zmq.eventloop import ioloop, zmqstream +from zmq.tests import (BaseZMQTestCase, SkipTest) + +class BaseAuthTestCase(BaseZMQTestCase): +    def setUp(self): +        if zmq.zmq_version_info() < (4,0): +            raise SkipTest("security is new in libzmq 4.0") +        try: +            zmq.curve_keypair() +        except zmq.ZMQError: +            raise SkipTest("security requires libzmq to be linked against libsodium") +        super(BaseAuthTestCase, self).setUp() +        # enable debug logging while we run tests +        logging.getLogger('zmq.auth').setLevel(logging.DEBUG) +        self.auth = self.make_auth() +        self.auth.start() +        self.base_dir, self.public_keys_dir, self.secret_keys_dir = self.create_certs() +     +    def make_auth(self): +        raise NotImplementedError() +     +    def tearDown(self): +        if self.auth: +            self.auth.stop() +            self.auth = None +        self.remove_certs(self.base_dir) +        super(BaseAuthTestCase, self).tearDown() +     +    def create_certs(self): +        """Create CURVE certificates for a test""" + +        # Create temporary CURVE keypairs for this test run. We create all keys in a +        # temp directory and then move them into the appropriate private or public +        # directory. + +        base_dir = tempfile.mkdtemp() +        keys_dir = os.path.join(base_dir, 'certificates') +        public_keys_dir = os.path.join(base_dir, 'public_keys') +        secret_keys_dir = os.path.join(base_dir, 'private_keys') + +        os.mkdir(keys_dir) +        os.mkdir(public_keys_dir) +        os.mkdir(secret_keys_dir) + +        server_public_file, server_secret_file = zmq.auth.create_certificates(keys_dir, "server") +        client_public_file, client_secret_file = zmq.auth.create_certificates(keys_dir, "client") + +        for key_file in os.listdir(keys_dir): +            if key_file.endswith(".key"): +                shutil.move(os.path.join(keys_dir, key_file), +                            os.path.join(public_keys_dir, '.')) + +        for key_file in os.listdir(keys_dir): +            if key_file.endswith(".key_secret"): +                shutil.move(os.path.join(keys_dir, key_file), +                            os.path.join(secret_keys_dir, '.')) + +        return (base_dir, public_keys_dir, secret_keys_dir) + +    def remove_certs(self, base_dir): +        """Remove certificates for a test""" +        shutil.rmtree(base_dir) + +    def load_certs(self, secret_keys_dir): +        """Return server and client certificate keys""" +        server_secret_file = os.path.join(secret_keys_dir, "server.key_secret") +        client_secret_file = os.path.join(secret_keys_dir, "client.key_secret") + +        server_public, server_secret = zmq.auth.load_certificate(server_secret_file) +        client_public, client_secret = zmq.auth.load_certificate(client_secret_file) + +        return server_public, server_secret, client_public, client_secret + + +class TestThreadAuthentication(BaseAuthTestCase): +    """Test authentication running in a thread""" + +    def make_auth(self): +        return ThreadAuthenticator(self.context) + +    def can_connect(self, server, client): +        """Check if client can connect to server using tcp transport""" +        result = False +        iface = 'tcp://127.0.0.1' +        port = server.bind_to_random_port(iface) +        client.connect("%s:%i" % (iface, port)) +        msg = [b"Hello World"] +        server.send_multipart(msg) +        if client.poll(1000): +            rcvd_msg = client.recv_multipart() +            self.assertEqual(rcvd_msg, msg) +            result = True +        return result + +    def test_null(self): +        """threaded auth - NULL""" +        # A default NULL connection should always succeed, and not +        # go through our authentication infrastructure at all. +        self.auth.stop() +        self.auth = None +         +        server = self.socket(zmq.PUSH) +        client = self.socket(zmq.PULL) +        self.assertTrue(self.can_connect(server, client)) + +        # By setting a domain we switch on authentication for NULL sockets, +        # though no policies are configured yet. The client connection +        # should still be allowed. +        server = self.socket(zmq.PUSH) +        server.zap_domain = b'global' +        client = self.socket(zmq.PULL) +        self.assertTrue(self.can_connect(server, client)) + +    def test_blacklist(self): +        """threaded auth - Blacklist""" +        # Blacklist 127.0.0.1, connection should fail +        self.auth.deny('127.0.0.1') +        server = self.socket(zmq.PUSH) +        # By setting a domain we switch on authentication for NULL sockets, +        # though no policies are configured yet. +        server.zap_domain = b'global' +        client = self.socket(zmq.PULL) +        self.assertFalse(self.can_connect(server, client)) + +    def test_whitelist(self): +        """threaded auth - Whitelist""" +        # Whitelist 127.0.0.1, connection should pass" +        self.auth.allow('127.0.0.1') +        server = self.socket(zmq.PUSH) +        # By setting a domain we switch on authentication for NULL sockets, +        # though no policies are configured yet. +        server.zap_domain = b'global' +        client = self.socket(zmq.PULL) +        self.assertTrue(self.can_connect(server, client)) + +    def test_plain(self): +        """threaded auth - PLAIN""" + +        # Try PLAIN authentication - without configuring server, connection should fail +        server = self.socket(zmq.PUSH) +        server.plain_server = True +        client = self.socket(zmq.PULL) +        client.plain_username = b'admin' +        client.plain_password = b'Password' +        self.assertFalse(self.can_connect(server, client)) + +        # Try PLAIN authentication - with server configured, connection should pass +        server = self.socket(zmq.PUSH) +        server.plain_server = True +        client = self.socket(zmq.PULL) +        client.plain_username = b'admin' +        client.plain_password = b'Password' +        self.auth.configure_plain(domain='*', passwords={'admin': 'Password'}) +        self.assertTrue(self.can_connect(server, client)) + +        # Try PLAIN authentication - with bogus credentials, connection should fail +        server = self.socket(zmq.PUSH) +        server.plain_server = True +        client = self.socket(zmq.PULL) +        client.plain_username = b'admin' +        client.plain_password = b'Bogus' +        self.assertFalse(self.can_connect(server, client)) + +        # Remove authenticator and check that a normal connection works +        self.auth.stop() +        self.auth = None + +        server = self.socket(zmq.PUSH) +        client = self.socket(zmq.PULL) +        self.assertTrue(self.can_connect(server, client)) +        client.close() +        server.close() + +    def test_curve(self): +        """threaded auth - CURVE""" +        self.auth.allow('127.0.0.1') +        certs = self.load_certs(self.secret_keys_dir) +        server_public, server_secret, client_public, client_secret = certs + +        #Try CURVE authentication - without configuring server, connection should fail +        server = self.socket(zmq.PUSH) +        server.curve_publickey = server_public +        server.curve_secretkey = server_secret +        server.curve_server = True +        client = self.socket(zmq.PULL) +        client.curve_publickey = client_public +        client.curve_secretkey = client_secret +        client.curve_serverkey = server_public +        self.assertFalse(self.can_connect(server, client)) + +        #Try CURVE authentication - with server configured to CURVE_ALLOW_ANY, connection should pass +        self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY) +        server = self.socket(zmq.PUSH) +        server.curve_publickey = server_public +        server.curve_secretkey = server_secret +        server.curve_server = True +        client = self.socket(zmq.PULL) +        client.curve_publickey = client_public +        client.curve_secretkey = client_secret +        client.curve_serverkey = server_public +        self.assertTrue(self.can_connect(server, client)) + +        # Try CURVE authentication - with server configured, connection should pass +        self.auth.configure_curve(domain='*', location=self.public_keys_dir) +        server = self.socket(zmq.PUSH) +        server.curve_publickey = server_public +        server.curve_secretkey = server_secret +        server.curve_server = True +        client = self.socket(zmq.PULL) +        client.curve_publickey = client_public +        client.curve_secretkey = client_secret +        client.curve_serverkey = server_public +        self.assertTrue(self.can_connect(server, client)) + +        # Remove authenticator and check that a normal connection works +        self.auth.stop() +        self.auth = None + +        # Try connecting using NULL and no authentication enabled, connection should pass +        server = self.socket(zmq.PUSH) +        client = self.socket(zmq.PULL) +        self.assertTrue(self.can_connect(server, client)) + + +def with_ioloop(method, expect_success=True): +    """decorator for running tests with an IOLoop""" +    def test_method(self): +        r = method(self) +         +        loop = self.io_loop +        if expect_success: +            self.pullstream.on_recv(self.on_message_succeed) +        else: +            self.pullstream.on_recv(self.on_message_fail) +         +        t = loop.time() +        loop.add_callback(self.attempt_connection) +        loop.add_callback(self.send_msg) +        if expect_success: +            loop.add_timeout(t + 1, self.on_test_timeout_fail) +        else: +            loop.add_timeout(t + 1, self.on_test_timeout_succeed) +         +        loop.start() +        if self.fail_msg: +            self.fail(self.fail_msg) +         +        return r +    return test_method + +def should_auth(method): +    return with_ioloop(method, True) + +def should_not_auth(method): +    return with_ioloop(method, False) + +class TestIOLoopAuthentication(BaseAuthTestCase): +    """Test authentication running in ioloop""" + +    def setUp(self): +        self.fail_msg = None +        self.io_loop = ioloop.IOLoop() +        super(TestIOLoopAuthentication, self).setUp() +        self.server = self.socket(zmq.PUSH) +        self.client = self.socket(zmq.PULL) +        self.pushstream = zmqstream.ZMQStream(self.server, self.io_loop) +        self.pullstream = zmqstream.ZMQStream(self.client, self.io_loop) +     +    def make_auth(self): +        return IOLoopAuthenticator(self.context, io_loop=self.io_loop) + +    def tearDown(self): +        if self.auth: +            self.auth.stop() +            self.auth = None +        self.io_loop.close(all_fds=True) +        super(TestIOLoopAuthentication, self).tearDown() + +    def attempt_connection(self): +        """Check if client can connect to server using tcp transport""" +        iface = 'tcp://127.0.0.1' +        port = self.server.bind_to_random_port(iface) +        self.client.connect("%s:%i" % (iface, port)) + +    def send_msg(self): +        """Send a message from server to a client""" +        msg = [b"Hello World"] +        self.pushstream.send_multipart(msg) +     +    def on_message_succeed(self, frames): +        """A message was received, as expected.""" +        if frames != [b"Hello World"]: +            self.fail_msg = "Unexpected message received" +        self.io_loop.stop() + +    def on_message_fail(self, frames): +        """A message was received, unexpectedly.""" +        self.fail_msg = 'Received messaged unexpectedly, security failed' +        self.io_loop.stop() + +    def on_test_timeout_succeed(self): +        """Test timer expired, indicates test success""" +        self.io_loop.stop() + +    def on_test_timeout_fail(self): +        """Test timer expired, indicates test failure""" +        self.fail_msg = 'Test timed out' +        self.io_loop.stop() + +    @should_auth +    def test_none(self): +        """ioloop auth - NONE""" +        # A default NULL connection should always succeed, and not +        # go through our authentication infrastructure at all. +        # no auth should be running +        self.auth.stop() +        self.auth = None + +    @should_auth +    def test_null(self): +        """ioloop auth - NULL""" +        # By setting a domain we switch on authentication for NULL sockets, +        # though no policies are configured yet. The client connection +        # should still be allowed. +        self.server.zap_domain = b'global' + +    @should_not_auth +    def test_blacklist(self): +        """ioloop auth - Blacklist""" +        # Blacklist 127.0.0.1, connection should fail +        self.auth.deny('127.0.0.1') +        self.server.zap_domain = b'global' + +    @should_auth +    def test_whitelist(self): +        """ioloop auth - Whitelist""" +        # Whitelist 127.0.0.1, which overrides the blacklist, connection should pass" +        self.auth.allow('127.0.0.1') + +        self.server.setsockopt(zmq.ZAP_DOMAIN, b'global') + +    @should_not_auth +    def test_plain_unconfigured_server(self): +        """ioloop auth - PLAIN, unconfigured server""" +        self.client.plain_username = b'admin' +        self.client.plain_password = b'Password' +        # Try PLAIN authentication - without configuring server, connection should fail +        self.server.plain_server = True + +    @should_auth +    def test_plain_configured_server(self): +        """ioloop auth - PLAIN, configured server""" +        self.client.plain_username = b'admin' +        self.client.plain_password = b'Password' +        # Try PLAIN authentication - with server configured, connection should pass +        self.server.plain_server = True +        self.auth.configure_plain(domain='*', passwords={'admin': 'Password'}) + +    @should_not_auth +    def test_plain_bogus_credentials(self): +        """ioloop auth - PLAIN, bogus credentials""" +        self.client.plain_username = b'admin' +        self.client.plain_password = b'Bogus' +        self.server.plain_server = True + +        self.auth.configure_plain(domain='*', passwords={'admin': 'Password'}) + +    @should_not_auth +    def test_curve_unconfigured_server(self): +        """ioloop auth - CURVE, unconfigured server""" +        certs = self.load_certs(self.secret_keys_dir) +        server_public, server_secret, client_public, client_secret = certs + +        self.auth.allow('127.0.0.1') + +        self.server.curve_publickey = server_public +        self.server.curve_secretkey = server_secret +        self.server.curve_server = True + +        self.client.curve_publickey = client_public +        self.client.curve_secretkey = client_secret +        self.client.curve_serverkey = server_public + +    @should_auth +    def test_curve_allow_any(self): +        """ioloop auth - CURVE, CURVE_ALLOW_ANY""" +        certs = self.load_certs(self.secret_keys_dir) +        server_public, server_secret, client_public, client_secret = certs + +        self.auth.allow('127.0.0.1') +        self.auth.configure_curve(domain='*', location=zmq.auth.CURVE_ALLOW_ANY) + +        self.server.curve_publickey = server_public +        self.server.curve_secretkey = server_secret +        self.server.curve_server = True + +        self.client.curve_publickey = client_public +        self.client.curve_secretkey = client_secret +        self.client.curve_serverkey = server_public + +    @should_auth +    def test_curve_configured_server(self): +        """ioloop auth - CURVE, configured server""" +        self.auth.allow('127.0.0.1') +        certs = self.load_certs(self.secret_keys_dir) +        server_public, server_secret, client_public, client_secret = certs + +        self.auth.configure_curve(domain='*', location=self.public_keys_dir) + +        self.server.curve_publickey = server_public +        self.server.curve_secretkey = server_secret +        self.server.curve_server = True + +        self.client.curve_publickey = client_public +        self.client.curve_secretkey = client_secret +        self.client.curve_serverkey = server_public diff --git a/zmq/tests/test_cffi_backend.py b/zmq/tests/test_cffi_backend.py new file mode 100644 index 0000000..1f85eeb --- /dev/null +++ b/zmq/tests/test_cffi_backend.py @@ -0,0 +1,310 @@ +# -*- coding: utf8 -*- + +import sys +import time + +from unittest import TestCase + +from zmq.tests import BaseZMQTestCase, SkipTest + +try: +    from zmq.backend.cffi import ( +        zmq_version_info, +        PUSH, PULL, IDENTITY, +        REQ, REP, POLLIN, POLLOUT, +    ) +    from zmq.backend.cffi._cffi import ffi, C +    have_ffi_backend = True +except ImportError: +    have_ffi_backend = False + + +class TestCFFIBackend(TestCase): +     +    def setUp(self): +        if not have_ffi_backend or not 'PyPy' in sys.version: +            raise SkipTest('PyPy Tests Only') + +    def test_zmq_version_info(self): +        version = zmq_version_info() + +        assert version[0] in range(2,11) + +    def test_zmq_ctx_new_destroy(self): +        ctx = C.zmq_ctx_new() + +        assert ctx != ffi.NULL +        assert 0 == C.zmq_ctx_destroy(ctx) + +    def test_zmq_socket_open_close(self): +        ctx = C.zmq_ctx_new() +        socket = C.zmq_socket(ctx, PUSH) + +        assert ctx != ffi.NULL +        assert ffi.NULL != socket +        assert 0 == C.zmq_close(socket) +        assert 0 == C.zmq_ctx_destroy(ctx) + +    def test_zmq_setsockopt(self): +        ctx = C.zmq_ctx_new() +        socket = C.zmq_socket(ctx, PUSH) + +        identity = ffi.new('char[3]', 'zmq') +        ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3) + +        assert ret == 0 +        assert ctx != ffi.NULL +        assert ffi.NULL != socket +        assert 0 == C.zmq_close(socket) +        assert 0 == C.zmq_ctx_destroy(ctx) + +    def test_zmq_getsockopt(self): +        ctx = C.zmq_ctx_new() +        socket = C.zmq_socket(ctx, PUSH) + +        identity = ffi.new('char[]', 'zmq') +        ret = C.zmq_setsockopt(socket, IDENTITY, ffi.cast('void*', identity), 3) +        assert ret == 0 + +        option_len = ffi.new('size_t*', 3) +        option = ffi.new('char*') +        ret = C.zmq_getsockopt(socket, +                            IDENTITY, +                            ffi.cast('void*', option), +                            option_len) + +        assert ret == 0 +        assert ffi.string(ffi.cast('char*', option))[0] == "z" +        assert ffi.string(ffi.cast('char*', option))[1] == "m" +        assert ffi.string(ffi.cast('char*', option))[2] == "q" +        assert ctx != ffi.NULL +        assert ffi.NULL != socket +        assert 0 == C.zmq_close(socket) +        assert 0 == C.zmq_ctx_destroy(ctx) + +    def test_zmq_bind(self): +        ctx = C.zmq_ctx_new() +        socket = C.zmq_socket(ctx, 8) + +        assert 0 == C.zmq_bind(socket, 'tcp://*:4444') +        assert ctx != ffi.NULL +        assert ffi.NULL != socket +        assert 0 == C.zmq_close(socket) +        assert 0 == C.zmq_ctx_destroy(ctx) + +    def test_zmq_bind_connect(self): +        ctx = C.zmq_ctx_new() + +        socket1 = C.zmq_socket(ctx, PUSH) +        socket2 = C.zmq_socket(ctx, PULL) + +        assert 0 == C.zmq_bind(socket1, 'tcp://*:4444') +        assert 0 == C.zmq_connect(socket2, 'tcp://127.0.0.1:4444') +        assert ctx != ffi.NULL +        assert ffi.NULL != socket1 +        assert ffi.NULL != socket2 +        assert 0 == C.zmq_close(socket1) +        assert 0 == C.zmq_close(socket2) +        assert 0 == C.zmq_ctx_destroy(ctx) + +    def test_zmq_msg_init_close(self): +        zmq_msg = ffi.new('zmq_msg_t*') + +        assert ffi.NULL != zmq_msg +        assert 0 == C.zmq_msg_init(zmq_msg) +        assert 0 == C.zmq_msg_close(zmq_msg) + +    def test_zmq_msg_init_size(self): +        zmq_msg = ffi.new('zmq_msg_t*') + +        assert ffi.NULL != zmq_msg +        assert 0 == C.zmq_msg_init_size(zmq_msg, 10) +        assert 0 == C.zmq_msg_close(zmq_msg) + +    def test_zmq_msg_init_data(self): +        zmq_msg = ffi.new('zmq_msg_t*') +        message = ffi.new('char[5]', 'Hello') + +        assert 0 == C.zmq_msg_init_data(zmq_msg, +                                        ffi.cast('void*', message), +                                        5, +                                        ffi.NULL, +                                        ffi.NULL) + +        assert ffi.NULL != zmq_msg +        assert 0 == C.zmq_msg_close(zmq_msg) + +    def test_zmq_msg_data(self): +        zmq_msg = ffi.new('zmq_msg_t*') +        message = ffi.new('char[]', 'Hello') +        assert 0 == C.zmq_msg_init_data(zmq_msg, +                                        ffi.cast('void*', message), +                                        5, +                                        ffi.NULL, +                                        ffi.NULL) + +        data = C.zmq_msg_data(zmq_msg) + +        assert ffi.NULL != zmq_msg +        assert ffi.string(ffi.cast("char*", data)) == 'Hello' +        assert 0 == C.zmq_msg_close(zmq_msg) + + +    def test_zmq_send(self): +        ctx = C.zmq_ctx_new() + +        sender = C.zmq_socket(ctx, REQ) +        receiver = C.zmq_socket(ctx, REP) + +        assert 0 == C.zmq_bind(receiver, 'tcp://*:7777') +        assert 0 == C.zmq_connect(sender, 'tcp://127.0.0.1:7777') + +        time.sleep(0.1) + +        zmq_msg = ffi.new('zmq_msg_t*') +        message = ffi.new('char[5]', 'Hello') + +        C.zmq_msg_init_data(zmq_msg, +                            ffi.cast('void*', message), +                            ffi.cast('size_t', 5), +                            ffi.NULL, +                            ffi.NULL) + +        assert 5 == C.zmq_msg_send(zmq_msg, sender, 0) +        assert 0 == C.zmq_msg_close(zmq_msg) +        assert C.zmq_close(sender) == 0 +        assert C.zmq_close(receiver) == 0 +        assert C.zmq_ctx_destroy(ctx) == 0 + +    def test_zmq_recv(self): +        ctx = C.zmq_ctx_new() + +        sender = C.zmq_socket(ctx, REQ) +        receiver = C.zmq_socket(ctx, REP) + +        assert 0 == C.zmq_bind(receiver, 'tcp://*:2222') +        assert 0 == C.zmq_connect(sender, 'tcp://127.0.0.1:2222') + +        time.sleep(0.1) + +        zmq_msg = ffi.new('zmq_msg_t*') +        message = ffi.new('char[5]', 'Hello') + +        C.zmq_msg_init_data(zmq_msg, +                            ffi.cast('void*', message), +                            ffi.cast('size_t', 5), +                            ffi.NULL, +                            ffi.NULL) + +        zmq_msg2 = ffi.new('zmq_msg_t*') +        C.zmq_msg_init(zmq_msg2) + +        assert 5 == C.zmq_msg_send(zmq_msg, sender, 0) +        assert 5 == C.zmq_msg_recv(zmq_msg2, receiver, 0) +        assert 5 == C.zmq_msg_size(zmq_msg2) +        assert b"Hello" == ffi.buffer(C.zmq_msg_data(zmq_msg2), +                                      C.zmq_msg_size(zmq_msg2))[:] +        assert C.zmq_close(sender) == 0 +        assert C.zmq_close(receiver) == 0 +        assert C.zmq_ctx_destroy(ctx) == 0 + +    def test_zmq_poll(self): +        ctx = C.zmq_ctx_new() + +        sender = C.zmq_socket(ctx, REQ) +        receiver = C.zmq_socket(ctx, REP) + +        r1 = C.zmq_bind(receiver, 'tcp://*:3333') +        r2 = C.zmq_connect(sender, 'tcp://127.0.0.1:3333') + +        zmq_msg = ffi.new('zmq_msg_t*') +        message = ffi.new('char[5]', 'Hello') + +        C.zmq_msg_init_data(zmq_msg, +                            ffi.cast('void*', message), +                            ffi.cast('size_t', 5), +                            ffi.NULL, +                            ffi.NULL) + +        receiver_pollitem = ffi.new('zmq_pollitem_t*') +        receiver_pollitem.socket = receiver +        receiver_pollitem.fd = 0 +        receiver_pollitem.events = POLLIN | POLLOUT +        receiver_pollitem.revents = 0 + +        ret = C.zmq_poll(ffi.NULL, 0, 0) +        assert ret == 0 + +        ret = C.zmq_poll(receiver_pollitem, 1, 0) +        assert ret == 0 + +        ret = C.zmq_msg_send(zmq_msg, sender, 0) +        print(ffi.string(C.zmq_strerror(C.zmq_errno()))) +        assert ret == 5 + +        time.sleep(0.2) + +        ret = C.zmq_poll(receiver_pollitem, 1, 0) +        assert ret == 1 + +        assert int(receiver_pollitem.revents) & POLLIN +        assert not int(receiver_pollitem.revents) & POLLOUT + +        zmq_msg2 = ffi.new('zmq_msg_t*') +        C.zmq_msg_init(zmq_msg2) + +        ret_recv = C.zmq_msg_recv(zmq_msg2, receiver, 0) +        assert ret_recv == 5 + +        assert 5 == C.zmq_msg_size(zmq_msg2) +        assert b"Hello" == ffi.buffer(C.zmq_msg_data(zmq_msg2), +                                    C.zmq_msg_size(zmq_msg2))[:] + +        sender_pollitem = ffi.new('zmq_pollitem_t*') +        sender_pollitem.socket = sender +        sender_pollitem.fd = 0 +        sender_pollitem.events = POLLIN | POLLOUT +        sender_pollitem.revents = 0 + +        ret = C.zmq_poll(sender_pollitem, 1, 0) +        assert ret == 0 + +        zmq_msg_again = ffi.new('zmq_msg_t*') +        message_again = ffi.new('char[11]', 'Hello Again') + +        C.zmq_msg_init_data(zmq_msg_again, +                            ffi.cast('void*', message_again), +                            ffi.cast('size_t', 11), +                            ffi.NULL, +                            ffi.NULL) + +        assert 11 == C.zmq_msg_send(zmq_msg_again, receiver, 0) + +        time.sleep(0.2) + +        assert 0 <= C.zmq_poll(sender_pollitem, 1, 0) +        assert int(sender_pollitem.revents) & POLLIN +        assert 11 == C.zmq_msg_recv(zmq_msg2, sender, 0) +        assert 11 == C.zmq_msg_size(zmq_msg2) +        assert b"Hello Again" == ffi.buffer(C.zmq_msg_data(zmq_msg2), +                                            int(C.zmq_msg_size(zmq_msg2)))[:] +        assert 0 == C.zmq_close(sender) +        assert 0 == C.zmq_close(receiver) +        assert 0 == C.zmq_ctx_destroy(ctx) +        assert 0 == C.zmq_msg_close(zmq_msg) +        assert 0 == C.zmq_msg_close(zmq_msg2) +        assert 0 == C.zmq_msg_close(zmq_msg_again) + +    def test_zmq_stopwatch_functions(self): +        stopwatch = C.zmq_stopwatch_start() +        ret = C.zmq_stopwatch_stop(stopwatch) + +        assert ffi.NULL != stopwatch +        assert 0 < int(ret) + +    def test_zmq_sleep(self): +        try: +            C.zmq_sleep(1) +        except Exception as e: +            raise AssertionError("Error executing zmq_sleep(int)") + diff --git a/zmq/tests/test_constants.py b/zmq/tests/test_constants.py new file mode 100644 index 0000000..d32b2b4 --- /dev/null +++ b/zmq/tests/test_constants.py @@ -0,0 +1,104 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import json +from unittest import TestCase + +import zmq + +from zmq.utils import constant_names +from zmq.sugar import constants as sugar_constants +from zmq.backend import constants as backend_constants + +all_set = set(constant_names.all_names) + +class TestConstants(TestCase): +     +    def _duplicate_test(self, namelist, listname): +        """test that a given list has no duplicates""" +        dupes = {} +        for name in set(namelist): +            cnt = namelist.count(name) +            if cnt > 1: +                dupes[name] = cnt +        if dupes: +            self.fail("The following names occur more than once in %s: %s" % (listname, json.dumps(dupes, indent=2))) +     +    def test_duplicate_all(self): +        return self._duplicate_test(constant_names.all_names, "all_names") +     +    def _change_key(self, change, version): +        """return changed-in key""" +        return "%s-in %d.%d.%d" % tuple([change] + list(version)) + +    def test_duplicate_changed(self): +        all_changed = [] +        for change in ("new", "removed"): +            d = getattr(constant_names, change + "_in") +            for version, namelist in d.items(): +                all_changed.extend(namelist) +                self._duplicate_test(namelist, self._change_key(change, version)) +         +        self._duplicate_test(all_changed, "all-changed") +     +    def test_changed_in_all(self): +        missing = {} +        for change in ("new", "removed"): +            d = getattr(constant_names, change + "_in") +            for version, namelist in d.items(): +                key = self._change_key(change, version) +                for name in namelist: +                    if name not in all_set: +                        if key not in missing: +                            missing[key] = [] +                        missing[key].append(name) +         +        if missing: +            self.fail( +                "The following names are missing in `all_names`: %s" % json.dumps(missing, indent=2) +            ) +     +    def test_no_negative_constants(self): +        for name in sugar_constants.__all__: +            self.assertNotEqual(getattr(zmq, name), sugar_constants._UNDEFINED) +     +    def test_undefined_constants(self): +        all_aliases = [] +        for alias_group in sugar_constants.aliases: +            all_aliases.extend(alias_group) +         +        for name in all_set.difference(all_aliases): +            raw = getattr(backend_constants, name) +            if raw == sugar_constants._UNDEFINED: +                self.assertRaises(AttributeError, getattr, zmq, name) +            else: +                self.assertEqual(getattr(zmq, name), raw) +     +    def test_new(self): +        zmq_version = zmq.zmq_version_info() +        for version, new_names in constant_names.new_in.items(): +            should_have = zmq_version >= version +            for name in new_names: +                try: +                    value = getattr(zmq, name) +                except AttributeError: +                    if should_have: +                        self.fail("AttributeError: zmq.%s" % name) +                else: +                    if not should_have: +                        self.fail("Shouldn't have: zmq.%s=%s" % (name, value)) + +    def test_removed(self): +        zmq_version = zmq.zmq_version_info() +        for version, new_names in constant_names.removed_in.items(): +            should_have = zmq_version < version +            for name in new_names: +                try: +                    value = getattr(zmq, name) +                except AttributeError: +                    if should_have: +                        self.fail("AttributeError: zmq.%s" % name) +                else: +                    if not should_have: +                        self.fail("Shouldn't have: zmq.%s=%s" % (name, value)) + diff --git a/zmq/tests/test_context.py b/zmq/tests/test_context.py new file mode 100644 index 0000000..e328077 --- /dev/null +++ b/zmq/tests/test_context.py @@ -0,0 +1,257 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import gc +import sys +import time +from threading import Thread, Event + +import zmq +from zmq.tests import ( +    BaseZMQTestCase, have_gevent, GreenTest, skip_green, PYPY, SkipTest, +) + + +class TestContext(BaseZMQTestCase): + +    def test_init(self): +        c1 = self.Context() +        self.assert_(isinstance(c1, self.Context)) +        del c1 +        c2 = self.Context() +        self.assert_(isinstance(c2, self.Context)) +        del c2 +        c3 = self.Context() +        self.assert_(isinstance(c3, self.Context)) +        del c3 + +    def test_dir(self): +        ctx = self.Context() +        self.assertTrue('socket' in dir(ctx)) +        if zmq.zmq_version_info() > (3,): +            self.assertTrue('IO_THREADS' in dir(ctx)) +        ctx.term() + +    def test_term(self): +        c = self.Context() +        c.term() +        self.assert_(c.closed) +     +    def test_context_manager(self): +        with self.Context() as c: +            pass +        self.assert_(c.closed) +     +    def test_fail_init(self): +        self.assertRaisesErrno(zmq.EINVAL, self.Context, -1) +     +    def test_term_hang(self): +        rep,req = self.create_bound_pair(zmq.ROUTER, zmq.DEALER) +        req.setsockopt(zmq.LINGER, 0) +        req.send(b'hello', copy=False) +        req.close() +        rep.close() +        self.context.term() +     +    def test_instance(self): +        ctx = self.Context.instance() +        c2 = self.Context.instance(io_threads=2) +        self.assertTrue(c2 is ctx) +        c2.term() +        c3 = self.Context.instance() +        c4 = self.Context.instance() +        self.assertFalse(c3 is c2) +        self.assertFalse(c3.closed) +        self.assertTrue(c3 is c4) +     +    def test_many_sockets(self): +        """opening and closing many sockets shouldn't cause problems""" +        ctx = self.Context() +        for i in range(16): +            sockets = [ ctx.socket(zmq.REP) for i in range(65) ] +            [ s.close() for s in sockets ] +            # give the reaper a chance +            time.sleep(1e-2) +        ctx.term() +     +    def test_sockopts(self): +        """setting socket options with ctx attributes""" +        ctx = self.Context() +        ctx.linger = 5 +        self.assertEqual(ctx.linger, 5) +        s = ctx.socket(zmq.REQ) +        self.assertEqual(s.linger, 5) +        self.assertEqual(s.getsockopt(zmq.LINGER), 5) +        s.close() +        # check that subscribe doesn't get set on sockets that don't subscribe: +        ctx.subscribe = b'' +        s = ctx.socket(zmq.REQ) +        s.close() +         +        ctx.term() + +     +    def test_destroy(self): +        """Context.destroy should close sockets""" +        ctx = self.Context() +        sockets = [ ctx.socket(zmq.REP) for i in range(65) ] +         +        # close half of the sockets +        [ s.close() for s in sockets[::2] ] +         +        ctx.destroy() +        # reaper is not instantaneous +        time.sleep(1e-2) +        for s in sockets: +            self.assertTrue(s.closed) +         +    def test_destroy_linger(self): +        """Context.destroy should set linger on closing sockets""" +        req,rep = self.create_bound_pair(zmq.REQ, zmq.REP) +        req.send(b'hi') +        time.sleep(1e-2) +        self.context.destroy(linger=0) +        # reaper is not instantaneous +        time.sleep(1e-2) +        for s in (req,rep): +            self.assertTrue(s.closed) +         +    def test_term_noclose(self): +        """Context.term won't close sockets""" +        ctx = self.Context() +        s = ctx.socket(zmq.REQ) +        self.assertFalse(s.closed) +        t = Thread(target=ctx.term) +        t.start() +        t.join(timeout=0.1) +        self.assertTrue(t.is_alive(), "Context should be waiting") +        s.close() +        t.join(timeout=0.1) +        self.assertFalse(t.is_alive(), "Context should have closed") +     +    def test_gc(self): +        """test close&term by garbage collection alone""" +        if PYPY: +            raise SkipTest("GC doesn't work ") +             +        # test credit @dln (GH #137): +        def gcf(): +            def inner(): +                ctx = self.Context() +                s = ctx.socket(zmq.PUSH) +            inner() +            gc.collect() +        t = Thread(target=gcf) +        t.start() +        t.join(timeout=1) +        self.assertFalse(t.is_alive(), "Garbage collection should have cleaned up context") +     +    def test_cyclic_destroy(self): +        """ctx.destroy should succeed when cyclic ref prevents gc""" +        # test credit @dln (GH #137): +        class CyclicReference(object): +            def __init__(self, parent=None): +                self.parent = parent +             +            def crash(self, sock): +                self.sock = sock +                self.child = CyclicReference(self) +         +        def crash_zmq(): +            ctx = self.Context() +            sock = ctx.socket(zmq.PULL) +            c = CyclicReference() +            c.crash(sock) +            ctx.destroy() +         +        crash_zmq() +     +    def test_term_thread(self): +        """ctx.term should not crash active threads (#139)""" +        ctx = self.Context() +        evt = Event() +        evt.clear() + +        def block(): +            s = ctx.socket(zmq.REP) +            s.bind_to_random_port('tcp://127.0.0.1') +            evt.set() +            try: +                s.recv() +            except zmq.ZMQError as e: +                self.assertEqual(e.errno, zmq.ETERM) +                return +            finally: +                s.close() +            self.fail("recv should have been interrupted with ETERM") +        t = Thread(target=block) +        t.start() +         +        evt.wait(1) +        self.assertTrue(evt.is_set(), "sync event never fired") +        time.sleep(0.01) +        ctx.term() +        t.join(timeout=1) +        self.assertFalse(t.is_alive(), "term should have interrupted s.recv()") +     +    def test_destroy_no_sockets(self): +        ctx = self.Context() +        s = ctx.socket(zmq.PUB) +        s.bind_to_random_port('tcp://127.0.0.1') +        s.close() +        ctx.destroy() +        assert s.closed +        assert ctx.closed +     +    def test_ctx_opts(self): +        if zmq.zmq_version_info() < (3,): +            raise SkipTest("context options require libzmq 3") +        ctx = self.Context() +        ctx.set(zmq.MAX_SOCKETS, 2) +        self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 2) +        ctx.max_sockets = 100 +        self.assertEqual(ctx.max_sockets, 100) +        self.assertEqual(ctx.get(zmq.MAX_SOCKETS), 100) +     +    def test_shadow(self): +        ctx = self.Context() +        ctx2 = self.Context.shadow(ctx.underlying) +        self.assertEqual(ctx.underlying, ctx2.underlying) +        s = ctx.socket(zmq.PUB) +        s.close() +        del ctx2 +        self.assertFalse(ctx.closed) +        s = ctx.socket(zmq.PUB) +        ctx2 = self.Context.shadow(ctx.underlying) +        s2 = ctx2.socket(zmq.PUB) +        s.close() +        s2.close() +        ctx.term() +        self.assertRaisesErrno(zmq.EFAULT, ctx2.socket, zmq.PUB) +        del ctx2 +     +    def test_shadow_pyczmq(self): +        try: +            from pyczmq import zctx, zsocket, zstr +        except Exception: +            raise SkipTest("Requires pyczmq") +         +        ctx = zctx.new() +        a = zsocket.new(ctx, zmq.PUSH) +        zsocket.bind(a, "inproc://a") +        ctx2 = self.Context.shadow_pyczmq(ctx) +        b = ctx2.socket(zmq.PULL) +        b.connect("inproc://a") +        zstr.send(a, b'hi') +        rcvd = self.recv(b) +        self.assertEqual(rcvd, b'hi') +        b.close() + + +if False: # disable green context tests +    class TestContextGreen(GreenTest, TestContext): +        """gevent subclass of context tests""" +        # skip tests that use real threads: +        test_gc = GreenTest.skip_green +        test_term_thread = GreenTest.skip_green +        test_destroy_linger = GreenTest.skip_green diff --git a/zmq/tests/test_device.py b/zmq/tests/test_device.py new file mode 100644 index 0000000..f830507 --- /dev/null +++ b/zmq/tests/test_device.py @@ -0,0 +1,146 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time + +import zmq +from zmq import devices +from zmq.tests import BaseZMQTestCase, SkipTest, have_gevent, GreenTest, PYPY +from zmq.utils.strtypes import (bytes,unicode,basestring) + +if PYPY: +    # cleanup of shared Context doesn't work on PyPy +    devices.Device.context_factory = zmq.Context + +class TestDevice(BaseZMQTestCase): +     +    def test_device_types(self): +        for devtype in (zmq.STREAMER, zmq.FORWARDER, zmq.QUEUE): +            dev = devices.Device(devtype, zmq.PAIR, zmq.PAIR) +            self.assertEqual(dev.device_type, devtype) +            del dev +     +    def test_device_attributes(self): +        dev = devices.Device(zmq.QUEUE, zmq.SUB, zmq.PUB) +        self.assertEqual(dev.in_type, zmq.SUB) +        self.assertEqual(dev.out_type, zmq.PUB) +        self.assertEqual(dev.device_type, zmq.QUEUE) +        self.assertEqual(dev.daemon, True) +        del dev +     +    def test_tsdevice_attributes(self): +        dev = devices.Device(zmq.QUEUE, zmq.SUB, zmq.PUB) +        self.assertEqual(dev.in_type, zmq.SUB) +        self.assertEqual(dev.out_type, zmq.PUB) +        self.assertEqual(dev.device_type, zmq.QUEUE) +        self.assertEqual(dev.daemon, True) +        del dev +         +     +    def test_single_socket_forwarder_connect(self): +        dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1) +        req = self.context.socket(zmq.REQ) +        port = req.bind_to_random_port('tcp://127.0.0.1') +        dev.connect_in('tcp://127.0.0.1:%i'%port) +        dev.start() +        time.sleep(.25) +        msg = b'hello' +        req.send(msg) +        self.assertEqual(msg, self.recv(req)) +        del dev +        req.close() +        dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1) +        req = self.context.socket(zmq.REQ) +        port = req.bind_to_random_port('tcp://127.0.0.1') +        dev.connect_out('tcp://127.0.0.1:%i'%port) +        dev.start() +        time.sleep(.25) +        msg = b'hello again' +        req.send(msg) +        self.assertEqual(msg, self.recv(req)) +        del dev +        req.close() +         +    def test_single_socket_forwarder_bind(self): +        dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1) +        # select random port: +        binder = self.context.socket(zmq.REQ) +        port = binder.bind_to_random_port('tcp://127.0.0.1') +        binder.close() +        time.sleep(0.1) +        req = self.context.socket(zmq.REQ) +        req.connect('tcp://127.0.0.1:%i'%port) +        dev.bind_in('tcp://127.0.0.1:%i'%port) +        dev.start() +        time.sleep(.25) +        msg = b'hello' +        req.send(msg) +        self.assertEqual(msg, self.recv(req)) +        del dev +        req.close() +        dev = devices.ThreadDevice(zmq.QUEUE, zmq.REP, -1) +        # select random port: +        binder = self.context.socket(zmq.REQ) +        port = binder.bind_to_random_port('tcp://127.0.0.1') +        binder.close() +        time.sleep(0.1) +        req = self.context.socket(zmq.REQ) +        req.connect('tcp://127.0.0.1:%i'%port) +        dev.bind_in('tcp://127.0.0.1:%i'%port) +        dev.start() +        time.sleep(.25) +        msg = b'hello again' +        req.send(msg) +        self.assertEqual(msg, self.recv(req)) +        del dev +        req.close() +     +    def test_proxy(self): +        if zmq.zmq_version_info() < (3,2): +            raise SkipTest("Proxies only in libzmq >= 3") +        dev = devices.ThreadProxy(zmq.PULL, zmq.PUSH, zmq.PUSH) +        binder = self.context.socket(zmq.REQ) +        iface = 'tcp://127.0.0.1' +        port = binder.bind_to_random_port(iface) +        port2 = binder.bind_to_random_port(iface) +        port3 = binder.bind_to_random_port(iface) +        binder.close() +        time.sleep(0.1) +        dev.bind_in("%s:%i" % (iface, port)) +        dev.bind_out("%s:%i" % (iface, port2)) +        dev.bind_mon("%s:%i" % (iface, port3)) +        dev.start() +        time.sleep(0.25) +        msg = b'hello' +        push = self.context.socket(zmq.PUSH) +        push.connect("%s:%i" % (iface, port)) +        pull = self.context.socket(zmq.PULL) +        pull.connect("%s:%i" % (iface, port2)) +        mon = self.context.socket(zmq.PULL) +        mon.connect("%s:%i" % (iface, port3)) +        push.send(msg) +        self.sockets.extend([push, pull, mon]) +        self.assertEqual(msg, self.recv(pull)) +        self.assertEqual(msg, self.recv(mon)) + +if have_gevent: +    import gevent +    import zmq.green +     +    class TestDeviceGreen(GreenTest, BaseZMQTestCase): +         +        def test_green_device(self): +            rep = self.context.socket(zmq.REP) +            req = self.context.socket(zmq.REQ) +            self.sockets.extend([req, rep]) +            port = rep.bind_to_random_port('tcp://127.0.0.1') +            g = gevent.spawn(zmq.green.device, zmq.QUEUE, rep, rep) +            req.connect('tcp://127.0.0.1:%i' % port) +            req.send(b'hi') +            timeout = gevent.Timeout(3) +            timeout.start() +            receiver = gevent.spawn(req.recv) +            self.assertEqual(receiver.get(2), b'hi') +            timeout.cancel() +            g.kill(block=True) +             diff --git a/zmq/tests/test_error.py b/zmq/tests/test_error.py new file mode 100644 index 0000000..a2eee14 --- /dev/null +++ b/zmq/tests/test_error.py @@ -0,0 +1,43 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import sys +import time + +import zmq +from zmq import ZMQError, strerror, Again, ContextTerminated +from zmq.tests import BaseZMQTestCase + +if sys.version_info[0] >= 3: +    long = int + +class TestZMQError(BaseZMQTestCase): +     +    def test_strerror(self): +        """test that strerror gets the right type.""" +        for i in range(10): +            e = strerror(i) +            self.assertTrue(isinstance(e, str)) +     +    def test_zmqerror(self): +        for errno in range(10): +            e = ZMQError(errno) +            self.assertEqual(e.errno, errno) +            self.assertEqual(str(e), strerror(errno)) +     +    def test_again(self): +        s = self.context.socket(zmq.REP) +        self.assertRaises(Again, s.recv, zmq.NOBLOCK) +        self.assertRaisesErrno(zmq.EAGAIN, s.recv, zmq.NOBLOCK) +        s.close() +     +    def atest_ctxterm(self): +        s = self.context.socket(zmq.REP) +        t = Thread(target=self.context.term) +        t.start() +        self.assertRaises(ContextTerminated, s.recv, zmq.NOBLOCK) +        self.assertRaisesErrno(zmq.TERM, s.recv, zmq.NOBLOCK) +        s.close() +        t.join() + diff --git a/zmq/tests/test_etc.py b/zmq/tests/test_etc.py new file mode 100644 index 0000000..ad22406 --- /dev/null +++ b/zmq/tests/test_etc.py @@ -0,0 +1,15 @@ +# Copyright (c) PyZMQ Developers. +# Distributed under the terms of the Modified BSD License. + +import sys + +import zmq + +from . import skip_if + +@skip_if(zmq.zmq_version_info() < (4,1), "libzmq < 4.1") +def test_has(): +    assert not zmq.has('something weird') +    has_ipc = zmq.has('ipc') +    not_windows = not sys.platform.startswith('win') +    assert has_ipc == not_windows diff --git a/zmq/tests/test_imports.py b/zmq/tests/test_imports.py new file mode 100644 index 0000000..c0ddfaa --- /dev/null +++ b/zmq/tests/test_imports.py @@ -0,0 +1,62 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import sys +from unittest import TestCase + +class TestImports(TestCase): +    """Test Imports - the quickest test to ensure that we haven't +    introduced version-incompatible syntax errors.""" +     +    def test_toplevel(self): +        """test toplevel import""" +        import zmq +         +    def test_core(self): +        """test core imports""" +        from zmq import Context +        from zmq import Socket +        from zmq import Poller +        from zmq import Frame +        from zmq import constants +        from zmq import device, proxy +        from zmq import Stopwatch +        from zmq import (  +            zmq_version, +            zmq_version_info, +            pyzmq_version, +            pyzmq_version_info, +        ) +     +    def test_devices(self): +        """test device imports""" +        import zmq.devices +        from zmq.devices import basedevice +        from zmq.devices import monitoredqueue +        from zmq.devices import monitoredqueuedevice +     +    def test_log(self): +        """test log imports""" +        import zmq.log +        from zmq.log import handlers +     +    def test_eventloop(self): +        """test eventloop imports""" +        import zmq.eventloop +        from zmq.eventloop import ioloop +        from zmq.eventloop import zmqstream +        from zmq.eventloop.minitornado.platform import auto +        from zmq.eventloop.minitornado import ioloop +     +    def test_utils(self): +        """test util imports""" +        import zmq.utils +        from zmq.utils import strtypes +        from zmq.utils import jsonapi +     +    def test_ssh(self): +        """test ssh imports""" +        from zmq.ssh import tunnel +     + + diff --git a/zmq/tests/test_ioloop.py b/zmq/tests/test_ioloop.py new file mode 100644 index 0000000..2a8b115 --- /dev/null +++ b/zmq/tests/test_ioloop.py @@ -0,0 +1,113 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time +import os +import threading + +import zmq +from zmq.tests import BaseZMQTestCase +from zmq.eventloop import ioloop +from zmq.eventloop.minitornado.ioloop import _Timeout +try: +    from tornado.ioloop import PollIOLoop, IOLoop as BaseIOLoop +except ImportError: +    from zmq.eventloop.minitornado.ioloop import IOLoop as BaseIOLoop + + +def printer(): +    os.system("say hello") +    raise Exception +    print (time.time()) + + +class Delay(threading.Thread): +    def __init__(self, f, delay=1): +        self.f=f +        self.delay=delay +        self.aborted=False +        self.cond=threading.Condition() +        super(Delay, self).__init__() +     +    def run(self): +        self.cond.acquire() +        self.cond.wait(self.delay) +        self.cond.release() +        if not self.aborted: +            self.f() +     +    def abort(self): +        self.aborted=True +        self.cond.acquire() +        self.cond.notify() +        self.cond.release() + + +class TestIOLoop(BaseZMQTestCase): + +    def test_simple(self): +        """simple IOLoop creation test""" +        loop = ioloop.IOLoop() +        dc = ioloop.PeriodicCallback(loop.stop, 200, loop) +        pc = ioloop.PeriodicCallback(lambda : None, 10, loop) +        pc.start() +        dc.start() +        t = Delay(loop.stop,1) +        t.start() +        loop.start() +        if t.isAlive(): +            t.abort() +        else: +            self.fail("IOLoop failed to exit") +     +    def test_timeout_compare(self): +        """test timeout comparisons""" +        loop = ioloop.IOLoop() +        t = _Timeout(1, 2, loop) +        t2 = _Timeout(1, 3, loop) +        self.assertEqual(t < t2, id(t) < id(t2)) +        t2 = _Timeout(2,1, loop) +        self.assertTrue(t < t2) + +    def test_poller_events(self): +        """Tornado poller implementation maps events correctly""" +        req,rep = self.create_bound_pair(zmq.REQ, zmq.REP) +        poller = ioloop.ZMQPoller() +        poller.register(req, ioloop.IOLoop.READ) +        poller.register(rep, ioloop.IOLoop.READ) +        events = dict(poller.poll(0)) +        self.assertEqual(events.get(rep), None) +        self.assertEqual(events.get(req), None) +         +        poller.register(req, ioloop.IOLoop.WRITE) +        poller.register(rep, ioloop.IOLoop.WRITE) +        events = dict(poller.poll(1)) +        self.assertEqual(events.get(req), ioloop.IOLoop.WRITE) +        self.assertEqual(events.get(rep), None) +         +        poller.register(rep, ioloop.IOLoop.READ) +        req.send(b'hi') +        events = dict(poller.poll(1)) +        self.assertEqual(events.get(rep), ioloop.IOLoop.READ) +        self.assertEqual(events.get(req), None) +     +    def test_instance(self): +        """Test IOLoop.instance returns the right object""" +        loop = ioloop.IOLoop.instance() +        self.assertEqual(loop.__class__, ioloop.IOLoop) +        loop = BaseIOLoop.instance() +        self.assertEqual(loop.__class__, ioloop.IOLoop) +     +    def test_close_all(self): +        """Test close(all_fds=True)""" +        loop = ioloop.IOLoop.instance() +        req,rep = self.create_bound_pair(zmq.REQ, zmq.REP) +        loop.add_handler(req, lambda msg: msg, ioloop.IOLoop.READ) +        loop.add_handler(rep, lambda msg: msg, ioloop.IOLoop.READ) +        self.assertEqual(req.closed, False) +        self.assertEqual(rep.closed, False) +        loop.close(all_fds=True) +        self.assertEqual(req.closed, True) +        self.assertEqual(rep.closed, True) +         + diff --git a/zmq/tests/test_log.py b/zmq/tests/test_log.py new file mode 100644 index 0000000..9206f09 --- /dev/null +++ b/zmq/tests/test_log.py @@ -0,0 +1,116 @@ +# encoding: utf-8 + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import logging +import time +from unittest import TestCase + +import zmq +from zmq.log import handlers +from zmq.utils.strtypes import b, u +from zmq.tests import BaseZMQTestCase + + +class TestPubLog(BaseZMQTestCase): +     +    iface = 'inproc://zmqlog' +    topic= 'zmq' +     +    @property +    def logger(self): +        # print dir(self) +        logger = logging.getLogger('zmqtest') +        logger.setLevel(logging.DEBUG) +        return logger +     +    def connect_handler(self, topic=None): +        topic = self.topic if topic is None else topic +        logger = self.logger +        pub,sub = self.create_bound_pair(zmq.PUB, zmq.SUB) +        handler = handlers.PUBHandler(pub) +        handler.setLevel(logging.DEBUG) +        handler.root_topic = topic +        logger.addHandler(handler) +        sub.setsockopt(zmq.SUBSCRIBE, b(topic)) +        time.sleep(0.1) +        return logger, handler, sub +     +    def test_init_iface(self): +        logger = self.logger +        ctx = self.context +        handler = handlers.PUBHandler(self.iface) +        self.assertFalse(handler.ctx is ctx) +        self.sockets.append(handler.socket) +        # handler.ctx.term() +        handler = handlers.PUBHandler(self.iface, self.context) +        self.sockets.append(handler.socket) +        self.assertTrue(handler.ctx is ctx) +        handler.setLevel(logging.DEBUG) +        handler.root_topic = self.topic +        logger.addHandler(handler) +        sub = ctx.socket(zmq.SUB) +        self.sockets.append(sub) +        sub.setsockopt(zmq.SUBSCRIBE, b(self.topic)) +        sub.connect(self.iface) +        import time; time.sleep(0.25) +        msg1 = 'message' +        logger.info(msg1) +         +        (topic, msg2) = sub.recv_multipart() +        self.assertEqual(topic, b'zmq.INFO') +        self.assertEqual(msg2, b(msg1)+b'\n') +        logger.removeHandler(handler) +     +    def test_init_socket(self): +        pub,sub = self.create_bound_pair(zmq.PUB, zmq.SUB) +        logger = self.logger +        handler = handlers.PUBHandler(pub) +        handler.setLevel(logging.DEBUG) +        handler.root_topic = self.topic +        logger.addHandler(handler) +         +        self.assertTrue(handler.socket is pub) +        self.assertTrue(handler.ctx is pub.context) +        self.assertTrue(handler.ctx is self.context) +        sub.setsockopt(zmq.SUBSCRIBE, b(self.topic)) +        import time; time.sleep(0.1) +        msg1 = 'message' +        logger.info(msg1) +         +        (topic, msg2) = sub.recv_multipart() +        self.assertEqual(topic, b'zmq.INFO') +        self.assertEqual(msg2, b(msg1)+b'\n') +        logger.removeHandler(handler) +     +    def test_root_topic(self): +        logger, handler, sub = self.connect_handler() +        handler.socket.bind(self.iface) +        sub2 = sub.context.socket(zmq.SUB) +        self.sockets.append(sub2) +        sub2.connect(self.iface) +        sub2.setsockopt(zmq.SUBSCRIBE, b'') +        handler.root_topic = b'twoonly' +        msg1 = 'ignored' +        logger.info(msg1) +        self.assertRaisesErrno(zmq.EAGAIN, sub.recv, zmq.NOBLOCK) +        topic,msg2 = sub2.recv_multipart() +        self.assertEqual(topic, b'twoonly.INFO') +        self.assertEqual(msg2, b(msg1)+b'\n') +         +        logger.removeHandler(handler) +     +    def test_unicode_message(self): +        logger, handler, sub = self.connect_handler() +        base_topic = b(self.topic + '.INFO') +        for msg, expected in [ +            (u('hello'), [base_topic, b('hello\n')]), +            (u('héllo'), [base_topic, b('héllo\n')]), +            (u('tøpic::héllo'), [base_topic + b('.tøpic'), b('héllo\n')]), +        ]: +            logger.info(msg) +            received = sub.recv_multipart() +            self.assertEqual(received, expected) + diff --git a/zmq/tests/test_message.py b/zmq/tests/test_message.py new file mode 100644 index 0000000..d8770bd --- /dev/null +++ b/zmq/tests/test_message.py @@ -0,0 +1,362 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import copy +import sys +try: +    from sys import getrefcount as grc +except ImportError: +    grc = None + +import time +from pprint import pprint +from unittest import TestCase + +import zmq +from zmq.tests import BaseZMQTestCase, SkipTest, skip_pypy, PYPY +from zmq.utils.strtypes import unicode, bytes, b, u + + +# some useful constants: + +x = b'x' + +try: +    view = memoryview +except NameError: +    view = buffer + +if grc: +    rc0 = grc(x) +    v = view(x) +    view_rc = grc(x) - rc0 + +def await_gc(obj, rc): +    """wait for refcount on an object to drop to an expected value +     +    Necessary because of the zero-copy gc thread, +    which can take some time to receive its DECREF message. +    """ +    for i in range(50): +        # rc + 2 because of the refs in this function +        if grc(obj) <= rc + 2: +            return +        time.sleep(0.05) +     +class TestFrame(BaseZMQTestCase): + +    @skip_pypy +    def test_above_30(self): +        """Message above 30 bytes are never copied by 0MQ.""" +        for i in range(5, 16):  # 32, 64,..., 65536 +            s = (2**i)*x +            self.assertEqual(grc(s), 2) +            m = zmq.Frame(s) +            self.assertEqual(grc(s), 4) +            del m +            await_gc(s, 2) +            self.assertEqual(grc(s), 2) +            del s + +    def test_str(self): +        """Test the str representations of the Frames.""" +        for i in range(16): +            s = (2**i)*x +            m = zmq.Frame(s) +            m_str = str(m) +            m_str_b = b(m_str) # py3compat +            self.assertEqual(s, m_str_b) + +    def test_bytes(self): +        """Test the Frame.bytes property.""" +        for i in range(1,16): +            s = (2**i)*x +            m = zmq.Frame(s) +            b = m.bytes +            self.assertEqual(s, m.bytes) +            if not PYPY: +                # check that it copies +                self.assert_(b is not s) +            # check that it copies only once +            self.assert_(b is m.bytes) + +    def test_unicode(self): +        """Test the unicode representations of the Frames.""" +        s = u('asdf') +        self.assertRaises(TypeError, zmq.Frame, s) +        for i in range(16): +            s = (2**i)*u('§') +            m = zmq.Frame(s.encode('utf8')) +            self.assertEqual(s, unicode(m.bytes,'utf8')) + +    def test_len(self): +        """Test the len of the Frames.""" +        for i in range(16): +            s = (2**i)*x +            m = zmq.Frame(s) +            self.assertEqual(len(s), len(m)) + +    @skip_pypy +    def test_lifecycle1(self): +        """Run through a ref counting cycle with a copy.""" +        for i in range(5, 16):  # 32, 64,..., 65536 +            s = (2**i)*x +            rc = 2 +            self.assertEqual(grc(s), rc) +            m = zmq.Frame(s) +            rc += 2 +            self.assertEqual(grc(s), rc) +            m2 = copy.copy(m) +            rc += 1 +            self.assertEqual(grc(s), rc) +            buf = m2.buffer + +            rc += view_rc +            self.assertEqual(grc(s), rc) + +            self.assertEqual(s, b(str(m))) +            self.assertEqual(s, bytes(m2)) +            self.assertEqual(s, m.bytes) +            # self.assert_(s is str(m)) +            # self.assert_(s is str(m2)) +            del m2 +            rc -= 1 +            self.assertEqual(grc(s), rc) +            rc -= view_rc +            del buf +            self.assertEqual(grc(s), rc) +            del m +            rc -= 2 +            await_gc(s, rc) +            self.assertEqual(grc(s), rc) +            self.assertEqual(rc, 2) +            del s + +    @skip_pypy +    def test_lifecycle2(self): +        """Run through a different ref counting cycle with a copy.""" +        for i in range(5, 16):  # 32, 64,..., 65536 +            s = (2**i)*x +            rc = 2 +            self.assertEqual(grc(s), rc) +            m = zmq.Frame(s) +            rc += 2 +            self.assertEqual(grc(s), rc) +            m2 = copy.copy(m) +            rc += 1 +            self.assertEqual(grc(s), rc) +            buf = m.buffer +            rc += view_rc +            self.assertEqual(grc(s), rc) +            self.assertEqual(s, b(str(m))) +            self.assertEqual(s, bytes(m2)) +            self.assertEqual(s, m2.bytes) +            self.assertEqual(s, m.bytes) +            # self.assert_(s is str(m)) +            # self.assert_(s is str(m2)) +            del buf +            self.assertEqual(grc(s), rc) +            del m +            # m.buffer is kept until m is del'd +            rc -= view_rc +            rc -= 1 +            self.assertEqual(grc(s), rc) +            del m2 +            rc -= 2 +            await_gc(s, rc) +            self.assertEqual(grc(s), rc) +            self.assertEqual(rc, 2) +            del s +     +    @skip_pypy +    def test_tracker(self): +        m = zmq.Frame(b'asdf', track=True) +        self.assertFalse(m.tracker.done) +        pm = zmq.MessageTracker(m) +        self.assertFalse(pm.done) +        del m +        for i in range(10): +            if pm.done: +                break +            time.sleep(0.1) +        self.assertTrue(pm.done) +     +    def test_no_tracker(self): +        m = zmq.Frame(b'asdf', track=False) +        self.assertEqual(m.tracker, None) +        m2 = copy.copy(m) +        self.assertEqual(m2.tracker, None) +        self.assertRaises(ValueError, zmq.MessageTracker, m) +     +    @skip_pypy +    def test_multi_tracker(self): +        m = zmq.Frame(b'asdf', track=True) +        m2 = zmq.Frame(b'whoda', track=True) +        mt = zmq.MessageTracker(m,m2) +        self.assertFalse(m.tracker.done) +        self.assertFalse(mt.done) +        self.assertRaises(zmq.NotDone, mt.wait, 0.1) +        del m +        time.sleep(0.1) +        self.assertRaises(zmq.NotDone, mt.wait, 0.1) +        self.assertFalse(mt.done) +        del m2 +        self.assertTrue(mt.wait() is None) +        self.assertTrue(mt.done) +         +     +    def test_buffer_in(self): +        """test using a buffer as input""" +        ins = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√") +        m = zmq.Frame(view(ins)) +     +    def test_bad_buffer_in(self): +        """test using a bad object""" +        self.assertRaises(TypeError, zmq.Frame, 5) +        self.assertRaises(TypeError, zmq.Frame, object()) +         +    def test_buffer_out(self): +        """receiving buffered output""" +        ins = b("§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√") +        m = zmq.Frame(ins) +        outb = m.buffer +        self.assertTrue(isinstance(outb, view)) +        self.assert_(outb is m.buffer) +        self.assert_(m.buffer is m.buffer) +     +    def test_multisend(self): +        """ensure that a message remains intact after multiple sends""" +        a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) +        s = b"message" +        m = zmq.Frame(s) +        self.assertEqual(s, m.bytes) +         +        a.send(m, copy=False) +        time.sleep(0.1) +        self.assertEqual(s, m.bytes) +        a.send(m, copy=False) +        time.sleep(0.1) +        self.assertEqual(s, m.bytes) +        a.send(m, copy=True) +        time.sleep(0.1) +        self.assertEqual(s, m.bytes) +        a.send(m, copy=True) +        time.sleep(0.1) +        self.assertEqual(s, m.bytes) +        for i in range(4): +            r = b.recv() +            self.assertEqual(s,r) +        self.assertEqual(s, m.bytes) +     +    def test_buffer_numpy(self): +        """test non-copying numpy array messages""" +        try: +            import numpy +        except ImportError: +            raise SkipTest("numpy required") +        rand = numpy.random.randint +        shapes = [ rand(2,16) for i in range(5) ] +        for i in range(1,len(shapes)+1): +            shape = shapes[:i] +            A = numpy.random.random(shape) +            m = zmq.Frame(A) +            if view.__name__ == 'buffer': +                self.assertEqual(A.data, m.buffer) +                B = numpy.frombuffer(m.buffer,dtype=A.dtype).reshape(A.shape) +            else: +                self.assertEqual(memoryview(A), m.buffer) +                B = numpy.array(m.buffer,dtype=A.dtype).reshape(A.shape) +            self.assertEqual((A==B).all(), True) +     +    def test_memoryview(self): +        """test messages from memoryview""" +        major,minor = sys.version_info[:2] +        if not (major >= 3 or (major == 2 and minor >= 7)): +            raise SkipTest("memoryviews only in python >= 2.7") + +        s = b'carrotjuice' +        v = memoryview(s) +        m = zmq.Frame(s) +        buf = m.buffer +        s2 = buf.tobytes() +        self.assertEqual(s2,s) +        self.assertEqual(m.bytes,s) +     +    def test_noncopying_recv(self): +        """check for clobbering message buffers""" +        null = b'\0'*64 +        sa,sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR) +        for i in range(32): +            # try a few times +            sb.send(null, copy=False) +            m = sa.recv(copy=False) +            mb = m.bytes +            # buf = view(m) +            buf = m.buffer +            del m +            for i in range(5): +                ff=b'\xff'*(40 + i*10) +                sb.send(ff, copy=False) +                m2 = sa.recv(copy=False) +                if view.__name__ == 'buffer': +                    b = bytes(buf) +                else: +                    b = buf.tobytes() +                self.assertEqual(b, null) +                self.assertEqual(mb, null) +                self.assertEqual(m2.bytes, ff) + +    @skip_pypy +    def test_buffer_numpy(self): +        """test non-copying numpy array messages""" +        try: +            import numpy +        except ImportError: +            raise SkipTest("requires numpy") +        if sys.version_info < (2,7): +            raise SkipTest("requires new-style buffer interface (py >= 2.7)") +        rand = numpy.random.randint +        shapes = [ rand(2,5) for i in range(5) ] +        a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) +        dtypes = [int, float, '>i4', 'B'] +        for i in range(1,len(shapes)+1): +            shape = shapes[:i] +            for dt in dtypes: +                A = numpy.empty(shape, dtype=dt) +                while numpy.isnan(A).any(): +                    # don't let nan sneak in +                    A = numpy.ndarray(shape, dtype=dt) +                a.send(A, copy=False) +                msg = b.recv(copy=False) +                 +                B = numpy.frombuffer(msg, A.dtype).reshape(A.shape) +                self.assertEqual(A.shape, B.shape) +                self.assertTrue((A==B).all()) +            A = numpy.empty(shape, dtype=[('a', int), ('b', float), ('c', 'a32')]) +            A['a'] = 1024 +            A['b'] = 1e9 +            A['c'] = 'hello there' +            a.send(A, copy=False) +            msg = b.recv(copy=False) +             +            B = numpy.frombuffer(msg, A.dtype).reshape(A.shape) +            self.assertEqual(A.shape, B.shape) +            self.assertTrue((A==B).all()) +     +    def test_frame_more(self): +        """test Frame.more attribute""" +        frame = zmq.Frame(b"hello") +        self.assertFalse(frame.more) +        sa,sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR) +        sa.send_multipart([b'hi', b'there']) +        frame = self.recv(sb, copy=False) +        self.assertTrue(frame.more) +        if zmq.zmq_version_info()[0] >= 3 and not PYPY: +            self.assertTrue(frame.get(zmq.MORE)) +        frame = self.recv(sb, copy=False) +        self.assertFalse(frame.more) +        if zmq.zmq_version_info()[0] >= 3 and not PYPY: +            self.assertFalse(frame.get(zmq.MORE)) + diff --git a/zmq/tests/test_monitor.py b/zmq/tests/test_monitor.py new file mode 100644 index 0000000..4f03538 --- /dev/null +++ b/zmq/tests/test_monitor.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import sys +import time +import struct + +from unittest import TestCase + +import zmq +from zmq.tests import BaseZMQTestCase, skip_if, skip_pypy +from zmq.utils.monitor import recv_monitor_message + +skip_lt_4 = skip_if(zmq.zmq_version_info() < (4,), "requires zmq >= 4") + +class TestSocketMonitor(BaseZMQTestCase): + +    @skip_lt_4 +    def test_monitor(self): +        """Test monitoring interface for sockets.""" +        s_rep = self.context.socket(zmq.REP) +        s_req = self.context.socket(zmq.REQ) +        self.sockets.extend([s_rep, s_req]) +        s_req.bind("tcp://127.0.0.1:6666") +        # try monitoring the REP socket +         +        s_rep.monitor("inproc://monitor.rep", zmq.EVENT_ALL) +        # create listening socket for monitor +        s_event = self.context.socket(zmq.PAIR) +        self.sockets.append(s_event) +        s_event.connect("inproc://monitor.rep") +        s_event.linger = 0 +        # test receive event for connect event +        s_rep.connect("tcp://127.0.0.1:6666") +        m = recv_monitor_message(s_event) +        if m['event'] == zmq.EVENT_CONNECT_DELAYED: +            self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6666") +            # test receive event for connected event +            m = recv_monitor_message(s_event) +        self.assertEqual(m['event'], zmq.EVENT_CONNECTED) +        self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6666") + +        # test monitor can be disabled. +        s_rep.disable_monitor() +        m = recv_monitor_message(s_event) +        self.assertEqual(m['event'], zmq.EVENT_MONITOR_STOPPED) + + +    @skip_lt_4 +    def test_monitor_connected(self): +        """Test connected monitoring socket.""" +        s_rep = self.context.socket(zmq.REP) +        s_req = self.context.socket(zmq.REQ) +        self.sockets.extend([s_rep, s_req]) +        s_req.bind("tcp://127.0.0.1:6667") +        # try monitoring the REP socket +        # create listening socket for monitor +        s_event = s_rep.get_monitor_socket() +        s_event.linger = 0 +        self.sockets.append(s_event) +        # test receive event for connect event +        s_rep.connect("tcp://127.0.0.1:6667") +        m = recv_monitor_message(s_event) +        if m['event'] == zmq.EVENT_CONNECT_DELAYED: +            self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6667") +            # test receive event for connected event +            m = recv_monitor_message(s_event) +        self.assertEqual(m['event'], zmq.EVENT_CONNECTED) +        self.assertEqual(m['endpoint'], b"tcp://127.0.0.1:6667") diff --git a/zmq/tests/test_monqueue.py b/zmq/tests/test_monqueue.py new file mode 100644 index 0000000..e855602 --- /dev/null +++ b/zmq/tests/test_monqueue.py @@ -0,0 +1,227 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time +from unittest import TestCase + +import zmq +from zmq import devices + +from zmq.tests import BaseZMQTestCase, SkipTest, PYPY +from zmq.utils.strtypes import unicode + + +if PYPY or zmq.zmq_version_info() >= (4,1): +    # cleanup of shared Context doesn't work on PyPy +    # there also seems to be a bug in cleanup in libzmq-4.1 (zeromq/libzmq#1052) +    devices.Device.context_factory = zmq.Context + + +class TestMonitoredQueue(BaseZMQTestCase): +     +    sockets = [] +     +    def build_device(self, mon_sub=b"", in_prefix=b'in', out_prefix=b'out'): +        self.device = devices.ThreadMonitoredQueue(zmq.PAIR, zmq.PAIR, zmq.PUB, +                                            in_prefix, out_prefix) +        alice = self.context.socket(zmq.PAIR) +        bob = self.context.socket(zmq.PAIR) +        mon = self.context.socket(zmq.SUB) +         +        aport = alice.bind_to_random_port('tcp://127.0.0.1') +        bport = bob.bind_to_random_port('tcp://127.0.0.1') +        mport = mon.bind_to_random_port('tcp://127.0.0.1') +        mon.setsockopt(zmq.SUBSCRIBE, mon_sub) +         +        self.device.connect_in("tcp://127.0.0.1:%i"%aport) +        self.device.connect_out("tcp://127.0.0.1:%i"%bport) +        self.device.connect_mon("tcp://127.0.0.1:%i"%mport) +        self.device.start() +        time.sleep(.2) +        try: +            # this is currenlty necessary to ensure no dropped monitor messages +            # see LIBZMQ-248 for more info +            mon.recv_multipart(zmq.NOBLOCK) +        except zmq.ZMQError: +            pass +        self.sockets.extend([alice, bob, mon]) +        return alice, bob, mon +         +     +    def teardown_device(self): +        for socket in self.sockets: +            socket.close() +            del socket +        del self.device +         +    def test_reply(self): +        alice, bob, mon = self.build_device() +        alices = b"hello bob".split() +        alice.send_multipart(alices) +        bobs = self.recv_multipart(bob) +        self.assertEqual(alices, bobs) +        bobs = b"hello alice".split() +        bob.send_multipart(bobs) +        alices = self.recv_multipart(alice) +        self.assertEqual(alices, bobs) +        self.teardown_device() +     +    def test_queue(self): +        alice, bob, mon = self.build_device() +        alices = b"hello bob".split() +        alice.send_multipart(alices) +        alices2 = b"hello again".split() +        alice.send_multipart(alices2) +        alices3 = b"hello again and again".split() +        alice.send_multipart(alices3) +        bobs = self.recv_multipart(bob) +        self.assertEqual(alices, bobs) +        bobs = self.recv_multipart(bob) +        self.assertEqual(alices2, bobs) +        bobs = self.recv_multipart(bob) +        self.assertEqual(alices3, bobs) +        bobs = b"hello alice".split() +        bob.send_multipart(bobs) +        alices = self.recv_multipart(alice) +        self.assertEqual(alices, bobs) +        self.teardown_device() +     +    def test_monitor(self): +        alice, bob, mon = self.build_device() +        alices = b"hello bob".split() +        alice.send_multipart(alices) +        alices2 = b"hello again".split() +        alice.send_multipart(alices2) +        alices3 = b"hello again and again".split() +        alice.send_multipart(alices3) +        bobs = self.recv_multipart(bob) +        self.assertEqual(alices, bobs) +        mons = self.recv_multipart(mon) +        self.assertEqual([b'in']+bobs, mons) +        bobs = self.recv_multipart(bob) +        self.assertEqual(alices2, bobs) +        bobs = self.recv_multipart(bob) +        self.assertEqual(alices3, bobs) +        mons = self.recv_multipart(mon) +        self.assertEqual([b'in']+alices2, mons) +        bobs = b"hello alice".split() +        bob.send_multipart(bobs) +        alices = self.recv_multipart(alice) +        self.assertEqual(alices, bobs) +        mons = self.recv_multipart(mon) +        self.assertEqual([b'in']+alices3, mons) +        mons = self.recv_multipart(mon) +        self.assertEqual([b'out']+bobs, mons) +        self.teardown_device() +     +    def test_prefix(self): +        alice, bob, mon = self.build_device(b"", b'foo', b'bar') +        alices = b"hello bob".split() +        alice.send_multipart(alices) +        alices2 = b"hello again".split() +        alice.send_multipart(alices2) +        alices3 = b"hello again and again".split() +        alice.send_multipart(alices3) +        bobs = self.recv_multipart(bob) +        self.assertEqual(alices, bobs) +        mons = self.recv_multipart(mon) +        self.assertEqual([b'foo']+bobs, mons) +        bobs = self.recv_multipart(bob) +        self.assertEqual(alices2, bobs) +        bobs = self.recv_multipart(bob) +        self.assertEqual(alices3, bobs) +        mons = self.recv_multipart(mon) +        self.assertEqual([b'foo']+alices2, mons) +        bobs = b"hello alice".split() +        bob.send_multipart(bobs) +        alices = self.recv_multipart(alice) +        self.assertEqual(alices, bobs) +        mons = self.recv_multipart(mon) +        self.assertEqual([b'foo']+alices3, mons) +        mons = self.recv_multipart(mon) +        self.assertEqual([b'bar']+bobs, mons) +        self.teardown_device() +     +    def test_monitor_subscribe(self): +        alice, bob, mon = self.build_device(b"out") +        alices = b"hello bob".split() +        alice.send_multipart(alices) +        alices2 = b"hello again".split() +        alice.send_multipart(alices2) +        alices3 = b"hello again and again".split() +        alice.send_multipart(alices3) +        bobs = self.recv_multipart(bob) +        self.assertEqual(alices, bobs) +        bobs = self.recv_multipart(bob) +        self.assertEqual(alices2, bobs) +        bobs = self.recv_multipart(bob) +        self.assertEqual(alices3, bobs) +        bobs = b"hello alice".split() +        bob.send_multipart(bobs) +        alices = self.recv_multipart(alice) +        self.assertEqual(alices, bobs) +        mons = self.recv_multipart(mon) +        self.assertEqual([b'out']+bobs, mons) +        self.teardown_device() +     +    def test_router_router(self): +        """test router-router MQ devices""" +        dev = devices.ThreadMonitoredQueue(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out') +        self.device = dev +        dev.setsockopt_in(zmq.LINGER, 0) +        dev.setsockopt_out(zmq.LINGER, 0) +        dev.setsockopt_mon(zmq.LINGER, 0) +         +        binder = self.context.socket(zmq.DEALER) +        porta = binder.bind_to_random_port('tcp://127.0.0.1') +        portb = binder.bind_to_random_port('tcp://127.0.0.1') +        binder.close() +        time.sleep(0.1) +        a = self.context.socket(zmq.DEALER) +        a.identity = b'a' +        b = self.context.socket(zmq.DEALER) +        b.identity = b'b' +        self.sockets.extend([a, b]) +         +        a.connect('tcp://127.0.0.1:%i'%porta) +        dev.bind_in('tcp://127.0.0.1:%i'%porta) +        b.connect('tcp://127.0.0.1:%i'%portb) +        dev.bind_out('tcp://127.0.0.1:%i'%portb) +        dev.start() +        time.sleep(0.2) +        if zmq.zmq_version_info() >= (3,1,0): +            # flush erroneous poll state, due to LIBZMQ-280 +            ping_msg = [ b'ping', b'pong' ] +            for s in (a,b): +                s.send_multipart(ping_msg) +                try: +                    s.recv(zmq.NOBLOCK) +                except zmq.ZMQError: +                    pass +        msg = [ b'hello', b'there' ] +        a.send_multipart([b'b']+msg) +        bmsg = self.recv_multipart(b) +        self.assertEqual(bmsg, [b'a']+msg) +        b.send_multipart(bmsg) +        amsg = self.recv_multipart(a) +        self.assertEqual(amsg, [b'b']+msg) +        self.teardown_device() +     +    def test_default_mq_args(self): +        self.device = dev = devices.ThreadMonitoredQueue(zmq.ROUTER, zmq.DEALER, zmq.PUB) +        dev.setsockopt_in(zmq.LINGER, 0) +        dev.setsockopt_out(zmq.LINGER, 0) +        dev.setsockopt_mon(zmq.LINGER, 0) +        # this will raise if default args are wrong +        dev.start() +        self.teardown_device() +     +    def test_mq_check_prefix(self): +        ins = self.context.socket(zmq.ROUTER) +        outs = self.context.socket(zmq.DEALER) +        mons = self.context.socket(zmq.PUB) +        self.sockets.extend([ins, outs, mons]) +         +        ins = unicode('in') +        outs = unicode('out') +        self.assertRaises(TypeError, devices.monitoredqueue, ins, outs, mons) diff --git a/zmq/tests/test_multipart.py b/zmq/tests/test_multipart.py new file mode 100644 index 0000000..24d41be --- /dev/null +++ b/zmq/tests/test_multipart.py @@ -0,0 +1,35 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import zmq + + +from zmq.tests import BaseZMQTestCase, SkipTest, have_gevent, GreenTest + + +class TestMultipart(BaseZMQTestCase): + +    def test_router_dealer(self): +        router, dealer = self.create_bound_pair(zmq.ROUTER, zmq.DEALER) + +        msg1 = b'message1' +        dealer.send(msg1) +        ident = self.recv(router) +        more = router.rcvmore +        self.assertEqual(more, True) +        msg2 = self.recv(router) +        self.assertEqual(msg1, msg2) +        more = router.rcvmore +        self.assertEqual(more, False) +     +    def test_basic_multipart(self): +        a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) +        msg = [ b'hi', b'there', b'b'] +        a.send_multipart(msg) +        recvd = b.recv_multipart() +        self.assertEqual(msg, recvd) + +if have_gevent: +    class TestMultipartGreen(GreenTest, TestMultipart): +        pass diff --git a/zmq/tests/test_pair.py b/zmq/tests/test_pair.py new file mode 100644 index 0000000..e88c1e8 --- /dev/null +++ b/zmq/tests/test_pair.py @@ -0,0 +1,53 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import zmq + + +from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest + + +x = b' ' +class TestPair(BaseZMQTestCase): + +    def test_basic(self): +        s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + +        msg1 = b'message1' +        msg2 = self.ping_pong(s1, s2, msg1) +        self.assertEqual(msg1, msg2) + +    def test_multiple(self): +        s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + +        for i in range(10): +            msg = i*x +            s1.send(msg) + +        for i in range(10): +            msg = i*x +            s2.send(msg) + +        for i in range(10): +            msg = s1.recv() +            self.assertEqual(msg, i*x) + +        for i in range(10): +            msg = s2.recv() +            self.assertEqual(msg, i*x) + +    def test_json(self): +        s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) +        o = dict(a=10,b=list(range(10))) +        o2 = self.ping_pong_json(s1, s2, o) + +    def test_pyobj(self): +        s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) +        o = dict(a=10,b=range(10)) +        o2 = self.ping_pong_pyobj(s1, s2, o) + +if have_gevent: +    class TestReqRepGreen(GreenTest, TestPair): +        pass + diff --git a/zmq/tests/test_poll.py b/zmq/tests/test_poll.py new file mode 100644 index 0000000..57346c8 --- /dev/null +++ b/zmq/tests/test_poll.py @@ -0,0 +1,229 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import time +from unittest import TestCase + +import zmq + +from zmq.tests import PollZMQTestCase, have_gevent, GreenTest + +def wait(): +    time.sleep(.25) + + +class TestPoll(PollZMQTestCase): + +    Poller = zmq.Poller + +    # This test is failing due to this issue: +    # http://github.com/sustrik/zeromq2/issues#issue/26 +    def test_pair(self): +        s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + +        # Sleep to allow sockets to connect. +        wait() + +        poller = self.Poller() +        poller.register(s1, zmq.POLLIN|zmq.POLLOUT) +        poller.register(s2, zmq.POLLIN|zmq.POLLOUT) +        # Poll result should contain both sockets +        socks = dict(poller.poll()) +        # Now make sure that both are send ready. +        self.assertEqual(socks[s1], zmq.POLLOUT) +        self.assertEqual(socks[s2], zmq.POLLOUT) +        # Now do a send on both, wait and test for zmq.POLLOUT|zmq.POLLIN +        s1.send(b'msg1') +        s2.send(b'msg2') +        wait() +        socks = dict(poller.poll()) +        self.assertEqual(socks[s1], zmq.POLLOUT|zmq.POLLIN) +        self.assertEqual(socks[s2], zmq.POLLOUT|zmq.POLLIN) +        # Make sure that both are in POLLOUT after recv. +        s1.recv() +        s2.recv() +        socks = dict(poller.poll()) +        self.assertEqual(socks[s1], zmq.POLLOUT) +        self.assertEqual(socks[s2], zmq.POLLOUT) + +        poller.unregister(s1) +        poller.unregister(s2) + +        # Wait for everything to finish. +        wait() + +    def test_reqrep(self): +        s1, s2 = self.create_bound_pair(zmq.REP, zmq.REQ) + +        # Sleep to allow sockets to connect. +        wait() + +        poller = self.Poller() +        poller.register(s1, zmq.POLLIN|zmq.POLLOUT) +        poller.register(s2, zmq.POLLIN|zmq.POLLOUT) + +        # Make sure that s1 is in state 0 and s2 is in POLLOUT +        socks = dict(poller.poll()) +        self.assertEqual(s1 in socks, 0) +        self.assertEqual(socks[s2], zmq.POLLOUT) + +        # Make sure that s2 goes immediately into state 0 after send. +        s2.send(b'msg1') +        socks = dict(poller.poll()) +        self.assertEqual(s2 in socks, 0) + +        # Make sure that s1 goes into POLLIN state after a time.sleep(). +        time.sleep(0.5) +        socks = dict(poller.poll()) +        self.assertEqual(socks[s1], zmq.POLLIN) + +        # Make sure that s1 goes into POLLOUT after recv. +        s1.recv() +        socks = dict(poller.poll()) +        self.assertEqual(socks[s1], zmq.POLLOUT) + +        # Make sure s1 goes into state 0 after send. +        s1.send(b'msg2') +        socks = dict(poller.poll()) +        self.assertEqual(s1 in socks, 0) + +        # Wait and then see that s2 is in POLLIN. +        time.sleep(0.5) +        socks = dict(poller.poll()) +        self.assertEqual(socks[s2], zmq.POLLIN) + +        # Make sure that s2 is in POLLOUT after recv. +        s2.recv() +        socks = dict(poller.poll()) +        self.assertEqual(socks[s2], zmq.POLLOUT) + +        poller.unregister(s1) +        poller.unregister(s2) + +        # Wait for everything to finish. +        wait() +     +    def test_no_events(self): +        s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) +        poller = self.Poller() +        poller.register(s1, zmq.POLLIN|zmq.POLLOUT) +        poller.register(s2, 0) +        self.assertTrue(s1 in poller) +        self.assertFalse(s2 in poller) +        poller.register(s1, 0) +        self.assertFalse(s1 in poller) + +    def test_pubsub(self): +        s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB) +        s2.setsockopt(zmq.SUBSCRIBE, b'') + +        # Sleep to allow sockets to connect. +        wait() + +        poller = self.Poller() +        poller.register(s1, zmq.POLLIN|zmq.POLLOUT) +        poller.register(s2, zmq.POLLIN) + +        # Now make sure that both are send ready. +        socks = dict(poller.poll()) +        self.assertEqual(socks[s1], zmq.POLLOUT) +        self.assertEqual(s2 in socks, 0) +        # Make sure that s1 stays in POLLOUT after a send. +        s1.send(b'msg1') +        socks = dict(poller.poll()) +        self.assertEqual(socks[s1], zmq.POLLOUT) + +        # Make sure that s2 is POLLIN after waiting. +        wait() +        socks = dict(poller.poll()) +        self.assertEqual(socks[s2], zmq.POLLIN) + +        # Make sure that s2 goes into 0 after recv. +        s2.recv() +        socks = dict(poller.poll()) +        self.assertEqual(s2 in socks, 0) + +        poller.unregister(s1) +        poller.unregister(s2) + +        # Wait for everything to finish. +        wait() +    def test_timeout(self): +        """make sure Poller.poll timeout has the right units (milliseconds).""" +        s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) +        poller = self.Poller() +        poller.register(s1, zmq.POLLIN) +        tic = time.time() +        evt = poller.poll(.005) +        toc = time.time() +        self.assertTrue(toc-tic < 0.1) +        tic = time.time() +        evt = poller.poll(5) +        toc = time.time() +        self.assertTrue(toc-tic < 0.1) +        self.assertTrue(toc-tic > .001) +        tic = time.time() +        evt = poller.poll(500) +        toc = time.time() +        self.assertTrue(toc-tic < 1) +        self.assertTrue(toc-tic > 0.1) + +class TestSelect(PollZMQTestCase): + +    def test_pair(self): +        s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + +        # Sleep to allow sockets to connect. +        wait() + +        rlist, wlist, xlist = zmq.select([s1, s2], [s1, s2], [s1, s2]) +        self.assert_(s1 in wlist) +        self.assert_(s2 in wlist) +        self.assert_(s1 not in rlist) +        self.assert_(s2 not in rlist) + +    def test_timeout(self): +        """make sure select timeout has the right units (seconds).""" +        s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) +        tic = time.time() +        r,w,x = zmq.select([s1,s2],[],[],.005) +        toc = time.time() +        self.assertTrue(toc-tic < 1) +        self.assertTrue(toc-tic > 0.001) +        tic = time.time() +        r,w,x = zmq.select([s1,s2],[],[],.25) +        toc = time.time() +        self.assertTrue(toc-tic < 1) +        self.assertTrue(toc-tic > 0.1) + + +if have_gevent: +    import gevent +    from zmq import green as gzmq + +    class TestPollGreen(GreenTest, TestPoll): +        Poller = gzmq.Poller + +        def test_wakeup(self): +            s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) +            poller = self.Poller() +            poller.register(s2, zmq.POLLIN) + +            tic = time.time() +            r = gevent.spawn(lambda: poller.poll(10000)) +            s = gevent.spawn(lambda: s1.send(b'msg1')) +            r.join() +            toc = time.time() +            self.assertTrue(toc-tic < 1) +         +        def test_socket_poll(self): +            s1, s2 = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + +            tic = time.time() +            r = gevent.spawn(lambda: s2.poll(10000)) +            s = gevent.spawn(lambda: s1.send(b'msg1')) +            r.join() +            toc = time.time() +            self.assertTrue(toc-tic < 1) + diff --git a/zmq/tests/test_pubsub.py b/zmq/tests/test_pubsub.py new file mode 100644 index 0000000..a3ee22a --- /dev/null +++ b/zmq/tests/test_pubsub.py @@ -0,0 +1,41 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import time +from unittest import TestCase + +import zmq + +from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest + + +class TestPubSub(BaseZMQTestCase): + +    pass + +    # We are disabling this test while an issue is being resolved. +    def test_basic(self): +        s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB) +        s2.setsockopt(zmq.SUBSCRIBE,b'') +        time.sleep(0.1) +        msg1 = b'message' +        s1.send(msg1) +        msg2 = s2.recv()  # This is blocking! +        self.assertEqual(msg1, msg2) + +    def test_topic(self): +        s1, s2 = self.create_bound_pair(zmq.PUB, zmq.SUB) +        s2.setsockopt(zmq.SUBSCRIBE, b'x') +        time.sleep(0.1) +        msg1 = b'message' +        s1.send(msg1) +        self.assertRaisesErrno(zmq.EAGAIN, s2.recv, zmq.NOBLOCK) +        msg1 = b'xmessage' +        s1.send(msg1) +        msg2 = s2.recv() +        self.assertEqual(msg1, msg2) + +if have_gevent: +    class TestPubSubGreen(GreenTest, TestPubSub): +        pass diff --git a/zmq/tests/test_reqrep.py b/zmq/tests/test_reqrep.py new file mode 100644 index 0000000..de17f2b --- /dev/null +++ b/zmq/tests/test_reqrep.py @@ -0,0 +1,62 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from unittest import TestCase + +import zmq +from zmq.tests import BaseZMQTestCase, have_gevent, GreenTest + + +class TestReqRep(BaseZMQTestCase): + +    def test_basic(self): +        s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) + +        msg1 = b'message 1' +        msg2 = self.ping_pong(s1, s2, msg1) +        self.assertEqual(msg1, msg2) + +    def test_multiple(self): +        s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) + +        for i in range(10): +            msg1 = i*b' ' +            msg2 = self.ping_pong(s1, s2, msg1) +            self.assertEqual(msg1, msg2) + +    def test_bad_send_recv(self): +        s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) +         +        if zmq.zmq_version() != '2.1.8': +            # this doesn't work on 2.1.8 +            for copy in (True,False): +                self.assertRaisesErrno(zmq.EFSM, s1.recv, copy=copy) +                self.assertRaisesErrno(zmq.EFSM, s2.send, b'asdf', copy=copy) + +        # I have to have this or we die on an Abort trap. +        msg1 = b'asdf' +        msg2 = self.ping_pong(s1, s2, msg1) +        self.assertEqual(msg1, msg2) + +    def test_json(self): +        s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) +        o = dict(a=10,b=list(range(10))) +        o2 = self.ping_pong_json(s1, s2, o) + +    def test_pyobj(self): +        s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) +        o = dict(a=10,b=range(10)) +        o2 = self.ping_pong_pyobj(s1, s2, o) + +    def test_large_msg(self): +        s1, s2 = self.create_bound_pair(zmq.REQ, zmq.REP) +        msg1 = 10000*b'X' + +        for i in range(10): +            msg2 = self.ping_pong(s1, s2, msg1) +            self.assertEqual(msg1, msg2) + +if have_gevent: +    class TestReqRepGreen(GreenTest, TestReqRep): +        pass diff --git a/zmq/tests/test_security.py b/zmq/tests/test_security.py new file mode 100644 index 0000000..687b7e0 --- /dev/null +++ b/zmq/tests/test_security.py @@ -0,0 +1,212 @@ +"""Test libzmq security (libzmq >= 3.3.0)""" +# -*- coding: utf8 -*- + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import os +from threading import Thread + +import zmq +from zmq.tests import ( +    BaseZMQTestCase, SkipTest, PYPY +) +from zmq.utils import z85 + + +USER = b"admin" +PASS = b"password" + +class TestSecurity(BaseZMQTestCase): +     +    def setUp(self): +        if zmq.zmq_version_info() < (4,0): +            raise SkipTest("security is new in libzmq 4.0") +        try: +            zmq.curve_keypair() +        except zmq.ZMQError: +            raise SkipTest("security requires libzmq to be linked against libsodium") +        super(TestSecurity, self).setUp() +     +     +    def zap_handler(self): +        socket = self.context.socket(zmq.REP) +        socket.bind("inproc://zeromq.zap.01") +        try: +            msg = self.recv_multipart(socket) + +            version, sequence, domain, address, identity, mechanism = msg[:6] +            if mechanism == b'PLAIN': +                username, password = msg[6:] +            elif mechanism == b'CURVE': +                key = msg[6] + +            self.assertEqual(version, b"1.0") +            self.assertEqual(identity, b"IDENT") +            reply = [version, sequence] +            if mechanism == b'CURVE' or \ +                (mechanism == b'PLAIN' and username == USER and password == PASS) or \ +                (mechanism == b'NULL'): +                reply.extend([ +                    b"200", +                    b"OK", +                    b"anonymous", +                    b"\5Hello\0\0\0\5World", +                ]) +            else: +                reply.extend([ +                    b"400", +                    b"Invalid username or password", +                    b"", +                    b"", +                ]) +            socket.send_multipart(reply) +        finally: +            socket.close() +     +    def start_zap(self): +        self.zap_thread = Thread(target=self.zap_handler) +        self.zap_thread.start() +     +    def stop_zap(self): +        self.zap_thread.join() + +    def bounce(self, server, client, test_metadata=True): +        msg = [os.urandom(64), os.urandom(64)] +        client.send_multipart(msg) +        frames = self.recv_multipart(server, copy=False) +        recvd = list(map(lambda x: x.bytes, frames)) + +        try: +            if test_metadata and not PYPY: +                for frame in frames: +                    self.assertEqual(frame.get('User-Id'), 'anonymous') +                    self.assertEqual(frame.get('Hello'), 'World') +                    self.assertEqual(frame['Socket-Type'], 'DEALER') +        except zmq.ZMQVersionError: +            pass + +        self.assertEqual(recvd, msg) +        server.send_multipart(recvd) +        msg2 = self.recv_multipart(client) +        self.assertEqual(msg2, msg) +     +    def test_null(self): +        """test NULL (default) security""" +        server = self.socket(zmq.DEALER) +        client = self.socket(zmq.DEALER) +        self.assertEqual(client.MECHANISM, zmq.NULL) +        self.assertEqual(server.mechanism, zmq.NULL) +        self.assertEqual(client.plain_server, 0) +        self.assertEqual(server.plain_server, 0) +        iface = 'tcp://127.0.0.1' +        port = server.bind_to_random_port(iface) +        client.connect("%s:%i" % (iface, port)) +        self.bounce(server, client, False) + +    def test_plain(self): +        """test PLAIN authentication""" +        server = self.socket(zmq.DEALER) +        server.identity = b'IDENT' +        client = self.socket(zmq.DEALER) +        self.assertEqual(client.plain_username, b'') +        self.assertEqual(client.plain_password, b'') +        client.plain_username = USER +        client.plain_password = PASS +        self.assertEqual(client.getsockopt(zmq.PLAIN_USERNAME), USER) +        self.assertEqual(client.getsockopt(zmq.PLAIN_PASSWORD), PASS) +        self.assertEqual(client.plain_server, 0) +        self.assertEqual(server.plain_server, 0) +        server.plain_server = True +        self.assertEqual(server.mechanism, zmq.PLAIN) +        self.assertEqual(client.mechanism, zmq.PLAIN) +         +        assert not client.plain_server +        assert server.plain_server +         +        self.start_zap() +         +        iface = 'tcp://127.0.0.1' +        port = server.bind_to_random_port(iface) +        client.connect("%s:%i" % (iface, port)) +        self.bounce(server, client) +        self.stop_zap() + +    def skip_plain_inauth(self): +        """test PLAIN failed authentication""" +        server = self.socket(zmq.DEALER) +        server.identity = b'IDENT' +        client = self.socket(zmq.DEALER) +        self.sockets.extend([server, client]) +        client.plain_username = USER +        client.plain_password = b'incorrect' +        server.plain_server = True +        self.assertEqual(server.mechanism, zmq.PLAIN) +        self.assertEqual(client.mechanism, zmq.PLAIN) +         +        self.start_zap() +         +        iface = 'tcp://127.0.0.1' +        port = server.bind_to_random_port(iface) +        client.connect("%s:%i" % (iface, port)) +        client.send(b'ping') +        server.rcvtimeo = 250 +        self.assertRaisesErrno(zmq.EAGAIN, server.recv) +        self.stop_zap() +     +    def test_keypair(self): +        """test curve_keypair""" +        try: +            public, secret = zmq.curve_keypair() +        except zmq.ZMQError: +            raise SkipTest("CURVE unsupported") +         +        self.assertEqual(type(secret), bytes) +        self.assertEqual(type(public), bytes) +        self.assertEqual(len(secret), 40) +        self.assertEqual(len(public), 40) +         +        # verify that it is indeed Z85 +        bsecret, bpublic = [ z85.decode(key) for key in (public, secret) ] +        self.assertEqual(type(bsecret), bytes) +        self.assertEqual(type(bpublic), bytes) +        self.assertEqual(len(bsecret), 32) +        self.assertEqual(len(bpublic), 32) +         +     +    def test_curve(self): +        """test CURVE encryption""" +        server = self.socket(zmq.DEALER) +        server.identity = b'IDENT' +        client = self.socket(zmq.DEALER) +        self.sockets.extend([server, client]) +        try: +            server.curve_server = True +        except zmq.ZMQError as e: +            # will raise EINVAL if not linked against libsodium +            if e.errno == zmq.EINVAL: +                raise SkipTest("CURVE unsupported") +         +        server_public, server_secret = zmq.curve_keypair() +        client_public, client_secret = zmq.curve_keypair() +         +        server.curve_secretkey = server_secret +        server.curve_publickey = server_public +        client.curve_serverkey = server_public +        client.curve_publickey = client_public +        client.curve_secretkey = client_secret +         +        self.assertEqual(server.mechanism, zmq.CURVE) +        self.assertEqual(client.mechanism, zmq.CURVE) +         +        self.assertEqual(server.get(zmq.CURVE_SERVER), True) +        self.assertEqual(client.get(zmq.CURVE_SERVER), False) +         +        self.start_zap() +         +        iface = 'tcp://127.0.0.1' +        port = server.bind_to_random_port(iface) +        client.connect("%s:%i" % (iface, port)) +        self.bounce(server, client) +        self.stop_zap() +         diff --git a/zmq/tests/test_socket.py b/zmq/tests/test_socket.py new file mode 100644 index 0000000..13bfed7 --- /dev/null +++ b/zmq/tests/test_socket.py @@ -0,0 +1,451 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import time +import warnings + +import zmq +from zmq.tests import ( +    BaseZMQTestCase, SkipTest, have_gevent, GreenTest, skip_pypy, skip_if +) +from zmq.utils.strtypes import bytes, unicode + + +class TestSocket(BaseZMQTestCase): + +    def test_create(self): +        ctx = self.Context() +        s = ctx.socket(zmq.PUB) +        # Superluminal protocol not yet implemented +        self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.bind, 'ftl://a') +        self.assertRaisesErrno(zmq.EPROTONOSUPPORT, s.connect, 'ftl://a') +        self.assertRaisesErrno(zmq.EINVAL, s.bind, 'tcp://') +        s.close() +        del ctx +     +    def test_context_manager(self): +        url = 'inproc://a' +        with self.Context() as ctx: +            with ctx.socket(zmq.PUSH) as a: +                a.bind(url) +                with ctx.socket(zmq.PULL) as b: +                    b.connect(url) +                    msg = b'hi' +                    a.send(msg) +                    rcvd = self.recv(b) +                    self.assertEqual(rcvd, msg) +                self.assertEqual(b.closed, True) +            self.assertEqual(a.closed, True) +        self.assertEqual(ctx.closed, True) +     +    def test_dir(self): +        ctx = self.Context() +        s = ctx.socket(zmq.PUB) +        self.assertTrue('send' in dir(s)) +        self.assertTrue('IDENTITY' in dir(s)) +        self.assertTrue('AFFINITY' in dir(s)) +        self.assertTrue('FD' in dir(s)) +        s.close() +        ctx.term() + +    def test_bind_unicode(self): +        s = self.socket(zmq.PUB) +        p = s.bind_to_random_port(unicode("tcp://*")) + +    def test_connect_unicode(self): +        s = self.socket(zmq.PUB) +        s.connect(unicode("tcp://127.0.0.1:5555")) + +    def test_bind_to_random_port(self): +        # Check that bind_to_random_port do not hide usefull exception +        ctx = self.Context() +        c = ctx.socket(zmq.PUB) +        # Invalid format +        try: +            c.bind_to_random_port('tcp:*') +        except zmq.ZMQError as e: +            self.assertEqual(e.errno, zmq.EINVAL) +        # Invalid protocol +        try: +            c.bind_to_random_port('rand://*') +        except zmq.ZMQError as e: +            self.assertEqual(e.errno, zmq.EPROTONOSUPPORT) + +    def test_identity(self): +        s = self.context.socket(zmq.PULL) +        self.sockets.append(s) +        ident = b'identity\0\0' +        s.identity = ident +        self.assertEqual(s.get(zmq.IDENTITY), ident) + +    def test_unicode_sockopts(self): +        """test setting/getting sockopts with unicode strings""" +        topic = "tést" +        if str is not unicode: +            topic = topic.decode('utf8') +        p,s = self.create_bound_pair(zmq.PUB, zmq.SUB) +        self.assertEqual(s.send_unicode, s.send_unicode) +        self.assertEqual(p.recv_unicode, p.recv_unicode) +        self.assertRaises(TypeError, s.setsockopt, zmq.SUBSCRIBE, topic) +        self.assertRaises(TypeError, s.setsockopt, zmq.IDENTITY, topic) +        s.setsockopt_unicode(zmq.IDENTITY, topic, 'utf16') +        self.assertRaises(TypeError, s.setsockopt, zmq.AFFINITY, topic) +        s.setsockopt_unicode(zmq.SUBSCRIBE, topic) +        self.assertRaises(TypeError, s.getsockopt_unicode, zmq.AFFINITY) +        self.assertRaisesErrno(zmq.EINVAL, s.getsockopt_unicode, zmq.SUBSCRIBE) +         +        identb = s.getsockopt(zmq.IDENTITY) +        identu = identb.decode('utf16') +        identu2 = s.getsockopt_unicode(zmq.IDENTITY, 'utf16') +        self.assertEqual(identu, identu2) +        time.sleep(0.1) # wait for connection/subscription +        p.send_unicode(topic,zmq.SNDMORE) +        p.send_unicode(topic*2, encoding='latin-1') +        self.assertEqual(topic, s.recv_unicode()) +        self.assertEqual(topic*2, s.recv_unicode(encoding='latin-1')) +     +    def test_int_sockopts(self): +        "test integer sockopts" +        v = zmq.zmq_version_info() +        if v < (3,0): +            default_hwm = 0 +        else: +            default_hwm = 1000 +        p,s = self.create_bound_pair(zmq.PUB, zmq.SUB) +        p.setsockopt(zmq.LINGER, 0) +        self.assertEqual(p.getsockopt(zmq.LINGER), 0) +        p.setsockopt(zmq.LINGER, -1) +        self.assertEqual(p.getsockopt(zmq.LINGER), -1) +        self.assertEqual(p.hwm, default_hwm) +        p.hwm = 11 +        self.assertEqual(p.hwm, 11) +        # p.setsockopt(zmq.EVENTS, zmq.POLLIN) +        self.assertEqual(p.getsockopt(zmq.EVENTS), zmq.POLLOUT) +        self.assertRaisesErrno(zmq.EINVAL, p.setsockopt,zmq.EVENTS, 2**7-1) +        self.assertEqual(p.getsockopt(zmq.TYPE), p.socket_type) +        self.assertEqual(p.getsockopt(zmq.TYPE), zmq.PUB) +        self.assertEqual(s.getsockopt(zmq.TYPE), s.socket_type) +        self.assertEqual(s.getsockopt(zmq.TYPE), zmq.SUB) +         +        # check for overflow / wrong type: +        errors = [] +        backref = {} +        constants = zmq.constants +        for name in constants.__all__: +            value = getattr(constants, name) +            if isinstance(value, int): +                backref[value] = name +        for opt in zmq.constants.int_sockopts.union(zmq.constants.int64_sockopts): +            sopt = backref[opt] +            if sopt.startswith(( +                'ROUTER', 'XPUB', 'TCP', 'FAIL', +                'REQ_', 'CURVE_', 'PROBE_ROUTER', +                'IPC_FILTER', 'GSSAPI', +                )): +                # some sockopts are write-only +                continue +            try: +                n = p.getsockopt(opt) +            except zmq.ZMQError as e: +                errors.append("getsockopt(zmq.%s) raised '%s'."%(sopt, e)) +            else: +                if n > 2**31: +                    errors.append("getsockopt(zmq.%s) returned a ridiculous value." +                                    " It is probably the wrong type."%sopt) +        if errors: +            self.fail('\n'.join([''] + errors)) +     +    def test_bad_sockopts(self): +        """Test that appropriate errors are raised on bad socket options""" +        s = self.context.socket(zmq.PUB) +        self.sockets.append(s) +        s.setsockopt(zmq.LINGER, 0) +        # unrecognized int sockopts pass through to libzmq, and should raise EINVAL +        self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, 9999, 5) +        self.assertRaisesErrno(zmq.EINVAL, s.getsockopt, 9999) +        # but only int sockopts are allowed through this way, otherwise raise a TypeError +        self.assertRaises(TypeError, s.setsockopt, 9999, b"5") +        # some sockopts are valid in general, but not on every socket: +        self.assertRaisesErrno(zmq.EINVAL, s.setsockopt, zmq.SUBSCRIBE, b'hi') +     +    def test_sockopt_roundtrip(self): +        "test set/getsockopt roundtrip." +        p = self.context.socket(zmq.PUB) +        self.sockets.append(p) +        self.assertEqual(p.getsockopt(zmq.LINGER), -1) +        p.setsockopt(zmq.LINGER, 11) +        self.assertEqual(p.getsockopt(zmq.LINGER), 11) +     +    def test_send_unicode(self): +        "test sending unicode objects" +        a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) +        self.sockets.extend([a,b]) +        u = "çπ§" +        if str is not unicode: +            u = u.decode('utf8') +        self.assertRaises(TypeError, a.send, u,copy=False) +        self.assertRaises(TypeError, a.send, u,copy=True) +        a.send_unicode(u) +        s = b.recv() +        self.assertEqual(s,u.encode('utf8')) +        self.assertEqual(s.decode('utf8'),u) +        a.send_unicode(u,encoding='utf16') +        s = b.recv_unicode(encoding='utf16') +        self.assertEqual(s,u) +     +    @skip_pypy +    def test_tracker(self): +        "test the MessageTracker object for tracking when zmq is done with a buffer" +        addr = 'tcp://127.0.0.1' +        a = self.context.socket(zmq.PUB) +        port = a.bind_to_random_port(addr) +        a.close() +        iface = "%s:%i"%(addr,port) +        a = self.context.socket(zmq.PAIR) +        # a.setsockopt(zmq.IDENTITY, b"a") +        b = self.context.socket(zmq.PAIR) +        self.sockets.extend([a,b]) +        a.connect(iface) +        time.sleep(0.1) +        p1 = a.send(b'something', copy=False, track=True) +        self.assertTrue(isinstance(p1, zmq.MessageTracker)) +        self.assertFalse(p1.done) +        p2 = a.send_multipart([b'something', b'else'], copy=False, track=True) +        self.assert_(isinstance(p2, zmq.MessageTracker)) +        self.assertEqual(p2.done, False) +        self.assertEqual(p1.done, False) + +        b.bind(iface) +        msg = b.recv_multipart() +        for i in range(10): +            if p1.done: +                break +            time.sleep(0.1) +        self.assertEqual(p1.done, True) +        self.assertEqual(msg, [b'something']) +        msg = b.recv_multipart() +        for i in range(10): +            if p2.done: +                break +            time.sleep(0.1) +        self.assertEqual(p2.done, True) +        self.assertEqual(msg, [b'something', b'else']) +        m = zmq.Frame(b"again", track=True) +        self.assertEqual(m.tracker.done, False) +        p1 = a.send(m, copy=False) +        p2 = a.send(m, copy=False) +        self.assertEqual(m.tracker.done, False) +        self.assertEqual(p1.done, False) +        self.assertEqual(p2.done, False) +        msg = b.recv_multipart() +        self.assertEqual(m.tracker.done, False) +        self.assertEqual(msg, [b'again']) +        msg = b.recv_multipart() +        self.assertEqual(m.tracker.done, False) +        self.assertEqual(msg, [b'again']) +        self.assertEqual(p1.done, False) +        self.assertEqual(p2.done, False) +        pm = m.tracker +        del m +        for i in range(10): +            if p1.done: +                break +            time.sleep(0.1) +        self.assertEqual(p1.done, True) +        self.assertEqual(p2.done, True) +        m = zmq.Frame(b'something', track=False) +        self.assertRaises(ValueError, a.send, m, copy=False, track=True) +         + +    def test_close(self): +        ctx = self.Context() +        s = ctx.socket(zmq.PUB) +        s.close() +        self.assertRaisesErrno(zmq.ENOTSOCK, s.bind, b'') +        self.assertRaisesErrno(zmq.ENOTSOCK, s.connect, b'') +        self.assertRaisesErrno(zmq.ENOTSOCK, s.setsockopt, zmq.SUBSCRIBE, b'') +        self.assertRaisesErrno(zmq.ENOTSOCK, s.send, b'asdf') +        self.assertRaisesErrno(zmq.ENOTSOCK, s.recv) +        del ctx +     +    def test_attr(self): +        """set setting/getting sockopts as attributes""" +        s = self.context.socket(zmq.DEALER) +        self.sockets.append(s) +        linger = 10 +        s.linger = linger +        self.assertEqual(linger, s.linger) +        self.assertEqual(linger, s.getsockopt(zmq.LINGER)) +        self.assertEqual(s.fd, s.getsockopt(zmq.FD)) +     +    def test_bad_attr(self): +        s = self.context.socket(zmq.DEALER) +        self.sockets.append(s) +        try: +            s.apple='foo' +        except AttributeError: +            pass +        else: +            self.fail("bad setattr should have raised AttributeError") +        try: +            s.apple +        except AttributeError: +            pass +        else: +            self.fail("bad getattr should have raised AttributeError") + +    def test_subclass(self): +        """subclasses can assign attributes""" +        class S(zmq.Socket): +            a = None +            def __init__(self, *a, **kw): +                self.a=-1 +                super(S, self).__init__(*a, **kw) +         +        s = S(self.context, zmq.REP) +        self.sockets.append(s) +        self.assertEqual(s.a, -1) +        s.a=1 +        self.assertEqual(s.a, 1) +        a=s.a +        self.assertEqual(a, 1) +     +    def test_recv_multipart(self): +        a,b = self.create_bound_pair() +        msg = b'hi' +        for i in range(3): +            a.send(msg) +        time.sleep(0.1) +        for i in range(3): +            self.assertEqual(b.recv_multipart(), [msg]) +     +    def test_close_after_destroy(self): +        """s.close() after ctx.destroy() should be fine""" +        ctx = self.Context() +        s = ctx.socket(zmq.REP) +        ctx.destroy() +        # reaper is not instantaneous +        time.sleep(1e-2) +        s.close() +        self.assertTrue(s.closed) +     +    def test_poll(self): +        a,b = self.create_bound_pair() +        tic = time.time() +        evt = a.poll(50) +        self.assertEqual(evt, 0) +        evt = a.poll(50, zmq.POLLOUT) +        self.assertEqual(evt, zmq.POLLOUT) +        msg = b'hi' +        a.send(msg) +        evt = b.poll(50) +        self.assertEqual(evt, zmq.POLLIN) +        msg2 = self.recv(b) +        evt = b.poll(50) +        self.assertEqual(evt, 0) +        self.assertEqual(msg2, msg) +     +    def test_ipc_path_max_length(self): +        """IPC_PATH_MAX_LEN is a sensible value""" +        if zmq.IPC_PATH_MAX_LEN == 0: +            raise SkipTest("IPC_PATH_MAX_LEN undefined") +         +        msg = "Surprising value for IPC_PATH_MAX_LEN: %s" % zmq.IPC_PATH_MAX_LEN +        self.assertTrue(zmq.IPC_PATH_MAX_LEN > 30, msg) +        self.assertTrue(zmq.IPC_PATH_MAX_LEN < 1025, msg) + +    def test_ipc_path_max_length_msg(self): +        if zmq.IPC_PATH_MAX_LEN == 0: +            raise SkipTest("IPC_PATH_MAX_LEN undefined") +         +        s = self.context.socket(zmq.PUB) +        self.sockets.append(s) +        try: +            s.bind('ipc://{0}'.format('a' * (zmq.IPC_PATH_MAX_LEN + 1))) +        except zmq.ZMQError as e: +            self.assertTrue(str(zmq.IPC_PATH_MAX_LEN) in e.strerror) +     +    def test_hwm(self): +        zmq3 = zmq.zmq_version_info()[0] >= 3 +        for stype in (zmq.PUB, zmq.ROUTER, zmq.SUB, zmq.REQ, zmq.DEALER): +            s = self.context.socket(stype) +            s.hwm = 100 +            self.assertEqual(s.hwm, 100) +            if zmq3: +                try: +                    self.assertEqual(s.sndhwm, 100) +                except AttributeError: +                    pass +                try: +                    self.assertEqual(s.rcvhwm, 100) +                except AttributeError: +                    pass +            s.close() +     +    def test_shadow(self): +        p = self.socket(zmq.PUSH) +        p.bind("tcp://127.0.0.1:5555") +        p2 = zmq.Socket.shadow(p.underlying) +        self.assertEqual(p.underlying, p2.underlying) +        s = self.socket(zmq.PULL) +        s2 = zmq.Socket.shadow(s.underlying) +        self.assertNotEqual(s.underlying, p.underlying) +        self.assertEqual(s.underlying, s2.underlying) +        s2.connect("tcp://127.0.0.1:5555") +        sent = b'hi' +        p2.send(sent) +        rcvd = self.recv(s2) +        self.assertEqual(rcvd, sent) +     +    def test_shadow_pyczmq(self): +        try: +            from pyczmq import zctx, zsocket +        except Exception: +            raise SkipTest("Requires pyczmq") +         +        ctx = zctx.new() +        ca = zsocket.new(ctx, zmq.PUSH) +        cb = zsocket.new(ctx, zmq.PULL) +        a = zmq.Socket.shadow(ca) +        b = zmq.Socket.shadow(cb) +        a.bind("inproc://a") +        b.connect("inproc://a") +        a.send(b'hi') +        rcvd = self.recv(b) +        self.assertEqual(rcvd, b'hi') + + +if have_gevent: +    import gevent +     +    class TestSocketGreen(GreenTest, TestSocket): +        test_bad_attr = GreenTest.skip_green +        test_close_after_destroy = GreenTest.skip_green +         +        def test_timeout(self): +            a,b = self.create_bound_pair() +            g = gevent.spawn_later(0.5, lambda: a.send(b'hi')) +            timeout = gevent.Timeout(0.1) +            timeout.start() +            self.assertRaises(gevent.Timeout, b.recv) +            g.kill() +         +        @skip_if(not hasattr(zmq, 'RCVTIMEO')) +        def test_warn_set_timeo(self): +            s = self.context.socket(zmq.REQ) +            with warnings.catch_warnings(record=True) as w: +                s.rcvtimeo = 5 +            s.close() +            self.assertEqual(len(w), 1) +            self.assertEqual(w[0].category, UserWarning) +             + +        @skip_if(not hasattr(zmq, 'SNDTIMEO')) +        def test_warn_get_timeo(self): +            s = self.context.socket(zmq.REQ) +            with warnings.catch_warnings(record=True) as w: +                s.sndtimeo +            s.close() +            self.assertEqual(len(w), 1) +            self.assertEqual(w[0].category, UserWarning) diff --git a/zmq/tests/test_stopwatch.py b/zmq/tests/test_stopwatch.py new file mode 100644 index 0000000..49fb79f --- /dev/null +++ b/zmq/tests/test_stopwatch.py @@ -0,0 +1,42 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import sys +import time + +from unittest import TestCase + +from zmq import Stopwatch, ZMQError + +if sys.version_info[0] >= 3: +    long = int + +class TestStopWatch(TestCase): +     +    def test_stop_long(self): +        """Ensure stop returns a long int.""" +        watch = Stopwatch() +        watch.start() +        us = watch.stop() +        self.assertTrue(isinstance(us, long)) +         +    def test_stop_microseconds(self): +        """Test that stop/sleep have right units.""" +        watch = Stopwatch() +        watch.start() +        tic = time.time() +        watch.sleep(1) +        us = watch.stop() +        toc = time.time() +        self.assertAlmostEqual(us/1e6,(toc-tic),places=0) +     +    def test_double_stop(self): +        """Test error raised on multiple calls to stop.""" +        watch = Stopwatch() +        watch.start() +        watch.stop() +        self.assertRaises(ZMQError, watch.stop) +        self.assertRaises(ZMQError, watch.stop) +     diff --git a/zmq/tests/test_version.py b/zmq/tests/test_version.py new file mode 100644 index 0000000..6ebebf3 --- /dev/null +++ b/zmq/tests/test_version.py @@ -0,0 +1,44 @@ +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +from unittest import TestCase +import zmq +from zmq.sugar import version + + +class TestVersion(TestCase): + +    def test_pyzmq_version(self): +        vs = zmq.pyzmq_version() +        vs2 = zmq.__version__ +        self.assertTrue(isinstance(vs, str)) +        if zmq.__revision__: +            self.assertEqual(vs, '@'.join(vs2, zmq.__revision__)) +        else: +            self.assertEqual(vs, vs2) +        if version.VERSION_EXTRA: +            self.assertTrue(version.VERSION_EXTRA in vs) +            self.assertTrue(version.VERSION_EXTRA in vs2) + +    def test_pyzmq_version_info(self): +        info = zmq.pyzmq_version_info() +        self.assertTrue(isinstance(info, tuple)) +        for n in info[:3]: +            self.assertTrue(isinstance(n, int)) +        if version.VERSION_EXTRA: +            self.assertEqual(len(info), 4) +            self.assertEqual(info[-1], float('inf')) +        else: +            self.assertEqual(len(info), 3) + +    def test_zmq_version_info(self): +        info = zmq.zmq_version_info() +        self.assertTrue(isinstance(info, tuple)) +        for n in info[:3]: +            self.assertTrue(isinstance(n, int)) + +    def test_zmq_version(self): +        v = zmq.zmq_version() +        self.assertTrue(isinstance(v, str)) + diff --git a/zmq/tests/test_win32_shim.py b/zmq/tests/test_win32_shim.py new file mode 100644 index 0000000..55657bd --- /dev/null +++ b/zmq/tests/test_win32_shim.py @@ -0,0 +1,56 @@ +from __future__ import print_function + +import os + +from functools import wraps +from zmq.tests import BaseZMQTestCase +from zmq.utils.win32 import allow_interrupt + + +def count_calls(f): +    @wraps(f) +    def _(*args, **kwds): +        try: +            return f(*args, **kwds) +        finally: +            _.__calls__ += 1 +    _.__calls__ = 0 +    return _ + + +class TestWindowsConsoleControlHandler(BaseZMQTestCase): + +    def test_handler(self): +        @count_calls +        def interrupt_polling(): +            print('Caught CTRL-C!') + +        if os.name == 'nt': +            from ctypes import windll +            from ctypes.wintypes import BOOL, DWORD + +            kernel32 = windll.LoadLibrary('kernel32') + +            # <http://msdn.microsoft.com/en-us/library/ms683155.aspx> +            GenerateConsoleCtrlEvent = kernel32.GenerateConsoleCtrlEvent +            GenerateConsoleCtrlEvent.argtypes = (DWORD, DWORD) +            GenerateConsoleCtrlEvent.restype = BOOL + +            try: +                # Simulate CTRL-C event while handler is active. +                with allow_interrupt(interrupt_polling): +                    result = GenerateConsoleCtrlEvent(0, 0) +                    if result == 0: +                        raise WindowsError +            except KeyboardInterrupt: +                pass +            else: +                self.fail('Expecting `KeyboardInterrupt` exception!') + +            # Make sure our handler was called. +            self.assertEqual(interrupt_polling.__calls__, 1) +        else: +            # On non-Windows systems, this utility is just a no-op! +            with allow_interrupt(interrupt_polling): +                pass +            self.assertEqual(interrupt_polling.__calls__, 0) diff --git a/zmq/tests/test_z85.py b/zmq/tests/test_z85.py new file mode 100644 index 0000000..8a73cb4 --- /dev/null +++ b/zmq/tests/test_z85.py @@ -0,0 +1,63 @@ +# -*- coding: utf8 -*- +"""Test Z85 encoding + +confirm values and roundtrip with test values from the reference implementation. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from unittest import TestCase +from zmq.utils import z85 + + +class TestZ85(TestCase): +     +    def test_client_public(self): +        client_public = \ +            b"\xBB\x88\x47\x1D\x65\xE2\x65\x9B" \ +            b"\x30\xC5\x5A\x53\x21\xCE\xBB\x5A" \ +            b"\xAB\x2B\x70\xA3\x98\x64\x5C\x26" \ +            b"\xDC\xA2\xB2\xFC\xB4\x3F\xC5\x18" +        encoded = z85.encode(client_public) +         +        self.assertEqual(encoded, b"Yne@$w-vo<fVvi]a<NY6T1ed:M$fCG*[IaLV{hID") +        decoded = z85.decode(encoded) +        self.assertEqual(decoded, client_public) +     +    def test_client_secret(self): +        client_secret = \ +            b"\x7B\xB8\x64\xB4\x89\xAF\xA3\x67" \ +            b"\x1F\xBE\x69\x10\x1F\x94\xB3\x89" \ +            b"\x72\xF2\x48\x16\xDF\xB0\x1B\x51" \ +            b"\x65\x6B\x3F\xEC\x8D\xFD\x08\x88" +        encoded = z85.encode(client_secret) +         +        self.assertEqual(encoded, b"D:)Q[IlAW!ahhC2ac:9*A}h:p?([4%wOTJ%JR%cs") +        decoded = z85.decode(encoded) +        self.assertEqual(decoded, client_secret) + +    def test_server_public(self): +        server_public = \ +            b"\x54\xFC\xBA\x24\xE9\x32\x49\x96" \ +            b"\x93\x16\xFB\x61\x7C\x87\x2B\xB0" \ +            b"\xC1\xD1\xFF\x14\x80\x04\x27\xC5" \ +            b"\x94\xCB\xFA\xCF\x1B\xC2\xD6\x52" +        encoded = z85.encode(server_public) +         +        self.assertEqual(encoded, b"rq:rM>}U?@Lns47E1%kR.o@n%FcmmsL/@{H8]yf7") +        decoded = z85.decode(encoded) +        self.assertEqual(decoded, server_public) +     +    def test_server_secret(self): +        server_secret = \ +            b"\x8E\x0B\xDD\x69\x76\x28\xB9\x1D" \ +            b"\x8F\x24\x55\x87\xEE\x95\xC5\xB0" \ +            b"\x4D\x48\x96\x3F\x79\x25\x98\x77" \ +            b"\xB4\x9C\xD9\x06\x3A\xEA\xD3\xB7" +        encoded = z85.encode(server_secret) +         +        self.assertEqual(encoded, b"JTKVSB%%)wK0E.X)V>+}o?pNmC{O&4W4b!Ni{Lh6") +        decoded = z85.decode(encoded) +        self.assertEqual(decoded, server_secret) + diff --git a/zmq/tests/test_zmqstream.py b/zmq/tests/test_zmqstream.py new file mode 100644 index 0000000..cdb3a17 --- /dev/null +++ b/zmq/tests/test_zmqstream.py @@ -0,0 +1,34 @@ +# -*- coding: utf8 -*- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import sys +import time + +from unittest import TestCase + +import zmq +from zmq.eventloop import ioloop, zmqstream + +class TestZMQStream(TestCase): +     +    def setUp(self): +        self.context = zmq.Context() +        self.socket = self.context.socket(zmq.REP) +        self.loop = ioloop.IOLoop.instance() +        self.stream = zmqstream.ZMQStream(self.socket) +     +    def tearDown(self): +        self.socket.close() +        self.context.term() +     +    def test_callable_check(self): +        """Ensure callable check works (py3k).""" +         +        self.stream.on_send(lambda *args: None) +        self.stream.on_recv(lambda *args: None) +        self.assertRaises(AssertionError, self.stream.on_recv, 1) +        self.assertRaises(AssertionError, self.stream.on_send, 1) +        self.assertRaises(AssertionError, self.stream.on_recv, zmq) +         diff --git a/zmq/utils/__init__.py b/zmq/utils/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/zmq/utils/__init__.py diff --git a/zmq/utils/buffers.pxd b/zmq/utils/buffers.pxd new file mode 100644 index 0000000..998aa55 --- /dev/null +++ b/zmq/utils/buffers.pxd @@ -0,0 +1,313 @@ +"""Python version-independent methods for C/Python buffers. + +This file was copied and adapted from mpi4py. + +Authors +------- +* MinRK +""" + +#----------------------------------------------------------------------------- +#  Copyright (c) 2010 Lisandro Dalcin +#  All rights reserved. +#  Used under BSD License: http://www.opensource.org/licenses/bsd-license.php +# +#  Retrieval: +#  Jul 23, 2010 18:00 PST (r539) +#  http://code.google.com/p/mpi4py/source/browse/trunk/src/MPI/asbuffer.pxi +# +#  Modifications from original: +#  Copyright (c) 2010-2012 Brian Granger, Min Ragan-Kelley +# +#  Distributed under the terms of the New BSD License.  The full license is in +#  the file COPYING.BSD, distributed as part of this software. +#----------------------------------------------------------------------------- + + +#----------------------------------------------------------------------------- +# Python includes. +#----------------------------------------------------------------------------- + +# get version-independent aliases: +cdef extern from "pyversion_compat.h": +    pass + +# Python 3 buffer interface (PEP 3118) +cdef extern from "Python.h": +    int PY_MAJOR_VERSION +    int PY_MINOR_VERSION +    ctypedef int Py_ssize_t +    ctypedef struct PyMemoryViewObject: +        pass +    ctypedef struct Py_buffer: +        void *buf +        Py_ssize_t len +        int readonly +        char *format +        int ndim +        Py_ssize_t *shape +        Py_ssize_t *strides +        Py_ssize_t *suboffsets +        Py_ssize_t itemsize +        void *internal +    cdef enum: +        PyBUF_SIMPLE +        PyBUF_WRITABLE +        PyBUF_FORMAT +        PyBUF_ANY_CONTIGUOUS +    int  PyObject_CheckBuffer(object) +    int  PyObject_GetBuffer(object, Py_buffer *, int) except -1 +    void PyBuffer_Release(Py_buffer *) +     +    int PyBuffer_FillInfo(Py_buffer *view, object obj, void *buf, +                Py_ssize_t len, int readonly, int infoflags) except -1 +    object PyMemoryView_FromBuffer(Py_buffer *info) +     +    object PyMemoryView_FromObject(object) + +# Python 2 buffer interface (legacy) +cdef extern from "Python.h": +    ctypedef void const_void "const void" +    Py_ssize_t Py_END_OF_BUFFER +    int PyObject_CheckReadBuffer(object) +    int PyObject_AsReadBuffer (object, const_void **, Py_ssize_t *) except -1 +    int PyObject_AsWriteBuffer(object, void **, Py_ssize_t *) except -1 +     +    object PyBuffer_FromMemory(void *ptr, Py_ssize_t s) +    object PyBuffer_FromReadWriteMemory(void *ptr, Py_ssize_t s) + +    object PyBuffer_FromObject(object, Py_ssize_t offset, Py_ssize_t size) +    object PyBuffer_FromReadWriteObject(object, Py_ssize_t offset, Py_ssize_t size) + + +#----------------------------------------------------------------------------- +# asbuffer: C buffer from python object +#----------------------------------------------------------------------------- + + +cdef inline int memoryview_available(): +    return PY_MAJOR_VERSION >= 3 or (PY_MAJOR_VERSION >=2 and PY_MINOR_VERSION >= 7) + +cdef inline int oldstyle_available(): +    return PY_MAJOR_VERSION < 3 + + +cdef inline int check_buffer(object ob): +    """Version independent check for whether an object is a buffer. +     +    Parameters +    ---------- +    object : object +        Any Python object + +    Returns +    ------- +    int : 0 if no buffer interface, 3 if newstyle buffer interface, 2 if oldstyle. +    """ +    if PyObject_CheckBuffer(ob): +        return 3 +    if oldstyle_available(): +        return PyObject_CheckReadBuffer(ob) and 2 +    return 0 + + +cdef inline object asbuffer(object ob, int writable, int format, +                            void **base, Py_ssize_t *size, +                            Py_ssize_t *itemsize): +    """Turn an object into a C buffer in a Python version-independent way. +     +    Parameters +    ---------- +    ob : object +        The object to be turned into a buffer. +        Must provide a Python Buffer interface +    writable : int +        Whether the resulting buffer should be allowed to write +        to the object. +    format : int +        The format of the buffer.  See Python buffer docs. +    base : void ** +        The pointer that will be used to store the resulting C buffer. +    size : Py_ssize_t * +        The size of the buffer(s). +    itemsize : Py_ssize_t * +        The size of an item, if the buffer is non-contiguous. +     +    Returns +    ------- +    An object describing the buffer format. Generally a str, such as 'B'. +    """ + +    cdef void *bptr = NULL +    cdef Py_ssize_t blen = 0, bitemlen = 0 +    cdef Py_buffer view +    cdef int flags = PyBUF_SIMPLE +    cdef int mode = 0 +     +    bfmt = None + +    mode = check_buffer(ob) +    if mode == 0: +        raise TypeError("%r does not provide a buffer interface."%ob) + +    if mode == 3: +        flags = PyBUF_ANY_CONTIGUOUS +        if writable: +            flags |= PyBUF_WRITABLE +        if format: +            flags |= PyBUF_FORMAT +        PyObject_GetBuffer(ob, &view, flags) +        bptr = view.buf +        blen = view.len +        if format: +            if view.format != NULL: +                bfmt = view.format +                bitemlen = view.itemsize +        PyBuffer_Release(&view) +    else: # oldstyle +        if writable: +            PyObject_AsWriteBuffer(ob, &bptr, &blen) +        else: +            PyObject_AsReadBuffer(ob, <const_void **>&bptr, &blen) +        if format: +            try: # numpy.ndarray +                dtype = ob.dtype +                bfmt = dtype.char +                bitemlen = dtype.itemsize +            except AttributeError: +                try: # array.array +                    bfmt = ob.typecode +                    bitemlen = ob.itemsize +                except AttributeError: +                    if isinstance(ob, bytes): +                        bfmt = b"B" +                        bitemlen = 1 +                    else: +                        # nothing found +                        bfmt = None +                        bitemlen = 0 +    if base: base[0] = <void *>bptr +    if size: size[0] = <Py_ssize_t>blen +    if itemsize: itemsize[0] = <Py_ssize_t>bitemlen +     +    if PY_MAJOR_VERSION >= 3 and bfmt is not None: +        return bfmt.decode('ascii') +    return bfmt + + +cdef inline object asbuffer_r(object ob, void **base, Py_ssize_t *size): +    """Wrapper for standard calls to asbuffer with a readonly buffer.""" +    asbuffer(ob, 0, 0, base, size, NULL) +    return ob + + +cdef inline object asbuffer_w(object ob, void **base, Py_ssize_t *size): +    """Wrapper for standard calls to asbuffer with a writable buffer.""" +    asbuffer(ob, 1, 0, base, size, NULL) +    return ob + +#------------------------------------------------------------------------------ +# frombuffer: python buffer/view from C buffer +#------------------------------------------------------------------------------ + + +cdef inline object frombuffer_3(void *ptr, Py_ssize_t s, int readonly): +    """Python 3 version of frombuffer. + +    This is the Python 3 model, but will work on Python >= 2.6. Currently, +    we use it only on >= 3.0. +    """ +    cdef Py_buffer pybuf +    cdef Py_ssize_t *shape = [s] +    cdef str astr="" +    PyBuffer_FillInfo(&pybuf, astr, ptr, s, readonly, PyBUF_SIMPLE) +    pybuf.format = "B" +    pybuf.shape = shape +    return PyMemoryView_FromBuffer(&pybuf) + + +cdef inline object frombuffer_2(void *ptr, Py_ssize_t s, int readonly): +    """Python 2 version of frombuffer.  + +    This must be used for Python <= 2.6, but we use it for all Python < 3. +    """ +     +    if oldstyle_available(): +        if readonly: +            return PyBuffer_FromMemory(ptr, s) +        else: +            return PyBuffer_FromReadWriteMemory(ptr, s) +    else: +        raise NotImplementedError("Old style buffers not available.") + + +cdef inline object frombuffer(void *ptr, Py_ssize_t s, int readonly): +    """Create a Python Buffer/View of a C array.  +     +    Parameters +    ---------- +    ptr : void * +        Pointer to the array to be copied. +    s : size_t +        Length of the buffer. +    readonly : int +        whether the resulting object should be allowed to write to the buffer. +     +    Returns +    ------- +    Python Buffer/View of the C buffer. +    """ +    # oldstyle first priority for now +    if oldstyle_available(): +        return frombuffer_2(ptr, s, readonly) +    else: +        return frombuffer_3(ptr, s, readonly) + + +cdef inline object frombuffer_r(void *ptr, Py_ssize_t s): +    """Wrapper for readonly view frombuffer.""" +    return frombuffer(ptr, s, 1) + + +cdef inline object frombuffer_w(void *ptr, Py_ssize_t s): +    """Wrapper for writable view frombuffer.""" +    return frombuffer(ptr, s, 0) + +#------------------------------------------------------------------------------ +# viewfromobject: python buffer/view from python object, refcounts intact +# frombuffer(asbuffer(obj)) would lose track of refs +#------------------------------------------------------------------------------ + +cdef inline object viewfromobject(object obj, int readonly): +    """Construct a Python Buffer/View object from another Python object. + +    This work in a Python version independent manner. +     +    Parameters +    ---------- +    obj : object +        The input object to be cast as a buffer +    readonly : int +        Whether the result should be prevented from overwriting the original. +     +    Returns +    ------- +    Buffer/View of the original object. +    """ +    if not memoryview_available(): +        if readonly: +            return PyBuffer_FromObject(obj, 0, Py_END_OF_BUFFER) +        else: +            return PyBuffer_FromReadWriteObject(obj, 0, Py_END_OF_BUFFER) +    else: +        return PyMemoryView_FromObject(obj) + + +cdef inline object viewfromobject_r(object obj): +    """Wrapper for readonly viewfromobject.""" +    return viewfromobject(obj, 1) + + +cdef inline object viewfromobject_w(object obj): +    """Wrapper for writable viewfromobject.""" +    return viewfromobject(obj, 0) diff --git a/zmq/utils/constant_names.py b/zmq/utils/constant_names.py new file mode 100644 index 0000000..47da9dc --- /dev/null +++ b/zmq/utils/constant_names.py @@ -0,0 +1,365 @@ +"""0MQ Constant names""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +# dictionaries of constants new or removed in particular versions + +new_in = { +    (2,2,0) : [ +        'RCVTIMEO', +        'SNDTIMEO', +    ], +    (3,2,2) : [ +        # errnos +        'EMSGSIZE', +        'EAFNOSUPPORT', +        'ENETUNREACH', +        'ECONNABORTED', +        'ECONNRESET', +        'ENOTCONN', +        'ETIMEDOUT', +        'EHOSTUNREACH', +        'ENETRESET', +         +        # ctx opts +        'IO_THREADS', +        'MAX_SOCKETS', +        'IO_THREADS_DFLT', +        'MAX_SOCKETS_DFLT', +         +        # socket opts +        'ROUTER_BEHAVIOR', +        'ROUTER_MANDATORY', +        'FAIL_UNROUTABLE', +        'TCP_KEEPALIVE', +        'TCP_KEEPALIVE_CNT', +        'TCP_KEEPALIVE_IDLE', +        'TCP_KEEPALIVE_INTVL', +        'DELAY_ATTACH_ON_CONNECT', +        'XPUB_VERBOSE', +         +        # msg opts +        'MORE', +         +        'EVENT_CONNECTED', +        'EVENT_CONNECT_DELAYED', +        'EVENT_CONNECT_RETRIED', +        'EVENT_LISTENING', +        'EVENT_BIND_FAILED', +        'EVENT_ACCEPTED', +        'EVENT_ACCEPT_FAILED', +        'EVENT_CLOSED', +        'EVENT_CLOSE_FAILED', +        'EVENT_DISCONNECTED', +        'EVENT_ALL', +    ], +    (4,0,0) : [ +        # socket types +        'STREAM', +         +        # socket opts +        'IMMEDIATE', +        'ROUTER_RAW', +        'IPV6', +        'MECHANISM', +        'PLAIN_SERVER', +        'PLAIN_USERNAME', +        'PLAIN_PASSWORD', +        'CURVE_SERVER', +        'CURVE_PUBLICKEY', +        'CURVE_SECRETKEY', +        'CURVE_SERVERKEY', +        'PROBE_ROUTER', +        'REQ_RELAXED', +        'REQ_CORRELATE', +        'CONFLATE', +        'ZAP_DOMAIN', +         +        # security +        'NULL', +        'PLAIN', +        'CURVE', +         +        # events +        'EVENT_MONITOR_STOPPED', +    ], +    (4,1,0) : [ +        # ctx opts +        'SOCKET_LIMIT', +        'THREAD_PRIORITY', +        'THREAD_PRIORITY_DFLT', +        'THREAD_SCHED_POLICY', +        'THREAD_SCHED_POLICY_DFLT', +         +        # socket opts +        'ROUTER_HANDOVER', +        'TOS', +        'IPC_FILTER_PID', +        'IPC_FILTER_UID', +        'IPC_FILTER_GID', +        'CONNECT_RID', +        'GSSAPI_SERVER', +        'GSSAPI_PRINCIPAL', +        'GSSAPI_SERVICE_PRINCIPAL', +        'GSSAPI_PLAINTEXT', +        'HANDSHAKE_IVL', +        'IDENTITY_FD', +        'XPUB_NODROP', +        'SOCKS_PROXY', +         +        # msg opts +        'SRCFD', +        'SHARED', +         +        # security +        'GSSAPI', +         +    ], +} + + +removed_in = { +    (3,2,2) : [ +        'UPSTREAM', +        'DOWNSTREAM', +         +        'HWM', +        'SWAP', +        'MCAST_LOOP', +        'RECOVERY_IVL_MSEC', +    ] +} + +# collections of zmq constant names based on their role +# base names have no specific use +# opt names are validated in get/set methods of various objects + +base_names = [ +    # base +    'VERSION', +    'VERSION_MAJOR', +    'VERSION_MINOR', +    'VERSION_PATCH', +    'NOBLOCK', +    'DONTWAIT', + +    'POLLIN', +    'POLLOUT', +    'POLLERR', +     +    'SNDMORE', + +    'STREAMER', +    'FORWARDER', +    'QUEUE', + +    'IO_THREADS_DFLT', +    'MAX_SOCKETS_DFLT', +    'POLLITEMS_DFLT', +    'THREAD_PRIORITY_DFLT', +    'THREAD_SCHED_POLICY_DFLT', + +    # socktypes +    'PAIR', +    'PUB', +    'SUB', +    'REQ', +    'REP', +    'DEALER', +    'ROUTER', +    'XREQ', +    'XREP', +    'PULL', +    'PUSH', +    'XPUB', +    'XSUB', +    'UPSTREAM', +    'DOWNSTREAM', +    'STREAM', + +    # events +    'EVENT_CONNECTED', +    'EVENT_CONNECT_DELAYED', +    'EVENT_CONNECT_RETRIED', +    'EVENT_LISTENING', +    'EVENT_BIND_FAILED', +    'EVENT_ACCEPTED', +    'EVENT_ACCEPT_FAILED', +    'EVENT_CLOSED', +    'EVENT_CLOSE_FAILED', +    'EVENT_DISCONNECTED', +    'EVENT_ALL', +    'EVENT_MONITOR_STOPPED', + +    # security +    'NULL', +    'PLAIN', +    'CURVE', +    'GSSAPI', + +    ## ERRNO +    # Often used (these are alse in errno.) +    'EAGAIN', +    'EINVAL', +    'EFAULT', +    'ENOMEM', +    'ENODEV', +    'EMSGSIZE', +    'EAFNOSUPPORT', +    'ENETUNREACH', +    'ECONNABORTED', +    'ECONNRESET', +    'ENOTCONN', +    'ETIMEDOUT', +    'EHOSTUNREACH', +    'ENETRESET', + +    # For Windows compatability +    'HAUSNUMERO', +    'ENOTSUP', +    'EPROTONOSUPPORT', +    'ENOBUFS', +    'ENETDOWN', +    'EADDRINUSE', +    'EADDRNOTAVAIL', +    'ECONNREFUSED', +    'EINPROGRESS', +    'ENOTSOCK', + +    # 0MQ Native +    'EFSM', +    'ENOCOMPATPROTO', +    'ETERM', +    'EMTHREAD', +] + +int64_sockopt_names = [ +    'AFFINITY', +    'MAXMSGSIZE', + +    # sockopts removed in 3.0.0 +    'HWM', +    'SWAP', +    'MCAST_LOOP', +    'RECOVERY_IVL_MSEC', +] + +bytes_sockopt_names = [ +    'IDENTITY', +    'SUBSCRIBE', +    'UNSUBSCRIBE', +    'LAST_ENDPOINT', +    'TCP_ACCEPT_FILTER', + +    'PLAIN_USERNAME', +    'PLAIN_PASSWORD', + +    'CURVE_PUBLICKEY', +    'CURVE_SECRETKEY', +    'CURVE_SERVERKEY', +    'ZAP_DOMAIN', +    'CONNECT_RID', +    'GSSAPI_PRINCIPAL', +    'GSSAPI_SERVICE_PRINCIPAL', +    'SOCKS_PROXY', +] + +fd_sockopt_names = [ +    'FD', +    'IDENTITY_FD', +] + +int_sockopt_names = [ +    # sockopts +    'RECONNECT_IVL_MAX', + +    # sockopts new in 2.2.0 +    'SNDTIMEO', +    'RCVTIMEO', + +    # new in 3.x +    'SNDHWM', +    'RCVHWM', +    'MULTICAST_HOPS', +    'IPV4ONLY', + +    'ROUTER_BEHAVIOR', +    'TCP_KEEPALIVE', +    'TCP_KEEPALIVE_CNT', +    'TCP_KEEPALIVE_IDLE', +    'TCP_KEEPALIVE_INTVL', +    'DELAY_ATTACH_ON_CONNECT', +    'XPUB_VERBOSE', + +    'EVENTS', +    'TYPE', +    'LINGER', +    'RECONNECT_IVL', +    'BACKLOG', +     +    'ROUTER_MANDATORY', +    'FAIL_UNROUTABLE', + +    'ROUTER_RAW', +    'IMMEDIATE', +    'IPV6', +    'MECHANISM', +    'PLAIN_SERVER', +    'CURVE_SERVER', +    'PROBE_ROUTER', +    'REQ_RELAXED', +    'REQ_CORRELATE', +    'CONFLATE', +    'ROUTER_HANDOVER', +    'TOS', +    'IPC_FILTER_PID', +    'IPC_FILTER_UID', +    'IPC_FILTER_GID', +    'GSSAPI_SERVER', +    'GSSAPI_PLAINTEXT', +    'HANDSHAKE_IVL', +    'XPUB_NODROP', +] + +switched_sockopt_names = [ +    'RATE', +    'RECOVERY_IVL', +    'SNDBUF', +    'RCVBUF', +    'RCVMORE', +] + +ctx_opt_names = [ +    'IO_THREADS', +    'MAX_SOCKETS', +    'SOCKET_LIMIT', +    'THREAD_PRIORITY', +    'THREAD_SCHED_POLICY', +] + +msg_opt_names = [ +    'MORE', +    'SRCFD', +    'SHARED', +] + +from itertools import chain + +all_names = list(chain( +    base_names, +    ctx_opt_names, +    bytes_sockopt_names, +    fd_sockopt_names, +    int_sockopt_names, +    int64_sockopt_names, +    switched_sockopt_names, +    msg_opt_names, +)) + +del chain + +def no_prefix(name): +    """does the given constant have a ZMQ_ prefix?""" +    return name.startswith('E') and not name.startswith('EVENT') + diff --git a/zmq/utils/garbage.py b/zmq/utils/garbage.py new file mode 100644 index 0000000..22a8977 --- /dev/null +++ b/zmq/utils/garbage.py @@ -0,0 +1,170 @@ +"""Garbage collection thread for representing zmq refcount of Python objects +used in zero-copy sends. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +import atexit +import struct + +from os import getpid +from collections import namedtuple +from threading import Thread, Event, Lock +import warnings + +import zmq + + +gcref = namedtuple('gcref', ['obj', 'event']) + +class GarbageCollectorThread(Thread): +    """Thread in which garbage collection actually happens.""" +    def __init__(self, gc): +        super(GarbageCollectorThread, self).__init__() +        self.gc = gc +        self.daemon = True +        self.pid = getpid() +        self.ready = Event() +     +    def run(self): +        # detect fork at begining of the thread +        if getpid is None or getpid() != self.pid: +            return +        s = self.gc.context.socket(zmq.PULL) +        s.linger = 0 +        s.bind(self.gc.url) + +        self.ready.set() +         +        while True: +            # detect fork +            if getpid is None or getpid() != self.pid: +                return +            msg = s.recv() +            if msg == b'DIE': +                break +            fmt = 'L' if len(msg) == 4 else 'Q' +            key = struct.unpack(fmt, msg)[0] +            tup = self.gc.refs.pop(key, None) +            if tup and tup.event: +                tup.event.set() +            del tup +        s.close() + + +class GarbageCollector(object): +    """PyZMQ Garbage Collector +     +    Used for representing the reference held by libzmq during zero-copy sends. +    This object holds a dictionary, keyed by Python id, +    of the Python objects whose memory are currently in use by zeromq. +     +    When zeromq is done with the memory, it sends a message on an inproc PUSH socket +    containing the packed size_t (32 or 64-bit unsigned int), +    which is the key in the dict. +    When the PULL socket in the gc thread receives that message, +    the reference is popped from the dict, +    and any tracker events that should be signaled fire. +    """ +     +    refs = None +    _context = None +    _lock = None +    url = "inproc://pyzmq.gc.01" +     +    def __init__(self, context=None): +        super(GarbageCollector, self).__init__() +        self.refs = {} +        self.pid = None +        self.thread = None +        self._context = context +        self._lock = Lock() +        self._stay_down = False +        atexit.register(self._atexit) +     +    @property +    def context(self): +        if self._context is None: +            self._context = zmq.Context() +        return self._context +     +    @context.setter +    def context(self, ctx): +        if self.is_alive(): +            if self.refs: +                warnings.warn("Replacing gc context while gc is running", RuntimeWarning) +            self.stop() +        self._context = ctx +     +    def _atexit(self): +        """atexit callback +         +        sets _stay_down flag so that gc doesn't try to start up again in other atexit handlers +        """ +        self._stay_down = True +        self.stop() +     +    def stop(self): +        """stop the garbage-collection thread""" +        if not self.is_alive(): +            return +        push = self.context.socket(zmq.PUSH) +        push.connect(self.url) +        push.send(b'DIE') +        push.close() +        self.thread.join() +        self.context.term() +        self.refs.clear() +     +    def start(self): +        """Start a new garbage collection thread. +         +        Creates a new zmq Context used for garbage collection. +        Under most circumstances, this will only be called once per process. +        """ +        self.pid = getpid() +        self.refs = {} +        self.thread = GarbageCollectorThread(self) +        self.thread.start() +        self.thread.ready.wait() +     +    def is_alive(self): +        """Is the garbage collection thread currently running? +         +        Includes checks for process shutdown or fork. +        """ +        if (getpid is None or +            getpid() != self.pid or +            self.thread is None or +            not self.thread.is_alive() +            ): +            return False +        return True +     +    def store(self, obj, event=None): +        """store an object and (optionally) event for zero-copy""" +        if not self.is_alive(): +            if self._stay_down: +                return 0 +            # safely start the gc thread +            # use lock and double check, +            # so we don't start multiple threads +            with self._lock: +                if not self.is_alive(): +                    self.start() +        tup = gcref(obj, event) +        theid = id(tup) +        self.refs[theid] = tup +        return theid +     +    def __del__(self): +        if not self.is_alive(): +            return +        try: +            self.stop() +        except Exception as e: +            raise (e) + +gc = GarbageCollector() diff --git a/zmq/utils/getpid_compat.h b/zmq/utils/getpid_compat.h new file mode 100644 index 0000000..47ce90f --- /dev/null +++ b/zmq/utils/getpid_compat.h @@ -0,0 +1,6 @@ +#ifdef _WIN32 +    #include <process.h> +    #define getpid _getpid +#else +    #include <unistd.h> +#endif diff --git a/zmq/utils/interop.py b/zmq/utils/interop.py new file mode 100644 index 0000000..26c0196 --- /dev/null +++ b/zmq/utils/interop.py @@ -0,0 +1,33 @@ +"""Utils for interoperability with other libraries. + +Just CFFI pointer casting for now. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + + +try: +    long +except NameError: +    long = int # Python 3 + + +def cast_int_addr(n): +    """Cast an address to a Python int +     +    This could be a Python integer or a CFFI pointer +    """ +    if isinstance(n, (int, long)): +        return n +    try: +        import cffi +    except ImportError: +        pass +    else: +        # from pyzmq, this is an FFI void * +        ffi = cffi.FFI() +        if isinstance(n, ffi.CData): +            return int(ffi.cast("size_t", n)) +     +    raise ValueError("Cannot cast %r to int" % n) diff --git a/zmq/utils/ipcmaxlen.h b/zmq/utils/ipcmaxlen.h new file mode 100644 index 0000000..7218db7 --- /dev/null +++ b/zmq/utils/ipcmaxlen.h @@ -0,0 +1,21 @@ +/* + +Platform-independant detection of IPC path max length + +Copyright (c) 2012 Godefroid Chapelle + +Distributed under the terms of the New BSD License.  The full license is in +the file COPYING.BSD, distributed as part of this software. + */ + +#if defined(HAVE_SYS_UN_H) +#include "sys/un.h" +int get_ipc_path_max_len(void) { +    struct sockaddr_un *dummy; +    return sizeof(dummy->sun_path) - 1; +} +#else +int get_ipc_path_max_len(void) { +    return 0; +} +#endif diff --git a/zmq/utils/jsonapi.py b/zmq/utils/jsonapi.py new file mode 100644 index 0000000..865ca6d --- /dev/null +++ b/zmq/utils/jsonapi.py @@ -0,0 +1,59 @@ +"""Priority based json library imports. + +Always serializes to bytes instead of unicode for zeromq compatibility +on Python 2 and 3. + +Use ``jsonapi.loads()`` and ``jsonapi.dumps()`` for guaranteed symmetry. + +Priority: ``simplejson`` > ``jsonlib2`` > stdlib ``json`` + +``jsonapi.loads/dumps`` provide kwarg-compatibility with stdlib json. + +``jsonapi.jsonmod`` will be the module of the actual underlying implementation. +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +from zmq.utils.strtypes import bytes, unicode + +jsonmod = None + +priority = ['simplejson', 'jsonlib2', 'json'] +for mod in priority: +    try: +        jsonmod = __import__(mod) +    except ImportError: +        pass +    else: +        break + +def dumps(o, **kwargs): +    """Serialize object to JSON bytes (utf-8). +     +    See jsonapi.jsonmod.dumps for details on kwargs. +    """ +     +    if 'separators' not in kwargs: +        kwargs['separators'] = (',', ':') +     +    s = jsonmod.dumps(o, **kwargs) +     +    if isinstance(s, unicode): +        s = s.encode('utf8') +     +    return s + +def loads(s, **kwargs): +    """Load object from JSON bytes (utf-8). +     +    See jsonapi.jsonmod.loads for details on kwargs. +    """ +     +    if str is unicode and isinstance(s, bytes): +        s = s.decode('utf8') +     +    return jsonmod.loads(s, **kwargs) + +__all__ = ['jsonmod', 'dumps', 'loads'] + diff --git a/zmq/utils/monitor.py b/zmq/utils/monitor.py new file mode 100644 index 0000000..734d54b --- /dev/null +++ b/zmq/utils/monitor.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +"""Module holding utility and convenience functions for zmq event monitoring.""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import struct +import zmq +from zmq.error import _check_version + +def parse_monitor_message(msg): +    """decode zmq_monitor event messages. +     +    Parameters +    ---------- +    msg : list(bytes) +        zmq multipart message that has arrived on a monitor PAIR socket. +         +        First frame is:: +         +            16 bit event id +            32 bit event value +            no padding + +        Second frame is the endpoint as a bytestring + +    Returns +    ------- +    event : dict +        event description as dict with the keys `event`, `value`, and `endpoint`. +    """ +     +    if len(msg) != 2 or len(msg[0]) != 6: +        raise RuntimeError("Invalid event message format: %s" % msg) +    event = {} +    event['event'], event['value'] = struct.unpack("=hi", msg[0]) +    event['endpoint'] = msg[1] +    return event + +def recv_monitor_message(socket, flags=0): +    """Receive and decode the given raw message from the monitoring socket and return a dict. + +    Requires libzmq ≥ 4.0 + +    The returned dict will have the following entries: +      event     : int, the event id as described in libzmq.zmq_socket_monitor +      value     : int, the event value associated with the event, see libzmq.zmq_socket_monitor +      endpoint  : string, the affected endpoint +     +    Parameters +    ---------- +    socket : zmq PAIR socket +        The PAIR socket (created by other.get_monitor_socket()) on which to recv the message +    flags : bitfield (int) +        standard zmq recv flags + +    Returns +    ------- +    event : dict +        event description as dict with the keys `event`, `value`, and `endpoint`. +    """ +    _check_version((4,0), 'libzmq event API') +    # will always return a list +    msg = socket.recv_multipart(flags) +    # 4.0-style event API +    return parse_monitor_message(msg) + +__all__ = ['parse_monitor_message', 'recv_monitor_message'] diff --git a/zmq/utils/pyversion_compat.h b/zmq/utils/pyversion_compat.h new file mode 100644 index 0000000..fac0904 --- /dev/null +++ b/zmq/utils/pyversion_compat.h @@ -0,0 +1,25 @@ +#include "Python.h" + +#if PY_VERSION_HEX < 0x02070000 +    #define PyMemoryView_FromBuffer(info) (PyErr_SetString(PyExc_NotImplementedError, \ +                    "new buffer interface is not available"), (PyObject *)NULL) +    #define PyMemoryView_FromObject(object)     (PyErr_SetString(PyExc_NotImplementedError, \ +                                        "new buffer interface is not available"), (PyObject *)NULL) +#endif + +#if PY_VERSION_HEX >= 0x03000000 +    // for buffers +    #define Py_END_OF_BUFFER ((Py_ssize_t) 0) + +    #define PyObject_CheckReadBuffer(object) (0) + +    #define PyBuffer_FromMemory(ptr, s) (PyErr_SetString(PyExc_NotImplementedError, \ +                            "old buffer interface is not available"), (PyObject *)NULL) +    #define PyBuffer_FromReadWriteMemory(ptr, s) (PyErr_SetString(PyExc_NotImplementedError, \ +                            "old buffer interface is not available"), (PyObject *)NULL) +    #define PyBuffer_FromObject(object, offset, size)  (PyErr_SetString(PyExc_NotImplementedError, \ +                            "old buffer interface is not available"), (PyObject *)NULL) +    #define PyBuffer_FromReadWriteObject(object, offset, size)  (PyErr_SetString(PyExc_NotImplementedError, \ +                            "old buffer interface is not available"), (PyObject *)NULL) + +#endif diff --git a/zmq/utils/sixcerpt.py b/zmq/utils/sixcerpt.py new file mode 100644 index 0000000..5492fd5 --- /dev/null +++ b/zmq/utils/sixcerpt.py @@ -0,0 +1,52 @@ +"""Excerpts of six.py""" + +# Copyright (C) 2010-2014 Benjamin Peterson +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import sys + +# Useful for very coarse version differentiation. +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 + +if PY3: + +    def reraise(tp, value, tb=None): +        if value.__traceback__ is not tb: +            raise value.with_traceback(tb) +        raise value + +else: +    def exec_(_code_, _globs_=None, _locs_=None): +        """Execute code in a namespace.""" +        if _globs_ is None: +            frame = sys._getframe(1) +            _globs_ = frame.f_globals +            if _locs_ is None: +                _locs_ = frame.f_locals +            del frame +        elif _locs_ is None: +            _locs_ = _globs_ +        exec("""exec _code_ in _globs_, _locs_""") + + +    exec_("""def reraise(tp, value, tb=None): +    raise tp, value, tb +""") diff --git a/zmq/utils/strtypes.py b/zmq/utils/strtypes.py new file mode 100644 index 0000000..548410d --- /dev/null +++ b/zmq/utils/strtypes.py @@ -0,0 +1,45 @@ +"""Declare basic string types unambiguously for various Python versions. + +Authors +------- +* MinRK +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import sys + +if sys.version_info[0] >= 3: +    bytes = bytes +    unicode = str +    basestring = (bytes, unicode) +else: +    unicode = unicode +    bytes = str +    basestring = basestring + +def cast_bytes(s, encoding='utf8', errors='strict'): +    """cast unicode or bytes to bytes""" +    if isinstance(s, bytes): +        return s +    elif isinstance(s, unicode): +        return s.encode(encoding, errors) +    else: +        raise TypeError("Expected unicode or bytes, got %r" % s) + +def cast_unicode(s, encoding='utf8', errors='strict'): +    """cast bytes or unicode to unicode""" +    if isinstance(s, bytes): +        return s.decode(encoding, errors) +    elif isinstance(s, unicode): +        return s +    else: +        raise TypeError("Expected unicode or bytes, got %r" % s) + +# give short 'b' alias for cast_bytes, so that we can use fake b('stuff') +# to simulate b'stuff' +b = asbytes = cast_bytes +u = cast_unicode + +__all__ = ['asbytes', 'bytes', 'unicode', 'basestring', 'b', 'u', 'cast_bytes', 'cast_unicode'] diff --git a/zmq/utils/win32.py b/zmq/utils/win32.py new file mode 100644 index 0000000..e423536 --- /dev/null +++ b/zmq/utils/win32.py @@ -0,0 +1,132 @@ +"""Win32 compatibility utilities.""" + +#----------------------------------------------------------------------------- +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. +#----------------------------------------------------------------------------- + +import os + +# No-op implementation for other platforms. +class _allow_interrupt(object): +    """Utility for fixing CTRL-C events on Windows. + +    On Windows, the Python interpreter intercepts CTRL-C events in order to +    translate them into ``KeyboardInterrupt`` exceptions.  It (presumably) +    does this by setting a flag in its "control control handler" and +    checking it later at a convenient location in the interpreter. + +    However, when the Python interpreter is blocked waiting for the ZMQ +    poll operation to complete, it must wait for ZMQ's ``select()`` +    operation to complete before translating the CTRL-C event into the +    ``KeyboardInterrupt`` exception. + +    The only way to fix this seems to be to add our own "console control +    handler" and perform some application-defined operation that will +    unblock the ZMQ polling operation in order to force ZMQ to pass control +    back to the Python interpreter. + +    This context manager performs all that Windows-y stuff, providing you +    with a hook that is called when a CTRL-C event is intercepted.  This +    hook allows you to unblock your ZMQ poll operation immediately, which +    will then result in the expected ``KeyboardInterrupt`` exception. + +    Without this context manager, your ZMQ-based application will not +    respond normally to CTRL-C events on Windows.  If a CTRL-C event occurs +    while blocked on ZMQ socket polling, the translation to a +    ``KeyboardInterrupt`` exception will be delayed until the I/O completes +    and control returns to the Python interpreter (this may never happen if +    you use an infinite timeout). + +    A no-op implementation is provided on non-Win32 systems to avoid the +    application from having to conditionally use it. + +    Example usage: + +    .. sourcecode:: python + +       def stop_my_application(): +           # ... + +       with allow_interrupt(stop_my_application): +           # main polling loop. + +    In a typical ZMQ application, you would use the "self pipe trick" to +    send message to a ``PAIR`` socket in order to interrupt your blocking +    socket polling operation. + +    In a Tornado event loop, you can use the ``IOLoop.stop`` method to +    unblock your I/O loop. +    """ + +    def __init__(self, action=None): +        """Translate ``action`` into a CTRL-C handler. + +        ``action`` is a callable that takes no arguments and returns no +        value (returned value is ignored).  It must *NEVER* raise an +        exception. +         +        If unspecified, a no-op will be used. +        """ +        self._init_action(action) +     +    def _init_action(self, action): +        pass + +    def __enter__(self): +        return self + +    def __exit__(self, *args): +        return + +if os.name == 'nt': +    from ctypes import WINFUNCTYPE, windll +    from ctypes.wintypes import BOOL, DWORD + +    kernel32 = windll.LoadLibrary('kernel32') + +    # <http://msdn.microsoft.com/en-us/library/ms686016.aspx> +    PHANDLER_ROUTINE = WINFUNCTYPE(BOOL, DWORD) +    SetConsoleCtrlHandler = kernel32.SetConsoleCtrlHandler +    SetConsoleCtrlHandler.argtypes = (PHANDLER_ROUTINE, BOOL) +    SetConsoleCtrlHandler.restype = BOOL + +    class allow_interrupt(_allow_interrupt): +        __doc__ = _allow_interrupt.__doc__ + +        def init_action(self, action): +            if action is None: +                action = lambda: None +            self.action = action +            @PHANDLER_ROUTINE +            def handle(event): +                if event == 0:  # CTRL_C_EVENT +                    action() +                    # Typical C implementations would return 1 to indicate that +                    # the event was processed and other control handlers in the +                    # stack should not be executed.  However, that would +                    # prevent the Python interpreter's handler from translating +                    # CTRL-C to a `KeyboardInterrupt` exception, so we pretend +                    # that we didn't handle it. +                return 0 +            self.handle = handle + +        def __enter__(self): +            """Install the custom CTRL-C handler.""" +            result = SetConsoleCtrlHandler(self.handle, 1) +            if result == 0: +                # Have standard library automatically call `GetLastError()` and +                # `FormatMessage()` into a nice exception object :-) +                raise WindowsError() + +        def __exit__(self, *args): +            """Remove the custom CTRL-C handler.""" +            result = SetConsoleCtrlHandler(self.handle, 0) +            if result == 0: +                # Have standard library automatically call `GetLastError()` and +                # `FormatMessage()` into a nice exception object :-) +                raise WindowsError() +else: +    class allow_interrupt(_allow_interrupt): +        __doc__ = _allow_interrupt.__doc__ +        pass diff --git a/zmq/utils/z85.py b/zmq/utils/z85.py new file mode 100644 index 0000000..1bb1784 --- /dev/null +++ b/zmq/utils/z85.py @@ -0,0 +1,56 @@ +"""Python implementation of Z85 85-bit encoding + +Z85 encoding is a plaintext encoding for a bytestring interpreted as 32bit integers. +Since the chunks are 32bit, a bytestring must be a multiple of 4 bytes. +See ZMQ RFC 32 for details. + + +""" + +# Copyright (C) PyZMQ Developers +# Distributed under the terms of the Modified BSD License. + +import sys +import struct + +PY3 = sys.version_info[0] >= 3 +# Z85CHARS is the base 85 symbol table +Z85CHARS = b"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.-:+=^!/*?&<>()[]{}@%$#" +# Z85MAP maps integers in [0,84] to the appropriate character in Z85CHARS +Z85MAP = dict([(c, idx) for idx, c in enumerate(Z85CHARS)]) + +_85s = [ 85**i for i in range(5) ][::-1] + +def encode(rawbytes): +    """encode raw bytes into Z85""" +    # Accepts only byte arrays bounded to 4 bytes +    if len(rawbytes) % 4: +        raise ValueError("length must be multiple of 4, not %i" % len(rawbytes)) +     +    nvalues = len(rawbytes) / 4 +     +    values = struct.unpack('>%dI' % nvalues, rawbytes) +    encoded = [] +    for v in values: +        for offset in _85s: +            encoded.append(Z85CHARS[(v // offset) % 85]) +     +    # In Python 3, encoded is a list of integers (obviously?!) +    if PY3: +        return bytes(encoded) +    else: +        return b''.join(encoded) + +def decode(z85bytes): +    """decode Z85 bytes to raw bytes""" +    if len(z85bytes) % 5: +        raise ValueError("Z85 length must be multiple of 5, not %i" % len(z85bytes)) +     +    nvalues = len(z85bytes) / 5 +    values = [] +    for i in range(0, len(z85bytes), 5): +        value = 0 +        for j, offset in enumerate(_85s): +            value += Z85MAP[z85bytes[i+j]] * offset +        values.append(value) +    return struct.pack('>%dI' % nvalues, *values) diff --git a/zmq/utils/zmq_compat.h b/zmq/utils/zmq_compat.h new file mode 100644 index 0000000..81c57b6 --- /dev/null +++ b/zmq/utils/zmq_compat.h @@ -0,0 +1,80 @@ +//----------------------------------------------------------------------------- +//  Copyright (c) 2010 Brian Granger, Min Ragan-Kelley +// +//  Distributed under the terms of the New BSD License.  The full license is in +//  the file COPYING.BSD, distributed as part of this software. +//----------------------------------------------------------------------------- + +#if defined(_MSC_VER) +#define pyzmq_int64_t __int64 +#else +#include <stdint.h> +#define pyzmq_int64_t int64_t +#endif + + +#include "zmq.h" +// version compatibility for constants: +#include "zmq_constants.h" + +#define _missing (-1) + + +// define fd type (from libzmq's fd.hpp) +#ifdef _WIN32 +  #ifdef _MSC_VER && _MSC_VER <= 1400 +    #define ZMQ_FD_T UINT_PTR +  #else +    #define ZMQ_FD_T SOCKET +  #endif +#else +    #define ZMQ_FD_T int +#endif + +// use unambiguous aliases for zmq_send/recv functions + +#if ZMQ_VERSION_MAJOR >= 4 +// nothing to remove +#else +    #define zmq_curve_keypair(z85_public_key, z85_secret_key) _missing +#endif + +#if ZMQ_VERSION_MAJOR >= 4 && ZMQ_VERSION_MINOR >= 1 +// nothing to remove +#else +    #define zmq_msg_gets(msg, prop) _missing +    #define zmq_has(capability) _missing +#endif + +#if ZMQ_VERSION_MAJOR >= 3 +    #define zmq_sendbuf zmq_send +    #define zmq_recvbuf zmq_recv + +    // 3.x deprecations - these symbols haven't been removed, +    // but let's protect against their planned removal +    #define zmq_device(device_type, isocket, osocket) _missing +    #define zmq_init(io_threads) ((void*)NULL) +    #define zmq_term zmq_ctx_destroy +#else +    #define zmq_ctx_set(ctx, opt, val) _missing +    #define zmq_ctx_get(ctx, opt) _missing +    #define zmq_ctx_destroy zmq_term +    #define zmq_ctx_new() ((void*)NULL) + +    #define zmq_proxy(a,b,c) _missing + +    #define zmq_disconnect(s, addr) _missing +    #define zmq_unbind(s, addr) _missing +     +    #define zmq_msg_more(msg) _missing +    #define zmq_msg_get(msg, opt) _missing +    #define zmq_msg_set(msg, opt, val) _missing +    #define zmq_msg_send(msg, s, flags) zmq_send(s, msg, flags) +    #define zmq_msg_recv(msg, s, flags) zmq_recv(s, msg, flags) +     +    #define zmq_sendbuf(s, buf, len, flags) _missing +    #define zmq_recvbuf(s, buf, len, flags) _missing + +    #define zmq_socket_monitor(s, addr, flags) _missing + +#endif diff --git a/zmq/utils/zmq_constants.h b/zmq/utils/zmq_constants.h new file mode 100644 index 0000000..9768302 --- /dev/null +++ b/zmq/utils/zmq_constants.h @@ -0,0 +1,622 @@ +#ifndef _PYZMQ_CONSTANT_DEFS +#define _PYZMQ_CONSTANT_DEFS + +#define _PYZMQ_UNDEFINED (-9999) +#ifndef ZMQ_VERSION +    #define ZMQ_VERSION (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_VERSION_MAJOR +    #define ZMQ_VERSION_MAJOR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_VERSION_MINOR +    #define ZMQ_VERSION_MINOR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_VERSION_PATCH +    #define ZMQ_VERSION_PATCH (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_NOBLOCK +    #define ZMQ_NOBLOCK (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_DONTWAIT +    #define ZMQ_DONTWAIT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_POLLIN +    #define ZMQ_POLLIN (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_POLLOUT +    #define ZMQ_POLLOUT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_POLLERR +    #define ZMQ_POLLERR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SNDMORE +    #define ZMQ_SNDMORE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_STREAMER +    #define ZMQ_STREAMER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_FORWARDER +    #define ZMQ_FORWARDER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_QUEUE +    #define ZMQ_QUEUE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IO_THREADS_DFLT +    #define ZMQ_IO_THREADS_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MAX_SOCKETS_DFLT +    #define ZMQ_MAX_SOCKETS_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_POLLITEMS_DFLT +    #define ZMQ_POLLITEMS_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_THREAD_PRIORITY_DFLT +    #define ZMQ_THREAD_PRIORITY_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_THREAD_SCHED_POLICY_DFLT +    #define ZMQ_THREAD_SCHED_POLICY_DFLT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PAIR +    #define ZMQ_PAIR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PUB +    #define ZMQ_PUB (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SUB +    #define ZMQ_SUB (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_REQ +    #define ZMQ_REQ (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_REP +    #define ZMQ_REP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_DEALER +    #define ZMQ_DEALER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER +    #define ZMQ_ROUTER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XREQ +    #define ZMQ_XREQ (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XREP +    #define ZMQ_XREP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PULL +    #define ZMQ_PULL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PUSH +    #define ZMQ_PUSH (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XPUB +    #define ZMQ_XPUB (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XSUB +    #define ZMQ_XSUB (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_UPSTREAM +    #define ZMQ_UPSTREAM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_DOWNSTREAM +    #define ZMQ_DOWNSTREAM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_STREAM +    #define ZMQ_STREAM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CONNECTED +    #define ZMQ_EVENT_CONNECTED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CONNECT_DELAYED +    #define ZMQ_EVENT_CONNECT_DELAYED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CONNECT_RETRIED +    #define ZMQ_EVENT_CONNECT_RETRIED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_LISTENING +    #define ZMQ_EVENT_LISTENING (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_BIND_FAILED +    #define ZMQ_EVENT_BIND_FAILED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_ACCEPTED +    #define ZMQ_EVENT_ACCEPTED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_ACCEPT_FAILED +    #define ZMQ_EVENT_ACCEPT_FAILED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CLOSED +    #define ZMQ_EVENT_CLOSED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_CLOSE_FAILED +    #define ZMQ_EVENT_CLOSE_FAILED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_DISCONNECTED +    #define ZMQ_EVENT_DISCONNECTED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_ALL +    #define ZMQ_EVENT_ALL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENT_MONITOR_STOPPED +    #define ZMQ_EVENT_MONITOR_STOPPED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_NULL +    #define ZMQ_NULL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PLAIN +    #define ZMQ_PLAIN (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE +    #define ZMQ_CURVE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI +    #define ZMQ_GSSAPI (_PYZMQ_UNDEFINED) +#endif + +#ifndef EAGAIN +    #define EAGAIN (_PYZMQ_UNDEFINED) +#endif + +#ifndef EINVAL +    #define EINVAL (_PYZMQ_UNDEFINED) +#endif + +#ifndef EFAULT +    #define EFAULT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOMEM +    #define ENOMEM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENODEV +    #define ENODEV (_PYZMQ_UNDEFINED) +#endif + +#ifndef EMSGSIZE +    #define EMSGSIZE (_PYZMQ_UNDEFINED) +#endif + +#ifndef EAFNOSUPPORT +    #define EAFNOSUPPORT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENETUNREACH +    #define ENETUNREACH (_PYZMQ_UNDEFINED) +#endif + +#ifndef ECONNABORTED +    #define ECONNABORTED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ECONNRESET +    #define ECONNRESET (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOTCONN +    #define ENOTCONN (_PYZMQ_UNDEFINED) +#endif + +#ifndef ETIMEDOUT +    #define ETIMEDOUT (_PYZMQ_UNDEFINED) +#endif + +#ifndef EHOSTUNREACH +    #define EHOSTUNREACH (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENETRESET +    #define ENETRESET (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_HAUSNUMERO +    #define ZMQ_HAUSNUMERO (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOTSUP +    #define ENOTSUP (_PYZMQ_UNDEFINED) +#endif + +#ifndef EPROTONOSUPPORT +    #define EPROTONOSUPPORT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOBUFS +    #define ENOBUFS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENETDOWN +    #define ENETDOWN (_PYZMQ_UNDEFINED) +#endif + +#ifndef EADDRINUSE +    #define EADDRINUSE (_PYZMQ_UNDEFINED) +#endif + +#ifndef EADDRNOTAVAIL +    #define EADDRNOTAVAIL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ECONNREFUSED +    #define ECONNREFUSED (_PYZMQ_UNDEFINED) +#endif + +#ifndef EINPROGRESS +    #define EINPROGRESS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOTSOCK +    #define ENOTSOCK (_PYZMQ_UNDEFINED) +#endif + +#ifndef EFSM +    #define EFSM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ENOCOMPATPROTO +    #define ENOCOMPATPROTO (_PYZMQ_UNDEFINED) +#endif + +#ifndef ETERM +    #define ETERM (_PYZMQ_UNDEFINED) +#endif + +#ifndef EMTHREAD +    #define EMTHREAD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IO_THREADS +    #define ZMQ_IO_THREADS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MAX_SOCKETS +    #define ZMQ_MAX_SOCKETS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SOCKET_LIMIT +    #define ZMQ_SOCKET_LIMIT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_THREAD_PRIORITY +    #define ZMQ_THREAD_PRIORITY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_THREAD_SCHED_POLICY +    #define ZMQ_THREAD_SCHED_POLICY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IDENTITY +    #define ZMQ_IDENTITY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SUBSCRIBE +    #define ZMQ_SUBSCRIBE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_UNSUBSCRIBE +    #define ZMQ_UNSUBSCRIBE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_LAST_ENDPOINT +    #define ZMQ_LAST_ENDPOINT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_ACCEPT_FILTER +    #define ZMQ_TCP_ACCEPT_FILTER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PLAIN_USERNAME +    #define ZMQ_PLAIN_USERNAME (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PLAIN_PASSWORD +    #define ZMQ_PLAIN_PASSWORD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE_PUBLICKEY +    #define ZMQ_CURVE_PUBLICKEY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE_SECRETKEY +    #define ZMQ_CURVE_SECRETKEY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE_SERVERKEY +    #define ZMQ_CURVE_SERVERKEY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ZAP_DOMAIN +    #define ZMQ_ZAP_DOMAIN (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CONNECT_RID +    #define ZMQ_CONNECT_RID (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI_PRINCIPAL +    #define ZMQ_GSSAPI_PRINCIPAL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI_SERVICE_PRINCIPAL +    #define ZMQ_GSSAPI_SERVICE_PRINCIPAL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SOCKS_PROXY +    #define ZMQ_SOCKS_PROXY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_FD +    #define ZMQ_FD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IDENTITY_FD +    #define ZMQ_IDENTITY_FD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RECONNECT_IVL_MAX +    #define ZMQ_RECONNECT_IVL_MAX (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SNDTIMEO +    #define ZMQ_SNDTIMEO (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RCVTIMEO +    #define ZMQ_RCVTIMEO (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SNDHWM +    #define ZMQ_SNDHWM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RCVHWM +    #define ZMQ_RCVHWM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MULTICAST_HOPS +    #define ZMQ_MULTICAST_HOPS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPV4ONLY +    #define ZMQ_IPV4ONLY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER_BEHAVIOR +    #define ZMQ_ROUTER_BEHAVIOR (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_KEEPALIVE +    #define ZMQ_TCP_KEEPALIVE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_KEEPALIVE_CNT +    #define ZMQ_TCP_KEEPALIVE_CNT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_KEEPALIVE_IDLE +    #define ZMQ_TCP_KEEPALIVE_IDLE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TCP_KEEPALIVE_INTVL +    #define ZMQ_TCP_KEEPALIVE_INTVL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_DELAY_ATTACH_ON_CONNECT +    #define ZMQ_DELAY_ATTACH_ON_CONNECT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XPUB_VERBOSE +    #define ZMQ_XPUB_VERBOSE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_EVENTS +    #define ZMQ_EVENTS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TYPE +    #define ZMQ_TYPE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_LINGER +    #define ZMQ_LINGER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RECONNECT_IVL +    #define ZMQ_RECONNECT_IVL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_BACKLOG +    #define ZMQ_BACKLOG (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER_MANDATORY +    #define ZMQ_ROUTER_MANDATORY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_FAIL_UNROUTABLE +    #define ZMQ_FAIL_UNROUTABLE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER_RAW +    #define ZMQ_ROUTER_RAW (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IMMEDIATE +    #define ZMQ_IMMEDIATE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPV6 +    #define ZMQ_IPV6 (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MECHANISM +    #define ZMQ_MECHANISM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PLAIN_SERVER +    #define ZMQ_PLAIN_SERVER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CURVE_SERVER +    #define ZMQ_CURVE_SERVER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_PROBE_ROUTER +    #define ZMQ_PROBE_ROUTER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_REQ_RELAXED +    #define ZMQ_REQ_RELAXED (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_REQ_CORRELATE +    #define ZMQ_REQ_CORRELATE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_CONFLATE +    #define ZMQ_CONFLATE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_ROUTER_HANDOVER +    #define ZMQ_ROUTER_HANDOVER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_TOS +    #define ZMQ_TOS (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPC_FILTER_PID +    #define ZMQ_IPC_FILTER_PID (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPC_FILTER_UID +    #define ZMQ_IPC_FILTER_UID (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_IPC_FILTER_GID +    #define ZMQ_IPC_FILTER_GID (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI_SERVER +    #define ZMQ_GSSAPI_SERVER (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_GSSAPI_PLAINTEXT +    #define ZMQ_GSSAPI_PLAINTEXT (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_HANDSHAKE_IVL +    #define ZMQ_HANDSHAKE_IVL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_XPUB_NODROP +    #define ZMQ_XPUB_NODROP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_AFFINITY +    #define ZMQ_AFFINITY (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MAXMSGSIZE +    #define ZMQ_MAXMSGSIZE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_HWM +    #define ZMQ_HWM (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SWAP +    #define ZMQ_SWAP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MCAST_LOOP +    #define ZMQ_MCAST_LOOP (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RECOVERY_IVL_MSEC +    #define ZMQ_RECOVERY_IVL_MSEC (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RATE +    #define ZMQ_RATE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RECOVERY_IVL +    #define ZMQ_RECOVERY_IVL (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SNDBUF +    #define ZMQ_SNDBUF (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RCVBUF +    #define ZMQ_RCVBUF (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_RCVMORE +    #define ZMQ_RCVMORE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_MORE +    #define ZMQ_MORE (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SRCFD +    #define ZMQ_SRCFD (_PYZMQ_UNDEFINED) +#endif + +#ifndef ZMQ_SHARED +    #define ZMQ_SHARED (_PYZMQ_UNDEFINED) +#endif + + +#endif // ifndef _PYZMQ_CONSTANT_DEFS  | 
