diff options
Diffstat (limited to 'pkg/linux')
| -rwxr-xr-x | pkg/linux/bitmask-root | 87 | 
1 files changed, 58 insertions, 29 deletions
| diff --git a/pkg/linux/bitmask-root b/pkg/linux/bitmask-root index 5b49a187..4cb214e1 100755 --- a/pkg/linux/bitmask-root +++ b/pkg/linux/bitmask-root @@ -1,4 +1,4 @@ -#!/usr/bin/python2 +#!/usr/bin/python  # -*- coding: utf-8 -*-  #  # Copyright (C) 2014 LEAP @@ -35,6 +35,8 @@ import socket  import sys  import re +cmdcheck = subprocess.check_output +  ##  ## CONSTANTS  ## @@ -71,17 +73,20 @@ ALLOWED_FLAGS = {  PARAM_FORMATS = {      "NUMBER": lambda s: re.match("^\d+$", s), -    "PROTO":  lambda s: re.match("^(tcp|udp)$", s), -    "IP":     lambda s: is_valid_address(s), +    "PROTO": lambda s: re.match("^(tcp|udp)$", s), +    "IP": lambda s: is_valid_address(s),      "CIPHER": lambda s: re.match("^[A-Z0-9-]+$", s), -    "USER":   lambda s: re.match("^[a-zA-Z0-9_\.\@][a-zA-Z0-9_\-\.\@]*\$?$", s), # IEEE Std 1003.1-2001 -    "FILE":   lambda s: os.path.isfile(s) +    "USER": lambda s: re.match( +        "^[a-zA-Z0-9_\.\@][a-zA-Z0-9_\-\.\@]*\$?$", s),  # IEEE Std 1003.1-2001 +    "FILE": lambda s: os.path.isfile(s)  } -DEBUG=os.getenv("DEBUG") + +DEBUG = os.getenv("DEBUG")  if DEBUG:      import logging -    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +    formatter = logging.Formatter( +        "%(asctime)s - %(name)s - %(levelname)s - %(message)s")      ch = logging.StreamHandler()      ch.setLevel(logging.DEBUG)      ch.setFormatter(formatter) @@ -94,6 +99,7 @@ if DEBUG:  ## UTILITY  ## +  def is_valid_address(value):      """      Validate that the passed ip is a valid IP address. @@ -109,10 +115,12 @@ def is_valid_address(value):          print "MALFORMED IP: %s!" % value          return False +  def split_list(list, regex):      """      Splits a list based on a regex: -    e.g. split_list(["xx", "yy", "x1", "zz"], "^x") => [["xx", "yy"], ["x1", "zz"]] +    e.g. split_list(["xx", "yy", "x1", "zz"], "^x") => [["xx", "yy"], ["x1", +    "zz"]]      :param list: the list to be split.      :type list: list @@ -140,12 +148,13 @@ def split_list(list, regex):  #def sanify(command, *args):  #    return [command] + [pipes.quote(a) for a in args] +  def run(command, *args, **options):      parts = [command]      parts.extend(args)      if DEBUG:          print "run: " + " ".join(parts) -    if options.get("check", True) == False or options.get("detach", False) == True: +    if not options.get("check", True) or options.get("detach", False):          subprocess.Popen(parts)      else:          try: @@ -153,7 +162,7 @@ def run(command, *args, **options):              subprocess.check_call(parts, stdout=devnull, stderr=devnull)              return 0          except subprocess.CalledProcessError as ex: -            if options.get("exitcode", False) == True: +            if options.get("exitcode", False):                  return ex.returncode              else:                  bail("Could not run %s: %s" % (ex.cmd, ex.output)) @@ -162,15 +171,17 @@ def run(command, *args, **options):  ## OPENVPN  ## +  def parse_openvpn_flags(args):      """ -    takes argument list from the command line and parses it, only allowing some configuration flags. +    takes argument list from the command line and parses it, only allowing some +    configuration flags.      """      result = []      try:          for flag in split_list(args, "^--"):              flag_name = flag[0] -            if ALLOWED_FLAGS.has_key(flag_name): +            if flag_name in ALLOWED_FLAGS:                  result.append(flag_name)                  required_params = ALLOWED_FLAGS[flag_name]                  if len(required_params) > 0: @@ -200,6 +211,7 @@ def openvpn_start(args):      else:          bail('ERROR: could not parse openvpn options') +  def openvpn_stop(args):      print "stop" @@ -207,6 +219,7 @@ def openvpn_stop(args):  ## FIREWALL  ## +  def get_gateways(gateways):      result = [gateway for gateway in gateways if is_valid_address(gateway)]      if not len(result): @@ -214,29 +227,33 @@ def get_gateways(gateways):      else:          return result +  def get_default_device():      routes = subprocess.check_output([IP, "route", "show"])      match = re.search("^default .*dev ([^\s]*) .*$", routes, flags=re.M)      if len(match.groups()) >= 1: -      return match.group(1) +        return match.group(1)      else: -      bail("could not find default device") +        bail("could not find default device") +  def get_local_network_ipv4(device): -    addresses = subprocess.check_output([IP, "-o", "address", "show", "dev", device]) +    addresses = cmdcheck([IP, "-o", "address", "show", "dev", device])      match = re.search("^.*inet ([^ ]*) .*$", addresses, flags=re.M)      if len(match.groups()) >= 1: -      return match.group(1) +        return match.group(1)      else: -      return None +        return None +  def get_local_network_ipv6(device): -    addresses = subprocess.check_output([IP, "-o", "address", "show", "dev", device]) +    addresses = cmdcheck([IP, "-o", "address", "show", "dev", device])      match = re.search("^.*inet6 ([^ ]*) .*$", addresses, flags=re.M)      if len(match.groups()) >= 1: -      return match.group(1) +        return match.group(1)      else: -      return None +        return None +  def run_iptable_with_check(cmd, *args, **options):      """ @@ -258,29 +275,35 @@ def run_iptable_with_check(cmd, *args, **options):      else:          run(cmd, *args, **options) +  def iptables(*args, **options):      ip4tables(*args, **options)      ip6tables(*args, **options) +  def ip4tables(*args, **options):      run_iptable_with_check(IPTABLES, *args, **options) +  def ip6tables(*args, **options):      run_iptable_with_check(IP6TABLES, *args, **options) +  def ipv4_chain_exists(table):      code = run(IPTABLES, "--list", table, "--numeric", exitcode=True)      return code == 0 +  def ipv6_chain_exists(table):      code = run(IP6TABLES, "--list", table, "--numeric", exitcode=True)      return code == 0 +  def firewall_start(args): -    default_device     = get_default_device() +    default_device = get_default_device()      local_network_ipv4 = get_local_network_ipv4(default_device)      local_network_ipv6 = get_local_network_ipv6(default_device) -    gateways           = get_gateways(args) +    gateways = get_gateways(args)      # add custom chain "bitmask"      if not ipv4_chain_exists("bitmask"): @@ -294,18 +317,24 @@ def firewall_start(args):      # allow traffic to gateways      for gateway in gateways: -        ip4tables("--insert", "bitmask", "--destination", gateway, "-o", default_device, "--jump", "ACCEPT") +        ip4tables("--insert", "bitmask", "--destination", gateway, +                  "-o", default_device, "--jump", "ACCEPT")      # allow traffic to IPs on local network      if local_network_ipv4: -        ip4tables("--insert", "bitmask", "--destination", local_network_ipv4, "-o", default_device, "--jump", "ACCEPT") +        ip4tables("--insert", "bitmask", "--destination", local_network_ipv4, +                  "-o", default_device, "--jump", "ACCEPT")      if local_network_ipv6: -        ip6tables("--insert", "bitmask", "--destination", local_network_ipv6, "-o", default_device, "--jump", "ACCEPT") +        ip6tables("--insert", "bitmask", "--destination", local_network_ipv6, +                  "-o", default_device, "--jump", "ACCEPT")      # block DNS requests to anyone but the service provider or localhost -    ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53", "--jump", "REJECT") -    for allowed_dns in gateways + ["127.0.0.1","127.0.1.1"]: -        ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53", "--destination", allowed_dns, "--jump", "ACCEPT") +    ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53", +              "--jump", "REJECT") +    for allowed_dns in gateways + ["127.0.0.1", "127.0.1.1"]: +        ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53", +                  "--destination", allowed_dns, "--jump", "ACCEPT") +  def firewall_stop(args):      iptables("--delete", "OUTPUT", "--jump", "bitmask") @@ -322,6 +351,7 @@ def bail(msg=""):          print(msg)      exit(1) +  def main():      if len(sys.argv) >= 3:          command = "_".join(sys.argv[1:3]) @@ -343,4 +373,3 @@ if __name__ == "__main__":      main()      print "done"      exit(0) - | 
