diff options
-rwxr-xr-x | pkg/linux/bitmask-root | 87 |
1 files changed, 58 insertions, 29 deletions
diff --git a/pkg/linux/bitmask-root b/pkg/linux/bitmask-root index 5b49a187..4cb214e1 100755 --- a/pkg/linux/bitmask-root +++ b/pkg/linux/bitmask-root @@ -1,4 +1,4 @@ -#!/usr/bin/python2 +#!/usr/bin/python # -*- coding: utf-8 -*- # # Copyright (C) 2014 LEAP @@ -35,6 +35,8 @@ import socket import sys import re +cmdcheck = subprocess.check_output + ## ## CONSTANTS ## @@ -71,17 +73,20 @@ ALLOWED_FLAGS = { PARAM_FORMATS = { "NUMBER": lambda s: re.match("^\d+$", s), - "PROTO": lambda s: re.match("^(tcp|udp)$", s), - "IP": lambda s: is_valid_address(s), + "PROTO": lambda s: re.match("^(tcp|udp)$", s), + "IP": lambda s: is_valid_address(s), "CIPHER": lambda s: re.match("^[A-Z0-9-]+$", s), - "USER": lambda s: re.match("^[a-zA-Z0-9_\.\@][a-zA-Z0-9_\-\.\@]*\$?$", s), # IEEE Std 1003.1-2001 - "FILE": lambda s: os.path.isfile(s) + "USER": lambda s: re.match( + "^[a-zA-Z0-9_\.\@][a-zA-Z0-9_\-\.\@]*\$?$", s), # IEEE Std 1003.1-2001 + "FILE": lambda s: os.path.isfile(s) } -DEBUG=os.getenv("DEBUG") + +DEBUG = os.getenv("DEBUG") if DEBUG: import logging - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s") ch = logging.StreamHandler() ch.setLevel(logging.DEBUG) ch.setFormatter(formatter) @@ -94,6 +99,7 @@ if DEBUG: ## UTILITY ## + def is_valid_address(value): """ Validate that the passed ip is a valid IP address. @@ -109,10 +115,12 @@ def is_valid_address(value): print "MALFORMED IP: %s!" % value return False + def split_list(list, regex): """ Splits a list based on a regex: - e.g. split_list(["xx", "yy", "x1", "zz"], "^x") => [["xx", "yy"], ["x1", "zz"]] + e.g. split_list(["xx", "yy", "x1", "zz"], "^x") => [["xx", "yy"], ["x1", + "zz"]] :param list: the list to be split. :type list: list @@ -140,12 +148,13 @@ def split_list(list, regex): #def sanify(command, *args): # return [command] + [pipes.quote(a) for a in args] + def run(command, *args, **options): parts = [command] parts.extend(args) if DEBUG: print "run: " + " ".join(parts) - if options.get("check", True) == False or options.get("detach", False) == True: + if not options.get("check", True) or options.get("detach", False): subprocess.Popen(parts) else: try: @@ -153,7 +162,7 @@ def run(command, *args, **options): subprocess.check_call(parts, stdout=devnull, stderr=devnull) return 0 except subprocess.CalledProcessError as ex: - if options.get("exitcode", False) == True: + if options.get("exitcode", False): return ex.returncode else: bail("Could not run %s: %s" % (ex.cmd, ex.output)) @@ -162,15 +171,17 @@ def run(command, *args, **options): ## OPENVPN ## + def parse_openvpn_flags(args): """ - takes argument list from the command line and parses it, only allowing some configuration flags. + takes argument list from the command line and parses it, only allowing some + configuration flags. """ result = [] try: for flag in split_list(args, "^--"): flag_name = flag[0] - if ALLOWED_FLAGS.has_key(flag_name): + if flag_name in ALLOWED_FLAGS: result.append(flag_name) required_params = ALLOWED_FLAGS[flag_name] if len(required_params) > 0: @@ -200,6 +211,7 @@ def openvpn_start(args): else: bail('ERROR: could not parse openvpn options') + def openvpn_stop(args): print "stop" @@ -207,6 +219,7 @@ def openvpn_stop(args): ## FIREWALL ## + def get_gateways(gateways): result = [gateway for gateway in gateways if is_valid_address(gateway)] if not len(result): @@ -214,29 +227,33 @@ def get_gateways(gateways): else: return result + def get_default_device(): routes = subprocess.check_output([IP, "route", "show"]) match = re.search("^default .*dev ([^\s]*) .*$", routes, flags=re.M) if len(match.groups()) >= 1: - return match.group(1) + return match.group(1) else: - bail("could not find default device") + bail("could not find default device") + def get_local_network_ipv4(device): - addresses = subprocess.check_output([IP, "-o", "address", "show", "dev", device]) + addresses = cmdcheck([IP, "-o", "address", "show", "dev", device]) match = re.search("^.*inet ([^ ]*) .*$", addresses, flags=re.M) if len(match.groups()) >= 1: - return match.group(1) + return match.group(1) else: - return None + return None + def get_local_network_ipv6(device): - addresses = subprocess.check_output([IP, "-o", "address", "show", "dev", device]) + addresses = cmdcheck([IP, "-o", "address", "show", "dev", device]) match = re.search("^.*inet6 ([^ ]*) .*$", addresses, flags=re.M) if len(match.groups()) >= 1: - return match.group(1) + return match.group(1) else: - return None + return None + def run_iptable_with_check(cmd, *args, **options): """ @@ -258,29 +275,35 @@ def run_iptable_with_check(cmd, *args, **options): else: run(cmd, *args, **options) + def iptables(*args, **options): ip4tables(*args, **options) ip6tables(*args, **options) + def ip4tables(*args, **options): run_iptable_with_check(IPTABLES, *args, **options) + def ip6tables(*args, **options): run_iptable_with_check(IP6TABLES, *args, **options) + def ipv4_chain_exists(table): code = run(IPTABLES, "--list", table, "--numeric", exitcode=True) return code == 0 + def ipv6_chain_exists(table): code = run(IP6TABLES, "--list", table, "--numeric", exitcode=True) return code == 0 + def firewall_start(args): - default_device = get_default_device() + default_device = get_default_device() local_network_ipv4 = get_local_network_ipv4(default_device) local_network_ipv6 = get_local_network_ipv6(default_device) - gateways = get_gateways(args) + gateways = get_gateways(args) # add custom chain "bitmask" if not ipv4_chain_exists("bitmask"): @@ -294,18 +317,24 @@ def firewall_start(args): # allow traffic to gateways for gateway in gateways: - ip4tables("--insert", "bitmask", "--destination", gateway, "-o", default_device, "--jump", "ACCEPT") + ip4tables("--insert", "bitmask", "--destination", gateway, + "-o", default_device, "--jump", "ACCEPT") # allow traffic to IPs on local network if local_network_ipv4: - ip4tables("--insert", "bitmask", "--destination", local_network_ipv4, "-o", default_device, "--jump", "ACCEPT") + ip4tables("--insert", "bitmask", "--destination", local_network_ipv4, + "-o", default_device, "--jump", "ACCEPT") if local_network_ipv6: - ip6tables("--insert", "bitmask", "--destination", local_network_ipv6, "-o", default_device, "--jump", "ACCEPT") + ip6tables("--insert", "bitmask", "--destination", local_network_ipv6, + "-o", default_device, "--jump", "ACCEPT") # block DNS requests to anyone but the service provider or localhost - ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53", "--jump", "REJECT") - for allowed_dns in gateways + ["127.0.0.1","127.0.1.1"]: - ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53", "--destination", allowed_dns, "--jump", "ACCEPT") + ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53", + "--jump", "REJECT") + for allowed_dns in gateways + ["127.0.0.1", "127.0.1.1"]: + ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53", + "--destination", allowed_dns, "--jump", "ACCEPT") + def firewall_stop(args): iptables("--delete", "OUTPUT", "--jump", "bitmask") @@ -322,6 +351,7 @@ def bail(msg=""): print(msg) exit(1) + def main(): if len(sys.argv) >= 3: command = "_".join(sys.argv[1:3]) @@ -343,4 +373,3 @@ if __name__ == "__main__": main() print "done" exit(0) - |