diff options
Diffstat (limited to 'client/src')
-rw-r--r-- | client/src/leap/soledad/client/_version.py | 490 | ||||
-rwxr-xr-x | client/src/taskthread/__init__.py | 296 | ||||
-rwxr-xr-x | client/src/taskthread/tests/__init__.py | 13 | ||||
-rwxr-xr-x | client/src/taskthread/tests/unit/__init__.py | 13 | ||||
-rwxr-xr-x | client/src/taskthread/tests/unit/test_taskthread.py | 315 |
5 files changed, 646 insertions, 481 deletions
diff --git a/client/src/leap/soledad/client/_version.py b/client/src/leap/soledad/client/_version.py index 3ee3f81b..23749c7c 100644 --- a/client/src/leap/soledad/client/_version.py +++ b/client/src/leap/soledad/client/_version.py @@ -1,484 +1,12 @@ +# This file was generated by the `freeze_debianver` command in setup.py +# Using 'versioneer.py' (0.7+) from +# revision-control system data, or from the parent directory name of an +# unpacked source archive. Distribution tarballs contain a pre-generated copy +# of this file. -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. +version_version = '0.8.0alpha2' +version_full = 'aa6a34bc4ac5962dacaa5908778e444fe5aae3d7' -# This file is released into the public domain. Generated by -# versioneer-0.16 (https://github.com/warner/python-versioneer) -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - keywords = {"refnames": git_refnames, "full": git_full} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "pep440" - cfg.tag_prefix = "" - cfg.parentdir_prefix = "None" - cfg.versionfile_source = "src/leap/soledad/client/_version.py" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - return None - return stdout - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes - both the project name and a version string. - """ - dirname = os.path.basename(root) - if not dirname.startswith(parentdir_prefix): - if verbose: - print("guessing rootdir is '%s', but '%s' doesn't start with " - "prefix '%s'" % (root, dirname, parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None} - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%s', no digits" % ",".join(refs-tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None - } - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags"} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - if not os.path.exists(os.path.join(root, ".git")): - if verbose: - print("no .git in %s" % root) - raise NotThisMethod("no .git directory") - - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"]} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree"} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version"} +def get_versions(default={}, verbose=False): + return {'version': version_version, 'full': version_full} diff --git a/client/src/taskthread/__init__.py b/client/src/taskthread/__init__.py new file mode 100755 index 00000000..a734a829 --- /dev/null +++ b/client/src/taskthread/__init__.py @@ -0,0 +1,296 @@ +# Copyright 2013 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import logging +import threading + +__version__ = '1.4' + + +logger = logging.getLogger(__name__) + + +class TaskInProcessException(BaseException): + pass + + +class TaskThread(threading.Thread): + """ + A thread object that repeats a task. + + Usage example:: + + from taskthread import TaskThread + + import time + + def my_task(*args, **kwargs): + print args, kwargs + + task_thread = TaskThread(my_task) + task_thread.start() + for i in xrange(10): + task_thread.run_task() + task_thread.join_task() + task_thread.join() + + .. note:: If :py:meth:`~TaskThread.run_task` is + invoked while run_task is in progress, + :py:class:`~.TaskInProcessException` will + be raised. + + :param task: + A ``function``. This param is the task to execute when + run_task is called. + :param event: + A ``threading.Event``. This event will be set when run_task + is called. The default value is a new event, but may be + specified for testing purposes. + """ + + daemon = True + ''' + Threads marked as daemon will be terminated. + ''' + def __init__(self, task, event=threading.Event(), + *args, **kwargs): + super(TaskThread, self).__init__() + self.task = task + self.task_event = event + self.running = True + self.running_lock = threading.Lock() + self.in_task = False + self.task_complete = threading.Event() + self.args = args + self.kwargs = kwargs + + def run(self): + """ + Called by threading.Thread, this runs in the new thread. + """ + while True: + self.task_event.wait() + if not self.running: + logger.debug("TaskThread exiting") + return + logger.debug("TaskThread starting task") + with self.running_lock: + self.task_event.clear() + self.task_complete.clear() + self.task(*self.args, **self.kwargs) + with self.running_lock: + self.in_task = False + self.task_complete.set() + + def run_task(self, *args, **kwargs): + """ + Run an instance of the task. + + :param args: + The arguments to pass to the task. + + :param kwargs: + The keyword arguments to pass to the task. + """ + # Don't allow this call if the thread is currently + # in a task. + with self.running_lock: + if self.in_task: + raise TaskInProcessException() + self.in_task = True + logger.debug("Waking up the thread") + self.args = args + self.kwargs = kwargs + # Wake up the thread to do it's thing + self.task_event.set() + + def join_task(self, time_out): + """ + Wait for the currently running task to complete. + + :param time_out: + An ``int``. The amount of time to wait for the + task to finish. + """ + with self.running_lock: + if not self.in_task: + return + + success = self.task_complete.wait(time_out) + if success: + self.task_complete.clear() + return success + + def join(self, timeout=None): + """ + Wait for the task to finish + """ + self.running = False + self.task_event.set() + super(TaskThread, self).join(timeout=timeout) + + +class TimerTask(object): + """ + An object that executes a commit function at a given interval. + This class is driven by a TaskThread. A new TaskThread will be + created the first time :py:meth:`.~start` is called. All + subsequent calls to start will reuse the same thread. + + Usage example:: + + from taskthread import TimerTask + import time + + count = 0 + def get_count(): + return count + def execute(): + print "Count: %d" % count + + task = TimerTask(execute, + timeout=10, + count_fcn=get_count, + threshold=1) + + task.start() + + for i in xrange(100000): + count += 1 + time.sleep(1) + task.stop() + count = 0 + task.start() + for i in xrange(100000): + count += 1 + time.sleep(1) + task.shutdown() + + :param execute_fcn: + A `function`. This function will be executed on each time interval. + + :param delay: + An `int`. The delay in **seconds** invocations of + `execute_fcn`. Default: `10`. + + :param count_fcn: + A `function`. This function returns a current count. If the count + has not changed more the `threshold` since the last invocation of + `execute_fcn`, `execute_fcn` will not be called again. If not + specified, `execute_fcn` will be called each time the timer fires. + **Optional**. If count_fcn is specified, ``threshold`` is required. + + :param threshold: + An `int`. Specifies the minimum delta in `count_fcn` that must be + met for `execute_fcn` to be invoked. **Optional**. Must be + specified in conjunction with `count_fcn`. + + """ + def __init__(self, execute_fcn, delay=10, count_fcn=None, threshold=None): + self.running = True + self.execute_fcn = execute_fcn + self.last_count = 0 + self.event = threading.Event() + self.delay = delay + self.thread = None + self.running_lock = threading.RLock() + if bool(threshold) != bool(count_fcn): + raise ValueError("Must specify threshold " + "and count_fcn, or neither") + + self.count_fcn = count_fcn + self.threshold = threshold + + def start(self): + """ + Start the task. This starts a :py:class:`.~TaskThread`, and starts + running run_threshold_timer on the thread. + + """ + if not self.thread: + logger.debug('Starting up the taskthread') + self.thread = TaskThread(self._run_threshold_timer) + self.thread.start() + + if self.threshold: + self.last_count = 0 + + logger.debug('Running the task') + self.running = True + self.thread.run_task() + + def stop(self): + """ + Stop the task, leaving the :py:class:`.~TaskThread` running + so start can be called again. + + """ + logger.debug('Stopping the task') + wait = False + with self.running_lock: + if self.running: + wait = True + self.running = False + if wait: + self.event.set() + self.thread.join_task(2) + + def shutdown(self): + """ + Close down the task thread and stop the task if it is running. + + """ + logger.debug('Shutting down the task') + self.stop() + self.thread.join(2) + + def _exec_if_threshold_met(self): + new_count = self.count_fcn() + logger.debug('new_count: %d' % new_count) + if new_count >= self.last_count + self.threshold: + self.execute_fcn() + self.last_count = new_count + + def _exec(self): + if self.count_fcn: + self._exec_if_threshold_met() + else: + self.execute_fcn() + + def _wait(self): + self.event.wait(timeout=self.delay) + self.event.clear() + logger.debug('Task woken up') + + def _exit_loop(self): + """ + If self.running is false, it means the task should shut down. + """ + exit_loop = False + with self.running_lock: + if not self.running: + exit_loop = True + logger.debug('Task shutting down') + return exit_loop + + def _run_threshold_timer(self): + """ + Main loop of the timer task + + """ + logger.debug('In Task') + while True: + self._wait() + if self._exit_loop(): + return + self._exec() diff --git a/client/src/taskthread/tests/__init__.py b/client/src/taskthread/tests/__init__.py new file mode 100755 index 00000000..92bd912f --- /dev/null +++ b/client/src/taskthread/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2013 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/client/src/taskthread/tests/unit/__init__.py b/client/src/taskthread/tests/unit/__init__.py new file mode 100755 index 00000000..92bd912f --- /dev/null +++ b/client/src/taskthread/tests/unit/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2013 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/client/src/taskthread/tests/unit/test_taskthread.py b/client/src/taskthread/tests/unit/test_taskthread.py new file mode 100755 index 00000000..82565922 --- /dev/null +++ b/client/src/taskthread/tests/unit/test_taskthread.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- +# Copyright 2013 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License.:w + + +import threading +import unittest2 as unittest + +from mock import Mock, patch + +from taskthread import TaskThread, TaskInProcessException, TimerTask + +forever_event = threading.Event() + + +def forever_function(*args, **kwargs): + forever_event.wait() + forever_event.clear() + + +class TaskThreadTestCase(unittest.TestCase): + """ + Tests for :py:class:`.TaskThread`. + """ + + def test___init__(self): + """ + Test the __init__ method. It doesn't really do much. + """ + task_thread = TaskThread(forever_function) + self.assertEqual(forever_function, task_thread.task) + + def test_run_not_running(self): + """ + Verifies that thread will shut down when running is false + """ + event = Mock() + event.wait = Mock(side_effect=[True]) + event.clear = Mock(side_effect=Exception("Should never be called")) + task_thread = TaskThread(forever_function, + event=event) + task_thread.running = False + task_thread.run() + event.wait.assert_called_once_with() + + def test_run_executes_task(self): + event = Mock() + event.wait = Mock(side_effect=[True, True]) + + def stop_iteration(*args, **kwargs): + args[0].running = False + + task_thread = TaskThread(stop_iteration, + event=event) + + task_thread.args = [task_thread] + task_thread.kwargs = {'a': 2} + task_thread.in_task = True + task_thread.run() + self.assertEqual(False, task_thread.in_task) + + def test_run_task(self): + event = Mock() + task_thread = TaskThread(forever_function, + event=event) + args = [1] + kwargs = {'a': 1} + + task_thread.run_task(*args, **kwargs) + self.assertEqual(tuple(args), task_thread.args) + self.assertEqual(kwargs, task_thread.kwargs) + event.set.assert_called_once_with() + + def test_run_task_task_in_progress(self): + event = Mock() + task_thread = TaskThread(forever_function, + event=event) + task_thread.in_task = True + self.assertRaises(TaskInProcessException, task_thread.run_task) + + def test_join_task(self): + task_thread = TaskThread(forever_function) + task_thread.in_task = True + task_thread.task_complete = Mock() + task_thread.task_complete.wait = Mock(side_effect=[True]) + success = task_thread.join_task(1) + self.assertTrue(success) + + def test_join_task_not_running(self): + task_thread = TaskThread(forever_function) + task_thread.task_complete = Mock() + task_thread.wait =\ + Mock(side_effect=Exception("Should never be called")) + task_thread.join_task(1) + + def test_join(self): + task_thread = TaskThread(forever_function) + task_thread.start() + task_thread.run_task() + # Set the event so the task completes + forever_event.set() + task_thread.join_task(1) + task_thread.join(1) + + def test_execute_multiple_tasks(self): + task_thread = TaskThread(forever_function) + task_thread.start() + task_thread.run_task() + # Set the event so the task completes + forever_event.set() + task_thread.join_task(1) + forever_event.set() + task_thread.join_task(1) + task_thread.join(1) + + +def my_func(): + pass + + +class TimerTaskTestCase(unittest.TestCase): + + def test___int__(self): + + task = TimerTask(my_func, + delay=100) + self.assertEqual(my_func, task.execute_fcn) + self.assertEqual(100, task.delay) + self.assertIsNone(task.count_fcn) + self.assertIsNone(task.threshold) + + def test___int__raises(self): + self.assertRaises(ValueError, TimerTask.__init__, + TimerTask(None), + my_func(), + count_fcn=Mock()) + + self.assertRaises(ValueError, TimerTask.__init__, + TimerTask(None), + my_func(), + threshold=Mock()) + + @patch('taskthread.TaskThread') + def test_start(self, TaskThreadMock): + task = TimerTask(my_func) + thread = TaskThreadMock.return_value + + task.start() + self.assertTrue(task.running) + self.assertEqual(thread, task.thread) + thread.start.assert_called_once_with() + thread.run_task.assert_called_once_with() + + @patch('taskthread.TaskThread') + def test_start_restarts(self, TaskThreadMock): + task = TimerTask(my_func, threshold=1, count_fcn=Mock()) + thread = TaskThreadMock.return_value + task.last_count = 1 + task.thread = thread + + task.start() + self.assertEqual(0, task.last_count) + self.assertEqual(0, thread.start.called) + thread.run_task.assert_called_once_with() + + @patch('taskthread.TaskThread') + def test_stop(self, TaskThreadMock): + running_lock = Mock() + running_lock.__enter__ = Mock() + running_lock.__exit__ = Mock() + task = TimerTask(my_func) + task.thread = TaskThreadMock.return_value + task.running = True + task.event = Mock() + task.running_lock = running_lock + + task.stop() + + self.assertEqual(False, task.running) + self.assertEqual(1, task.event.set.called) + running_lock.__enter__.assert_called_once_with() + running_lock.__exit__.assert_called_once_with(None, None, None) + task.thread.join_task.assert_called_once_with(2) + + @patch('taskthread.TaskThread') + def test_stop_not_running(self, TaskThreadMock): + task = TimerTask(my_func) + task.thread = TaskThreadMock.return_value + task.running = False + task.event = Mock() + + task.stop() + + self.assertEqual(False, task.running) + self.assertEqual(0, task.event.set.called) + self.assertEqual(0, task.thread.join_task.called) + + @patch('taskthread.TaskThread') + def test_shutdown(self, TaskThreadMock): + task = TimerTask(my_func) + task.thread = TaskThreadMock.return_value + task.running = False + task.shutdown() + task.thread.join.assert_called_once_with(2) + + def test__exec_if_threshold_met(self): + self.called = False + + def exec_fcn(): + self.called = True + + def count_fcn(): + return 10 + + task = TimerTask(exec_fcn, count_fcn=count_fcn, threshold=1) + task.last_count = 9 + task._exec_if_threshold_met() + self.assertTrue(self.called) + self.assertEqual(10, task.last_count) + + def test__exec_if_threshold_met_not_met(self): + + def exec_fcn(): + raise Exception("This shouldn't happen!!") + + def count_fcn(): + return 10 + + task = TimerTask(exec_fcn, count_fcn=count_fcn, threshold=10) + task.last_count = 9 + task._exec_if_threshold_met() + self.assertEqual(9, task.last_count) + + def test__exec(self): + self.called = False + + def exec_fcn(): + self.called = True + + task = TimerTask(exec_fcn) + task._exec() + self.assertTrue(self.called) + + def test__exec_threshold(self): + self.called = False + + def exec_fcn(): + self.called = True + + def count_fcn(): + return 1 + + task = TimerTask(exec_fcn, count_fcn=count_fcn, threshold=1) + task._exec() + self.assertTrue(self.called) + + @patch('threading.Event') + def test__wait(self, event_mock): + task = TimerTask(my_func) + event = event_mock.return_value + + task._wait() + event.wait.assert_called_once_with(timeout=task.delay) + self.assertEqual(1, event.clear.called) + + @patch('threading.RLock') + def test__exit_loop(self, mock_rlock): + task = TimerTask(my_func) + task.running = False + lock = mock_rlock.return_value + lock.__enter__ = Mock() + lock.__exit__ = Mock() + self.assertTrue(task._exit_loop()) + self.assertEqual(1, lock.__enter__.called) + lock.__exit__.assert_called_once_with(None, None, None) + + @patch('threading.RLock') + def test__exit_loop_running(self, mock_rlock): + lock = mock_rlock.return_value + lock.__enter__ = Mock() + lock.__exit__ = Mock() + task = TimerTask(my_func) + task.running = True + self.assertFalse(task._exit_loop()) + self.assertEqual(1, lock.__enter__.called) + lock.__exit__.assert_called_once_with(None, None, None) + + @patch('threading.RLock') + @patch('threading.Event') + def test__run_threshold_timer(self, event_mock, rlock_mock): + self.task = None + event = event_mock.return_value + lock = rlock_mock.return_value + lock.__enter__ = Mock() + lock.__exit__ = Mock() + + def exec_fcn(): + self.task.running = False + + self.task = TimerTask(exec_fcn) + self.task._run_threshold_timer() + + self.assertFalse(self.task.running) + self.assertEqual(2, event.wait.call_count) |