summaryrefslogtreecommitdiff
path: root/sh.py
diff options
context:
space:
mode:
Diffstat (limited to 'sh.py')
-rw-r--r--sh.py175
1 files changed, 132 insertions, 43 deletions
diff --git a/sh.py b/sh.py
index 54bf92d..0e46f14 100644
--- a/sh.py
+++ b/sh.py
@@ -21,7 +21,7 @@
#===============================================================================
-__version__ = "1.07"
+__version__ = "1.08"
__project_url__ = "https://github.com/amoffat/sh"
@@ -120,6 +120,17 @@ class ErrorReturnCode(Exception):
msg = "\n\n RAN: %r\n\n STDOUT:\n%s\n\n STDERR:\n%s" %\
(full_cmd, tstdout.decode(DEFAULT_ENCODING), tstderr.decode(DEFAULT_ENCODING))
super(ErrorReturnCode, self).__init__(msg)
+
+
+class SignalException(ErrorReturnCode): pass
+
+SIGNALS_THAT_SHOULD_THROW_EXCEPTION = (
+ signal.SIGKILL,
+ signal.SIGSEGV,
+ signal.SIGTERM,
+ signal.SIGINT,
+ signal.SIGQUIT
+)
# we subclass AttributeError because:
@@ -127,7 +138,7 @@ class ErrorReturnCode(Exception):
# https://github.com/amoffat/sh/issues/97#issuecomment-10610629
class CommandNotFound(AttributeError): pass
-rc_exc_regex = re.compile("ErrorReturnCode_(\d+)")
+rc_exc_regex = re.compile("(ErrorReturnCode|SignalException)_(\d+)")
rc_exc_cache = {}
def get_rc_exc(rc):
@@ -135,8 +146,13 @@ def get_rc_exc(rc):
try: return rc_exc_cache[rc]
except KeyError: pass
- name = "ErrorReturnCode_%d" % rc
- exc = type(name, (ErrorReturnCode,), {})
+ if rc > 0:
+ name = "ErrorReturnCode_%d" % rc
+ exc = type(name, (ErrorReturnCode,), {})
+ else:
+ name = "SignalException_%d" % abs(rc)
+ exc = type(name, (SignalException,), {})
+
rc_exc_cache[rc] = exc
return exc
@@ -222,6 +238,13 @@ class RunningCommand(object):
self.cmd = cmd
self.ran = " ".join(cmd)
self.process = None
+
+ # this flag is for whether or not we've handled the exit code (like
+ # by raising an exception). this is necessary because .wait() is called
+ # from multiple places, and wait() triggers the exit code to be
+ # processed. but we don't want to raise multiple exceptions, only
+ # one (if any at all)
+ self._handled_exit_code = False
self.should_wait = True
spawn_process = True
@@ -275,11 +298,18 @@ class RunningCommand(object):
# here we determine if we had an exception, or an error code that we weren't
# expecting to see. if we did, we create and raise an exception
def _handle_exit_code(self, code):
- if code not in self.call_args["ok_code"] and code >= 0: raise get_rc_exc(code)(
- " ".join(self.cmd),
- self.process.stdout,
- self.process.stderr
- )
+ if self._handled_exit_code: return
+ self._handled_exit_code = True
+
+ if code not in self.call_args["ok_code"] and \
+ (code > 0 or -code in SIGNALS_THAT_SHOULD_THROW_EXCEPTION):
+ raise get_rc_exc(code)(
+ " ".join(self.cmd),
+ self.process.stdout,
+ self.process.stderr
+ )
+
+
@property
def stdout(self):
@@ -416,6 +446,7 @@ class Command(object):
"iter_noblock": None,
"ok_code": 0,
"cwd": None,
+ "long_sep": "=",
# this is for programs that expect their input to be from a terminal.
# ssh is one of those programs
@@ -451,24 +482,48 @@ class Command(object):
("piped", "iter", "You cannot iterate when this command is being piped"),
)
+
+ # this method exists because of the need to have some way of letting
+ # manual object instantiation not perform the underscore-to-dash command
+ # conversion that resolve_program uses.
+ #
+ # there are 2 ways to create a Command object. using sh.Command(<program>)
+ # or by using sh.<program>. the method fed into sh.Command must be taken
+ # literally, and so no underscore-dash conversion is performed. the one
+ # for sh.<program> must do the underscore-dash converesion, because we
+ # can't type dashes in method names
@classmethod
- def _create(cls, program):
+ def _create(cls, program, **default_kwargs):
path = resolve_program(program)
if not path: raise CommandNotFound(program)
- return cls(path)
-
+
+ cmd = cls(path)
+ if default_kwargs: cmd = cmd.bake(**default_kwargs)
+
+ return cmd
+
+
def __init__(self, path):
- self._path = which(path)
+ path = which(path)
+ if not path: raise CommandNotFound(path)
+ self._path = path
+
self._partial = False
self._partial_baked_args = []
self._partial_call_args = {}
+ # bugfix for functools.wraps. issue #121
+ self.__name__ = repr(self)
+
+
def __getattribute__(self, name):
# convenience
getattr = partial(object.__getattribute__, self)
+
+ if name.startswith("_"): return getattr(name)
+ if name == "bake": return getattr("bake")
+ if name.endswith("_"): name = name[:-1]
- if name.startswith("_"): return getattr(name)
- if name == "bake": return getattr("bake")
return getattr("bake")(name)
@@ -497,12 +552,44 @@ class Command(object):
return call_args, kwargs
+ # this helper method is for normalizing an argument into a string in the
+ # system's default encoding. we can feed it a number or a string or
+ # whatever
def _format_arg(self, arg):
if IS_PY3: arg = str(arg)
- else: arg = unicode(arg).encode(DEFAULT_ENCODING)
+ else:
+ # if the argument is already unicode, or a number or whatever,
+ # this first call will fail.
+ try: arg = unicode(arg, DEFAULT_ENCODING).encode(DEFAULT_ENCODING)
+ except TypeError: arg = unicode(arg).encode(DEFAULT_ENCODING)
return arg
- def _compile_args(self, args, kwargs):
+
+ def _aggregate_keywords(self, keywords, sep, raw=False):
+ processed = []
+ for k, v in keywords.items():
+ # we're passing a short arg as a kwarg, example:
+ # cut(d="\t")
+ if len(k) == 1:
+ if v is not False:
+ processed.append("-" + k)
+ if v is not True:
+ processed.append(self._format_arg(v))
+
+ # we're doing a long arg
+ else:
+ if not raw: k = k.replace("_", "-")
+
+ if v is True:
+ processed.append("--" + k)
+ elif v is False:
+ pass
+ else:
+ processed.append("--%s%s%s" % (k, sep, self._format_arg(v)))
+ return processed
+
+
+ def _compile_args(self, args, kwargs, sep):
processed_args = []
# aggregate positional args
@@ -512,28 +599,18 @@ class Command(object):
warnings.warn("Empty list passed as an argument to %r. \
If you're using glob.glob(), please use sh.glob() instead." % self.path, stacklevel=3)
for sub_arg in arg: processed_args.append(self._format_arg(sub_arg))
- else: processed_args.append(self._format_arg(arg))
-
+ elif isinstance(arg, dict):
+ processed_args += self._aggregate_keywords(arg, sep, raw=True)
+ else:
+ processed_args.append(self._format_arg(arg))
+
# aggregate the keyword arguments
- for k,v in kwargs.items():
- # we're passing a short arg as a kwarg, example:
- # cut(d="\t")
- if len(k) == 1:
- if v is not False:
- processed_args.append("-"+k)
- if v is not True: processed_args.append(self._format_arg(v))
-
- # we're doing a long arg
- else:
- k = k.replace("_", "-")
-
- if v is True: processed_args.append("--"+k)
- elif v is False: pass
- else: processed_args.append("--%s=%s" % (k, self._format_arg(v)))
+ processed_args += self._aggregate_keywords(kwargs, sep)
return processed_args
+ # TODO needs documentation
def bake(self, *args, **kwargs):
fn = Command(self._path)
fn._partial = True
@@ -550,7 +627,8 @@ If you're using glob.glob(), please use sh.glob() instead." % self.path, stackle
fn._partial_call_args.update(self._partial_call_args)
fn._partial_call_args.update(pruned_call_args)
fn._partial_baked_args.extend(self._partial_baked_args)
- fn._partial_baked_args.extend(self._compile_args(args, kwargs))
+ sep = pruned_call_args.get("long_sep", self._call_args["long_sep"])
+ fn._partial_baked_args.extend(self._compile_args(args, kwargs, sep))
return fn
def __str__(self):
@@ -562,7 +640,7 @@ If you're using glob.glob(), please use sh.glob() instead." % self.path, stackle
except: return False
def __repr__(self):
- return str(self)
+ return "<Command %r>" % str(self)
def __unicode__(self):
baked_args = " ".join(self._partial_baked_args)
@@ -618,7 +696,7 @@ If you're using glob.glob(), please use sh.glob() instead." % self.path, stackle
else:
args.insert(0, first_arg)
- processed_args = self._compile_args(args, kwargs)
+ processed_args = self._compile_args(args, kwargs, call_args["long_sep"])
# makes sure our arguments are broken up correctly
split_args = self._partial_baked_args + processed_args
@@ -1455,9 +1533,10 @@ class StreamBufferer(object):
# the exec() statement used in this file requires the "globals" argument to
# be a dictionary
class Environment(dict):
- def __init__(self, globs):
+ def __init__(self, globs, baked_args={}):
self.globs = globs
-
+ self.baked_args = baked_args
+
def __setitem__(self, k, v):
self.globs[k] = v
@@ -1483,7 +1562,10 @@ Please import sh or import programs individually.")
try: return rc_exc_cache[k]
except KeyError:
m = rc_exc_regex.match(k)
- if m: return get_rc_exc(int(m.group(1)))
+ if m:
+ exit_code = int(m.group(2))
+ if m.group(1) == "SignalException": exit_code = -exit_code
+ return get_rc_exc(exit_code)
# is it a builtin?
try: return getattr(self["__builtins__"], k)
@@ -1505,7 +1587,10 @@ Please import sh or import programs individually.")
if builtin: return builtin
# it must be a command then
- return Command._create(k)
+ # we use _create instead of instantiating the class directly because
+ # _create uses resolve_program, which will automatically do underscore-
+ # to-dash conversions. instantiating directly does not use that
+ return Command._create(k, **self.baked_args)
# methods that begin with "b_" are custom builtins and will override any
@@ -1546,7 +1631,7 @@ def run_repl(env):
# system PATH worth of commands. in this case, we just proxy the
# import lookup to our Environment class
class SelfWrapper(ModuleType):
- def __init__(self, self_module):
+ def __init__(self, self_module, baked_args={}):
# this is super ugly to have to copy attributes like this,
# but it seems to be the only way to make reload() behave
# nicely. if i make these attributes dynamic lookups in
@@ -1558,7 +1643,7 @@ class SelfWrapper(ModuleType):
# if we set this to None. and 3.3 needs a value for __path__
self.__path__ = []
self.self_module = self_module
- self.env = Environment(globals())
+ self.env = Environment(globals(), baked_args)
def __setattr__(self, name, value):
if hasattr(self, "env"): self.env[name] = value
@@ -1568,6 +1653,10 @@ class SelfWrapper(ModuleType):
if name == "env": raise AttributeError
return self.env[name]
+ # accept special keywords argument to define defaults for all operations
+ # that will be processed with given by return SelfWrapper
+ def __call__(self, **kwargs):
+ return SelfWrapper(self.self_module, kwargs)