#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2014 LEAP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
"""
This is a privileged helper script for safely running certain commands as root.
It should only be called by the Bitmask application.

USAGE:
  bitmask-root firewall stop
  bitmask-root firewall start GATEWAY1 GATEWAY2 ...
  bitmask-root openvpn stop
  bitmask-root openvpn start CONFIG1 CONFIG2 ...

GATEWAYn are IPv4 addresses; CONFIGn are whitelisted openvpn flags, each
followed by its parameters (for example: --remote IP PORT PROTO).
"""
# TODO should be tested with python3, which can be the default on some distros.
from __future__ import print_function

import os
import re
import socket
import subprocess
import sys

cmdcheck = subprocess.check_output

##
## CONSTANTS
##

OPENVPN = "/usr/sbin/openvpn"
# NOTE(review): IP was referenced below but never defined. /bin/ip is the
# Debian location; confirm the path on other target distributions.
IP = "/bin/ip"
IPTABLES = "/sbin/iptables"
IP6TABLES = "/sbin/ip6tables"
UPDATE_RESOLV_CONF = "/etc/openvpn/update-resolv-conf"

# Flags always passed to openvpn, whatever the caller asks for.
# NOTE(review): the management socket lives at a fixed, predictable path in
# /tmp while openvpn runs as root; a root-owned directory would be safer.
FIXED_FLAGS = [
    "--setenv", "LEAPOPENVPN", "1",
    "--nobind",
    "--client",
    "--dev", "tun",
    "--tls-client",
    "--remote-cert-tls", "server",
    "--management-signal",
    "--management", "/tmp/openvpn.socket", "unix",
    "--up", UPDATE_RESOLV_CONF,
    "--down", UPDATE_RESOLV_CONF,
    "--script-security", "2",
]

# Whitelist of caller-supplied openvpn flags, each mapped to the types of the
# parameters that must follow it (validated with PARAM_FORMATS).
ALLOWED_FLAGS = {
    "--remote": ["IP", "NUMBER", "PROTO"],
    "--tls-cipher": ["CIPHER"],
    "--cipher": ["CIPHER"],
    "--auth": ["CIPHER"],
    "--management-client-user": ["USER"],
    "--cert": ["FILE"],
    "--key": ["FILE"],
    "--ca": ["FILE"],
}

# Validators per parameter type; each returns a truthy value when the
# parameter is acceptable. Patterns end in \Z rather than $ because $ also
# matches just before a trailing newline, and use [0-9] because \d matches
# non-ASCII digits on Python 3.
PARAM_FORMATS = {
    "NUMBER": lambda s: re.match(r"^[0-9]+\Z", s),
    "PROTO": lambda s: re.match(r"^(tcp|udp)\Z", s),
    "IP": lambda s: is_valid_address(s),
    "CIPHER": lambda s: re.match(r"^[A-Z0-9-]+\Z", s),
    "USER": lambda s: re.match(
        r"^[a-zA-Z0-9_.@][a-zA-Z0-9_.@-]*\$?\Z", s),  # IEEE Std 1003.1-2001
    "FILE": lambda s: os.path.isfile(s),
}

DEBUG = os.getenv("DEBUG")

if DEBUG:
    import logging
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(ch)
    logger.debug(" ".join(sys.argv))

##
## UTILITY
##


def is_valid_address(value):
    """
    Validate that the passed value is a dotted-quad IPv4 address.

    inet_pton is used instead of inet_aton: inet_aton also accepts shorthand
    forms such as "10.1" (see inet(3)) and may tolerate trailing text on some
    C libraries, neither of which a root helper should pass on to iptables.

    :param value: the value to be validated
    :type value: str
    :rtype: bool
    """
    try:
        socket.inet_pton(socket.AF_INET, value)
        return True
    except (socket.error, TypeError, ValueError):
        print("MALFORMED IP: %s!" % value)
        return False


def split_list(list, regex):
    """
    Split a list into groups, each starting at an element matching a regex:

        split_list(["xx", "yy", "x1", "zz"], "^x")
            => [["xx", "yy"], ["x1", "zz"]]

    Elements before the first matching element are discarded; an empty list
    yields an empty result.

    :param list: the list to be split (name kept for caller compatibility).
    :type list: list
    :param regex: a pattern string or a compiled regex.
    :rtype: list
    """
    if not hasattr(regex, "match"):
        regex = re.compile(regex)
    result = []
    for item in list:
        if regex.match(item):
            result.append([item])
        elif result:
            result[-1].append(item)
    return result


def run(command, *args, **options):
    """
    Run ``command`` with ``args`` as an argument list (never via a shell).

    Keyword options:
      * check (default True) / detach (default False): when check is False
        or detach is True, start the process with Popen and do not wait.
      * exitcode (default False): when a waited-on command fails, return its
        exit status instead of aborting the script.

    :return: None for commands not waited on; 0 on success; the failing exit
             status when ``exitcode`` is set. Otherwise failures call bail().
    """
    parts = [command] + list(args)
    if DEBUG:
        print("run: " + " ".join(parts))

    if not options.get("check", True) or options.get("detach", False):
        subprocess.Popen(parts)
        return None

    with open(os.devnull, "w") as devnull:
        try:
            subprocess.check_call(parts, stdout=devnull, stderr=devnull)
        except subprocess.CalledProcessError as ex:
            if options.get("exitcode", False):
                return ex.returncode
            # Output goes to /dev/null and is not captured, so report the
            # exit status rather than the (always None) ex.output.
            bail("Could not run %s: exit status %d"
                 % (" ".join(parts), ex.returncode))
        except OSError as ex:
            # e.g. the binary does not exist at the configured path.
            bail("Could not run %s: %s" % (" ".join(parts), ex))
    return 0

##
## OPENVPN
##


def parse_openvpn_flags(args):
    """
    Parse the command-line argument list, allowing only the flags listed in
    ALLOWED_FLAGS, each followed by exactly the parameters described there
    and validated with PARAM_FORMATS.

    Unrecognized flags (and their parameters) are dropped with a warning;
    arguments before the first ``--flag`` are ignored.

    :param args: the arguments following ``openvpn start``.
    :type args: list of str
    :return: the accepted flags and parameters, or None when a whitelisted
             flag has the wrong number of parameters or an invalid one.
    :rtype: list or None
    """
    result = []
    try:
        for flag in split_list(args, "^--"):
            flag_name, flag_params = flag[0], flag[1:]
            if flag_name not in ALLOWED_FLAGS:
                print("WARNING: unrecognized openvpn flag %s" % flag_name)
                continue
            required_params = ALLOWED_FLAGS[flag_name]
            result.append(flag_name)
            if not required_params:
                continue
            if len(flag_params) != len(required_params):
                print("ERROR: wrong number of params for %s" % flag_name)
                return None
            for param, param_type in zip(flag_params, required_params):
                if not PARAM_FORMATS[param_type](param):
                    print("ERROR: Bad argument %s" % param)
                    return None
                result.append(param)
        return result
    except (TypeError, ValueError) as ex:
        # A validator rejected the value by raising (e.g. an embedded NUL).
        print(ex)
        return None


def openvpn_start(args):
    """Start openvpn (detached) with FIXED_FLAGS plus the validated flags."""
    openvpn_flags = parse_openvpn_flags(args)
    if openvpn_flags:
        flags = FIXED_FLAGS + openvpn_flags
        run(OPENVPN, *flags, detach=True)
    else:
        bail('ERROR: could not parse openvpn options')


def openvpn_stop(args):
    """Placeholder: currently only prints "stop"; no process is stopped."""
    # TODO(review): implement stopping the openvpn process started above.
    print("stop")

##
## FIREWALL
##


def get_gateways(gateways):
    """Return the valid IPv4 gateways, or bail if there are none."""
    result = [gateway for gateway in gateways if is_valid_address(gateway)]
    if not result:
        bail("No valid gateways specified")
    return result


def parse_default_device(routes):
    """
    Return the device of the first ``default`` route in ``ip route show``
    output, or None when there is no such route.
    """
    match = re.search(r"^default .*?\bdev (\S+)", routes, flags=re.M)
    return match.group(1) if match else None


def parse_first_address(addresses, family):
    """
    Return the first address (with prefix length) of the given family
    keyword ("inet" or "inet6") in ``ip -o address show`` output, or None.
    """
    match = re.search(r"\b%s (\S+)" % re.escape(family), addresses)
    return match.group(1) if match else None


def get_default_device():
    """Return the device carrying the default route, or bail."""
    routes = cmdcheck([IP, "route", "show"], universal_newlines=True)
    device = parse_default_device(routes)
    if device is None:
        bail("could not find default device")
    return device


def get_local_network_ipv4(device):
    """Return the first IPv4 address/prefix on ``device``, or None."""
    addresses = cmdcheck([IP, "-o", "address", "show", "dev", device],
                         universal_newlines=True)
    return parse_first_address(addresses, "inet")


def get_local_network_ipv6(device):
    """Return the first IPv6 address/prefix on ``device``, or None."""
    addresses = cmdcheck([IP, "-o", "address", "show", "dev", device],
                         universal_newlines=True)
    return parse_first_address(addresses, "inet6")


def run_iptable_with_check(cmd, *args, **options):
    """
    Run an iptables command, checking first whether it is needed:
      * --insert: run only if the rule does not already exist.
      * --delete: run only if the rule does exist.
    Other commands are run unconditionally.
    """
    if "--insert" in args:
        check_args = [arg.replace("--insert", "--check") for arg in args]
        if run(cmd, *check_args, exitcode=True) != 0:
            run(cmd, *args, **options)
    elif "--delete" in args:
        check_args = [arg.replace("--delete", "--check") for arg in args]
        if run(cmd, *check_args, exitcode=True) == 0:
            run(cmd, *args, **options)
    else:
        run(cmd, *args, **options)


def iptables(*args, **options):
    """Apply the same rule to both the IPv4 and IPv6 tables."""
    ip4tables(*args, **options)
    ip6tables(*args, **options)


def ip4tables(*args, **options):
    """Run a checked iptables (IPv4) command."""
    run_iptable_with_check(IPTABLES, *args, **options)


def ip6tables(*args, **options):
    """Run a checked ip6tables (IPv6) command."""
    run_iptable_with_check(IP6TABLES, *args, **options)


def ipv4_chain_exists(table):
    """Return True if the IPv4 chain ``table`` can be listed."""
    code = run(IPTABLES, "--list", table, "--numeric", exitcode=True)
    return code == 0


def ipv6_chain_exists(table):
    """Return True if the IPv6 chain ``table`` can be listed."""
    code = run(IP6TABLES, "--list", table, "--numeric", exitcode=True)
    return code == 0


def firewall_start(args):
    """
    Install the "bitmask" chain: outgoing traffic on the default device is
    rejected except to the gateways in ``args`` and the local network, and
    UDP DNS is only allowed to the gateways and localhost.
    """
    default_device = get_default_device()
    local_network_ipv4 = get_local_network_ipv4(default_device)
    local_network_ipv6 = get_local_network_ipv6(default_device)
    gateways = get_gateways(args)

    # add custom chain "bitmask"
    if not ipv4_chain_exists("bitmask"):
        ip4tables("--new-chain", "bitmask")
    if not ipv6_chain_exists("bitmask"):
        ip6tables("--new-chain", "bitmask")
    iptables("--insert", "OUTPUT", "--jump", "bitmask")

    # Every rule below is --insert'ed at the head of the chain, so rules
    # added later are evaluated earlier; the catch-all REJECT ends up last.

    # reject everything
    iptables("--insert", "bitmask", "-o", default_device, "--jump", "REJECT")

    # allow traffic to gateways
    for gateway in gateways:
        ip4tables("--insert", "bitmask", "--destination", gateway,
                  "-o", default_device, "--jump", "ACCEPT")

    # allow traffic to IPs on local network
    if local_network_ipv4:
        ip4tables("--insert", "bitmask", "--destination", local_network_ipv4,
                  "-o", default_device, "--jump", "ACCEPT")
    if local_network_ipv6:
        ip6tables("--insert", "bitmask", "--destination", local_network_ipv6,
                  "-o", default_device, "--jump", "ACCEPT")

    # block DNS requests to anyone but the service provider or localhost
    ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53",
              "--jump", "REJECT")
    for allowed_dns in gateways + ["127.0.0.1", "127.0.1.1"]:
        ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53",
                  "--destination", allowed_dns, "--jump", "ACCEPT")


def firewall_stop(args):
    """Unhook, flush and delete the "bitmask" chain (IPv4 and IPv6)."""
    iptables("--delete", "OUTPUT", "--jump", "bitmask")
    if ipv4_chain_exists("bitmask"):
        ip4tables("--flush", "bitmask")
        ip4tables("--delete-chain", "bitmask")
    if ipv6_chain_exists("bitmask"):
        ip6tables("--flush", "bitmask")
        ip6tables("--delete-chain", "bitmask")


def bail(msg=""):
    """Print ``msg`` (if any) and terminate with exit status 1."""
    if msg:
        print(msg)
    sys.exit(1)


def main():
    """Dispatch ``bitmask-root <subsystem> <action> [args...]``."""
    commands = {
        "openvpn_start": openvpn_start,
        "openvpn_stop": openvpn_stop,
        "firewall_start": firewall_start,
        "firewall_stop": firewall_stop,
    }
    if len(sys.argv) < 3:
        bail("no such command")
    handler = commands.get("_".join(sys.argv[1:3]))
    if handler is None:
        bail("no such command")
    handler(sys.argv[3:])


if __name__ == "__main__":
    main()
    print("done")
    sys.exit(0)