summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/leap/common/__init__.py2
-rw-r--r--src/leap/common/_version.py541
-rw-r--r--src/leap/common/certs.py4
-rw-r--r--src/leap/common/events/auth.py100
-rw-r--r--src/leap/common/events/catalog.py85
-rw-r--r--src/leap/common/events/client.py23
-rw-r--r--src/leap/common/events/examples/README.txt49
-rw-r--r--src/leap/common/events/examples/client.py2
-rw-r--r--src/leap/common/events/examples/server.py4
-rw-r--r--src/leap/common/events/server.py24
-rw-r--r--src/leap/common/events/tests/test_auth.py64
-rw-r--r--src/leap/common/events/tests/test_events.py (renamed from src/leap/common/tests/test_events.py)23
-rw-r--r--src/leap/common/events/txclient.py10
-rw-r--r--src/leap/common/events/zmq_components.py147
-rw-r--r--src/leap/common/service_hooks.py75
-rw-r--r--src/leap/common/testing/basetest.py9
-rw-r--r--src/leap/common/zmq_utils.py5
17 files changed, 881 insertions, 286 deletions
diff --git a/src/leap/common/__init__.py b/src/leap/common/__init__.py
index 383e198..3b07cf8 100644
--- a/src/leap/common/__init__.py
+++ b/src/leap/common/__init__.py
@@ -4,7 +4,6 @@ from leap.common import certs
from leap.common import check
from leap.common import files
from leap.common import events
-from ._version import get_versions
logger = logging.getLogger(__name__)
@@ -17,5 +16,6 @@ except ImportError:
__all__ = ["certs", "check", "files", "events"]
+from ._version import get_versions
__version__ = get_versions()['version']
del get_versions
diff --git a/src/leap/common/_version.py b/src/leap/common/_version.py
index de94ba8..e29d969 100644
--- a/src/leap/common/_version.py
+++ b/src/leap/common/_version.py
@@ -1,73 +1,157 @@
+
# This file helps to compute a version number in source trees obtained from
# git-archive tarball (such as those provided by githubs download-from-tag
-# feature). Distribution tarballs (build by setup.py sdist) and build
+# feature). Distribution tarballs (built by setup.py sdist) and build
# directories (produced by setup.py build) will contain a much shorter file
# that just contains the computed version number.
# This file is released into the public domain. Generated by
-# versioneer-0.7+ (https://github.com/warner/python-versioneer)
+# versioneer-0.16 (https://github.com/warner/python-versioneer)
-# these strings will be replaced by git during git-archive
+"""Git implementation of _version.py."""
+import errno
+import os
+import re
import subprocess
import sys
-import re
-import os.path
-IN_LONG_VERSION_PY = True
-git_refnames = "$Format:%d$"
-git_full = "$Format:%H$"
+def get_keywords():
+ """Get the keywords needed to look up the version information."""
+ # these strings will be replaced by git during git-archive.
+ # setup.py/versioneer.py will grep for the variable names, so they must
+ # each be defined on a line of their own. _version.py will just call
+ # get_keywords().
+ git_refnames = "$Format:%d$"
+ git_full = "$Format:%H$"
+ keywords = {"refnames": git_refnames, "full": git_full}
+ return keywords
-def run_command(args, cwd=None, verbose=False):
- try:
- # remember shell=False, so use git.cmd on windows, not just git
- p = subprocess.Popen(args, stdout=subprocess.PIPE, cwd=cwd)
- except EnvironmentError:
- e = sys.exc_info()[1]
+
+class VersioneerConfig:
+ """Container for Versioneer configuration parameters."""
+
+
+def get_config():
+ """Create, populate and return the VersioneerConfig() object."""
+ # these strings are filled in when 'setup.py versioneer' creates
+ # _version.py
+ cfg = VersioneerConfig()
+ cfg.VCS = "git"
+ cfg.style = "pep440"
+ cfg.tag_prefix = ""
+ cfg.parentdir_prefix = "None"
+ cfg.versionfile_source = "src/leap/common/_version.py"
+ cfg.verbose = False
+ return cfg
+
+
+class NotThisMethod(Exception):
+ """Exception raised if a method is not valid for the current scenario."""
+
+
+LONG_VERSION_PY = {}
+HANDLERS = {}
+
+
+def register_vcs_handler(vcs, method): # decorator
+ """Decorator to mark a method as the handler for a particular VCS."""
+ def decorate(f):
+ """Store f in HANDLERS[vcs][method]."""
+ if vcs not in HANDLERS:
+ HANDLERS[vcs] = {}
+ HANDLERS[vcs][method] = f
+ return f
+ return decorate
+
+
+def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False):
+ """Call the given command(s)."""
+ assert isinstance(commands, list)
+ p = None
+ for c in commands:
+ try:
+ dispcmd = str([c] + args)
+ # remember shell=False, so use git.cmd on windows, not just git
+ p = subprocess.Popen([c] + args, cwd=cwd, stdout=subprocess.PIPE,
+ stderr=(subprocess.PIPE if hide_stderr
+ else None))
+ break
+ except EnvironmentError:
+ e = sys.exc_info()[1]
+ if e.errno == errno.ENOENT:
+ continue
+ if verbose:
+ print("unable to run %s" % dispcmd)
+ print(e)
+ return None
+ else:
if verbose:
- print("unable to run %s" % args[0])
- print(e)
+ print("unable to find command, tried %s" % (commands,))
return None
stdout = p.communicate()[0].strip()
- if sys.version >= '3':
+ if sys.version_info[0] >= 3:
stdout = stdout.decode()
if p.returncode != 0:
if verbose:
- print("unable to run %s (error)" % args[0])
+ print("unable to run %s (error)" % dispcmd)
return None
return stdout
-def get_expanded_variables(versionfile_source):
+def versions_from_parentdir(parentdir_prefix, root, verbose):
+ """Try to determine the version from the parent directory name.
+
+ Source tarballs conventionally unpack into a directory that includes
+ both the project name and a version string.
+ """
+ dirname = os.path.basename(root)
+ if not dirname.startswith(parentdir_prefix):
+ if verbose:
+ print("guessing rootdir is '%s', but '%s' doesn't start with "
+ "prefix '%s'" % (root, dirname, parentdir_prefix))
+ raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
+ return {"version": dirname[len(parentdir_prefix):],
+ "full-revisionid": None,
+ "dirty": False, "error": None}
+
+
+@register_vcs_handler("git", "get_keywords")
+def git_get_keywords(versionfile_abs):
+ """Extract version information from the given file."""
# the code embedded in _version.py can just fetch the value of these
- # variables. When used from setup.py, we don't want to import
- # _version.py, so we do it with a regexp instead. This function is not
- # used from _version.py.
- variables = {}
+ # keywords. When used from setup.py, we don't want to import _version.py,
+ # so we do it with a regexp instead. This function is not used from
+ # _version.py.
+ keywords = {}
try:
- f = open(versionfile_source, "r")
+ f = open(versionfile_abs, "r")
for line in f.readlines():
if line.strip().startswith("git_refnames ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
- variables["refnames"] = mo.group(1)
+ keywords["refnames"] = mo.group(1)
if line.strip().startswith("git_full ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
- variables["full"] = mo.group(1)
+ keywords["full"] = mo.group(1)
f.close()
except EnvironmentError:
pass
- return variables
+ return keywords
-def versions_from_expanded_variables(variables, tag_prefix, verbose=False):
- refnames = variables["refnames"].strip()
+@register_vcs_handler("git", "keywords")
+def git_versions_from_keywords(keywords, tag_prefix, verbose):
+ """Get version information from git keywords."""
+ if not keywords:
+ raise NotThisMethod("no keywords at all, weird")
+ refnames = keywords["refnames"].strip()
if refnames.startswith("$Format"):
if verbose:
- print("variables are unexpanded, not using")
- return {} # unexpanded, so not in an unpacked git-archive tarball
+ print("keywords are unexpanded, not using")
+ raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
refs = set([r.strip() for r in refnames.strip("()").split(",")])
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
@@ -83,7 +167,7 @@ def versions_from_expanded_variables(variables, tag_prefix, verbose=False):
# "stabilization", as well as "HEAD" and "master".
tags = set([r for r in refs if re.search(r'\d', r)])
if verbose:
- print("discarding '%s', no digits" % ",".join(refs - tags))
+ print("discarding '%s', no digits" % ",".join(refs-tags))
if verbose:
print("likely tags: %s" % ",".join(sorted(tags)))
for ref in sorted(tags):
@@ -93,111 +177,308 @@ def versions_from_expanded_variables(variables, tag_prefix, verbose=False):
if verbose:
print("picking %s" % r)
return {"version": r,
- "full": variables["full"].strip()}
- # no suitable tags, so we use the full revision id
+ "full-revisionid": keywords["full"].strip(),
+ "dirty": False, "error": None
+ }
+ # no suitable tags, so version is "0+unknown", but full hex is still there
if verbose:
- print("no suitable tags, using full revision id")
- return {"version": variables["full"].strip(),
- "full": variables["full"].strip()}
-
-
-def versions_from_vcs(tag_prefix, versionfile_source, verbose=False):
- # this runs 'git' from the root of the source tree. That either means
- # someone ran a setup.py command (and this code is in versioneer.py, so
- # IN_LONG_VERSION_PY=False, thus the containing directory is the root of
- # the source tree), or someone ran a project-specific entry point (and
- # this code is in _version.py, so IN_LONG_VERSION_PY=True, thus the
- # containing directory is somewhere deeper in the source tree). This only
- # gets called if the git-archive 'subst' variables were *not* expanded,
- # and _version.py hasn't already been rewritten with a short version
- # string, meaning we're inside a checked out source tree.
+ print("no suitable tags, using unknown + full revision id")
+ return {"version": "0+unknown",
+ "full-revisionid": keywords["full"].strip(),
+ "dirty": False, "error": "no suitable tags"}
- try:
- here = os.path.abspath(__file__)
- except NameError:
- # some py2exe/bbfreeze/non-CPython implementations don't do __file__
- return {} # not always correct
-
- # versionfile_source is the relative path from the top of the source tree
- # (where the .git directory might live) to this file. Invert this to find
- # the root from __file__.
- root = here
- if IN_LONG_VERSION_PY:
- for i in range(len(versionfile_source.split("/"))):
- root = os.path.dirname(root)
- else:
- root = os.path.dirname(here)
+
+@register_vcs_handler("git", "pieces_from_vcs")
+def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
+ """Get version from 'git describe' in the root of the source tree.
+
+ This only gets called if the git-archive 'subst' keywords were *not*
+ expanded, and _version.py hasn't already been rewritten with a short
+ version string, meaning we're inside a checked out source tree.
+ """
if not os.path.exists(os.path.join(root, ".git")):
if verbose:
print("no .git in %s" % root)
- return {}
+ raise NotThisMethod("no .git directory")
- GIT = "git"
+ GITS = ["git"]
if sys.platform == "win32":
- GIT = "git.cmd"
- stdout = run_command([GIT, "describe", "--tags", "--dirty", "--always"],
- cwd=root)
- if stdout is None:
- return {}
- if not stdout.startswith(tag_prefix):
- if verbose:
- print("tag '%s' doesn't start with prefix '%s'" % (stdout, tag_prefix))
- return {}
- tag = stdout[len(tag_prefix):]
- stdout = run_command([GIT, "rev-parse", "HEAD"], cwd=root)
- if stdout is None:
- return {}
- full = stdout.strip()
- if tag.endswith("-dirty"):
- full += "-dirty"
- return {"version": tag, "full": full}
-
-
-def versions_from_parentdir(parentdir_prefix, versionfile_source, verbose=False):
- if IN_LONG_VERSION_PY:
- # We're running from _version.py. If it's from a source tree
- # (execute-in-place), we can work upwards to find the root of the
- # tree, and then check the parent directory for a version string. If
- # it's in an installed application, there's no hope.
- try:
- here = os.path.abspath(__file__)
- except NameError:
- # py2exe/bbfreeze/non-CPython don't have __file__
- return {} # without __file__, we have no hope
+ GITS = ["git.cmd", "git.exe"]
+ # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
+ # if there isn't one, this yields HEX[-dirty] (no NUM)
+ describe_out = run_command(GITS, ["describe", "--tags", "--dirty",
+ "--always", "--long",
+ "--match", "%s*" % tag_prefix],
+ cwd=root)
+ # --long was added in git-1.5.5
+ if describe_out is None:
+ raise NotThisMethod("'git describe' failed")
+ describe_out = describe_out.strip()
+ full_out = run_command(GITS, ["rev-parse", "HEAD"], cwd=root)
+ if full_out is None:
+ raise NotThisMethod("'git rev-parse' failed")
+ full_out = full_out.strip()
+
+ pieces = {}
+ pieces["long"] = full_out
+ pieces["short"] = full_out[:7] # maybe improved later
+ pieces["error"] = None
+
+ # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
+ # TAG might have hyphens.
+ git_describe = describe_out
+
+ # look for -dirty suffix
+ dirty = git_describe.endswith("-dirty")
+ pieces["dirty"] = dirty
+ if dirty:
+ git_describe = git_describe[:git_describe.rindex("-dirty")]
+
+ # now we have TAG-NUM-gHEX or HEX
+
+ if "-" in git_describe:
+ # TAG-NUM-gHEX
+ mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
+ if not mo:
+ # unparseable. Maybe git-describe is misbehaving?
+ pieces["error"] = ("unable to parse git-describe output: '%s'"
+ % describe_out)
+ return pieces
+
+ # tag
+ full_tag = mo.group(1)
+ if not full_tag.startswith(tag_prefix):
+ if verbose:
+ fmt = "tag '%s' doesn't start with prefix '%s'"
+ print(fmt % (full_tag, tag_prefix))
+ pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
+ % (full_tag, tag_prefix))
+ return pieces
+ pieces["closest-tag"] = full_tag[len(tag_prefix):]
+
+ # distance: number of commits since tag
+ pieces["distance"] = int(mo.group(2))
+
+ # commit: short hex revision ID
+ pieces["short"] = mo.group(3)
+
+ else:
+ # HEX: no tags
+ pieces["closest-tag"] = None
+ count_out = run_command(GITS, ["rev-list", "HEAD", "--count"],
+ cwd=root)
+ pieces["distance"] = int(count_out) # total number of commits
+
+ return pieces
+
+
+def plus_or_dot(pieces):
+ """Return a + if we don't already have one, else return a ."""
+ if "+" in pieces.get("closest-tag", ""):
+ return "."
+ return "+"
+
+
+def render_pep440(pieces):
+ """Build up version string, with post-release "local version identifier".
+
+ Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
+ get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
+
+ Exceptions:
+ 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"] or pieces["dirty"]:
+ rendered += plus_or_dot(pieces)
+ rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ else:
+ # exception #1
+ rendered = "0+untagged.%d.g%s" % (pieces["distance"],
+ pieces["short"])
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ return rendered
+
+
+def render_pep440_pre(pieces):
+ """TAG[.post.devDISTANCE] -- No -dirty.
+
+ Exceptions:
+ 1: no tags. 0.post.devDISTANCE
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"]:
+ rendered += ".post.dev%d" % pieces["distance"]
+ else:
+ # exception #1
+ rendered = "0.post.dev%d" % pieces["distance"]
+ return rendered
+
+
+def render_pep440_post(pieces):
+ """TAG[.postDISTANCE[.dev0]+gHEX] .
+
+ The ".dev0" means dirty. Note that .dev0 sorts backwards
+ (a dirty tree will appear "older" than the corresponding clean one),
+ but you shouldn't be releasing software with -dirty anyways.
+
+ Exceptions:
+ 1: no tags. 0.postDISTANCE[.dev0]
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"] or pieces["dirty"]:
+ rendered += ".post%d" % pieces["distance"]
+ if pieces["dirty"]:
+ rendered += ".dev0"
+ rendered += plus_or_dot(pieces)
+ rendered += "g%s" % pieces["short"]
+ else:
+ # exception #1
+ rendered = "0.post%d" % pieces["distance"]
+ if pieces["dirty"]:
+ rendered += ".dev0"
+ rendered += "+g%s" % pieces["short"]
+ return rendered
+
+
+def render_pep440_old(pieces):
+ """TAG[.postDISTANCE[.dev0]] .
+
+ The ".dev0" means dirty.
+
+ Eexceptions:
+ 1: no tags. 0.postDISTANCE[.dev0]
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"] or pieces["dirty"]:
+ rendered += ".post%d" % pieces["distance"]
+ if pieces["dirty"]:
+ rendered += ".dev0"
+ else:
+ # exception #1
+ rendered = "0.post%d" % pieces["distance"]
+ if pieces["dirty"]:
+ rendered += ".dev0"
+ return rendered
+
+
+def render_git_describe(pieces):
+ """TAG[-DISTANCE-gHEX][-dirty].
+
+ Like 'git describe --tags --dirty --always'.
+
+ Exceptions:
+ 1: no tags. HEX[-dirty] (note: no 'g' prefix)
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"]:
+ rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
+ else:
+ # exception #1
+ rendered = pieces["short"]
+ if pieces["dirty"]:
+ rendered += "-dirty"
+ return rendered
+
+
+def render_git_describe_long(pieces):
+ """TAG-DISTANCE-gHEX[-dirty].
+
+ Like 'git describe --tags --dirty --always -long'.
+ The distance/hash is unconditional.
+
+ Exceptions:
+ 1: no tags. HEX[-dirty] (note: no 'g' prefix)
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
+ else:
+ # exception #1
+ rendered = pieces["short"]
+ if pieces["dirty"]:
+ rendered += "-dirty"
+ return rendered
+
+
+def render(pieces, style):
+ """Render the given version pieces into the requested style."""
+ if pieces["error"]:
+ return {"version": "unknown",
+ "full-revisionid": pieces.get("long"),
+ "dirty": None,
+ "error": pieces["error"]}
+
+ if not style or style == "default":
+ style = "pep440" # the default
+
+ if style == "pep440":
+ rendered = render_pep440(pieces)
+ elif style == "pep440-pre":
+ rendered = render_pep440_pre(pieces)
+ elif style == "pep440-post":
+ rendered = render_pep440_post(pieces)
+ elif style == "pep440-old":
+ rendered = render_pep440_old(pieces)
+ elif style == "git-describe":
+ rendered = render_git_describe(pieces)
+ elif style == "git-describe-long":
+ rendered = render_git_describe_long(pieces)
+ else:
+ raise ValueError("unknown style '%s'" % style)
+
+ return {"version": rendered, "full-revisionid": pieces["long"],
+ "dirty": pieces["dirty"], "error": None}
+
+
+def get_versions():
+ """Get version information or return default if unable to do so."""
+ # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
+ # __file__, we can work backwards from there to the root. Some
+ # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
+ # case we can only use expanded keywords.
+
+ cfg = get_config()
+ verbose = cfg.verbose
+
+ try:
+ return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
+ verbose)
+ except NotThisMethod:
+ pass
+
+ try:
+ root = os.path.realpath(__file__)
# versionfile_source is the relative path from the top of the source
- # tree to _version.py. Invert this to find the root from __file__.
- root = here
- for i in range(len(versionfile_source.split("/"))):
+ # tree (where the .git directory might live) to this file. Invert
+ # this to find the root from __file__.
+ for i in cfg.versionfile_source.split('/'):
root = os.path.dirname(root)
- else:
- # we're running from versioneer.py, which means we're running from
- # the setup.py in a source tree. sys.argv[0] is setup.py in the root.
- here = os.path.abspath(sys.argv[0])
- root = os.path.dirname(here)
+ except NameError:
+ return {"version": "0+unknown", "full-revisionid": None,
+ "dirty": None,
+ "error": "unable to find root of source tree"}
- # Source tarballs conventionally unpack into a directory that includes
- # both the project name and a version string.
- dirname = os.path.basename(root)
- if not dirname.startswith(parentdir_prefix):
- if verbose:
- print("guessing rootdir is '%s', but '%s' doesn't start with prefix '%s'" %
- (root, dirname, parentdir_prefix))
- return None
- return {"version": dirname[len(parentdir_prefix):], "full": ""}
-
-tag_prefix = ""
-parentdir_prefix = "leap.common-"
-versionfile_source = "src/leap/common/_version.py"
-
-
-def get_versions(default={"version": "unknown", "full": ""}, verbose=False):
- variables = {"refnames": git_refnames, "full": git_full}
- ver = versions_from_expanded_variables(variables, tag_prefix, verbose)
- if not ver:
- ver = versions_from_vcs(tag_prefix, versionfile_source, verbose)
- if not ver:
- ver = versions_from_parentdir(parentdir_prefix, versionfile_source,
- verbose)
- if not ver:
- ver = default
- return ver
+ try:
+ pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
+ return render(pieces, cfg.style)
+ except NotThisMethod:
+ pass
+
+ try:
+ if cfg.parentdir_prefix:
+ return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
+ except NotThisMethod:
+ pass
+
+ return {"version": "0+unknown", "full-revisionid": None,
+ "dirty": None,
+ "error": "unable to compute version"}
diff --git a/src/leap/common/certs.py b/src/leap/common/certs.py
index c49015a..95704a6 100644
--- a/src/leap/common/certs.py
+++ b/src/leap/common/certs.py
@@ -192,8 +192,8 @@ def get_compatible_ssl_context_factory(cert_path=None):
class WebClientContextFactory(ssl.ClientContextFactory):
"""
- A web context factory which ignores the hostname and port and does no
- certificate verification.
+ A web context factory which ignores the hostname and port and does
+ no certificate verification.
"""
def getContext(self, hostname, port):
return ssl.ClientContextFactory.getContext(self)
diff --git a/src/leap/common/events/auth.py b/src/leap/common/events/auth.py
new file mode 100644
index 0000000..db217ca
--- /dev/null
+++ b/src/leap/common/events/auth.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+# auth.py
+# Copyright (C) 2016 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+ZAP authentication, twisted style.
+"""
+from zmq import PAIR
+from zmq.auth.base import Authenticator, VERSION
+from txzmq.connection import ZmqConnection
+from zmq.utils.strtypes import b, u
+
+from twisted.python import log
+
+from txzmq.connection import ZmqEndpoint, ZmqEndpointType
+
+
+class TxAuthenticator(ZmqConnection):
+
+ """
+ This does not implement the whole ZAP protocol, but the bare minimum that
+ we need.
+ """
+
+ socketType = PAIR
+ address = 'inproc://zeromq.zap.01'
+ encoding = 'utf-8'
+
+ def __init__(self, factory, *args, **kw):
+ super(TxAuthenticator, self).__init__(factory, *args, **kw)
+ self.authenticator = Authenticator(factory.context)
+ self.authenticator._send_zap_reply = self._send_zap_reply
+
+ def start(self):
+ endpoint = ZmqEndpoint(ZmqEndpointType.bind, self.address)
+ self.addEndpoints([endpoint])
+
+ def messageReceived(self, msg):
+
+ command = msg[0]
+
+ if command == b'ALLOW':
+ addresses = [u(m, self.encoding) for m in msg[1:]]
+ try:
+ self.authenticator.allow(*addresses)
+ except Exception as e:
+ log.err("Failed to allow %s", addresses)
+
+ elif command == b'CURVE':
+ domain = u(msg[1], self.encoding)
+ location = u(msg[2], self.encoding)
+ self.authenticator.configure_curve(domain, location)
+
+ def _send_zap_reply(self, request_id, status_code, status_text,
+ user_id='user'):
+ """
+ Send a ZAP reply to finish the authentication.
+ """
+ user_id = user_id if status_code == b'200' else b''
+ if isinstance(user_id, unicode):
+ user_id = user_id.encode(self.encoding, 'replace')
+ metadata = b'' # not currently used
+ reply = [VERSION, request_id, status_code, status_text,
+ user_id, metadata]
+ self.send(reply)
+
+ def shutdown(self):
+ if self.factory:
+ super(TxAuthenticator, self).shutdown()
+
+
+class TxAuthenticationRequest(ZmqConnection):
+
+ socketType = PAIR
+ address = 'inproc://zeromq.zap.01'
+ encoding = 'utf-8'
+
+ def start(self):
+ endpoint = ZmqEndpoint(ZmqEndpointType.connect, self.address)
+ self.addEndpoints([endpoint])
+
+ def allow(self, *addresses):
+ self.send([b'ALLOW'] + [b(a, self.encoding) for a in addresses])
+
+ def configure_curve(self, domain='*', location=''):
+ domain = b(domain, self.encoding)
+ location = b(location, self.encoding)
+ self.send([b'CURVE', domain, location])
diff --git a/src/leap/common/events/catalog.py b/src/leap/common/events/catalog.py
index 8bddd2c..9a834b2 100644
--- a/src/leap/common/events/catalog.py
+++ b/src/leap/common/events/catalog.py
@@ -24,49 +24,54 @@ Events catalog.
EVENTS = [
"CLIENT_SESSION_ID",
"CLIENT_UID",
- "IMAP_CLIENT_LOGIN",
- "IMAP_SERVICE_FAILED_TO_START",
- "IMAP_SERVICE_STARTED",
- "IMAP_UNHANDLED_ERROR",
- "KEYMANAGER_DONE_UPLOADING_KEYS",
- "KEYMANAGER_FINISHED_KEY_GENERATION",
- "KEYMANAGER_KEY_FOUND",
- "KEYMANAGER_KEY_NOT_FOUND",
- "KEYMANAGER_LOOKING_FOR_KEY",
- "KEYMANAGER_STARTED_KEY_GENERATION",
- "MAIL_FETCHED_INCOMING",
- "MAIL_MSG_DECRYPTED",
- "MAIL_MSG_DELETED_INCOMING",
- "MAIL_MSG_PROCESSING",
- "MAIL_MSG_SAVED_LOCALLY",
- "MAIL_UNREAD_MESSAGES",
"RAISE_WINDOW",
- "SMTP_CONNECTION_LOST",
- "SMTP_END_ENCRYPT_AND_SIGN",
- "SMTP_END_SIGN",
- "SMTP_RECIPIENT_ACCEPTED_ENCRYPTED",
- "SMTP_RECIPIENT_ACCEPTED_UNENCRYPTED",
- "SMTP_RECIPIENT_REJECTED",
- "SMTP_SEND_MESSAGE_ERROR",
- "SMTP_SEND_MESSAGE_START",
- "SMTP_SEND_MESSAGE_SUCCESS",
- "SMTP_SERVICE_FAILED_TO_START",
- "SMTP_SERVICE_STARTED",
- "SMTP_START_ENCRYPT_AND_SIGN",
- "SMTP_START_SIGN",
- "SOLEDAD_CREATING_KEYS",
- "SOLEDAD_DONE_CREATING_KEYS",
- "SOLEDAD_DONE_DATA_SYNC",
- "SOLEDAD_DONE_DOWNLOADING_KEYS",
- "SOLEDAD_DONE_UPLOADING_KEYS",
- "SOLEDAD_DOWNLOADING_KEYS",
- "SOLEDAD_INVALID_AUTH_TOKEN",
- "SOLEDAD_NEW_DATA_TO_SYNC",
- "SOLEDAD_SYNC_RECEIVE_STATUS",
- "SOLEDAD_SYNC_SEND_STATUS",
- "SOLEDAD_UPLOADING_KEYS",
"UPDATER_DONE_UPDATING",
"UPDATER_NEW_UPDATES",
+
+ "KEYMANAGER_DONE_UPLOADING_KEYS", # (address)
+ "KEYMANAGER_FINISHED_KEY_GENERATION", # (address)
+ "KEYMANAGER_KEY_FOUND", # (address)
+ "KEYMANAGER_KEY_NOT_FOUND", # (address)
+ "KEYMANAGER_LOOKING_FOR_KEY", # (address)
+ "KEYMANAGER_STARTED_KEY_GENERATION", # (address)
+
+ "SOLEDAD_CREATING_KEYS", # {uuid, userid}
+ "SOLEDAD_DONE_CREATING_KEYS", # {uuid, userid}
+ "SOLEDAD_DONE_DATA_SYNC", # {uuid, userid}
+ "SOLEDAD_DONE_DOWNLOADING_KEYS", # {uuid, userid}
+ "SOLEDAD_DONE_UPLOADING_KEYS", # {uuid, userid}
+ "SOLEDAD_DOWNLOADING_KEYS", # {uuid, userid}
+ "SOLEDAD_INVALID_AUTH_TOKEN", # {uuid, userid}
+ "SOLEDAD_SYNC_RECEIVE_STATUS", # {uuid, userid}
+ "SOLEDAD_SYNC_SEND_STATUS", # {uuid, userid}
+ "SOLEDAD_UPLOADING_KEYS", # {uuid, userid}
+ "SOLEDAD_NEW_DATA_TO_SYNC",
+
+ "MAIL_FETCHED_INCOMING", # (userid)
+ "MAIL_MSG_DECRYPTED", # (userid)
+ "MAIL_MSG_DELETED_INCOMING", # (userid)
+ "MAIL_MSG_PROCESSING", # (userid)
+ "MAIL_MSG_SAVED_LOCALLY", # (userid)
+ "MAIL_UNREAD_MESSAGES", # (userid, number)
+
+ "IMAP_SERVICE_STARTED",
+ "IMAP_SERVICE_FAILED_TO_START",
+ "IMAP_UNHANDLED_ERROR",
+ "IMAP_CLIENT_LOGIN", # (username)
+
+ "SMTP_SERVICE_STARTED",
+ "SMTP_SERVICE_FAILED_TO_START",
+ "SMTP_START_ENCRYPT_AND_SIGN", # (from_addr)
+ "SMTP_END_ENCRYPT_AND_SIGN", # (from_addr)
+ "SMTP_START_SIGN", # (from_addr)
+ "SMTP_END_SIGN", # (from_addr)
+ "SMTP_SEND_MESSAGE_START", # (from_addr)
+ "SMTP_SEND_MESSAGE_SUCCESS", # (from_addr)
+ "SMTP_RECIPIENT_ACCEPTED_ENCRYPTED", # (userid, dest)
+ "SMTP_RECIPIENT_ACCEPTED_UNENCRYPTED", # (userid, dest)
+ "SMTP_CONNECTION_LOST", # (userid, dest)
+ "SMTP_RECIPIENT_REJECTED", # (userid, dest)
+ "SMTP_SEND_MESSAGE_ERROR", # (userid, dest)
]
diff --git a/src/leap/common/events/client.py b/src/leap/common/events/client.py
index 60d24bc..78617de 100644
--- a/src/leap/common/events/client.py
+++ b/src/leap/common/events/client.py
@@ -63,14 +63,18 @@ logger = logging.getLogger(__name__)
_emit_addr = EMIT_ADDR
_reg_addr = REG_ADDR
+_factory = None
+_enable_curve = True
-def configure_client(emit_addr, reg_addr):
- global _emit_addr, _reg_addr
+def configure_client(emit_addr, reg_addr, factory=None, enable_curve=True):
+ global _emit_addr, _reg_addr, _factory, _enable_curve
logger.debug("Configuring client with addresses: (%s, %s)" %
(emit_addr, reg_addr))
_emit_addr = emit_addr
_reg_addr = reg_addr
+ _factory = factory
+ _enable_curve = enable_curve
class EventsClient(object):
@@ -103,7 +107,9 @@ class EventsClient(object):
"""
with cls._instance_lock:
if cls._instance is None:
- cls._instance = cls(_emit_addr, _reg_addr)
+ cls._instance = cls(
+ _emit_addr, _reg_addr, factory=_factory,
+ enable_curve=_enable_curve)
return cls._instance
def register(self, event, callback, uid=None, replace=False):
@@ -270,7 +276,7 @@ class EventsClientThread(threading.Thread, EventsClient):
A threaded version of the events client.
"""
- def __init__(self, emit_addr, reg_addr):
+ def __init__(self, emit_addr, reg_addr, factory=None, enable_curve=True):
"""
Initialize the events client.
"""
@@ -281,15 +287,22 @@ class EventsClientThread(threading.Thread, EventsClient):
self._config_prefix = os.path.join(
get_path_prefix(flags.STANDALONE), "leap", "events")
self._loop = None
+ self._factory = factory
self._context = None
self._push = None
self._sub = None
+ if enable_curve:
+ self.use_curve = zmq_has_curve()
+ else:
+ self.use_curve = False
+
def _init_zmq(self):
"""
Initialize ZMQ connections.
"""
self._loop = EventsIOLoop()
+ # we need a new context for each thread
self._context = zmq.Context()
# connect SUB first, otherwise we might miss some event sent from this
# same client
@@ -311,7 +324,7 @@ class EventsClientThread(threading.Thread, EventsClient):
logger.debug("Connecting %s to %s." % (socktype, address))
socket = self._context.socket(socktype)
# configure curve authentication
- if zmq_has_curve():
+ if self.use_curve:
public, private = maybe_create_and_get_certificates(
self._config_prefix, "client")
server_public_file = os.path.join(
diff --git a/src/leap/common/events/examples/README.txt b/src/leap/common/events/examples/README.txt
new file mode 100644
index 0000000..0bb0df6
--- /dev/null
+++ b/src/leap/common/events/examples/README.txt
@@ -0,0 +1,49 @@
+How to debug
+-----------------------------------------
+monitor the events socket:
+ sudo ngrep -W byline -d any port 9000
+
+launch the server:
+ python server.py
+
+launch the client:
+ python client.py
+
+if zmq is available and enabled, you should see encrypted messages passing by
+the socket.
+
+You should see something like the following:
+
+####
+T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP]
+..........
+##
+T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP]
+...........
+##
+T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP]
+..CURVE...............................................
+#
+T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP]
+.CURVE...............................................
+#
+T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP]
+...HELLO.............................................................................:....^...".....'.S...n......Y...................O.7.+.D.q".*..R...j.....8..qu..~......Ck.G\....:...m....Tg.s..M..x<..
+##
+T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP]
+...WELCOME..%.'.,Td... I..}...........`..Nm......./_.Je...4.....-.....f<v.|.".jJ...^.D...$lJ..U......g..../w.......\..W.....!........i.v....0...........3..a.5}.@F..v./..$
+#
+T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP]
+..........INITIATE......!.*.=0.-......D..]{...A\.tz...!2.....A./
+6.......Y.h.N....cb.U.|..f..)....W..3..X.2U.3PGl.........m..95.(......NJ....5.'..W.GQ..B/.....\%.,Q..r.'L5.......{.W<=._.$.(6j.G...
+...37.H..Th...'.........0 ........,..q....U..G..M.`!_..w....f."..........
+.d.K.Y.>f.n.kV.
+#
+T 127.0.0.1:9000 -> 127.0.0.1:33122 [AP]
+.2.READY............A...e.)......*.8y....k.<.N1Z.4..
+#
+T 127.0.0.1:33122 -> 127.0.0.1:9000 [AP]
+.+.MESSAGE........o...*M..,....
+.r..w..[.GwcU
+###
+
diff --git a/src/leap/common/events/examples/client.py b/src/leap/common/events/examples/client.py
new file mode 100644
index 0000000..d6d8985
--- /dev/null
+++ b/src/leap/common/events/examples/client.py
@@ -0,0 +1,2 @@
+from leap.common.events.txclient import emit
+emit('stuff!')
diff --git a/src/leap/common/events/examples/server.py b/src/leap/common/events/examples/server.py
new file mode 100644
index 0000000..f40f8dc
--- /dev/null
+++ b/src/leap/common/events/examples/server.py
@@ -0,0 +1,4 @@
+from twisted.internet import reactor
+from leap.common.events.server import ensure_server
+reactor.callWhenRunning(ensure_server)
+reactor.run()
diff --git a/src/leap/common/events/server.py b/src/leap/common/events/server.py
index a69202e..05fc23e 100644
--- a/src/leap/common/events/server.py
+++ b/src/leap/common/events/server.py
@@ -14,33 +14,31 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
-
-
"""
The server for the events mechanism.
"""
-
-
import logging
+import platform
+
import txzmq
from leap.common.zmq_utils import zmq_has_curve
-
from leap.common.events.zmq_components import TxZmqServerComponent
-if zmq_has_curve():
+if zmq_has_curve() or platform.system() == "Windows":
+ # Windows doesn't have ipc sockets, we need to use always tcp
EMIT_ADDR = "tcp://127.0.0.1:9000"
REG_ADDR = "tcp://127.0.0.1:9001"
else:
EMIT_ADDR = "ipc:///tmp/leap.common.events.socket.0"
REG_ADDR = "ipc:///tmp/leap.common.events.socket.1"
-
logger = logging.getLogger(__name__)
-def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR):
+def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR, path_prefix=None,
+ factory=None, enable_curve=True):
"""
Make sure the server is running in the given addresses.
@@ -52,7 +50,8 @@ def ensure_server(emit_addr=EMIT_ADDR, reg_addr=REG_ADDR):
:return: an events server instance
:rtype: EventsServer
"""
- _server = EventsServer(emit_addr, reg_addr)
+ _server = EventsServer(emit_addr, reg_addr, path_prefix, factory=factory,
+ enable_curve=enable_curve)
return _server
@@ -62,7 +61,8 @@ class EventsServer(TxZmqServerComponent):
events in another address.
"""
- def __init__(self, emit_addr, reg_addr):
+ def __init__(self, emit_addr, reg_addr, path_prefix=None, factory=None,
+ enable_curve=True):
"""
Initialize the events server.
@@ -71,7 +71,9 @@ class EventsServer(TxZmqServerComponent):
:param reg_addr: The address to which publish events to clients.
:type reg_addr: str
"""
- TxZmqServerComponent.__init__(self)
+ TxZmqServerComponent.__init__(self, path_prefix=path_prefix,
+ factory=factory,
+ enable_curve=enable_curve)
# bind PULL and PUB sockets
self._pull, self.pull_port = self._zmq_bind(
txzmq.ZmqPullConnection, emit_addr)
diff --git a/src/leap/common/events/tests/test_auth.py b/src/leap/common/events/tests/test_auth.py
new file mode 100644
index 0000000..78ffd9f
--- /dev/null
+++ b/src/leap/common/events/tests/test_auth.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+# test_zmq_components.py
+# Copyright (C) 2014 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+Tests for the auth module.
+"""
+import os
+
+from twisted.trial import unittest
+from txzmq import ZmqFactory
+
+from leap.common.events import auth
+from leap.common.testing.basetest import BaseLeapTest
+from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX
+from leap.common.zmq_utils import maybe_create_and_get_certificates
+
+from txzmq.test import _wait
+
+
+class ZmqAuthTestCase(unittest.TestCase, BaseLeapTest):
+
+ def setUp(self):
+ self.setUpEnv(launch_events_server=False)
+
+ self.factory = ZmqFactory()
+ self._config_prefix = os.path.join(self.tempdir, "leap", "events")
+
+ self.public, self.secret = maybe_create_and_get_certificates(
+ self._config_prefix, 'server')
+
+ self.authenticator = auth.TxAuthenticator(self.factory)
+ self.authenticator.start()
+ self.auth_req = auth.TxAuthenticationRequest(self.factory)
+
+ def tearDown(self):
+ self.factory.shutdown()
+ self.tearDownEnv()
+
+ def test_curve_auth(self):
+ self.auth_req.start()
+ self.auth_req.allow('127.0.0.1')
+ public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
+ self.auth_req.configure_curve(domain="*", location=public_keys_dir)
+
+ def check(ignored):
+ authenticator = self.authenticator.authenticator
+ certs = authenticator.certs['*']
+ self.failUnlessEqual(authenticator.whitelist, set([u'127.0.0.1']))
+ self.failUnlessEqual(certs[certs.keys()[0]], True)
+
+ return _wait(0.1).addCallback(check)
diff --git a/src/leap/common/tests/test_events.py b/src/leap/common/events/tests/test_events.py
index 2ad097e..d8435c6 100644
--- a/src/leap/common/tests/test_events.py
+++ b/src/leap/common/events/tests/test_events.py
@@ -14,16 +14,18 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
-
-
+"""
+Tests for the events framework
+"""
import os
import logging
-import time
from twisted.internet.reactor import callFromThread
from twisted.trial import unittest
from twisted.internet import defer
+from txzmq import ZmqFactory
+
from leap.common.events import server
from leap.common.events import client
from leap.common.events import flags
@@ -40,19 +42,22 @@ class EventsGenericClientTestCase(object):
def setUp(self):
flags.set_events_enabled(True)
+ self.factory = ZmqFactory()
self._server = server.ensure_server(
emit_addr="tcp://127.0.0.1:0",
- reg_addr="tcp://127.0.0.1:0")
+ reg_addr="tcp://127.0.0.1:0",
+ factory=self.factory,
+ enable_curve=False)
+
self._client.configure_client(
emit_addr="tcp://127.0.0.1:%d" % self._server.pull_port,
- reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port)
+ reg_addr="tcp://127.0.0.1:%d" % self._server.pub_port,
+ factory=self.factory, enable_curve=False)
def tearDown(self):
- self._client.shutdown()
- self._server.shutdown()
flags.set_events_enabled(False)
- # wait a bit for sockets to close properly
- time.sleep(0.1)
+ self.factory.shutdown()
+ self._client.instance().reset()
def test_client_register(self):
"""
diff --git a/src/leap/common/events/txclient.py b/src/leap/common/events/txclient.py
index dfd0533..63f12d7 100644
--- a/src/leap/common/events/txclient.py
+++ b/src/leap/common/events/txclient.py
@@ -58,16 +58,19 @@ class EventsTxClient(TxZmqClientComponent, EventsClient):
"""
def __init__(self, emit_addr=EMIT_ADDR, reg_addr=REG_ADDR,
- path_prefix=None):
+ path_prefix=None, factory=None, enable_curve=True):
"""
- Initialize the events server.
+ Initialize the events client.
"""
- TxZmqClientComponent.__init__(self, path_prefix=path_prefix)
+ TxZmqClientComponent.__init__(
+ self, path_prefix=path_prefix, factory=factory,
+ enable_curve=enable_curve)
EventsClient.__init__(self, emit_addr, reg_addr)
# connect SUB first, otherwise we might miss some event sent from this
# same client
self._sub = self._zmq_connect(txzmq.ZmqSubConnection, reg_addr)
self._sub.gotMessage = self._gotMessage
+
self._push = self._zmq_connect(txzmq.ZmqPushConnection, emit_addr)
def _gotMessage(self, msg, tag):
@@ -122,7 +125,6 @@ class EventsTxClient(TxZmqClientComponent, EventsClient):
callback(event, *content)
def shutdown(self):
- TxZmqClientComponent.shutdown(self)
EventsClient.shutdown(self)
diff --git a/src/leap/common/events/zmq_components.py b/src/leap/common/events/zmq_components.py
index 51de02c..c533a74 100644
--- a/src/leap/common/events/zmq_components.py
+++ b/src/leap/common/events/zmq_components.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# zmq.py
-# Copyright (C) 2015 LEAP
+# Copyright (C) 2015, 2016 LEAP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
@@ -14,60 +14,63 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
-
-
"""
The server for the events mechanism.
"""
-
-
import os
import logging
import txzmq
import re
-import time
from abc import ABCMeta
-# XXX some distros don't package libsodium, so we have to be prepared for
-# absence of zmq.auth
try:
import zmq.auth
- from zmq.auth.thread import ThreadAuthenticator
+ from leap.common.events.auth import TxAuthenticator
+ from leap.common.events.auth import TxAuthenticationRequest
except ImportError:
pass
+from txzmq.connection import ZmqEndpoint, ZmqEndpointType
+
from leap.common.config import flags, get_path_prefix
from leap.common.zmq_utils import zmq_has_curve
from leap.common.zmq_utils import maybe_create_and_get_certificates
from leap.common.zmq_utils import PUBLIC_KEYS_PREFIX
-
logger = logging.getLogger(__name__)
-
ADDRESS_RE = re.compile("^([a-z]+)://([^:]+):?(\d+)?$")
+LOCALHOST_ALLOWED = '127.0.0.1'
+
class TxZmqComponent(object):
"""
A twisted-powered zmq events component.
"""
+ _factory = txzmq.ZmqFactory()
+ _factory.registerForShutdown()
+ _auth = None
__metaclass__ = ABCMeta
_component_type = None
- def __init__(self, path_prefix=None):
+ def __init__(self, path_prefix=None, enable_curve=True, factory=None):
"""
Initialize the txzmq component.
"""
- self._factory = txzmq.ZmqFactory()
- self._factory.registerForShutdown()
if path_prefix is None:
path_prefix = get_path_prefix(flags.STANDALONE)
+ if factory is not None:
+ self._factory = factory
self._config_prefix = os.path.join(path_prefix, "leap", "events")
self._connections = []
+ if enable_curve:
+ self.use_curve = zmq_has_curve()
+ else:
+ self.use_curve = False
@property
def component_type(self):
@@ -77,105 +80,89 @@ class TxZmqComponent(object):
"define a self._component_type!")
return self._component_type
- def _zmq_connect(self, connClass, address):
+ def _zmq_bind(self, connClass, address):
"""
- Connect to an address.
+ Bind to an address.
:param connClass: The connection class to be used.
:type connClass: txzmq.ZmqConnection
- :param address: The address to connect to.
+ :param address: The address to bind to.
:type address: str
- :return: The binded connection.
- :rtype: txzmq.ZmqConnection
+ :return: The binded connection and port.
+ :rtype: (txzmq.ZmqConnection, int)
"""
+ proto, addr, port = ADDRESS_RE.search(address).groups()
+
+ endpoint = ZmqEndpoint(ZmqEndpointType.bind, address)
connection = connClass(self._factory)
- # create and configure socket
- socket = connection.socket
- if zmq_has_curve():
+
+ if self.use_curve:
+ socket = connection.socket
+
public, secret = maybe_create_and_get_certificates(
self._config_prefix, self.component_type)
- server_public_file = os.path.join(
- self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
- server_public, _ = zmq.auth.load_certificate(server_public_file)
socket.curve_publickey = public
socket.curve_secretkey = secret
- socket.curve_serverkey = server_public
- socket.connect(address)
- logger.debug("Connected %s to %s." % (connClass, address))
- self._connections.append(connection)
- return connection
+ self._start_authentication(connection.socket)
- def _zmq_bind(self, connClass, address):
+ if proto == 'tcp' and int(port) == 0:
+ connection.endpoints.extend([endpoint])
+ port = connection.socket.bind_to_random_port('tcp://%s' % addr)
+ else:
+ connection.addEndpoints([endpoint])
+
+ return connection, int(port)
+
+ def _zmq_connect(self, connClass, address):
"""
- Bind to an address.
+ Connect to an address.
:param connClass: The connection class to be used.
:type connClass: txzmq.ZmqConnection
- :param address: The address to bind to.
+ :param address: The address to connect to.
:type address: str
- :return: The binded connection and port.
- :rtype: (txzmq.ZmqConnection, int)
+ :return: The binded connection.
+ :rtype: txzmq.ZmqConnection
"""
+ endpoint = ZmqEndpoint(ZmqEndpointType.connect, address)
connection = connClass(self._factory)
- socket = connection.socket
- if zmq_has_curve():
+
+ if self.use_curve:
+ socket = connection.socket
public, secret = maybe_create_and_get_certificates(
self._config_prefix, self.component_type)
+ server_public_file = os.path.join(
+ self._config_prefix, PUBLIC_KEYS_PREFIX, "server.key")
+
+ server_public, _ = zmq.auth.load_certificate(server_public_file)
socket.curve_publickey = public
socket.curve_secretkey = secret
- self._start_thread_auth(connection.socket)
+ socket.curve_serverkey = server_public
- proto, addr, port = ADDRESS_RE.search(address).groups()
+ connection.addEndpoints([endpoint])
+ return connection
- if proto == "tcp":
- if port is None or port is '0':
- params = proto, addr
- port = socket.bind_to_random_port("%s://%s" % params)
- logger.debug("Binded %s to %s://%s." % ((connClass,) + params))
- else:
- params = proto, addr, int(port)
- socket.bind("%s://%s:%d" % params)
- logger.debug(
- "Binded %s to %s://%s:%d." % ((connClass,) + params))
- else:
- params = proto, addr
- socket.bind("%s://%s" % params)
- logger.debug(
- "Binded %s to %s://%s" % ((connClass,) + params))
- self._connections.append(connection)
- return connection, port
-
- def _start_thread_auth(self, socket):
- """
- Start the zmq curve thread authenticator.
+ def _start_authentication(self, socket):
- :param socket: The socket in which to configure the authenticator.
- :type socket: zmq.Socket
- """
- authenticator = ThreadAuthenticator(self._factory.context)
+ if not TxZmqComponent._auth:
+ TxZmqComponent._auth = TxAuthenticator(self._factory)
+ TxZmqComponent._auth.start()
- # Temporary fix until we understand what the problem is
- # See https://leap.se/code/issues/7536
- time.sleep(0.5)
+ auth_req = TxAuthenticationRequest(self._factory)
+ auth_req.start()
+ auth_req.allow(LOCALHOST_ALLOWED)
- authenticator.start()
- # XXX do not hardcode this here.
- authenticator.allow('127.0.0.1')
# tell authenticator to use the certificate in a directory
public_keys_dir = os.path.join(self._config_prefix, PUBLIC_KEYS_PREFIX)
- authenticator.configure_curve(domain="*", location=public_keys_dir)
- socket.curve_server = True # must come before bind
+ auth_req.configure_curve(domain="*", location=public_keys_dir)
+ auth_req.shutdown()
+ TxZmqComponent._auth.shutdown()
- def shutdown(self):
- """
- Shutdown the component.
- """
- logger.debug("Shutting down component %s." % str(self))
- for conn in self._connections:
- conn.shutdown()
- self._factory.shutdown()
+ # This has to be set before binding the socket, that's why this method
+ # has to be called before addEndpoints()
+ socket.curve_server = True
class TxZmqServerComponent(TxZmqComponent):
diff --git a/src/leap/common/service_hooks.py b/src/leap/common/service_hooks.py
new file mode 100644
index 0000000..96e95cc
--- /dev/null
+++ b/src/leap/common/service_hooks.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+# service_hooks.py
+# Copyright (C) 2016 LEAP
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+"""
+Hooks for service composition.
+"""
+from collections import defaultdict
+
+from twisted.application.service import IService, Service
+from twisted.python import log
+
+from zope.interface import implementer
+
+
+@implementer(IService)
+class HookableService(Service):
+
+ """
+ This service allows for other services in a Twisted Service tree to be
+ notified whenever a certain kind of hook is triggered.
+
+ During the service composition, one is expected to register
+ a hook name with the name of the service that wants to react to the
+ triggering of the hook. All the services, both hooked and listeners, should
+ be registered against the same parent service.
+
+ Upon the hook being triggered, the method "hook_<name>" will be called with
+ the passed data in the listener service.
+ """
+
+ def register_hook(self, name, listener):
+ if not hasattr(self, 'event_listeners'):
+ self.event_listeners = defaultdict(list)
+ log.msg("Registering hook %s->%s" % (name, listener))
+ self.event_listeners[name].append(listener)
+
+ def trigger_hook(self, name, **data):
+
+ def react_to_hook(listener, name, **kw):
+ try:
+ getattr(listener, 'hook_' + name)(**kw)
+ except AttributeError:
+ raise RuntimeError(
+ "Tried to notify a hook, but the listener service class %s"
+ "has not defined the proper method" % listener.__class__)
+
+ if not hasattr(self, 'event_listeners'):
+ self.event_listeners = defaultdict(list)
+ listeners = self._get_listener_services(name)
+
+ for listener in listeners:
+ react_to_hook(listener, name, **data)
+
+ def _get_sibling_service(self, name):
+ return self.parent.getServiceNamed(name)
+
+ def _get_listener_services(self, hook):
+ if hook in self.event_listeners:
+ service_names = self.event_listeners[hook]
+ services = [
+ self._get_sibling_service(name) for name in service_names]
+ return services
diff --git a/src/leap/common/testing/basetest.py b/src/leap/common/testing/basetest.py
index 3d3cee0..2e84a25 100644
--- a/src/leap/common/testing/basetest.py
+++ b/src/leap/common/testing/basetest.py
@@ -52,7 +52,7 @@ class BaseLeapTest(unittest.TestCase):
cls.tearDownEnv()
@classmethod
- def setUpEnv(cls):
+ def setUpEnv(cls, launch_events_server=True):
"""
Sets up common facilities for testing this TestCase:
- custom PATH and HOME environmental variables
@@ -72,14 +72,15 @@ class BaseLeapTest(unittest.TestCase):
os.environ["PATH"] = bin_tdir
os.environ["HOME"] = cls.tempdir
os.environ["XDG_CONFIG_HOME"] = os.path.join(cls.tempdir, ".config")
- cls._init_events()
+ if launch_events_server:
+ cls._init_events()
@classmethod
def _init_events(cls):
if flags.EVENTS_ENABLED:
cls._server = events_server.ensure_server(
- emit_addr="tcp://127.0.0.1:0",
- reg_addr="tcp://127.0.0.1:0")
+ emit_addr="tcp://127.0.0.1",
+ reg_addr="tcp://127.0.0.1")
events_client.configure_client(
emit_addr="tcp://127.0.0.1:%d" % cls._server.pull_port,
reg_addr="tcp://127.0.0.1:%d" % cls._server.pub_port)
diff --git a/src/leap/common/zmq_utils.py b/src/leap/common/zmq_utils.py
index 0a781de..39a49c7 100644
--- a/src/leap/common/zmq_utils.py
+++ b/src/leap/common/zmq_utils.py
@@ -19,6 +19,7 @@ Utilities to handle ZMQ certificates.
"""
import os
import logging
+import platform
import stat
import shutil
@@ -52,6 +53,10 @@ def zmq_has_curve():
`zmq.auth` module is new in version 14.1
`zmq.has()` is new in version 14.1, new in version libzmq-4.1.
"""
+ if platform.system() == "Windows":
+ # TODO: curve is not working on windows #7919
+ return False
+
zmq_version = zmq.zmq_version_info()
pyzmq_version = zmq.pyzmq_version_info()