diff options
author | Tomás Touceda <chiiph@leap.se> | 2014-05-16 15:50:41 -0300 |
---|---|---|
committer | Tomás Touceda <chiiph@leap.se> | 2014-05-16 15:50:41 -0300 |
commit | 2bb3e9200adf7fcc1781ba7ff119f6dd6d09f7dc (patch) | |
tree | 400ac91ea22e121344b123ea6a0ac1ba6769120d /pkg/linux/bitmask-root | |
parent | f995ff502089bcb7ef6a0fe71c692ee7274d45cb (diff) | |
parent | dfbe8c4f0158366e91ea5118e5aa68c07d28ddbf (diff) |
Merge remote-tracking branch 'kali/linux-firewall-root-py-2' into develop
Diffstat (limited to 'pkg/linux/bitmask-root')
-rwxr-xr-x | pkg/linux/bitmask-root | 829 |
1 files changed, 829 insertions, 0 deletions
diff --git a/pkg/linux/bitmask-root b/pkg/linux/bitmask-root new file mode 100755 index 00000000..136fd6a4 --- /dev/null +++ b/pkg/linux/bitmask-root @@ -0,0 +1,829 @@ +#!/usr/bin/python +# -*- coding: utf-8 -*- +# +# Copyright (C) 2014 LEAP +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +""" +This is a privileged helper script for safely running certain commands as root. +It should only be called by the Bitmask application. + +USAGE: + bitmask-root firewall stop + bitmask-root firewall start GATEWAY1 GATEWAY2 ... + bitmask-root openvpn stop + bitmask-root openvpn start CONFIG1 CONFIG1 ... + +All actions return exit code 0 for success, non-zero otherwise. + +The `openvpn start` action is special: it calls exec on openvpn and replaces +the current process. +""" +# TODO should be tested with python3, which can be the default on some distro. +from __future__ import print_function +import atexit +import os +import re +import signal +import socket +import subprocess +import sys +import time +import traceback + + +cmdcheck = subprocess.check_output + +## +## CONSTANTS +## + +SCRIPT = "bitmask-root" +NAMESERVER = "10.42.0.1" +BITMASK_CHAIN = "bitmask" + +IP = "/bin/ip" +IPTABLES = "/sbin/iptables" +IP6TABLES = "/sbin/ip6tables" + +RESOLVCONF_SYSTEM_BIN = "/sbin/resolvconf" +RESOLVCONF_LEAP_BIN = "/usr/local/sbin/leap-resolvconf" + +OPENVPN_USER = "nobody" +OPENVPN_GROUP = "nogroup" +LEAPOPENVPN = "LEAPOPENVPN" +OPENVPN_SYSTEM_BIN = "/usr/sbin/openvpn" # Debian location +OPENVPN_LEAP_BIN = "/usr/sbin/leap-openvpn" # installed by bundle + + +""" +The path to the script to update resolv.conf +""" +# XXX We have to check if we have a valid resolvconf, and use +# the old resolv-update if not. +LEAP_UPDATE_RESOLVCONF_FILE = "/etc/leap/update-resolv-conf" +LEAP_RESOLV_UPDATE = "/etc/leap/resolv-update" + +FIXED_FLAGS = [ + "--setenv", "LEAPOPENVPN", "1", + "--nobind", + "--client", + "--dev", "tun", + "--tls-client", + "--remote-cert-tls", "server", + "--management-signal", + "--script-security", "1", + "--user", "nobody", + "--group", "nogroup", + "--remap-usr1", "SIGTERM", +] + +ALLOWED_FLAGS = { + "--remote": ["IP", "NUMBER", "PROTO"], + "--tls-cipher": ["CIPHER"], + "--cipher": ["CIPHER"], + "--auth": ["CIPHER"], + "--management": ["DIR", "UNIXSOCKET"], + "--management-client-user": ["USER"], + "--cert": ["FILE"], + "--key": ["FILE"], + "--ca": ["FILE"] +} + +PARAM_FORMATS = { + "NUMBER": lambda s: re.match("^\d+$", s), + "PROTO": lambda s: re.match("^(tcp|udp)$", s), + "IP": lambda s: is_valid_address(s), + "CIPHER": lambda s: re.match("^[A-Z0-9-]+$", s), + "USER": lambda s: re.match( + "^[a-zA-Z0-9_\.\@][a-zA-Z0-9_\-\.\@]*\$?$", s), # IEEE Std 1003.1-2001 + "FILE": lambda s: os.path.isfile(s), + "DIR": lambda s: os.path.isdir(os.path.split(s)[0]), + "UNIXSOCKET": lambda s: s == "unix" +} + + +DEBUG = os.getenv("DEBUG") +TEST = os.getenv("TEST") + +if DEBUG: + import logging + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ch = logging.StreamHandler() + ch.setLevel(logging.DEBUG) + ch.setFormatter(formatter) + logger = logging.getLogger(__name__) + logger.setLevel(logging.DEBUG) + logger.addHandler(ch) + +## +## UTILITY +## + + +def is_valid_address(value): + """ + Validate that the passed ip is a valid IP address. + + :param value: the value to be validated + :type value: str + :rtype: bool + """ + try: + socket.inet_aton(value) + return True + except Exception: + print("%s: ERROR: MALFORMED IP: %s!" % (SCRIPT, value)) + return False + + +def has_system_resolvconf(): + """ + Return True if resolvconf is found in the system. + + :rtype: bool + """ + return os.path.isfile(RESOLVCONF) + + +def has_valid_update_resolvconf(): + """ + Return True if a valid update-resolv-conf script is found in the system. + + :rtype: bool + """ + return os.path.isfile(LEAP_UPDATE_RESOLVCONF_FILE) + + +def has_valid_leap_resolv_update(): + """ + Return True if a valid resolv-update script is found in the system. + + :rtype: bool + """ + return os.path.isfile(LEAP_RESOLV_UPDATE) + + +def split_list(_list, regex): + """ + Split a list based on a regex: + e.g. split_list(["xx", "yy", "x1", "zz"], "^x") => [["xx", "yy"], ["x1", + "zz"]] + + :param _list: the list to be split. + :type _list: list + :param regex: the regex expression to filter with. + :type regex: str + + :rtype: list + """ + if not hasattr(regex, "match"): + regex = re.compile(regex) + result = [] + i = 0 + if not _list: + return result + while True: + if regex.match(_list[i]): + result.append([]) + while True: + result[-1].append(_list[i]) + i += 1 + if i >= len(_list) or regex.match(_list[i]): + break + else: + i += 1 + if i >= len(_list): + break + return result + + +def get_process_list(): + """ + Get a process list by reading `/proc` filesystem. + + :return: a list of tuples, each containing pid and command string. + :rtype: tuple if lists + """ + res = [] + pids = [pid for pid in os.listdir('/proc') if pid.isdigit()] + + for pid in pids: + try: + res.append((pid, open( + os.path.join( + '/proc', pid, 'cmdline'), 'rb').read())) + except IOError: # proc has already terminated + continue + return filter(None, res) + + +class Daemon(object): + """ + A generic daemon class. + """ + def __init__(self, pidfile, stdin='/dev/null', + stdout='/dev/null', stderr='/dev/null'): + self.stdin = stdin + self.stdout = stdout + self.stderr = stderr + self.pidfile = pidfile + + def daemonize(self): + """ + Do the UNIX double-fork magic, see Stevens' "Advanced + Programming in the UNIX Environment" for details (ISBN 0201563177) + http://www.erlenstar.demon.co.uk/unix/faq_2.html#SEC16 + """ + try: + pid = os.fork() + if pid > 0: + # exit first parent + sys.exit(0) + except OSError, e: + sys.stderr.write( + "fork #1 failed: %d (%s)\n" % (e.errno, e.strerror)) + sys.exit(1) + + # decouple from parent environment + os.chdir("/") + os.setsid() + os.umask(0) + + # do second fork + try: + pid = os.fork() + if pid > 0: + # exit from second parent + sys.exit(0) + except OSError, e: + sys.stderr.write( + "fork #2 failed: %d (%s)\n" % (e.errno, e.strerror)) + sys.exit(1) + + # redirect standard file descriptors + sys.stdout.flush() + sys.stderr.flush() + si = file(self.stdin, 'r') + so = file(self.stdout, 'a+') + se = file(self.stderr, 'a+', 0) + os.dup2(si.fileno(), sys.stdin.fileno()) + os.dup2(so.fileno(), sys.stdout.fileno()) + os.dup2(se.fileno(), sys.stderr.fileno()) + + # write pidfile + atexit.register(self.delpid) + pid = str(os.getpid()) + file(self.pidfile, 'w+').write("%s\n" % pid) + + def delpid(self): + """ + Delete the pidfile. + """ + os.remove(self.pidfile) + + def start(self, *args): + """ + Start the daemon. + """ + # Check for a pidfile to see if the daemon already runs + try: + pf = file(self.pidfile, 'r') + pid = int(pf.read().strip()) + pf.close() + except IOError: + pid = None + + if pid: + message = "pidfile %s already exist. Daemon already running?\n" + sys.stderr.write(message % self.pidfile) + sys.exit(1) + + # Start the daemon + self.daemonize() + self.run(args) + + def stop(self): + """ + Stop the daemon. + """ + # Get the pid from the pidfile + try: + pf = file(self.pidfile, 'r') + pid = int(pf.read().strip()) + pf.close() + except IOError: + pid = None + + if not pid: + message = "pidfile %s does not exist. Daemon not running?\n" + sys.stderr.write(message % self.pidfile) + return # not an error in a restart + + # Try killing the daemon process + try: + while 1: + os.kill(pid, signal.SIGTERM) + time.sleep(0.1) + except OSError, err: + err = str(err) + if err.find("No such process") > 0: + if os.path.exists(self.pidfile): + os.remove(self.pidfile) + else: + print(str(err)) + sys.exit(1) + + def restart(self): + """ + Restart the daemon. + """ + self.stop() + self.start() + + def run(self): + """ + This should be overridden by derived classes. + """ + + +def run(command, *args, **options): + """ + Run an external command. + + Options: + + `check`: If True, check the command's output. bail if non-zero. (the + default is true unless detach or input is true) + `exitcode`: like `check`, but return exitcode instead of bailing. + `detach`: If True, run in detached process. + `input`: If True, open command for writing stream to, returning the Popen + object. + """ + parts = [command] + parts.extend(args) + if TEST or DEBUG: + print("%s run: %s " (SCRIPT, " ".join(parts))) + + _check = options.get("check", True) + _detach = options.get("detach", False) + _input = options.get("input", False) + _exitcode = options.get("exitcode", False) + + if not _check or _detach or _input: + if _input: + return subprocess.Popen(parts, stdin=subprocess.PIPE) + else: + # XXX ok with return None ?? + subprocess.Popen(parts) + else: + try: + devnull = open('/dev/null', 'w') + subprocess.check_call(parts, stdout=devnull, stderr=devnull) + return 0 + except subprocess.CalledProcessError as exc: + if DEBUG: + logger.exception(exc) + if _exitcode: + return exc.returncode + else: + bail("ERROR: Could not run %s: %s" % (exc.cmd, exc.output), + exception=exc) + + +def bail(msg=None, exception=None): + """ + Abnormal exit. + + :param msg: optional error message. + :type msg: str + """ + if msg is not None: + print("%s: %s" % (SCRIPT, msg)) + if exception is not None: + traceback.print_exc() + exit(1) + +## +## OPENVPN +## + + +def get_openvpn_bin(): + """ + Return the path for either the system openvpn or the one the + bundle has put there. + """ + if os.path.isfile(OPENVPN_SYSTEM_BIN): + return OPENVPN_SYSTEM_BIN + + # the bundle option should be removed from the debian package. + if os.path.isfile(OPENVPN_LEAP_BIN): + return OPENVPN_LEAP_BIN + + +def parse_openvpn_flags(args): + """ + Take argument list from the command line and parse it, only allowing some + configuration flags. + + :type args: list + """ + result = [] + try: + for flag in split_list(args, "^--"): + flag_name = flag[0] + if flag_name in ALLOWED_FLAGS: + result.append(flag_name) + required_params = ALLOWED_FLAGS[flag_name] + if required_params: + flag_params = flag[1:] + if len(flag_params) != len(required_params): + print("%s: ERROR: not enough params for %s" % + (SCRIPT, flag_name)) + return None + for param, param_type in zip(flag_params, required_params): + if PARAM_FORMATS[param_type](param): + result.append(param) + else: + print("%s: ERROR: Bad argument %s" % + (SCRIPT, param)) + return None + else: + print("WARNING: unrecognized openvpn flag %s" % flag_name) + return result + except Exception as exc: + print("%s: ERROR PARSING FLAGS: %s" % (SCRIPT, exc)) + if DEBUG: + logger.exception(exc) + return None + + +def openvpn_start(args): + """ + Launch openvpn, sanitizing input, and replacing the current process with + the openvpn process. + + :param args: arguments to be passed to openvpn + :type args: list + """ + openvpn_flags = parse_openvpn_flags(args) + if openvpn_flags: + OPENVPN = get_openvpn_bin() + flags = [OPENVPN] + FIXED_FLAGS + openvpn_flags + if DEBUG: + print("%s: running openvpn with flags:" % (SCRIPT,)) + print(flags) + # note: first argument to command is ignored, but customarily set to + # the command. + os.execv(OPENVPN, flags) + else: + bail('ERROR: could not parse openvpn options') + + +def openvpn_stop(args): + """ + Stop the openvpn that has likely been launched by bitmask. + + :param args: arguments to openvpn + :type args: list + """ + plist = get_process_list() + OPENVPN_BIN = get_openvpn_bin() + found_leap_openvpn = filter( + lambda (p, s): s.startswith(OPENVPN_BIN) and LEAPOPENVPN in s, + plist) + + if found_leap_openvpn: + pid = found_leap_openvpn[0][0] + os.kill(int(pid), signal.SIGTERM) + +## +## DNS +## + + +def get_resolvconf_bin(): + """ + Return the path for either the system resolvconf or the one the + bundle has put there. + """ + if os.path.isfile(RESOLVCONF_SYSTEM_BIN): + return RESOLVCONF_SYSTEM_BIN + + # the bundle option should be removed from the debian package. + if os.path.isfile(RESOLVCONF_LEAP_BIN): + return RESOLVCONF_LEAP_BIN + +RESOLVCONF = get_resolvconf_bin() + + +class NameserverSetter(Daemon): + """ + A daemon that will add leap nameserver inside the tunnel + to the system `resolv.conf` + """ + + def run(self, *args): + """ + Run when daemonized. + """ + if args: + ip_address = args[0] + self.set_dns_nameserver(ip_address) + + def set_dns_nameserver(self, ip_address): + """ + Add the tunnel DNS server to `resolv.conf` + + :param ip_address: the ip to add to `resolv.conf` + :type ip_address: str + """ + if os.path.isfile(RESOLVCONF): + process = run(RESOLVCONF, "-a", "bitmask", input=True) + process.communicate("nameserver %s\n" % ip_address) + else: + bail("ERROR: package openresolv or resolvconf not installed.") + +nameserver_setter = NameserverSetter('/tmp/leap-dns-up.pid') + + +class NameserverRestorer(Daemon): + """ + A daemon that will restore the previous nameservers. + """ + + def run(self): + """ + Run when daemonized. + """ + self.restore_dns_nameserver() + + def restore_dns_nameserver(self): + """ + Remove tunnel DNS server from `resolv.conf` + """ + if os.path.isfile(RESOLVCONF): + run(RESOLVCONF, "-d", "bitmask") + else: + print("%s: ERROR: package openresolv " + "or resolvconf not installed." % + (SCRIPT,)) + +nameserver_restorer = NameserverRestorer('/tmp/leap-dns-down.pid') + + +## +## FIREWALL +## + + +def get_gateways(gateways): + """ + Filter a passed sequence of gateways, returning only the valid ones. + + :param gateways: a sequence of gateways to filter. + :type gateways: iterable + :rtype: iterable + """ + result = filter(is_valid_address, gateways) + if not result: + bail("ERROR: No valid gateways specified") + else: + return result + + +def get_default_device(): + """ + Retrieve the current default network device. + + :rtype: str + """ + routes = subprocess.check_output([IP, "route", "show"]) + match = re.search("^default .*dev ([^\s]*) .*$", routes, flags=re.M) + if match.groups(): + return match.group(1) + else: + bail("Could not find default device") + + +def get_local_network_ipv4(device): + """ + Get the local ipv4 addres for a given device. + + :param device: + :type device: str + """ + addresses = cmdcheck([IP, "-o", "address", "show", "dev", device]) + match = re.search("^.*inet ([^ ]*) .*$", addresses, flags=re.M) + if match.groups(): + return match.group(1) + else: + return None + + +def get_local_network_ipv6(device): + """ + Get the local ipv6 addres for a given device. + + :param device: + :type device: str + """ + addresses = cmdcheck([IP, "-o", "address", "show", "dev", device]) + match = re.search("^.*inet6 ([^ ]*) .*$", addresses, flags=re.M) + if match.groups(): + return match.group(1) + else: + return None + + +def run_iptable_with_check(cmd, *args, **options): + """ + Run an iptables command checking to see if it should: + for --insert: run only if rule does not already exist. + for --delete: run only if rule does exist. + other commands are run normally. + """ + if "--insert" in args: + check_args = [arg.replace("--insert", "--check") for arg in args] + check_code = run(cmd, *check_args, exitcode=True) + if check_code != 0: + run(cmd, *args, **options) + elif "--delete" in args: + check_args = [arg.replace("--delete", "--check") for arg in args] + check_code = run(cmd, *check_args, exitcode=True) + if check_code == 0: + run(cmd, *args, **options) + else: + run(cmd, *args, **options) + + +def iptables(*args, **options): + """ + Run iptables4 and iptables6. + """ + ip4tables(*args, **options) + ip6tables(*args, **options) + + +def ip4tables(*args, **options): + """ + Run iptables4 with checks. + """ + run_iptable_with_check(IPTABLES, *args, **options) + + +def ip6tables(*args, **options): + """ + Run iptables6 with checks. + """ + run_iptable_with_check(IP6TABLES, *args, **options) + + +def ipv4_chain_exists(table): + """ + Check if a given chain exists. + + :param table: the table to check against + :type table: str + :rtype: bool + """ + code = run(IPTABLES, "--list", table, "--numeric", exitcode=True) + return code == 0 + + +def ipv6_chain_exists(table): + """ + Check if a given chain exists. + + :param table: the table to check against + :type table: str + :rtype: bool + """ + code = run(IP6TABLES, "--list", table, "--numeric", exitcode=True) + return code == 0 + + +def firewall_start(args): + """ + Bring up the firewall. + + :param args: list of gateways, to be sanitized. + :type args: list + """ + default_device = get_default_device() + local_network_ipv4 = get_local_network_ipv4(default_device) + local_network_ipv6 = get_local_network_ipv6(default_device) + gateways = get_gateways(args) + + # add custom chain "bitmask" + if not ipv4_chain_exists(BITMASK_CHAIN): + ip4tables("--new-chain", BITMASK_CHAIN) + if not ipv6_chain_exists(BITMASK_CHAIN): + ip6tables("--new-chain", BITMASK_CHAIN) + iptables("--insert", "OUTPUT", "--jump", BITMASK_CHAIN) + + # reject everything + iptables("--insert", BITMASK_CHAIN, "-o", default_device, + "--jump", "REJECT") + + # allow traffic to gateways + for gateway in gateways: + ip4tables("--insert", BITMASK_CHAIN, "--destination", gateway, + "-o", default_device, "--jump", "ACCEPT") + + # allow traffic to IPs on local network + if local_network_ipv4: + ip4tables("--insert", BITMASK_CHAIN, + "--destination", local_network_ipv4, "-o", default_device, + "--jump", "ACCEPT") + if local_network_ipv6: + ip6tables("--insert", BITMASK_CHAIN, + "--destination", local_network_ipv6, "-o", default_device, + "--jump", "ACCEPT") + + # block DNS requests to anyone but the service provider or localhost + # when we actually route ipv6, we will need dns rules for it too + ip4tables("--insert", BITMASK_CHAIN, "--protocol", "udp", "--dport", "53", + "--jump", "REJECT") + + for allowed_dns in [NAMESERVER, "127.0.0.1", "127.0.1.1"]: + ip4tables("--insert", BITMASK_CHAIN, "--protocol", "udp", + "--dport", "53", "--destination", allowed_dns, + "--jump", "ACCEPT") + + +def firewall_stop(): + """ + Stop the firewall. + """ + iptables("--delete", "OUTPUT", "--jump", BITMASK_CHAIN) + if ipv4_chain_exists(BITMASK_CHAIN): + ip4tables("--flush", BITMASK_CHAIN) + ip4tables("--delete-chain", BITMASK_CHAIN) + if ipv6_chain_exists(BITMASK_CHAIN): + ip6tables("--flush", BITMASK_CHAIN) + ip6tables("--delete-chain", BITMASK_CHAIN) + +## +## MAIN +## + + +def main(): + if len(sys.argv) >= 3: + command = "_".join(sys.argv[1:3]) + args = sys.argv[3:] + + if command == "openvpn_start": + openvpn_start(args) + + elif command == "openvpn_stop": + openvpn_stop(args) + + elif command == "firewall_start": + try: + firewall_start(args) + nameserver_setter.start(NAMESERVER) + except Exception as ex: + nameserver_restorer.start() + firewall_stop() + bail("ERROR: could not start firewall", ex) + + elif command == "firewall_stop": + try: + firewall_stop() + nameserver_restorer.start() + except Exception as ex: + bail("ERROR: could not stop firewall", ex) + + elif command == "firewall_isup": + if ipv4_chain_exists(BITMASK_CHAIN): + print("%s: INFO: bitmask firewall is up" % (SCRIPT,)) + else: + bail("INFO: bitmask firewall is down") + + else: + bail("ERROR: No such command") + else: + bail("ERROR: No such command") + +if __name__ == "__main__": + if DEBUG: + logger.debug(" ".join(sys.argv)) + main() + print("%s: done" % (SCRIPT,)) + exit(0) |