diff options
Diffstat (limited to 'pkg/linux/bitmask-root')
-rwxr-xr-x | pkg/linux/bitmask-root | 406 |
1 files changed, 328 insertions, 78 deletions
diff --git a/pkg/linux/bitmask-root b/pkg/linux/bitmask-root index 4cb214e1..b9a7acbc 100755 --- a/pkg/linux/bitmask-root +++ b/pkg/linux/bitmask-root @@ -25,15 +25,25 @@ USAGE: bitmask-root firewall start GATEWAY1 GATEWAY2 ... bitmask-root openvpn stop bitmask-root openvpn start CONFIG1 CONFIG1 ... + +All actions return exit code 0 for success, non-zero otherwise. + +The `openvpn start` action is special: it calls exec on openvpn and replaces +the current process. """ # TODO should be tested with python3, which can be the default on some distro. from __future__ import print_function import os +import re import subprocess import socket import sys -import re +import traceback + + +# XXX not standard +import psutil cmdcheck = subprocess.check_output @@ -41,10 +51,29 @@ cmdcheck = subprocess.check_output ## CONSTANTS ## -OPENVPN = "/usr/sbin/openvpn" +SCRIPT = "bitmask-root" +NAMESERVER = "10.42.0.1" +BITMASK_CHAIN = "bitmask" + +IP = "/bin/ip" IPTABLES = "/sbin/iptables" IP6TABLES = "/sbin/ip6tables" -UPDATE_RESOLV_CONF = "/etc/openvpn/update-resolv-conf" +RESOLVCONF = "/sbin/resolvconf" +OPENVPN_USER = "nobody" +OPENVPN_GROUP = "nogroup" + +LEAPOPENVPN = "LEAPOPENVPN" +OPENVPN_SYSTEM_BIN = "/usr/sbin/openvpn" # Debian location +OPENVPN_LEAP_BIN = "/usr/sbin/leap-openvpn" # installed by bundle + + +""" +The path to the script to update resolv.conf +""" +# XXX We have to check if we have a valid resolvconf, and use +# the old resolv-update if not. +LEAP_UPDATE_RESOLVCONF_FILE = "/etc/leap/update-resolv-conf" +LEAP_RESOLV_UPDATE = "/etc/leap/resolv-update" FIXED_FLAGS = [ "--setenv", "LEAPOPENVPN", "1", @@ -54,17 +83,20 @@ FIXED_FLAGS = [ "--tls-client", "--remote-cert-tls", "server", "--management-signal", - "--management", "/tmp/openvpn.socket", "unix", - "--up", UPDATE_RESOLV_CONF, - "--down", UPDATE_RESOLV_CONF, - "--script-security", "2" + "--management", MANAGEMENT_SOCKET, "unix", + "--script-security", "1" + "--user", "nobody", + "--group", "nogroup", ] +# "--management", MANAGEMENT_SOCKET, "unix", + ALLOWED_FLAGS = { "--remote": ["IP", "NUMBER", "PROTO"], "--tls-cipher": ["CIPHER"], "--cipher": ["CIPHER"], "--auth": ["CIPHER"], + "--management": ["DIR", "UNIXSOCKET"], "--management-client-user": ["USER"], "--cert": ["FILE"], "--key": ["FILE"], @@ -78,11 +110,15 @@ PARAM_FORMATS = { "CIPHER": lambda s: re.match("^[A-Z0-9-]+$", s), "USER": lambda s: re.match( "^[a-zA-Z0-9_\.\@][a-zA-Z0-9_\-\.\@]*\$?$", s), # IEEE Std 1003.1-2001 - "FILE": lambda s: os.path.isfile(s) + "FILE": lambda s: os.path.isfile(s), + "DIR": lambda s: os.path.isdir(os.path.split(s)[0]), + "UNIXSOCKET": lambda s: s == "unix" } DEBUG = os.getenv("DEBUG") +TEST = os.getenv("TEST") + if DEBUG: import logging formatter = logging.Formatter( @@ -93,7 +129,6 @@ if DEBUG: logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) logger.addHandler(ch) - logger.debug(" ".join(sys.argv)) ## ## UTILITY @@ -112,70 +147,152 @@ def is_valid_address(value): socket.inet_aton(value) return True except Exception: - print "MALFORMED IP: %s!" % value + print("%s: ERROR: MALFORMED IP: %s!" % (SCRIPT, value)) return False -def split_list(list, regex): +def has_system_resolvconf(): + """ + Return True if resolvconf is found in the system. + + :rtype: bool + """ + return os.path.isfile(RESOLVCONF) + + +def has_valid_update_resolvconf(): + """ + Return True if a valid update-resolv-conf script is found in the system. + + :rtype: bool + """ + return os.path.isfile(LEAP_UPDATE_RESOLVCONF_FILE) + + +def has_valid_leap_resolv_update(): + """ + Return True if a valid resolv-update script is found in the system. + + :rtype: bool + """ + return os.path.isfile(LEAP_RESOLV_UPDATE) + + +def split_list(_list, regex): """ - Splits a list based on a regex: + Split a list based on a regex: e.g. split_list(["xx", "yy", "x1", "zz"], "^x") => [["xx", "yy"], ["x1", "zz"]] - :param list: the list to be split. - :type list: list + :param _list: the list to be split. + :type _list: list + :param regex: the regex expression to filter with. + :type regex: str + :rtype: list """ if not hasattr(regex, "match"): regex = re.compile(regex) result = [] i = 0 + if not _list: + return result while True: - if regex.match(list[i]): + if regex.match(_list[i]): result.append([]) while True: - result[-1].append(list[i]) + result[-1].append(_list[i]) i += 1 - if i >= len(list) or regex.match(list[i]): + if i >= len(_list) or regex.match(_list[i]): break else: i += 1 - if i >= len(list): + if i >= len(_list): break return result -# i think this is not needed with shell=False -#def sanify(command, *args): -# return [command] + [pipes.quote(a) for a in args] - def run(command, *args, **options): + """ + Run an external command. + + Options: + + `check`: If True, check the command's output. bail if non-zero. (the + default is true unless detach or input is true) + `exitcode`: like `check`, but return exitcode instead of bailing. + `detach`: If True, run in detached process. + `input`: If True, open command for writing stream to, returning the Popen + object. + """ parts = [command] parts.extend(args) - if DEBUG: - print "run: " + " ".join(parts) - if not options.get("check", True) or options.get("detach", False): - subprocess.Popen(parts) + if TEST or DEBUG: + print("%s run: %s " (SCRIPT, " ".join(parts))) + + _check = options.get("check", True) + _detach = options.get("detach", False) + _input = options.get("input", False) + _exitcode = options.get("exitcode", False) + + if not _check or _detach or _input: + if _input: + return subprocess.Popen(parts, stdin=subprocess.PIPE) + else: + # XXX ok with return None ?? + subprocess.Popen(parts) else: try: devnull = open('/dev/null', 'w') subprocess.check_call(parts, stdout=devnull, stderr=devnull) return 0 - except subprocess.CalledProcessError as ex: - if options.get("exitcode", False): - return ex.returncode + except subprocess.CalledProcessError as exc: + if DEBUG: + logger.exception(exc) + if _exitcode: + return exc.returncode else: - bail("Could not run %s: %s" % (ex.cmd, ex.output)) + bail("ERROR: Could not run %s: %s" % (exc.cmd, exc.output), + exception=exc) + + +def bail(msg=None, exception=None): + """ + Abnormal exit. + + :param msg: optional error message. + :type msg: str + """ + if msg is not None: + print("%s: %s" % (SCRIPT, msg)) + if exception is not None: + traceback.print_exc() + exit(1) ## ## OPENVPN ## +def get_openvpn_bin(): + """ + Return the path for either the system openvpn or the one the + bundle has put there. + """ + if os.path.isfile(OPENVPN_SYSTEM_BIN): + return OPENVPN_SYSTEM_BIN + + # the bundle option should be removed from the debian package. + if os.path.isfile(OPENVPN_LEAP_BIN): + return OPENVPN_LEAP_BIN + + def parse_openvpn_flags(args): """ - takes argument list from the command line and parses it, only allowing some + Take argument list from the command line and parse it, only allowing some configuration flags. + + :type args: list """ result = [] try: @@ -184,36 +301,95 @@ def parse_openvpn_flags(args): if flag_name in ALLOWED_FLAGS: result.append(flag_name) required_params = ALLOWED_FLAGS[flag_name] - if len(required_params) > 0: + if required_params: flag_params = flag[1:] if len(flag_params) != len(required_params): - print "ERROR: not enough params for %s" % flag_name + print("%s: ERROR: not enough params for %s" % + (SCRIPT, flag_name)) return None for param, param_type in zip(flag_params, required_params): if PARAM_FORMATS[param_type](param): result.append(param) else: - print "ERROR: Bad argument %s" % param + print("%s: ERROR: Bad argument %s" % + (SCRIPT, param)) return None else: - print "WARNING: unrecognized openvpn flag %s" % flag_name + print("WARNING: unrecognized openvpn flag %s" % flag_name) return result - except Exception as ex: - print ex + except Exception as exc: + print("%s: ERROR PARSING FLAGS: %s" % (SCRIPT, exc)) + if DEBUG: + logger.exception(exc) return None def openvpn_start(args): + """ + Launch openvpn, sanitizing input, and replacing the current process with + the openvpn process. + + :param args: arguments to be passed to openvpn + :type args: list + """ openvpn_flags = parse_openvpn_flags(args) if openvpn_flags: - flags = FIXED_FLAGS + openvpn_flags - run(OPENVPN, *flags, detach=True) + OPENVPN = get_openvpn_bin() + flags = [OPENVPN] + FIXED_FLAGS + openvpn_flags + if DEBUG: + print("%s: running openvpn with flags:" % (SCRIPT,)) + print(flags) + # note: first argument to command is ignored, but customarily set to + # the command. + os.execv(OPENVPN, flags) else: bail('ERROR: could not parse openvpn options') def openvpn_stop(args): - print "stop" + """ + Stop openvpn. + + :param args: arguments to openvpn + :type args: list + """ + # XXX this deps on psutil, which is not there in the bundle + # case. We could try to manually parse proc system. + for proc in psutil.process_iter(): + if LEAPOPENVPN in proc.cmdline: + # FIXME naive approach. this will kill try to kill *anythin*, we + # should check that the command is openvpn. -- kali + proc.terminate() + +## +## DNS +## + + +def set_dns_nameserver(ip_address): + """ + Add the tunnel DNS server to `resolv.conf` + + :param ip_address: the ip to add to `resolv.conf` + :type ip_address: str + """ + if os.path.isfile(RESOLVCONF): + process = run(RESOLVCONF, "-a", "bitmask", input=True) + process.communicate("nameserver %s\n" % ip_address) + else: + bail("ERROR: package openresolv or resolvconf not installed.") + + +def restore_dns_nameserver(): + """ + Remove tunnel DNS server from `resolv.conf` + """ + if os.path.isfile(RESOLVCONF): + run(RESOLVCONF, "-d", "bitmask") + else: + print("%s: ERROR: package openresolv or resolvconf not installed." % + (SCRIPT,)) + ## ## FIREWALL @@ -221,35 +397,59 @@ def openvpn_stop(args): def get_gateways(gateways): - result = [gateway for gateway in gateways if is_valid_address(gateway)] - if not len(result): - bail("No valid gateways specified") + """ + Filter a passed sequence of gateways, returning only the valid ones. + + :param gateways: a sequence of gateways to filter. + :type gateways: iterable + :rtype: iterable + """ + result = filter(is_valid_address, gateways) + if not result: + bail("ERROR: No valid gateways specified") else: return result def get_default_device(): + """ + Retrieve the current default network device. + + :rtype: str + """ routes = subprocess.check_output([IP, "route", "show"]) match = re.search("^default .*dev ([^\s]*) .*$", routes, flags=re.M) - if len(match.groups()) >= 1: + if match.groups(): return match.group(1) else: - bail("could not find default device") + bail("Could not find default device") def get_local_network_ipv4(device): + """ + Get the local ipv4 addres for a given device. + + :param device: + :type device: str + """ addresses = cmdcheck([IP, "-o", "address", "show", "dev", device]) match = re.search("^.*inet ([^ ]*) .*$", addresses, flags=re.M) - if len(match.groups()) >= 1: + if match.groups(): return match.group(1) else: return None def get_local_network_ipv6(device): + """ + Get the local ipv6 addres for a given device. + + :param device: + :type device: str + """ addresses = cmdcheck([IP, "-o", "address", "show", "dev", device]) match = re.search("^.*inet6 ([^ ]*) .*$", addresses, flags=re.M) - if len(match.groups()) >= 1: + if match.groups(): return match.group(1) else: return None @@ -257,7 +457,7 @@ def get_local_network_ipv6(device): def run_iptable_with_check(cmd, *args, **options): """ - runs an iptables command checking to see if it should: + Run an iptables command checking to see if it should: for --insert: run only if rule does not already exist. for --delete: run only if rule does exist. other commands are run normally. @@ -277,99 +477,149 @@ def run_iptable_with_check(cmd, *args, **options): def iptables(*args, **options): + """ + Run iptables4 and iptables6. + """ ip4tables(*args, **options) ip6tables(*args, **options) def ip4tables(*args, **options): + """ + Run iptables4 with checks. + """ run_iptable_with_check(IPTABLES, *args, **options) def ip6tables(*args, **options): + """ + Run iptables6 with checks. + """ run_iptable_with_check(IP6TABLES, *args, **options) def ipv4_chain_exists(table): + """ + Check if a given chain exists. + + :param table: the table to check against + :type table: str + :rtype: bool + """ code = run(IPTABLES, "--list", table, "--numeric", exitcode=True) return code == 0 def ipv6_chain_exists(table): + """ + Check if a given chain exists. + + :param table: the table to check against + :type table: str + :rtype: bool + """ code = run(IP6TABLES, "--list", table, "--numeric", exitcode=True) return code == 0 def firewall_start(args): + """ + Bring up the firewall. + + :param args: list of gateways, to be sanitized. + :type args: list + """ default_device = get_default_device() local_network_ipv4 = get_local_network_ipv4(default_device) local_network_ipv6 = get_local_network_ipv6(default_device) gateways = get_gateways(args) # add custom chain "bitmask" - if not ipv4_chain_exists("bitmask"): - ip4tables("--new-chain", "bitmask") - if not ipv6_chain_exists("bitmask"): - ip6tables("--new-chain", "bitmask") - iptables("--insert", "OUTPUT", "--jump", "bitmask") + if not ipv4_chain_exists(BITMASK_CHAIN): + ip4tables("--new-chain", BITMASK_CHAIN) + if not ipv6_chain_exists(BITMASK_CHAIN): + ip6tables("--new-chain", BITMASK_CHAIN) + iptables("--insert", "OUTPUT", "--jump", BITMASK_CHAIN) # reject everything - iptables("--insert", "bitmask", "-o", default_device, "--jump", "REJECT") + iptables("--insert", BITMASK_CHAIN, "-o", default_device, + "--jump", "REJECT") # allow traffic to gateways for gateway in gateways: - ip4tables("--insert", "bitmask", "--destination", gateway, + ip4tables("--insert", BITMASK_CHAIN, "--destination", gateway, "-o", default_device, "--jump", "ACCEPT") # allow traffic to IPs on local network if local_network_ipv4: - ip4tables("--insert", "bitmask", "--destination", local_network_ipv4, - "-o", default_device, "--jump", "ACCEPT") + ip4tables("--insert", BITMASK_CHAIN, + "--destination", local_network_ipv4, "-o", default_device, + "--jump", "ACCEPT") if local_network_ipv6: - ip6tables("--insert", "bitmask", "--destination", local_network_ipv6, - "-o", default_device, "--jump", "ACCEPT") + ip6tables("--insert", BITMASK_CHAIN, + "--destination", local_network_ipv6, "-o", default_device, + "--jump", "ACCEPT") # block DNS requests to anyone but the service provider or localhost - ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53", + ip4tables("--insert", BITMASK_CHAIN, "--protocol", "udp", "--dport", "53", "--jump", "REJECT") for allowed_dns in gateways + ["127.0.0.1", "127.0.1.1"]: ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53", "--destination", allowed_dns, "--jump", "ACCEPT") -def firewall_stop(args): - iptables("--delete", "OUTPUT", "--jump", "bitmask") - if ipv4_chain_exists("bitmask"): - ip4tables("--flush", "bitmask") - ip4tables("--delete-chain", "bitmask") - if ipv6_chain_exists("bitmask"): - ip6tables("--flush", "bitmask") - ip6tables("--delete-chain", "bitmask") - +def firewall_stop(): + """ + Stop the firewall. + """ + iptables("--delete", "OUTPUT", "--jump", BITMASK_CHAIN) + if ipv4_chain_exists(BITMASK_CHAIN): + ip4tables("--flush", BITMASK_CHAIN) + ip4tables("--delete-chain", BITMASK_CHAIN) + if ipv6_chain_exists(BITMASK_CHAIN): + ip6tables("--flush", BITMASK_CHAIN) + ip6tables("--delete-chain", BITMASK_CHAIN) -def bail(msg=""): - if msg: - print(msg) - exit(1) +## +## MAIN +## def main(): if len(sys.argv) >= 3: command = "_".join(sys.argv[1:3]) args = sys.argv[3:] + if command == "openvpn_start": openvpn_start(args) + elif command == "openvpn_stop": openvpn_stop(args) + elif command == "firewall_start": - firewall_start(args) + try: + firewall_start(args) + set_dns_nameserver(NAMESERVER) + except Exception as ex: + restore_dns_nameserver() + firewall_stop() + bail("ERROR: could not start firewall", ex) + elif command == "firewall_stop": - firewall_stop(args) + try: + restore_dns_nameserver() + firewall_stop() + except Exception as ex: + bail("ERROR: could not stop firewall", ex) + else: - bail("no such command") + bail("ERROR: No such command") else: - bail("no such command") + bail("ERROR: No such command") if __name__ == "__main__": + if DEBUG: + logger.debug(" ".join(sys.argv)) main() - print "done" + print("%s: done" % (SCRIPT,)) exit(0) |