diff options
| author | Kali Kaneko <kali@leap.se> | 2014-05-05 09:23:11 -0500 | 
|---|---|---|
| committer | Kali Kaneko <kali@leap.se> | 2014-05-12 11:24:43 -0500 | 
| commit | 7dd7d8dac61db9623ae97fc9669eaac693b9a3ee (patch) | |
| tree | 2bf9332fe60b18502a424f5eae5b415b7030e6e6 | |
| parent | 65688daee1d10163d82970426467aa4fed6359b1 (diff) | |
bitmask-root wrapper improvements
* add missing constant for ip command
* use all prints as functions
* add missing docstrings
* add alternatives for openvpn bin and resolvconf script
* use random dirs for management socket
* use exec to spawn openvpn
* make bitmask chain constant
* add script name in stdout lines
| -rwxr-xr-x | pkg/linux/bitmask-root | 406 | 
1 files changed, 328 insertions, 78 deletions
| diff --git a/pkg/linux/bitmask-root b/pkg/linux/bitmask-root index 4cb214e1..b9a7acbc 100755 --- a/pkg/linux/bitmask-root +++ b/pkg/linux/bitmask-root @@ -25,15 +25,25 @@ USAGE:    bitmask-root firewall start GATEWAY1 GATEWAY2 ...    bitmask-root openvpn stop    bitmask-root openvpn start CONFIG1 CONFIG1 ... + +All actions return exit code 0 for success, non-zero otherwise. + +The `openvpn start` action is special: it calls exec on openvpn and replaces +the current process.  """  # TODO should be tested with python3, which can be the default on some distro.  from __future__ import print_function  import os +import re  import subprocess  import socket  import sys -import re +import traceback + + +# XXX not standard +import psutil  cmdcheck = subprocess.check_output @@ -41,10 +51,29 @@ cmdcheck = subprocess.check_output  ## CONSTANTS  ## -OPENVPN = "/usr/sbin/openvpn" +SCRIPT = "bitmask-root" +NAMESERVER = "10.42.0.1" +BITMASK_CHAIN = "bitmask" + +IP = "/bin/ip"  IPTABLES = "/sbin/iptables"  IP6TABLES = "/sbin/ip6tables" -UPDATE_RESOLV_CONF = "/etc/openvpn/update-resolv-conf" +RESOLVCONF = "/sbin/resolvconf" +OPENVPN_USER = "nobody" +OPENVPN_GROUP = "nogroup" + +LEAPOPENVPN = "LEAPOPENVPN" +OPENVPN_SYSTEM_BIN = "/usr/sbin/openvpn"  # Debian location +OPENVPN_LEAP_BIN = "/usr/sbin/leap-openvpn"  # installed by bundle + + +""" +The path to the script to update resolv.conf +""" +# XXX We have to check if we have a valid resolvconf, and use +# the old resolv-update if not. +LEAP_UPDATE_RESOLVCONF_FILE = "/etc/leap/update-resolv-conf" +LEAP_RESOLV_UPDATE = "/etc/leap/resolv-update"  FIXED_FLAGS = [      "--setenv", "LEAPOPENVPN", "1", @@ -54,17 +83,20 @@ FIXED_FLAGS = [      "--tls-client",      "--remote-cert-tls", "server",      "--management-signal", -    "--management", "/tmp/openvpn.socket", "unix", -    "--up", UPDATE_RESOLV_CONF, -    "--down", UPDATE_RESOLV_CONF, -    "--script-security", "2" +    "--management", MANAGEMENT_SOCKET, "unix", +    "--script-security", "1" +    "--user", "nobody", +    "--group", "nogroup",  ] +#    "--management", MANAGEMENT_SOCKET, "unix", +  ALLOWED_FLAGS = {      "--remote": ["IP", "NUMBER", "PROTO"],      "--tls-cipher": ["CIPHER"],      "--cipher": ["CIPHER"],      "--auth": ["CIPHER"], +    "--management": ["DIR", "UNIXSOCKET"],      "--management-client-user": ["USER"],      "--cert": ["FILE"],      "--key": ["FILE"], @@ -78,11 +110,15 @@ PARAM_FORMATS = {      "CIPHER": lambda s: re.match("^[A-Z0-9-]+$", s),      "USER": lambda s: re.match(          "^[a-zA-Z0-9_\.\@][a-zA-Z0-9_\-\.\@]*\$?$", s),  # IEEE Std 1003.1-2001 -    "FILE": lambda s: os.path.isfile(s) +    "FILE": lambda s: os.path.isfile(s), +    "DIR": lambda s: os.path.isdir(os.path.split(s)[0]), +    "UNIXSOCKET": lambda s: s == "unix"  }  DEBUG = os.getenv("DEBUG") +TEST = os.getenv("TEST") +  if DEBUG:      import logging      formatter = logging.Formatter( @@ -93,7 +129,6 @@ if DEBUG:      logger = logging.getLogger(__name__)      logger.setLevel(logging.DEBUG)      logger.addHandler(ch) -    logger.debug(" ".join(sys.argv))  ##  ## UTILITY @@ -112,70 +147,152 @@ def is_valid_address(value):          socket.inet_aton(value)          return True      except Exception: -        print "MALFORMED IP: %s!" % value +        print("%s: ERROR: MALFORMED IP: %s!" % (SCRIPT, value))          return False -def split_list(list, regex): +def has_system_resolvconf(): +    """ +    Return True if resolvconf is found in the system. + +    :rtype: bool +    """ +    return os.path.isfile(RESOLVCONF) + + +def has_valid_update_resolvconf(): +    """ +    Return True if a valid update-resolv-conf script is found in the system. + +    :rtype: bool +    """ +    return os.path.isfile(LEAP_UPDATE_RESOLVCONF_FILE) + + +def has_valid_leap_resolv_update(): +    """ +    Return True if a valid resolv-update script is found in the system. + +    :rtype: bool +    """ +    return os.path.isfile(LEAP_RESOLV_UPDATE) + + +def split_list(_list, regex):      """ -    Splits a list based on a regex: +    Split a list based on a regex:      e.g. split_list(["xx", "yy", "x1", "zz"], "^x") => [["xx", "yy"], ["x1",      "zz"]] -    :param list: the list to be split. -    :type list: list +    :param _list: the list to be split. +    :type _list: list +    :param regex: the regex expression to filter with. +    :type regex: str +      :rtype: list      """      if not hasattr(regex, "match"):          regex = re.compile(regex)      result = []      i = 0 +    if not _list: +        return result      while True: -        if regex.match(list[i]): +        if regex.match(_list[i]):              result.append([])              while True: -                result[-1].append(list[i]) +                result[-1].append(_list[i])                  i += 1 -                if i >= len(list) or regex.match(list[i]): +                if i >= len(_list) or regex.match(_list[i]):                      break          else:              i += 1 -        if i >= len(list): +        if i >= len(_list):              break      return result -# i think this is not needed with shell=False -#def sanify(command, *args): -#    return [command] + [pipes.quote(a) for a in args] -  def run(command, *args, **options): +    """ +    Run an external command. + +    Options: + +      `check`: If True, check the command's output. bail if non-zero. (the +               default is true unless detach or input is true) +      `exitcode`: like `check`, but return exitcode instead of bailing. +      `detach`: If True, run in detached process. +      `input`: If True, open command for writing stream to, returning the Popen +               object. +    """      parts = [command]      parts.extend(args) -    if DEBUG: -        print "run: " + " ".join(parts) -    if not options.get("check", True) or options.get("detach", False): -        subprocess.Popen(parts) +    if TEST or DEBUG: +        print("%s run: %s " (SCRIPT, " ".join(parts))) + +    _check = options.get("check", True) +    _detach = options.get("detach", False) +    _input = options.get("input", False) +    _exitcode = options.get("exitcode", False) + +    if not _check or _detach or _input: +        if _input: +            return subprocess.Popen(parts, stdin=subprocess.PIPE) +        else: +            # XXX ok with return None ?? +            subprocess.Popen(parts)      else:          try:              devnull = open('/dev/null', 'w')              subprocess.check_call(parts, stdout=devnull, stderr=devnull)              return 0 -        except subprocess.CalledProcessError as ex: -            if options.get("exitcode", False): -                return ex.returncode +        except subprocess.CalledProcessError as exc: +            if DEBUG: +                logger.exception(exc) +            if _exitcode: +                return exc.returncode              else: -                bail("Could not run %s: %s" % (ex.cmd, ex.output)) +                bail("ERROR: Could not run %s: %s" % (exc.cmd, exc.output), +                     exception=exc) + + +def bail(msg=None, exception=None): +    """ +    Abnormal exit. + +    :param msg: optional error message. +    :type msg: str +    """ +    if msg is not None: +        print("%s: %s" % (SCRIPT, msg)) +    if exception is not None: +        traceback.print_exc() +    exit(1)  ##  ## OPENVPN  ## +def get_openvpn_bin(): +    """ +    Return the path for either the system openvpn or the one the +    bundle has put there. +    """ +    if os.path.isfile(OPENVPN_SYSTEM_BIN): +        return OPENVPN_SYSTEM_BIN + +    # the bundle option should be removed from the debian package. +    if os.path.isfile(OPENVPN_LEAP_BIN): +        return OPENVPN_LEAP_BIN + +  def parse_openvpn_flags(args):      """ -    takes argument list from the command line and parses it, only allowing some +    Take argument list from the command line and parse it, only allowing some      configuration flags. + +    :type args: list      """      result = []      try: @@ -184,36 +301,95 @@ def parse_openvpn_flags(args):              if flag_name in ALLOWED_FLAGS:                  result.append(flag_name)                  required_params = ALLOWED_FLAGS[flag_name] -                if len(required_params) > 0: +                if required_params:                      flag_params = flag[1:]                      if len(flag_params) != len(required_params): -                        print "ERROR: not enough params for %s" % flag_name +                        print("%s: ERROR: not enough params for %s" % +                              (SCRIPT, flag_name))                          return None                      for param, param_type in zip(flag_params, required_params):                          if PARAM_FORMATS[param_type](param):                              result.append(param)                          else: -                            print "ERROR: Bad argument %s" % param +                            print("%s: ERROR: Bad argument %s" % +                                  (SCRIPT, param))                              return None              else: -                print "WARNING: unrecognized openvpn flag %s" % flag_name +                print("WARNING: unrecognized openvpn flag %s" % flag_name)          return result -    except Exception as ex: -        print ex +    except Exception as exc: +        print("%s: ERROR PARSING FLAGS: %s" % (SCRIPT, exc)) +        if DEBUG: +            logger.exception(exc)          return None  def openvpn_start(args): +    """ +    Launch openvpn, sanitizing input, and replacing the current process with +    the openvpn process. + +    :param args: arguments to be passed to openvpn +    :type args: list +    """      openvpn_flags = parse_openvpn_flags(args)      if openvpn_flags: -        flags = FIXED_FLAGS + openvpn_flags -        run(OPENVPN, *flags, detach=True) +        OPENVPN = get_openvpn_bin() +        flags = [OPENVPN] + FIXED_FLAGS + openvpn_flags +        if DEBUG: +            print("%s: running openvpn with flags:" % (SCRIPT,)) +            print(flags) +        # note: first argument to command is ignored, but customarily set to +        # the command. +        os.execv(OPENVPN, flags)      else:          bail('ERROR: could not parse openvpn options')  def openvpn_stop(args): -    print "stop" +    """ +    Stop openvpn. + +    :param args: arguments to openvpn +    :type args: list +    """ +    # XXX this deps on psutil, which is not there in the bundle +    # case. We could try to manually parse proc system. +    for proc in psutil.process_iter(): +        if LEAPOPENVPN in proc.cmdline: +            # FIXME naive approach. this will kill try to kill *anythin*, we +            # should check that the command is openvpn. -- kali +            proc.terminate() + +## +## DNS +## + + +def set_dns_nameserver(ip_address): +    """ +    Add the tunnel DNS server to `resolv.conf` + +    :param ip_address: the ip to add to `resolv.conf` +    :type ip_address: str +    """ +    if os.path.isfile(RESOLVCONF): +        process = run(RESOLVCONF, "-a", "bitmask", input=True) +        process.communicate("nameserver %s\n" % ip_address) +    else: +        bail("ERROR: package openresolv or resolvconf not installed.") + + +def restore_dns_nameserver(): +    """ +    Remove tunnel DNS server from `resolv.conf` +    """ +    if os.path.isfile(RESOLVCONF): +        run(RESOLVCONF, "-d", "bitmask") +    else: +        print("%s: ERROR: package openresolv or resolvconf not installed." % +              (SCRIPT,)) +  ##  ## FIREWALL @@ -221,35 +397,59 @@ def openvpn_stop(args):  def get_gateways(gateways): -    result = [gateway for gateway in gateways if is_valid_address(gateway)] -    if not len(result): -        bail("No valid gateways specified") +    """ +    Filter a passed sequence of gateways, returning only the valid ones. + +    :param gateways: a sequence of gateways to filter. +    :type gateways: iterable +    :rtype: iterable +    """ +    result = filter(is_valid_address, gateways) +    if not result: +        bail("ERROR: No valid gateways specified")      else:          return result  def get_default_device(): +    """ +    Retrieve the current default network device. + +    :rtype: str +    """      routes = subprocess.check_output([IP, "route", "show"])      match = re.search("^default .*dev ([^\s]*) .*$", routes, flags=re.M) -    if len(match.groups()) >= 1: +    if match.groups():          return match.group(1)      else: -        bail("could not find default device") +        bail("Could not find default device")  def get_local_network_ipv4(device): +    """ +    Get the local ipv4 addres for a given device. + +    :param device: +    :type device: str +    """      addresses = cmdcheck([IP, "-o", "address", "show", "dev", device])      match = re.search("^.*inet ([^ ]*) .*$", addresses, flags=re.M) -    if len(match.groups()) >= 1: +    if match.groups():          return match.group(1)      else:          return None  def get_local_network_ipv6(device): +    """ +    Get the local ipv6 addres for a given device. + +    :param device: +    :type device: str +    """      addresses = cmdcheck([IP, "-o", "address", "show", "dev", device])      match = re.search("^.*inet6 ([^ ]*) .*$", addresses, flags=re.M) -    if len(match.groups()) >= 1: +    if match.groups():          return match.group(1)      else:          return None @@ -257,7 +457,7 @@ def get_local_network_ipv6(device):  def run_iptable_with_check(cmd, *args, **options):      """ -    runs an iptables command checking to see if it should: +    Run an iptables command checking to see if it should:        for --insert: run only if rule does not already exist.        for --delete: run only if rule does exist.      other commands are run normally. @@ -277,99 +477,149 @@ def run_iptable_with_check(cmd, *args, **options):  def iptables(*args, **options): +    """ +    Run iptables4 and iptables6. +    """      ip4tables(*args, **options)      ip6tables(*args, **options)  def ip4tables(*args, **options): +    """ +    Run iptables4 with checks. +    """      run_iptable_with_check(IPTABLES, *args, **options)  def ip6tables(*args, **options): +    """ +    Run iptables6 with checks. +    """      run_iptable_with_check(IP6TABLES, *args, **options)  def ipv4_chain_exists(table): +    """ +    Check if a given chain exists. + +    :param table: the table to check against +    :type table: str +    :rtype: bool +    """      code = run(IPTABLES, "--list", table, "--numeric", exitcode=True)      return code == 0  def ipv6_chain_exists(table): +    """ +    Check if a given chain exists. + +    :param table: the table to check against +    :type table: str +    :rtype: bool +    """      code = run(IP6TABLES, "--list", table, "--numeric", exitcode=True)      return code == 0  def firewall_start(args): +    """ +    Bring up the firewall. + +    :param args: list of gateways, to be sanitized. +    :type args: list +    """      default_device = get_default_device()      local_network_ipv4 = get_local_network_ipv4(default_device)      local_network_ipv6 = get_local_network_ipv6(default_device)      gateways = get_gateways(args)      # add custom chain "bitmask" -    if not ipv4_chain_exists("bitmask"): -        ip4tables("--new-chain", "bitmask") -    if not ipv6_chain_exists("bitmask"): -        ip6tables("--new-chain", "bitmask") -    iptables("--insert", "OUTPUT", "--jump", "bitmask") +    if not ipv4_chain_exists(BITMASK_CHAIN): +        ip4tables("--new-chain", BITMASK_CHAIN) +    if not ipv6_chain_exists(BITMASK_CHAIN): +        ip6tables("--new-chain", BITMASK_CHAIN) +    iptables("--insert", "OUTPUT", "--jump", BITMASK_CHAIN)      # reject everything -    iptables("--insert", "bitmask", "-o", default_device, "--jump", "REJECT") +    iptables("--insert", BITMASK_CHAIN, "-o", default_device, +             "--jump", "REJECT")      # allow traffic to gateways      for gateway in gateways: -        ip4tables("--insert", "bitmask", "--destination", gateway, +        ip4tables("--insert", BITMASK_CHAIN, "--destination", gateway,                    "-o", default_device, "--jump", "ACCEPT")      # allow traffic to IPs on local network      if local_network_ipv4: -        ip4tables("--insert", "bitmask", "--destination", local_network_ipv4, -                  "-o", default_device, "--jump", "ACCEPT") +        ip4tables("--insert", BITMASK_CHAIN, +                  "--destination", local_network_ipv4, "-o", default_device, +                  "--jump", "ACCEPT")      if local_network_ipv6: -        ip6tables("--insert", "bitmask", "--destination", local_network_ipv6, -                  "-o", default_device, "--jump", "ACCEPT") +        ip6tables("--insert", BITMASK_CHAIN, +                  "--destination", local_network_ipv6, "-o", default_device, +                  "--jump", "ACCEPT")      # block DNS requests to anyone but the service provider or localhost -    ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53", +    ip4tables("--insert", BITMASK_CHAIN, "--protocol", "udp", "--dport", "53",                "--jump", "REJECT")      for allowed_dns in gateways + ["127.0.0.1", "127.0.1.1"]:          ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53",                    "--destination", allowed_dns, "--jump", "ACCEPT") -def firewall_stop(args): -    iptables("--delete", "OUTPUT", "--jump", "bitmask") -    if ipv4_chain_exists("bitmask"): -        ip4tables("--flush", "bitmask") -        ip4tables("--delete-chain", "bitmask") -    if ipv6_chain_exists("bitmask"): -        ip6tables("--flush", "bitmask") -        ip6tables("--delete-chain", "bitmask") - +def firewall_stop(): +    """ +    Stop the firewall. +    """ +    iptables("--delete", "OUTPUT", "--jump", BITMASK_CHAIN) +    if ipv4_chain_exists(BITMASK_CHAIN): +        ip4tables("--flush", BITMASK_CHAIN) +        ip4tables("--delete-chain", BITMASK_CHAIN) +    if ipv6_chain_exists(BITMASK_CHAIN): +        ip6tables("--flush", BITMASK_CHAIN) +        ip6tables("--delete-chain", BITMASK_CHAIN) -def bail(msg=""): -    if msg: -        print(msg) -    exit(1) +## +## MAIN +##  def main():      if len(sys.argv) >= 3:          command = "_".join(sys.argv[1:3])          args = sys.argv[3:] +          if command == "openvpn_start":              openvpn_start(args) +          elif command == "openvpn_stop":              openvpn_stop(args) +          elif command == "firewall_start": -            firewall_start(args) +            try: +                firewall_start(args) +                set_dns_nameserver(NAMESERVER) +            except Exception as ex: +                restore_dns_nameserver() +                firewall_stop() +                bail("ERROR: could not start firewall", ex) +          elif command == "firewall_stop": -            firewall_stop(args) +            try: +                restore_dns_nameserver() +                firewall_stop() +            except Exception as ex: +                bail("ERROR: could not stop firewall", ex) +          else: -            bail("no such command") +            bail("ERROR: No such command")      else: -        bail("no such command") +        bail("ERROR: No such command")  if __name__ == "__main__": +    if DEBUG: +        logger.debug(" ".join(sys.argv))      main() -    print "done" +    print("%s: done" % (SCRIPT,))      exit(0) | 
