diff options
| author | Kali Kaneko <kali@leap.se> | 2016-04-18 10:56:20 -0400 | 
|---|---|---|
| committer | Kali Kaneko <kali@leap.se> | 2016-04-18 10:56:20 -0400 | 
| commit | e30e06d9062578e1932b5a6a4c4124a1663e18c2 (patch) | |
| tree | 7c81bef2afd8d32d0179cc0192239271252bc311 /src | |
| parent | e5796bf55e3db177ee567118519136fd96ada3c4 (diff) | |
| parent | cef15c04610ee188052af78ead8cfe7ea29d81c6 (diff) | |
Merge tag '0.5.1'
Tag leap.bitmask version 0.5.1
# gpg: Signature made Mon 18 Apr 2016 10:52:44 AM BOT
# gpg:                using RSA key 1CAF6C5B9F720808
# gpg: Good signature from "Kaliyuga <kaliyuga@riseup.net>" [ultimate]
# gpg:                 aka "Kali Kaneko (leap communications) <kali@leap.se>" [ultimate]
Diffstat (limited to 'src')
| -rw-r--r-- | src/leap/common/__init__.py | 2 | ||||
| -rw-r--r-- | src/leap/common/_version.py | 541 | ||||
| -rw-r--r-- | src/leap/common/certs.py | 4 | ||||
| -rw-r--r-- | src/leap/common/events/auth.py | 100 | ||||
| -rw-r--r-- | src/leap/common/events/catalog.py | 85 | ||||
| -rw-r--r-- | src/leap/common/events/client.py | 23 | ||||
| -rw-r--r-- | src/leap/common/events/examples/README.txt | 49 | ||||
| -rw-r--r-- | src/leap/common/events/examples/client.py | 2 | ||||
| -rw-r--r-- | src/leap/common/events/examples/server.py | 4 | ||||
| -rw-r--r-- | src/leap/common/events/server.py | 24 | ||||
| -rw-r--r-- | src/leap/common/events/tests/test_auth.py | 64 | ||||
| -rw-r--r-- | src/leap/common/events/tests/test_events.py (renamed from src/leap/common/tests/test_events.py) | 23 | ||||
| -rw-r--r-- | src/leap/common/events/txclient.py | 10 | ||||
| -rw-r--r-- | src/leap/common/events/zmq_components.py | 147 | ||||
| -rw-r--r-- | src/leap/common/service_hooks.py | 75 | ||||
| -rw-r--r-- | src/leap/common/testing/basetest.py | 9 | ||||
| -rw-r--r-- | src/leap/common/zmq_utils.py | 5 | 
17 files changed, 881 insertions, 286 deletions
| diff --git a/src/leap/common/__init__.py b/src/leap/common/__init__.py index 383e198..3b07cf8 100644 --- a/src/leap/common/__init__.py +++ b/src/leap/common/__init__.py @@ -4,7 +4,6 @@ from leap.common import certs  from leap.common import check  from leap.common import files  from leap.common import events -from ._version import get_versions  logger = logging.getLogger(__name__) @@ -17,5 +16,6 @@ except ImportError:  __all__ = ["certs", "check", "files", "events"] +from ._version import get_versions  __version__ = get_versions()['version']  del get_versions diff --git a/src/leap/common/_version.py b/src/leap/common/_version.py index de94ba8..e29d969 100644 --- a/src/leap/common/_version.py +++ b/src/leap/common/_version.py @@ -1,73 +1,157 @@ +  # This file helps to compute a version number in source trees obtained from  # git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (build by setup.py sdist) and build +# feature). Distribution tarballs (built by setup.py sdist) and build  # directories (produced by setup.py build) will contain a much shorter file  # that just contains the computed version number.  # This file is released into the public domain. Generated by -# versioneer-0.7+ (https://github.com/warner/python-versioneer) +# versioneer-0.16 (https://github.com/warner/python-versioneer) -# these strings will be replaced by git during git-archive +"""Git implementation of _version.py.""" +import errno +import os +import re  import subprocess  import sys -import re -import os.path -IN_LONG_VERSION_PY = True -git_refnames = "$Format:%d$" -git_full = "$Format:%H$" +def get_keywords(): +    """Get the keywords needed to look up the version information.""" +    # these strings will be replaced by git during git-archive. +    # setup.py/versioneer.py will grep for the variable names, so they must +    # each be defined on a line of their own. _version.py will just call +    # get_keywords(). +    git_refnames = "$Format:%d$" +    git_full = "$Format:%H$" +    keywords = {"refnames": git_refnames, "full": git_full} +    return keywords -def run_command(args, cwd=None, verbose=False): -    try: -        # remember shell=False, so use git.cmd on windows, not just git -        p = subprocess.Popen(args, stdout=subprocess.PIPE, cwd=cwd) -    except EnvironmentError: -        e = sys.exc_info()[1] + +class VersioneerConfig: +    """Container for Versioneer configuration parameters.""" + + +def get_config(): +    """Create, populate and return the VersioneerConfig() object.""" +    # these strings are filled in when 'setup.py versioneer' creates +    # _version.py +    cfg = VersioneerConfig() +    cfg.VCS = "git" +    cfg.style = "pep440" +    cfg.tag_prefix = "" +    cfg.parentdir_prefix = "None" +    cfg.versionfile_source = "src/leap/common/_version.py" +    cfg.verbose = False +    return cfg + + +class NotThisMethod(Exception): +    """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY = {} +HANDLERS = {} + + +def register_vcs_handler(vcs, method):  # decorator +    """Decorator to mark a method as the handler for a particular VCS.""" +    def decorate(f): +        """Store f in HANDLERS[vcs][method].""" +        if vcs not in HANDLERS: +            HANDLERS[vcs] = {} +        HANDLERS[vcs][method] = f +        return f +    return decorate + + +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False): +    """Call the given command(s).""" +    assert isinstance(commands, list) +    p = None +    for c in commands: +        try: +            dispcmd = str([c] + args) +            # remember shell=False, so use git.cmd on windows, not just git +            p = subprocess.Popen([c] + args, cwd=cwd, stdout=subprocess.PIPE, +                                 stderr=(subprocess.PIPE if hide_stderr +                                         else None)) +            break +        except EnvironmentError: +            e = sys.exc_info()[1] +            if e.errno == errno.ENOENT: +                continue +            if verbose: +                print("unable to run %s" % dispcmd) +                print(e) +            return None +    else:          if verbose: -            print("unable to run %s" % args[0]) -            print(e) +            print("unable to find command, tried %s" % (commands,))          return None      stdout = p.communicate()[0].strip() -    if sys.version >= '3': +    if sys.version_info[0] >= 3:          stdout = stdout.decode()      if p.returncode != 0:          if verbose: -            print("unable to run %s (error)" % args[0]) +            print("unable to run %s (error)" % dispcmd)          return None      return stdout -def get_expanded_variables(versionfile_source): +def versions_from_parentdir(parentdir_prefix, root, verbose): +    """Try to determine the version from the parent directory name. + +    Source tarballs conventionally unpack into a directory that includes +    both the project name and a version string. +    """ +    dirname = os.path.basename(root) +    if not dirname.startswith(parentdir_prefix): +        if verbose: +            print("guessing rootdir is '%s', but '%s' doesn't start with " +                  "prefix '%s'" % (root, dirname, parentdir_prefix)) +        raise NotThisMethod("rootdir doesn't start with parentdir_prefix") +    return {"version": dirname[len(parentdir_prefix):], +            "full-revisionid": None, +            "dirty": False, "error": None} + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): +    """Extract version information from the given file."""      # the code embedded in _version.py can just fetch the value of these -    # variables. When used from setup.py, we don't want to import -    # _version.py, so we do it with a regexp instead. This function is not -    # used from _version.py. -    variables = {} +    # keywords. When used from setup.py, we don't want to import _version.py, +    # so we do it with a regexp instead. This function is not used from +    # _version.py. +    keywords = {}      try: -        f = open(versionfile_source, "r") +        f = open(versionfile_abs, "r")          for line in f.readlines():              if line.strip().startswith("git_refnames ="):                  mo = re.search(r'=\s*"(.*)"', line)                  if mo: -                    variables["refnames"] = mo.group(1) +                    keywords["refnames"] = mo.group(1)              if line.strip().startswith("git_full ="):                  mo = re.search(r'=\s*"(.*)"', line)                  if mo: -                    variables["full"] = mo.group(1) +                    keywords["full"] = mo.group(1)          f.close()      except EnvironmentError:          pass -    return variables +    return keywords -def versions_from_expanded_variables(variables, tag_prefix, verbose=False): -    refnames = variables["refnames"].strip() +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): +    """Get version information from git keywords.""" +    if not keywords: +        raise NotThisMethod("no keywords at all, weird") +    refnames = keywords["refnames"].strip()      if refnames.startswith("$Format"):          if verbose: -            print("variables are unexpanded, not using") -        return {}  # unexpanded, so not in an unpacked git-archive tarball +            print("keywords are unexpanded, not using") +        raise NotThisMethod("unexpanded keywords, not a git-archive tarball")      refs = set([r.strip() for r in refnames.strip("()").split(",")])      # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of      # just "foo-1.0". If we see a "tag: " prefix, prefer those. @@ -83,7 +167,7 @@ def versions_from_expanded_variables(variables, tag_prefix, verbose=False):          # "stabilization", as well as "HEAD" and "master".          tags = set([r for r in refs if re.search(r'\d', r)])          if verbose: -            print("discarding '%s', no digits" % ",".join(refs - tags)) +            print("discarding '%s', no digits" % ",".join(refs-tags))      if verbose:          print("likely tags: %s" % ",".join(sorted(tags)))      for ref in sorted(tags): @@ -93,111 +177,308 @@ def versions_from_expanded_variables(variables, tag_prefix, verbose=False):              if verbose:                  print("picking %s" % r)              return {"version": r, -                    "full": variables["full"].strip()} -    # no suitable tags, so we use the full revision id +                    "full-revisionid": keywords["full"].strip(), +                    "dirty": False, "error": None +                    } +    # no suitable tags, so version is "0+unknown", but full hex is still there      if verbose: -        print("no suitable tags, using full revision id") -    return {"version": variables["full"].strip(), -            "full": variables["full"].strip()} - - -def versions_from_vcs(tag_prefix, versionfile_source, verbose=False): -    # this runs 'git' from the root of the source tree. That either means -    # someone ran a setup.py command (and this code is in versioneer.py, so -    # IN_LONG_VERSION_PY=False, thus the containing directory is the root of -    # the source tree), or someone ran a project-specific entry point (and -    # this code is in _version.py, so IN_LONG_VERSION_PY=True, thus the -    # containing directory is somewhere deeper in the source tree). This only -    # gets called if the git-archive 'subst' variables were *not* expanded, -    # and _version.py hasn't already been rewritten with a short version -    # string, meaning we're inside a checked out source tree. +        print("no suitable tags, using unknown + full revision id") +    return {"version": "0+unknown", +            "full-revisionid": keywords["full"].strip(), +            "dirty": False, "error": "no suitable tags"} -    try: -        here = os.path.abspath(__file__) -    except NameError: -        # some py2exe/bbfreeze/non-CPython implementations don't do __file__ -        return {}  # not always correct - -    # versionfile_source is the relative path from the top of the source tree -    # (where the .git directory might live) to this file. Invert this to find -    # the root from __file__. -    root = here -    if IN_LONG_VERSION_PY: -        for i in range(len(versionfile_source.split("/"))): -            root = os.path.dirname(root) -    else: -        root = os.path.dirname(here) + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +    """Get version from 'git describe' in the root of the source tree. + +    This only gets called if the git-archive 'subst' keywords were *not* +    expanded, and _version.py hasn't already been rewritten with a short +    version string, meaning we're inside a checked out source tree. +    """      if not os.path.exists(os.path.join(root, ".git")):          if verbose:              print("no .git in %s" % root) -        return {} +        raise NotThisMethod("no .git directory") -    GIT = "git" +    GITS = ["git"]      if sys.platform == "win32": -        GIT = "git.cmd" -    stdout = run_command([GIT, "describe", "--tags", "--dirty", "--always"], -                         cwd=root) -    if stdout is None: -        return {} -    if not stdout.startswith(tag_prefix): -        if verbose: -            print("tag '%s' doesn't start with prefix '%s'" % (stdout, tag_prefix)) -        return {} -    tag = stdout[len(tag_prefix):] -    stdout = run_command([GIT, "rev-parse", "HEAD"], cwd=root) -    if stdout is None: -        return {} -    full = stdout.strip() -    if tag.endswith("-dirty"): -        full += "-dirty" -    return {"version": tag, "full": full} - - -def versions_from_parentdir(parentdir_prefix, versionfile_source, verbose=False): -    if IN_LONG_VERSION_PY: -        # We're running from _version.py. If it's from a source tree -        # (execute-in-place), we can work upwards to find the root of the -        # tree, and then check the parent directory for a version string. If -        # it's in an installed application, there's no hope. -        try: -            here = os.path.abspath(__file__) -        except NameError: -            # py2exe/bbfreeze/non-CPython don't have __file__ -            return {}  # without __file__, we have no hope +        GITS = ["git.cmd", "git.exe"] +    # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] +    # if there isn't one, this yields HEX[-dirty] (no NUM) +    describe_out = run_command(GITS, ["describe", "--tags", "--dirty", +                                      "--always", "--long", +                                      "--match", "%s*" % tag_prefix], +                               cwd=root) +    # --long was added in git-1.5.5 +    if describe_out is None: +        raise NotThisMethod("'git describe' failed") +    describe_out = describe_out.strip() +    full_out = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) +    if full_out is None: +        raise NotThisMethod("'git rev-parse' failed") +    full_out = full_out.strip() + +    pieces = {} +    pieces["long"] = full_out +    pieces["short"] = full_out[:7]  # maybe improved later +    pieces["error"] = None + +    # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] +    # TAG might have hyphens. +    git_describe = describe_out + +    # look for -dirty suffix +    dirty = git_describe.endswith("-dirty") +    pieces["dirty"] = dirty +    if dirty: +        git_describe = git_describe[:git_describe.rindex("-dirty")] + +    # now we have TAG-NUM-gHEX or HEX + +    if "-" in git_describe: +        # TAG-NUM-gHEX +        mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) +        if not mo: +            # unparseable. Maybe git-describe is misbehaving? +            pieces["error"] = ("unable to parse git-describe output: '%s'" +                               % describe_out) +            return pieces + +        # tag +        full_tag = mo.group(1) +        if not full_tag.startswith(tag_prefix): +            if verbose: +                fmt = "tag '%s' doesn't start with prefix '%s'" +                print(fmt % (full_tag, tag_prefix)) +            pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" +                               % (full_tag, tag_prefix)) +            return pieces +        pieces["closest-tag"] = full_tag[len(tag_prefix):] + +        # distance: number of commits since tag +        pieces["distance"] = int(mo.group(2)) + +        # commit: short hex revision ID +        pieces["short"] = mo.group(3) + +    else: +        # HEX: no tags +        pieces["closest-tag"] = None +        count_out = run_command(GITS, ["rev-list", "HEAD", "--count"], +                                cwd=root) +        pieces["distance"] = int(count_out)  # total number of commits + +    return pieces + + +def plus_or_dot(pieces): +    """Return a + if we don't already have one, else return a .""" +    if "+" in pieces.get("closest-tag", ""): +        return "." +    return "+" + + +def render_pep440(pieces): +    """Build up version string, with post-release "local version identifier". + +    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you +    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + +    Exceptions: +    1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] +    """ +    if pieces["closest-tag"]: +        rendered = pieces["closest-tag"] +        if pieces["distance"] or pieces["dirty"]: +            rendered += plus_or_dot(pieces) +            rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) +            if pieces["dirty"]: +                rendered += ".dirty" +    else: +        # exception #1 +        rendered = "0+untagged.%d.g%s" % (pieces["distance"], +                                          pieces["short"]) +        if pieces["dirty"]: +            rendered += ".dirty" +    return rendered + + +def render_pep440_pre(pieces): +    """TAG[.post.devDISTANCE] -- No -dirty. + +    Exceptions: +    1: no tags. 0.post.devDISTANCE +    """ +    if pieces["closest-tag"]: +        rendered = pieces["closest-tag"] +        if pieces["distance"]: +            rendered += ".post.dev%d" % pieces["distance"] +    else: +        # exception #1 +        rendered = "0.post.dev%d" % pieces["distance"] +    return rendered + + +def render_pep440_post(pieces): +    """TAG[.postDISTANCE[.dev0]+gHEX] . + +    The ".dev0" means dirty. Note that .dev0 sorts backwards +    (a dirty tree will appear "older" than the corresponding clean one), +    but you shouldn't be releasing software with -dirty anyways. + +    Exceptions: +    1: no tags. 0.postDISTANCE[.dev0] +    """ +    if pieces["closest-tag"]: +        rendered = pieces["closest-tag"] +        if pieces["distance"] or pieces["dirty"]: +            rendered += ".post%d" % pieces["distance"] +            if pieces["dirty"]: +                rendered += ".dev0" +            rendered += plus_or_dot(pieces) +            rendered += "g%s" % pieces["short"] +    else: +        # exception #1 +        rendered = "0.post%d" % pieces["distance"] +        if pieces["dirty"]: +            rendered += ".dev0" +        rendered += "+g%s" % pieces["short"] +    return rendered + + +def render_pep440_old(pieces): +    """TAG[.postDISTANCE[.dev0]] . + +    The ".dev0" means dirty. + +    Eexceptions: +    1: no tags. 0.postDISTANCE[.dev0] +    """ +    if pieces["closest-tag"]: +        rendered = pieces["closest-tag"] +        if pieces["distance"] or pieces["dirty"]: +            rendered += ".post%d" % pieces["distance"] +            if pieces["dirty"]: +                rendered += ".dev0" +    else: +        # exception #1 +        rendered = "0.post%d" % pieces["distance"] +        if pieces["dirty"]: +            rendered += ".dev0" +    return rendered + + +def render_git_describe(pieces): +    """TAG[-DISTANCE-gHEX][-dirty]. + +    Like 'git describe --tags --dirty --always'. + +    Exceptions: +    1: no tags. HEX[-dirty]  (note: no 'g' prefix) +    """ +    if pieces["closest-tag"]: +        rendered = pieces["closest-tag"] +        if pieces["distance"]: +            rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) +    else: +        # exception #1 +        rendered = pieces["short"] +    if pieces["dirty"]: +        rendered += "-dirty" +    return rendered + + +def render_git_describe_long(pieces): +    """TAG-DISTANCE-gHEX[-dirty]. + +    Like 'git describe --tags --dirty --always -long'. +    The distance/hash is unconditional. + +    Exceptions: +    1: no tags. HEX[-dirty]  (note: no 'g' prefix) +    """ +    if pieces["closest-tag"]: +        rendered = pieces["closest-tag"] +        rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) +    else: +        # exception #1 +        rendered = pieces["short"] +    if pieces["dirty"]: +        rendered += "-dirty" +    return rendered + + +def render(pieces, style): +    """Render the given version pieces into the requested style.""" +    if pieces["error"]: +        return {"version": "unknown", +                "full-revisionid": pieces.get("long"), +                "dirty": None, +                "error": pieces["error"]} + +    if not style or style == "default": +        style = "pep440"  # the default + +    if style == "pep440": +        rendered = render_pep440(pieces) +    elif style == "pep440-pre": +        rendered = render_pep440_pre(pieces) +    elif style == "pep440-post": +        rendered = render_pep440_post(pieces) +    elif style == "pep440-old": +        rendered = render_pep440_old(pieces) +    elif style == "git-describe": +        rendered = render_git_describe(pieces) +    elif style == "git-describe-long": +        rendered = render_git_describe_long(pieces) +    else: +        raise ValueError("unknown style '%s'" % style) + +    return {"version": rendered, "full-revisionid": pieces["long"], +            "dirty": pieces["dirty"], "error": None} + + +def get_versions(): +    """Get version information or return default if unable to do so.""" +    # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have +    # __file__, we can work backwards from there to the root. Some +    # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which +    # case we can only use expanded keywords. + +    cfg = get_config() +    verbose = cfg.verbose + +    try: +        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, +                                          verbose) +    except NotThisMethod: +        pass + +    try: +        root = os.path.realpath(__file__)          # versionfile_source is the relative path from the top of the source -        # tree to _version.py. Invert this to find the root from __file__. -        root = here -        for i in range(len(versionfile_source.split("/"))): +        # tree (where the .git directory might live) to this file. Invert +        # this to find the root from __file__. +        for i in cfg.versionfile_source.split('/'):              root = os.path.dirname(root) -    else: -        # we're running from versioneer.py, which means we're running from -        # the setup.py in a source tree. sys.argv[0] is setup.py in the root. -        here = os.path.abspath(sys.argv[0]) -        root = os.path.dirname(here) +    except NameError: +        return {"version": "0+unknown", "full-revisionid": None, +                "dirty": None, +                "error": "unable to find root of source tree"} -    # Source tarballs conventionally unpack into a directory that includes -    # both the project name and a version string. -    dirname = os.path.basename(root) -    if not dirname.startswith(parentdir_prefix): -        if verbose: -            print("guessing rootdir is '%s', but '%s' doesn't start with prefix '%s'" % -                  (root, dirname, parentdir_prefix)) -        return None -    return {"version": dirname[len(parentdir_prefix):], "full": ""} - -tag_prefix = "" -parentdir_prefix = "leap.common-" -versionfile_source = "src/leap/common/_version.py" - - -def get_versions(default={"version": "unknown", "full": ""}, verbose=False): -    variables = {"refnames": git_refnames, "full": git_full} -    ver = versions_from_expanded_variables(variables, tag_prefix, verbose) -    if not ver: -        ver = versions_from_vcs(tag_prefix, versionfile_source, verbose) -    if not ver: -        ver = versions_from_parentdir(parentdir_prefix, versionfile_source, -                                      verbose) -    if not ver: -        ver = default -    return ver +    try: +        pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) +        return render(pieces, cfg.style) +    except NotThisMethod: +        pass + +    try: +        if cfg.parentdir_prefix: +            return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) +    except NotThisMethod: +        pass + +    return {"version": "0+unknown", "full-revisionid": None, +            "dirty": None, +            "error": "unable to compute version"} diff --git a/src/leap/common/certs.py b/src/leap/common/certs.py index c49015a..95704a6 100644 --- a/src/leap/common/certs.py +++ b/src/leap/common/certs.py @@ -192,8 +192,8 @@ def get_compatible_ssl_context_factory(cert_path=None):          class WebClientContextFactory(ssl.ClientContextFactory):              """ -            A web context factory which ignores the hostname and port and does no -            certificate verification. +            A web context factory which ignores the hostname and port and does +            no certificate verification.              """              def getContext(self, hostname, port):                  return ssl.ClientContextFactory.getContext(self) diff --git a/src/leap/common/events/auth.py b/src/leap/common/events/auth.py new file mode 100644 index 0000000..db217ca --- /dev/null +++ b/src/leap/common/events/auth.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# auth.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program.  If not, see <http://www.gnu.org/licenses/>. +""" +ZAP authentication, twisted style. +""" +from zmq import PAIR +from zmq.auth.base import Authenticator, VERSION +from txzmq.connection import ZmqConnection +from zmq.utils.strtypes import b, u + +from twisted.python import log + +from txzmq.connection import ZmqEndpoint, ZmqEndpointType + + +class TxAuthenticator(ZmqConnection): + +    """ +    This does not implement the whole ZAP protocol, but the bare minimum that +    we need. +    """ + +    socketType = PAIR +    address = 'inproc://zeromq.zap.01' +    encoding = 'utf-8' + +    def __init__(self, factory, *args, **kw): +        super(TxAuthenticator, self).__init__(factory, *args, **kw) +        self.authenticator = Authenticator(factory.context) +        self.authenticator._send_zap_reply = self._send_zap_reply + +    def start(self): +        endpoint = ZmqEndpoint(ZmqEndpointType.bind, self.address) +        self.addEndpoints([endpoint]) + +    def messageReceived(self, msg): + +        command = msg[0] + +        if command == b'ALLOW': +            addresses = [u(m, self.encoding) for m in msg[1:]] +            try: +                self.authenticator.allow(*addresses) +            except Exception as e: +                log.err("Failed to allow %s", addresses) + +        elif command == b'CURVE': +            domain = u(msg[1], self.encoding) +            location = u(msg[2], self.encoding) +            self.authenticator.configure_curve(domain, location) + +    def _send_zap_reply(self, request_id, status_code, status_text, +                        user_id='user'): +        """ +        Send a ZAP reply to finish the authentication. +        """ +        user_id = user_id if status_code == b'200' else b'' +        if isinstance(user_id, unicode): +            user_id = user_id.encode(self.encoding, 'replace') +        metadata = b''  # not currently used +        reply = [VERSION, request_id, status_code, status_text, +                 user_id, metadata] +        self.send(reply) + +    def shutdown(self): +        if self.factory: +            super(TxAuthenticator, self).shutdown() + + +class TxAuthenticationRequest(ZmqConnection): + +    socketType = PAIR +    address = 'inproc://zeromq.zap.01' +    encoding = 'utf-8' + +    def start(self): +        endpoint = ZmqEndpoint(ZmqEndpointType.connect, self.address) +        self.addEndpoints([endpoint]) + +    def allow(self, *addresses): +        self.send([b'ALLOW'] + [b(a, self.encoding) for a in addresses]) + +    def configure_curve(self, domain='*', location=''): +        domain = b(domain, self.encoding) +        location = b(location, self.encoding) +        self.send([b'CURVE', domain, location]) diff --git a/src/leap/common/events/catalog.py b/src/leap/common/events/catalog.py index 8bddd2c..9a834b2 100644 --- a/src/leap/common/events/catalog.py +++ b/src/leap/common/events/catalog.py @@ -24,49 +24,54 @@ Events catalog.  EVENTS = [      "CLIENT_SESSION_ID",      "CLIENT_UID", -    "IMAP_CLIENT_LOGIN", -    "IMAP_SERVICE_FAILED_TO_START", -    "IMAP_SERVICE_STARTED", -    "IMAP_UNHANDLED_ERROR", -    "KEYMANAGER_DONE_UPLOADING_KEYS", -    "KEYMANAGER_FINISHED_KEY_GENERATION", -    "KEYMANAGER_KEY_FOUND", -    "KEYMANAGER_KEY_NOT_FOUND", -    "KEYMANAGER_LOOKING_FOR_KEY", -    "KEYMANAGER_STARTED_KEY_GENERATION", -    "MAIL_FETCHED_INCOMING", -    "MAIL_MSG_DECRYPTED", -    "MAIL_MSG_DELETED_INCOMING", -    "MAIL_MSG_PROCESSING", -    "MAIL_MSG_SAVED_LOCALLY", -    "MAIL_UNREAD_MESSAGES",      "RAISE_WINDOW", -    "SMTP_CONNECTION_LOST", -    "SMTP_END_ENCRYPT_AND_SIGN", -    "SMTP_END_SIGN", -    "SMTP_RECIPIENT_ACCEPTED_ENCRYPTED", -    "SMTP_RECIPIENT_ACCEPTED_UNENCRYPTED", -    "SMTP_RECIPIENT_REJECTED", -    "SMTP_SEND_MESSAGE_ERROR", -    "SMTP_SEND_MESSAGE_START", -    "SMTP_SEND_MESSAGE_SUCCESS", -    "SMTP_SERVICE_FAILED_TO_START", -    "SMTP_SERVICE_STARTED", -    "SMTP_START_ENCRYPT_AND_SIGN", -    "SMTP_START_SIGN", -    "SOLEDAD_CREATING_KEYS", -    "SOLEDAD_DONE_CREATING_KEYS", -    "SOLEDAD_DONE_DATA_SYNC", -    "SOLEDAD_DONE_DOWNLOADING_KEYS", -    "SOLEDAD_DONE_UPLOADING_KEYS", -    "SOLEDAD_DOWNLOADING_KEYS", -    "SOLEDAD_INVALID_AUTH_TOKEN", -    "SOLEDAD_NEW_DATA_TO_SYNC", -    "SOLEDAD_SYNC_RECEIVE_STATUS", -    "SOLEDAD_SYNC_SEND_STATUS", -    "SOLEDAD_UPLOADING_KEYS",      "UPDATER_DONE_UPDATING",      "UPDATER_NEW_UPDATES", + +    "KEYMANAGER_DONE_UPLOADING_KEYS",  # (address) +    "KEYMANAGER_FINISHED_KEY_GENERATION",  # (address) +    "KEYMANAGER_KEY_FOUND",  # (address) +    "KEYMANAGER_KEY_NOT_FOUND",  # (address) +    "KEYMANAGER_LOOKING_FOR_KEY",  # (address) +    "KEYMANAGER_STARTED_KEY_GENERATION",  # (address) + +    "SOLEDAD_CREATING_KEYS",  # {uuid, userid} +    "SOLEDAD_DONE_CREATING_KEYS",  # {uuid, userid} +    "SOLEDAD_DONE_DATA_SYNC",  # {uuid, userid} +    "SOLEDAD_DONE_DOWNLOADING_KEYS",  # {uuid, userid} +    "SOLEDAD_DONE_UPLOADING_KEYS",  # {uuid, userid} +    "SOLEDAD_DOWNLOADING_KEYS",  # {uuid, userid} +    "SOLEDAD_INVALID_AUTH_TOKEN",  # {uuid, userid} +    "SOLEDAD_SYNC_RECEIVE_STATUS",  # {uuid, userid} +    "SOLEDAD_SYNC_SEND_STATUS",  # {uuid, userid} +    "SOLEDAD_UPLOADING_KEYS",  # {uuid, userid} +    "SOLEDAD_NEW_DATA_TO_SYNC", + +    "MAIL_FETCHED_INCOMING",  # (userid) +    "MAIL_MSG_DECRYPTED",  # (userid) +    "MAIL_MSG_DELETED_INCOMING",  # (userid) +    "MAIL_MSG_PROCESSING",  # (userid) +    "MAIL_MSG_SAVED_LOCALLY",  # (userid) +    "MAIL_UNREAD_MESSAGES",  # (userid, number) + +    "IMAP_SERVICE_STARTED", +    "IMAP_SERVICE_FAILED_TO_START", +    "IMAP_UNHANDLED_ERROR", +    "IMAP_CLIENT_LOGIN",  # (username) + +    "SMTP_SERVICE_STARTED", +    "SMTP_SERVICE_FAILED_TO_START", +    "SMTP_START_ENCRYPT_AND_SIGN",  # (from_addr) +    "SMTP_END_ENCRYPT_AND_SIGN",  # (from_addr) +    "SMTP_START_SIGN",  # (from_addr) +    "SMTP_END_SIGN",  # (from_addr) +    "SMTP_SEND_MESSAGE_START",  # (from_addr) +    "SMTP_SEND_MESSAGE_SUCCESS",  # (from_addr) +    "SMTP_RECIPIENT_ACCEPTED_ENCRYPTED",  # (userid, dest) +    "SMTP_RECIPIENT_ACCEPTED_UNENCRYPTED",  # (userid, dest) +    "SMTP_CONNECTION_LOST",  # (userid, dest) +    "SMTP_RECIPIENT_REJECTED",  # (userid, dest) +    "SMTP_SEND_MESSAGE_ERROR",  # (userid, dest)  ] diff --git a/src/leap/common/events/client.py b/src/leap/common/events/client.py index 60d24bc..78617de 100644 --- a/src/leap/common/events/client.py +++ b/src/leap/common/events/client.py @@ -63,14 +63,18 @@ logger = logging.getLogger(__name__)  _emit_addr = EMIT_ADDR  _reg_addr = REG_ADDR +_factory = None +_enable_curve = True -def configure_client(emit_addr, reg_addr): -    global _emit_addr, _reg_addr +def configure_client(emit_addr, reg_addr, factory=None, enable_curve=True): +    global _emit_addr, _reg_addr, _factory, _enable_curve      logger.debug("Configuring client with addresses: (%s, %s)" %                   (emit_addr, reg_addr))      _emit_addr = emit_addr      _reg_addr = reg_addr +    _factory = factory +    _enable_curve = enable_curve  class EventsClient(object): @@ -103,7 +107,9 @@ class EventsClient(object):          """          with cls._instance_lock:              if cls._instance is None: -                cls._instance = cls(_emit_addr, _reg_addr) +                cls._instance = cls( +                    _emit_addr, _reg_addr, factory=_factory, +                    enable_curve=_enable_curve)          return cls._instance      def register(self, event, callback, uid=None, replace=False): @@ -270,7 +276,7 @@ class EventsClientThread(threading.Thread, EventsClient):      A threaded version of the events client.      """ -    def __init__(self, emit_addr, reg_addr): +    def __init__(self, emit_addr, reg_addr, factory=None, enable_curve=True):          """          Initialize the events client.          """ @@ -281,15 +287,22 @@ class EventsClientThread(threading.Thread, EventsClient):          self._config_prefix = os.path.join(              get_path_prefix(flags.STANDALONE), "leap", "events")          self._loop = None +        self._factory = factory          self._context = None          self._push = None          self._sub = None +        if enable_curve: +            self.use_curve = zmq_has_curve() +        else: +            self.use_curve = False +      def _init_zmq(self):          """          Initialize ZMQ connections.          """          self._loop = EventsIOLoop() +        # we need a new context for each thread          self._context = zmq.Context()          # connect SUB first, otherwise we might miss some event sent from this          # same client @@ -311,7 +324,7 @@ class EventsClientThread(threading.Thread, EventsClient):          logger.debug("Connecting %s to %s." % (socktype, address))          socket = self._context.socket(socktype)          # configure curve authentication -        if zmq_has_curve(): +        if self.use_curve:              public, private = maybe_create_and_get_certificates(                  self._config_prefix, "client")              server_public_file = os.path.join( diff --git a/src/leap/common/events/examples/README.txt b/src/leap/common/events/examples/README.txt new file mode 100644 index 0000000..0bb0df6 --- /dev/null +++ b/src/leap/common/events/examples/README.txt @@ -0,0 +1,49 @@ +How to debug +----------------------------------------- +monitor the events socket: +  sudo ngrep -W byline -d any port 9000 + +launch the server: +  python server.py + +launch the client: +  python client.py + +if zmq is available and enabled, you should see encrypted messages passing by +the socket. + +You should see something like the following: + +#### +T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP] +.......... +## +T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP] +........... +## +T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP] +..CURVE............................................... +# +T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP] +.CURVE............................................... +# +T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP] +...HELLO.............................................................................:....^...".....'.S...n......Y...................O.7.+.D.q".*..R...j.....8..qu..~......Ck.G\....:...m....Tg.s..M..x<.. +## +T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP] +...WELCOME..%.'.,Td... I..}...........`..Nm......./_.Je...4.....-.....f<v.|.".jJ...^.D...$lJ..U......g..../w.......\..W.....!........i.v....0...........3..a.5}.@F..v./..$ +# +T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP] +..........INITIATE......!.*.=0.-......D..]{...A\.tz...!2.....A./ +6.......Y.h.N....cb.U.|..f..)....W..3..X.2U.3PGl.........m..95.(......NJ....5.'..W.GQ..B/.....\%.,Q..r.'L5.......{.W<=._.$.(6j.G... +...37.H..Th...'.........0 ........,..q....U..G..M.`!_..w....f.".......... +.d.K.Y.>f.n.kV. +# +T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP] +.2.READY............A...e.)......*.8y....k.<.N1Z.4.. +# +T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP] +.+.MESSAGE........o...*M..,.... +.r..w..[.GwcU +### + diff --git a/src/leap/common/events/examples/client.py b/src/leap/common/events/examples/client.py new file mode 100644 index 0000000..d6d8985 --- /dev/null +++ b/src/leap/common/events/examples/client.py @@ -0,0 +1,2 @@ +from leap.common.events.txclient import emit +emit('stuff!') diff --git a/src/leap/common/events/examples/server.py b/src/leap/common/events/examples/server.py new file mode 100644 index 0000000..f40f8dc --- /dev/null +++ b/src/leap/common/events/examples/server.py @@ -0,0 +1,4 @@ +from twisted.internet import reactor +from leap.common.events.server import ensure_server +reactor.callWhenRunning(ensure_server) +reactor.run() diff --git a/src/leap/common/events/server.py b/src/leap/common/events/server.py index a69202e..05fc23e 100644 --- a/src/leap/common/events/server.py +++ b/src/leap/common/events/server.py @@ -14,33 +14,31 @@  #  # You should have received a copy of the GNU General Public License  # along with this program. If not, see <http://www.gnu.org/licenses/>. - -  """  The server for the events mechanism.  """ - -  import logging +import platform +  import txzmq  from leap.common.zmq_utils import zmq_has_curve -  from leap.common.events.zmq_components import TxZmqServerComponent -if zmq_has_curve(): +if zmq_has_curve() or platform.system() == "Windows": +    # Windows doesn't have ipc sockets, we need to use always tcp      EMIT_ADDR = "tcp://127.0.0.1:9000"      REG_ADDR = "tcp://127.0.0.1:9001"  else:      EMIT_ADDR = "ipc:///tmp/leap.common.events.socket.0"      REG_ADDR = "ipc:///tmp/leap.common.events.socket.1" -  logger = logging.getLogger(__name__) -def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR): +def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, path_prefix=None, +                  factory=None, enable_curve=True):      """      Make sure the server is running in the given addresses. @@ -52,7 +50,8 @@ def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR):      :return: an events server instance      :rtype: EventsServer      """ -    _server = EventsServer(emit_addr, reg_addr) +    _server = EventsServer(emit_addr, reg_addr, path_prefix, factory=factory, +                           enable_curve=enable_curve)      return _server @@ -62,7 +61,8 @@ class EventsServer(TxZmqServerComponent):      events in another address.      """ -    def __init__(self, emit_addr, reg_addr): +    def __init__(self, emit_addr, reg_addr, path_prefix=None, factory=None, +                 enable_curve=True):          """          Initialize the events server. @@ -71,7 +71,9 @@ class EventsServer(TxZmqServerComponent):          :param reg_addr: The address to which publish events to clients.          :type reg_addr: str          """ -        TxZmqServerComponent.__init__(self) +        TxZmqServerComponent.__init__(self, path_prefix=path_prefix, +                                      factory=factory, +                                      enable_curve=enable_curve)          # bind PULL and PUB sockets          self._pull, self.pull_port = self._zmq_bind(              txzmq.ZmqPullConnection, emit_addr) diff --git a/src/leap/common/events/tests/test_auth.py b/src/leap/common/events/tests/test_auth.py new file mode 100644 index 0000000..78ffd9f --- /dev/null +++ b/src/leap/common/events/tests/test_auth.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# test_zmq_components.py +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program.  If not, see <http://www.gnu.org/licenses/>. +""" +Tests for the auth module. +""" +import os + +from twisted.trial import unittest +from txzmq import ZmqFactory + +from leap.common.events import auth +from leap.common.testing.basetest import BaseLeapTest +from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX +from leap.common.zmq_utils import maybe_create_and_get_certificates + +from txzmq.test import _wait + + +class ZmqAuthTestCase(unittest.TestCase, BaseLeapTest): + +    def setUp(self): +        self.setUpEnv(launch_events_server=False) + +        self.factory = ZmqFactory() +        self._config_prefix = os.path.join(self.tempdir, "leap", "events") + +        self.public, self.secret = maybe_create_and_get_certificates( +            self._config_prefix, 'server') + +        self.authenticator = auth.TxAuthenticator(self.factory) +        self.authenticator.start() +        self.auth_req = auth.TxAuthenticationRequest(self.factory) + +    def tearDown(self): +        self.factory.shutdown() +        self.tearDownEnv() + +    def test_curve_auth(self): +        self.auth_req.start() +        self.auth_req.allow('127.0.0.1') +        public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) +        self.auth_req.configure_curve(domain="*", location=public_keys_dir) + +        def check(ignored): +            authenticator = self.authenticator.authenticator +            certs = authenticator.certs['*'] +            self.failUnlessEqual(authenticator.whitelist, set([u'127.0.0.1'])) +            self.failUnlessEqual(certs[certs.keys()[0]], True) + +        return _wait(0.1).addCallback(check) diff --git a/src/leap/common/tests/test_events.py b/src/leap/common/events/tests/test_events.py index 2ad097e..d8435c6 100644 --- a/src/leap/common/tests/test_events.py +++ b/src/leap/common/events/tests/test_events.py @@ -14,16 +14,18 @@  #  # You should have received a copy of the GNU General Public License  # along with this program. If not, see <http://www.gnu.org/licenses/>. - - +""" +Tests for the events framework +"""  import os  import logging -import time  from twisted.internet.reactor import callFromThread  from twisted.trial import unittest  from twisted.internet import defer +from txzmq import ZmqFactory +  from leap.common.events import server  from leap.common.events import client  from leap.common.events import flags @@ -40,19 +42,22 @@ class EventsGenericClientTestCase(object):      def setUp(self):          flags.set_events_enabled(True) +        self.factory = ZmqFactory()          self._server = server.ensure_server(              emit_addr="tcp://127.0.0.1:0", -            reg_addr="tcp://127.0.0.1:0") +            reg_addr="tcp://127.0.0.1:0", +            factory=self.factory, +            enable_curve=False) +          self._client.configure_client(              emit_addr="tcp://127.0.0.1:%d" % self._server.pull_port, -            reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port) +            reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port, +            factory=self.factory, enable_curve=False)      def tearDown(self): -        self._client.shutdown() -        self._server.shutdown()          flags.set_events_enabled(False) -        # wait a bit for sockets to close properly -        time.sleep(0.1) +        self.factory.shutdown() +        self._client.instance().reset()      def test_client_register(self):          """ diff --git a/src/leap/common/events/txclient.py b/src/leap/common/events/txclient.py index dfd0533..63f12d7 100644 --- a/src/leap/common/events/txclient.py +++ b/src/leap/common/events/txclient.py @@ -58,16 +58,19 @@ class EventsTxClient(TxZmqClientComponent, EventsClient):      """      def __init__(self, emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, -                 path_prefix=None): +                 path_prefix=None, factory=None, enable_curve=True):          """ -        Initialize the events server. +        Initialize the events client.          """ -        TxZmqClientComponent.__init__(self, path_prefix=path_prefix) +        TxZmqClientComponent.__init__( +            self, path_prefix=path_prefix, factory=factory, +            enable_curve=enable_curve)          EventsClient.__init__(self, emit_addr, reg_addr)          # connect SUB first, otherwise we might miss some event sent from this          # same client          self._sub = self._zmq_connect(txzmq.ZmqSubConnection, reg_addr)          self._sub.gotMessage = self._gotMessage +          self._push = self._zmq_connect(txzmq.ZmqPushConnection, emit_addr)      def _gotMessage(self, msg, tag): @@ -122,7 +125,6 @@ class EventsTxClient(TxZmqClientComponent, EventsClient):          callback(event, *content)      def shutdown(self): -        TxZmqClientComponent.shutdown(self)          EventsClient.shutdown(self) diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py index 51de02c..c533a74 100644 --- a/src/leap/common/events/zmq_components.py +++ b/src/leap/common/events/zmq_components.py @@ -1,6 +1,6 @@  # -*- coding: utf-8 -*-  # zmq.py -# Copyright (C) 2015 LEAP +# Copyright (C) 2015, 2016 LEAP  #  # This program is free software: you can redistribute it and/or modify  # it under the terms of the GNU General Public License as published by @@ -14,60 +14,63 @@  #  # You should have received a copy of the GNU General Public License  # along with this program. If not, see <http://www.gnu.org/licenses/>. - -  """  The server for the events mechanism.  """ - -  import os  import logging  import txzmq  import re -import time  from abc import ABCMeta -# XXX some distros don't package libsodium, so we have to be prepared for -#     absence of zmq.auth  try:      import zmq.auth -    from zmq.auth.thread import ThreadAuthenticator +    from leap.common.events.auth import TxAuthenticator +    from leap.common.events.auth import TxAuthenticationRequest  except ImportError:      pass +from txzmq.connection import ZmqEndpoint, ZmqEndpointType +  from leap.common.config import flags, get_path_prefix  from leap.common.zmq_utils import zmq_has_curve  from leap.common.zmq_utils import maybe_create_and_get_certificates  from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX -  logger = logging.getLogger(__name__) -  ADDRESS_RE = re.compile("^([a-z]+)://([^:]+):?(\d+)?$") +LOCALHOST_ALLOWED = '127.0.0.1' +  class TxZmqComponent(object):      """      A twisted-powered zmq events component.      """ +    _factory = txzmq.ZmqFactory() +    _factory.registerForShutdown() +    _auth = None      __metaclass__ = ABCMeta      _component_type = None -    def __init__(self, path_prefix=None): +    def __init__(self, path_prefix=None, enable_curve=True, factory=None):          """          Initialize the txzmq component.          """ -        self._factory = txzmq.ZmqFactory() -        self._factory.registerForShutdown()          if path_prefix is None:              path_prefix = get_path_prefix(flags.STANDALONE) +        if factory is not None: +            self._factory = factory          self._config_prefix = os.path.join(path_prefix, "leap", "events")          self._connections = [] +        if enable_curve: +            self.use_curve = zmq_has_curve() +        else: +            self.use_curve = False      @property      def component_type(self): @@ -77,105 +80,89 @@ class TxZmqComponent(object):                  "define a self._component_type!")          return self._component_type -    def _zmq_connect(self, connClass, address): +    def _zmq_bind(self, connClass, address):          """ -        Connect to an address. +        Bind to an address.          :param connClass: The connection class to be used.          :type connClass: txzmq.ZmqConnection -        :param address: The address to connect to. +        :param address: The address to bind to.          :type address: str -        :return: The binded connection. -        :rtype: txzmq.ZmqConnection +        :return: The binded connection and port. +        :rtype: (txzmq.ZmqConnection, int)          """ +        proto, addr, port = ADDRESS_RE.search(address).groups() + +        endpoint = ZmqEndpoint(ZmqEndpointType.bind, address)          connection = connClass(self._factory) -        # create and configure socket -        socket = connection.socket -        if zmq_has_curve(): + +        if self.use_curve: +            socket = connection.socket +              public, secret = maybe_create_and_get_certificates(                  self._config_prefix, self.component_type) -            server_public_file = os.path.join( -                self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") -            server_public, _ = zmq.auth.load_certificate(server_public_file)              socket.curve_publickey = public              socket.curve_secretkey = secret -            socket.curve_serverkey = server_public -        socket.connect(address) -        logger.debug("Connected %s to %s." % (connClass, address)) -        self._connections.append(connection) -        return connection +            self._start_authentication(connection.socket) -    def _zmq_bind(self, connClass, address): +        if proto == 'tcp' and int(port) == 0: +            connection.endpoints.extend([endpoint]) +            port = connection.socket.bind_to_random_port('tcp://%s' % addr) +        else: +            connection.addEndpoints([endpoint]) + +        return connection, int(port) + +    def _zmq_connect(self, connClass, address):          """ -        Bind to an address. +        Connect to an address.          :param connClass: The connection class to be used.          :type connClass: txzmq.ZmqConnection -        :param address: The address to bind to. +        :param address: The address to connect to.          :type address: str -        :return: The binded connection and port. -        :rtype: (txzmq.ZmqConnection, int) +        :return: The binded connection. +        :rtype: txzmq.ZmqConnection          """ +        endpoint = ZmqEndpoint(ZmqEndpointType.connect, address)          connection = connClass(self._factory) -        socket = connection.socket -        if zmq_has_curve(): + +        if self.use_curve: +            socket = connection.socket              public, secret = maybe_create_and_get_certificates(                  self._config_prefix, self.component_type) +            server_public_file = os.path.join( +                self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key") + +            server_public, _ = zmq.auth.load_certificate(server_public_file)              socket.curve_publickey = public              socket.curve_secretkey = secret -            self._start_thread_auth(connection.socket) +            socket.curve_serverkey = server_public -        proto, addr, port = ADDRESS_RE.search(address).groups() +        connection.addEndpoints([endpoint]) +        return connection -        if proto == "tcp": -            if port is None or port is '0': -                params = proto, addr -                port = socket.bind_to_random_port("%s://%s" % params) -                logger.debug("Binded %s to %s://%s." % ((connClass,) + params)) -            else: -                params = proto, addr, int(port) -                socket.bind("%s://%s:%d" % params) -                logger.debug( -                    "Binded %s to %s://%s:%d." % ((connClass,) + params)) -        else: -            params = proto, addr -            socket.bind("%s://%s" % params) -            logger.debug( -                "Binded %s to %s://%s" % ((connClass,) + params)) -        self._connections.append(connection) -        return connection, port - -    def _start_thread_auth(self, socket): -        """ -        Start the zmq curve thread authenticator. +    def _start_authentication(self, socket): -        :param socket: The socket in which to configure the authenticator. -        :type socket: zmq.Socket -        """ -        authenticator = ThreadAuthenticator(self._factory.context) +        if not TxZmqComponent._auth: +            TxZmqComponent._auth = TxAuthenticator(self._factory) +            TxZmqComponent._auth.start() -        # Temporary fix until we understand what the problem is -        # See https://leap.se/code/issues/7536 -        time.sleep(0.5) +        auth_req = TxAuthenticationRequest(self._factory) +        auth_req.start() +        auth_req.allow(LOCALHOST_ALLOWED) -        authenticator.start() -        # XXX do not hardcode this here. -        authenticator.allow('127.0.0.1')          # tell authenticator to use the certificate in a directory          public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX) -        authenticator.configure_curve(domain="*", location=public_keys_dir) -        socket.curve_server = True  # must come before bind +        auth_req.configure_curve(domain="*", location=public_keys_dir) +        auth_req.shutdown() +        TxZmqComponent._auth.shutdown() -    def shutdown(self): -        """ -        Shutdown the component. -        """ -        logger.debug("Shutting down component %s." % str(self)) -        for conn in self._connections: -            conn.shutdown() -        self._factory.shutdown() +        # This has to be set before binding the socket, that's why this method +        # has to be called before addEndpoints() +        socket.curve_server = True  class TxZmqServerComponent(TxZmqComponent): diff --git a/src/leap/common/service_hooks.py b/src/leap/common/service_hooks.py new file mode 100644 index 0000000..96e95cc --- /dev/null +++ b/src/leap/common/service_hooks.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# service_hooks.py +# Copyright (C) 2016 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program.  If not, see <http://www.gnu.org/licenses/>. +""" +Hooks for service composition. +""" +from collections import defaultdict + +from twisted.application.service import IService, Service +from twisted.python import log + +from zope.interface import implementer + + +@implementer(IService) +class HookableService(Service): + +    """ +    This service allows for other services in a Twisted Service tree to be +    notified whenever a certain kind of hook is triggered. + +    During the service composition, one is expected to register +    a hook name with the name of the service that wants to react to the +    triggering of the hook. All the services, both hooked and listeners, should +    be registered against the same parent service. + +    Upon the hook being triggered, the method "hook_<name>" will be called with +    the passed data in the listener service. +    """ + +    def register_hook(self, name, listener): +        if not hasattr(self, 'event_listeners'): +            self.event_listeners = defaultdict(list) +        log.msg("Registering hook %s->%s" % (name, listener)) +        self.event_listeners[name].append(listener) + +    def trigger_hook(self, name, **data): + +        def react_to_hook(listener, name, **kw): +            try: +                getattr(listener, 'hook_' + name)(**kw) +            except AttributeError: +                raise RuntimeError( +                    "Tried to notify a hook, but the listener service class %s" +                    "has not defined the proper method" % listener.__class__) + +        if not hasattr(self, 'event_listeners'): +            self.event_listeners = defaultdict(list) +        listeners = self._get_listener_services(name) + +        for listener in listeners: +            react_to_hook(listener, name, **data) + +    def _get_sibling_service(self, name): +        return self.parent.getServiceNamed(name) + +    def _get_listener_services(self, hook): +        if hook in self.event_listeners: +            service_names = self.event_listeners[hook] +            services = [ +                self._get_sibling_service(name) for name in service_names] +            return services diff --git a/src/leap/common/testing/basetest.py b/src/leap/common/testing/basetest.py index 3d3cee0..2e84a25 100644 --- a/src/leap/common/testing/basetest.py +++ b/src/leap/common/testing/basetest.py @@ -52,7 +52,7 @@ class BaseLeapTest(unittest.TestCase):          cls.tearDownEnv()      @classmethod -    def setUpEnv(cls): +    def setUpEnv(cls, launch_events_server=True):          """          Sets up common facilities for testing this TestCase:          - custom PATH and HOME environmental variables @@ -72,14 +72,15 @@ class BaseLeapTest(unittest.TestCase):          os.environ["PATH"] = bin_tdir          os.environ["HOME"] = cls.tempdir          os.environ["XDG_CONFIG_HOME"] = os.path.join(cls.tempdir, ".config") -        cls._init_events() +        if launch_events_server: +            cls._init_events()      @classmethod      def _init_events(cls):          if flags.EVENTS_ENABLED:              cls._server = events_server.ensure_server( -                emit_addr="tcp://127.0.0.1:0", -                reg_addr="tcp://127.0.0.1:0") +                emit_addr="tcp://127.0.0.1", +                reg_addr="tcp://127.0.0.1")              events_client.configure_client(                  emit_addr="tcp://127.0.0.1:%d" % cls._server.pull_port,                  reg_addr="tcp://127.0.0.1:%d" % cls._server.pub_port) diff --git a/src/leap/common/zmq_utils.py b/src/leap/common/zmq_utils.py index 0a781de..39a49c7 100644 --- a/src/leap/common/zmq_utils.py +++ b/src/leap/common/zmq_utils.py @@ -19,6 +19,7 @@ Utilities to handle ZMQ certificates.  """  import os  import logging +import platform  import stat  import shutil @@ -52,6 +53,10 @@ def zmq_has_curve():         `zmq.auth` module is new in version 14.1         `zmq.has()` is new in version 14.1, new in version libzmq-4.1.      """ +    if platform.system() == "Windows": +        # TODO: curve is not working on windows #7919 +        return False +      zmq_version = zmq.zmq_version_info()      pyzmq_version = zmq.pyzmq_version_info() | 
