#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2014 LEAP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
"""
This is a privileged helper script for safely running certain commands as
root. It should only be called by the Bitmask application.

USAGE:
  bitmask-root firewall stop
  bitmask-root firewall start [restart] GATEWAY1 GATEWAY2 ...
  bitmask-root openvpn stop
  bitmask-root openvpn start CONFIG1 CONFIG2 ...

All actions return exit code 0 for success, non-zero otherwise.

The `openvpn start` action is special: it calls exec on openvpn and replaces
the current process.

If the `restart` parameter is passed, the firewall will not be torn down in
the case of an error during launch.
"""
# TODO should be tested with python3, which can be the default on some distro.
from __future__ import print_function

import atexit
import errno
import os
import re
import signal
import socket
import syslog
import subprocess
import sys
import time
import traceback

cmdcheck = subprocess.check_output

##
## CONSTANTS
##

VERSION = "1"
SCRIPT = "bitmask-root"
NAMESERVER = "10.42.0.1"
BITMASK_CHAIN = "bitmask"

IP = "/bin/ip"
IPTABLES = "/sbin/iptables"
IP6TABLES = "/sbin/ip6tables"

RESOLVCONF_SYSTEM_BIN = "/sbin/resolvconf"
RESOLVCONF_LEAP_BIN = "/usr/local/sbin/leap-resolvconf"

OPENVPN_USER = "nobody"
OPENVPN_GROUP = "nogroup"
LEAPOPENVPN = "LEAPOPENVPN"
OPENVPN_SYSTEM_BIN = "/usr/sbin/openvpn"  # Debian location
OPENVPN_LEAP_BIN = "/usr/local/sbin/leap-openvpn"  # installed by bundle

# The path to the script to update resolv.conf
# XXX We have to check if we have a valid resolvconf, and use
# the old resolv-update if not.
LEAP_UPDATE_RESOLVCONF_FILE = "/etc/leap/update-resolv-conf"
LEAP_RESOLV_UPDATE = "/etc/leap/resolv-update"

# Flags that are always passed to openvpn, whatever the caller asks for.
FIXED_FLAGS = [
    "--setenv", LEAPOPENVPN, "1",
    "--nobind",
    "--client",
    "--dev", "tun",
    "--tls-client",
    "--remote-cert-tls", "server",
    "--management-signal",
    "--script-security", "1",
    "--user", OPENVPN_USER,
    "--group", OPENVPN_GROUP,
    "--remap-usr1", "SIGTERM",
]

# Whitelist of caller-supplied openvpn flags, mapped to the list of
# parameter types (keys of PARAM_FORMATS) each flag must be followed by.
ALLOWED_FLAGS = {
    "--remote": ["IP", "NUMBER", "PROTO"],
    "--tls-cipher": ["CIPHER"],
    "--cipher": ["CIPHER"],
    "--auth": ["CIPHER"],
    "--management": ["DIR", "UNIXSOCKET"],
    "--management-client-user": ["USER"],
    "--cert": ["FILE"],
    "--key": ["FILE"],
    "--ca": ["FILE"],
}

# Validators for each parameter type; each returns a truthy value when the
# string is acceptable.
PARAM_FORMATS = {
    "NUMBER": lambda s: re.match(r"^\d+$", s),
    "PROTO": lambda s: re.match(r"^(tcp|udp)$", s),
    "IP": lambda s: is_valid_address(s),
    "CIPHER": lambda s: re.match(r"^[A-Z0-9-]+$", s),
    "USER": lambda s: re.match(
        r"^[a-zA-Z0-9_\.\@][a-zA-Z0-9_\-\.\@]*\$?$", s),  # IEEE Std 1003.1-2001
    "FILE": lambda s: os.path.isfile(s),
    "DIR": lambda s: os.path.isdir(os.path.split(s)[0]),
    "UNIXSOCKET": lambda s: s == "unix",
}


DEBUG = os.getenv("DEBUG")
TEST = os.getenv("TEST")

if DEBUG:
    import logging
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(ch)

syslog.openlog(SCRIPT)

##
## UTILITY
##


def is_valid_address(value):
    """
    Validate that the passed ip is a valid IP address.

    The check is whatever `socket.inet_aton` accepts.

    :param value: the value to be validated
    :type value: str
    :rtype: bool
    """
    try:
        socket.inet_aton(value)
        return True
    except (socket.error, TypeError, ValueError):
        log("%s: ERROR: MALFORMED IP: %s!" % (SCRIPT, value))
        return False


def has_system_resolvconf():
    """
    Return True if resolvconf is found in the system.

    :rtype: bool
    """
    # RESOLVCONF is None when neither known binary exists, and
    # os.path.isfile(None) raises TypeError instead of returning False.
    return bool(RESOLVCONF) and os.path.isfile(RESOLVCONF)


def has_valid_update_resolvconf():
    """
    Return True if a valid update-resolv-conf script is found in the system.

    :rtype: bool
    """
    return os.path.isfile(LEAP_UPDATE_RESOLVCONF_FILE)


def has_valid_leap_resolv_update():
    """
    Return True if a valid resolv-update script is found in the system.

    :rtype: bool
    """
    return os.path.isfile(LEAP_RESOLV_UPDATE)


def split_list(_list, regex):
    """
    Split a list based on a regex:
    e.g. split_list(["xx", "yy", "x1", "zz"], "^x") =>
         [["xx", "yy"], ["x1", "zz"]]

    Every item matching the regex opens a new group; the following
    non-matching items are added to that group. Items that come before the
    first match are dropped.

    :param _list: the list to be split.
    :type _list: list
    :param regex: the regex expression (or compiled pattern) to filter with.
    :type regex: str
    :rtype: list
    """
    if not hasattr(regex, "match"):
        regex = re.compile(regex)
    groups = []
    for item in _list:
        if regex.match(item):
            groups.append([item])
        elif groups:
            groups[-1].append(item)
    return groups


def get_process_list():
    """
    Get a process list by reading `/proc` filesystem.

    :return: a list of tuples, each containing pid and command string
             (the raw contents of `/proc/<pid>/cmdline`, decoded as text).
    :rtype: list of tuples
    """
    res = []
    pids = [pid for pid in os.listdir('/proc') if pid.isdigit()]
    for pid in pids:
        try:
            with open(os.path.join('/proc', pid, 'cmdline'), 'rb') as cmdline:
                # decode so the command line can be compared against text
                # constants on both python 2 and python 3.
                res.append((pid, cmdline.read().decode('utf-8', 'replace')))
        except IOError:
            # proc has already terminated
            continue
    return res


class Daemon(object):
    """
    A generic daemon class.

    Subclasses override `run`, which is called in the daemonized process
    with the positional arguments that were given to `start`.
    """
    def __init__(self, pidfile, stdin='/dev/null',
                 stdout='/dev/null', stderr='/dev/null'):
        """
        :param pidfile: path of the file where the daemon pid is stored.
        :param stdin: path the daemon's standard input is redirected from.
        :param stdout: path the daemon's standard output is appended to.
        :param stderr: path the daemon's standard error is appended to.
        """
        self.stdin = stdin
        self.stdout = stdout
        self.stderr = stderr
        self.pidfile = pidfile

    def _read_pid(self):
        """
        Return the pid stored in the pidfile, or None if the pidfile is
        missing or does not contain a number.
        """
        try:
            with open(self.pidfile, 'r') as pf:
                return int(pf.read().strip())
        except (IOError, ValueError):
            return None

    @staticmethod
    def _fork_and_exit_parent(number):
        """
        Fork once; the parent exits with status 0, the child returns.
        Exits with status 1 if the fork fails.
        """
        try:
            pid = os.fork()
        except OSError as e:
            sys.stderr.write(
                "fork #%d failed: %d (%s)\n" % (number, e.errno, e.strerror))
            sys.exit(1)
        if pid > 0:
            sys.exit(0)

    def daemonize(self):
        """
        Do the UNIX double-fork magic, see Stevens' "Advanced
        Programming in the UNIX Environment" for details (ISBN 0201563177)
        http://www.erlenstar.demon.co.uk/unix/faq_2.html#SEC16

        Note that the calling process exits (status 0) at the first fork;
        only the daemonized grandchild returns from this method.
        """
        self._fork_and_exit_parent(1)

        # decouple from parent environment
        os.chdir("/")
        os.setsid()
        os.umask(0)

        # do second fork
        self._fork_and_exit_parent(2)

        # redirect standard file descriptors
        sys.stdout.flush()
        sys.stderr.flush()
        # stderr is opened unbuffered, which python 3 only allows in binary
        # mode; only the file descriptors are used, so the mode of the file
        # objects is irrelevant once dup2 has run.
        with open(self.stdin, 'r') as si, \
                open(self.stdout, 'a+') as so, \
                open(self.stderr, 'ab+', 0) as se:
            os.dup2(si.fileno(), sys.stdin.fileno())
            os.dup2(so.fileno(), sys.stdout.fileno())
            os.dup2(se.fileno(), sys.stderr.fileno())

        # write pidfile
        atexit.register(self.delpid)
        with open(self.pidfile, 'w+') as pf:
            pf.write("%s\n" % os.getpid())

    def delpid(self):
        """
        Delete the pidfile. A pidfile that is already gone is not an error.
        """
        try:
            os.remove(self.pidfile)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise

    def start(self, *args):
        """
        Start the daemon.

        :param args: positional arguments forwarded to `run`.
        """
        # Check for a pidfile to see if the daemon already runs
        if self._read_pid():
            message = "pidfile %s already exist. Daemon already running?\n"
            sys.stderr.write(message % self.pidfile)
            sys.exit(1)

        # Start the daemon
        self.daemonize()
        self.run(*args)

    def stop(self):
        """
        Stop the daemon.
        """
        # Get the pid from the pidfile
        pid = self._read_pid()
        if not pid:
            message = "pidfile %s does not exist. Daemon not running?\n"
            sys.stderr.write(message % self.pidfile)
            return  # not an error in a restart

        # Try killing the daemon process: keep signalling until the kernel
        # reports that the process is gone.
        try:
            while True:
                os.kill(pid, signal.SIGTERM)
                time.sleep(0.1)
        except OSError as err:
            if err.errno == errno.ESRCH:
                if os.path.exists(self.pidfile):
                    os.remove(self.pidfile)
            else:
                log(str(err))
                sys.exit(1)

    def restart(self, *args):
        """
        Restart the daemon.

        :param args: positional arguments forwarded to `start`.
        """
        self.stop()
        self.start(*args)

    def run(self, *args):
        """
        This should be overridden by derived classes.
        """


def run(command, *args, **options):
    """
    Run an external command.

    Options:

      `check`: If True, check the command's output. bail if non-zero. (the
               default is true unless detach or input is true)
      `exitcode`: like `check`, but return exitcode instead of bailing.
      `detach`: If True, run in detached process.
      `input`: If True, open command for writing stream to, returning the Popen
               object.
      `throw`: If True, raise an exception if there is an error instead
               of bailing.

    :return: the Popen object when `input` is set, None when the command is
             launched without being checked, 0 when a checked command
             succeeded, or the exit code when `exitcode` is set and the
             command failed.
    """
    parts = [command]
    parts.extend(args)
    debug("%s run: %s " % (SCRIPT, " ".join(parts)))

    _check = options.get("check", True)
    _detach = options.get("detach", False)
    _input = options.get("input", False)
    _exitcode = options.get("exitcode", False)
    _throw = options.get("throw", False)

    if not (_check or _throw) or _detach or _input:
        if _input:
            return subprocess.Popen(parts, stdin=subprocess.PIPE)
        subprocess.Popen(parts)
        return None

    try:
        with open(os.devnull, 'w') as devnull:
            subprocess.check_call(parts, stdout=devnull, stderr=devnull)
        return 0
    except subprocess.CalledProcessError as exc:
        if DEBUG:
            logger.exception(exc)
        if _exitcode:
            debug("ERROR: Could not run %s: %s" % (exc.cmd, exc.output),
                  exception=exc)
            return exc.returncode
        elif _throw:
            raise
        else:
            bail("ERROR: Could not run %s: %s" % (exc.cmd, exc.output),
                 exception=exc)


def log(msg=None, exception=None, priority=syslog.LOG_INFO):
    """
    print and log warning message or exception.

    :param msg: optional error message.
    :type msg: str
    :param exception: optional exception; when given, the traceback of the
                      exception currently being handled is logged.
    :type exception: Exception
    :param priority: syslog level
    :type priority: one of LOG_EMERG, LOG_ALERT, LOG_CRIT, LOG_ERR,
                    LOG_WARNING, LOG_NOTICE, LOG_INFO, LOG_DEBUG
    """
    if msg is not None:
        print("%s: %s" % (SCRIPT, msg))
        # syslog.syslog only accepts strings; some callers pass lists.
        if not isinstance(msg, str):
            msg = str(msg)
        syslog.syslog(priority, msg)
    if exception is not None:
        if TEST or DEBUG:
            traceback.print_exc()
        syslog.syslog(priority, traceback.format_exc())


def debug(msg=None, exception=None):
    """
    Just like log, but is skipped unless DEBUG. Use syslog.LOG_INFO
    even for debug messages (we don't want to miss them).
    """
    if TEST or DEBUG:
        log(msg, exception)


def bail(msg=None, exception=None):
    """
    abnormal exit. like log(), but exits with error status code.
    """
    log(msg, exception)
    sys.exit(1)

##
## OPENVPN
##


def get_openvpn_bin():
    """
    Return the path for either the system openvpn or the one the
    bundle has put there, or None if neither exists.
    """
    if os.path.isfile(OPENVPN_SYSTEM_BIN):
        return OPENVPN_SYSTEM_BIN

    # the bundle option should be removed from the debian package.
    if os.path.isfile(OPENVPN_LEAP_BIN):
        return OPENVPN_LEAP_BIN

    return None


def parse_openvpn_flags(args):
    """
    Take argument list from the command line and parse it, only allowing some
    configuration flags.

    Flags not listed in ALLOWED_FLAGS are dropped with a warning. A flag
    with the wrong number of parameters, or with a parameter that fails its
    PARAM_FORMATS validator, rejects the whole command line.

    :type args: list
    :return: the sanitized flag list, or None if the arguments are rejected.
    :rtype: list or None
    """
    result = []
    try:
        for flag in split_list(args, "^--"):
            flag_name = flag[0]
            if flag_name not in ALLOWED_FLAGS:
                log("WARNING: unrecognized openvpn flag %s" % flag_name)
                continue

            result.append(flag_name)
            required_params = ALLOWED_FLAGS[flag_name]
            if not required_params:
                continue

            flag_params = flag[1:]
            if len(flag_params) != len(required_params):
                log("%s: ERROR: not enough params for %s" %
                    (SCRIPT, flag_name))
                return None
            for param, param_type in zip(flag_params, required_params):
                if PARAM_FORMATS[param_type](param):
                    result.append(param)
                else:
                    log("%s: ERROR: Bad argument %s" %
                        (SCRIPT, param))
                    return None
        return result
    except Exception as exc:
        # last line of defence: whatever goes wrong while validating,
        # refuse to launch openvpn.
        log("%s: ERROR PARSING FLAGS: %s" % (SCRIPT, exc))
        if DEBUG:
            logger.exception(exc)
        return None


def openvpn_start(args):
    """
    Launch openvpn, sanitizing input, and replacing the current process with
    the openvpn process.

    :param args: arguments to be passed to openvpn
    :type args: list
    """
    openvpn_flags = parse_openvpn_flags(args)
    if not openvpn_flags:
        bail('ERROR: could not parse openvpn options')

    OPENVPN = get_openvpn_bin()
    if OPENVPN is None:
        bail('ERROR: could not find the openvpn binary')

    flags = [OPENVPN] + FIXED_FLAGS + openvpn_flags
    if DEBUG:
        log("%s: running openvpn with flags:" % (SCRIPT,))
        log(flags)
    # note: first argument to command is ignored, but customarily set to
    # the command.
    os.execv(OPENVPN, flags)


def openvpn_stop(args):
    """
    Stop the openvpn that has likely been launched by bitmask.

    :param args: arguments to openvpn
    :type args: list
    """
    OPENVPN_BIN = get_openvpn_bin()
    if OPENVPN_BIN is None:
        log("%s: ERROR: could not find the openvpn binary" % (SCRIPT,))
        return

    found_leap_openvpn = [
        (pid, cmdline) for pid, cmdline in get_process_list()
        if cmdline.startswith(OPENVPN_BIN) and LEAPOPENVPN in cmdline]

    if found_leap_openvpn:
        pid = found_leap_openvpn[0][0]
        os.kill(int(pid), signal.SIGTERM)

##
## DNS
##


def get_resolvconf_bin():
    """
    Return the path for either the system resolvconf or the one the
    bundle has put there, or None if neither exists.
    """
    if os.path.isfile(RESOLVCONF_SYSTEM_BIN):
        return RESOLVCONF_SYSTEM_BIN

    # the bundle option should be removed from the debian package.
    if os.path.isfile(RESOLVCONF_LEAP_BIN):
        return RESOLVCONF_LEAP_BIN

    return None

RESOLVCONF = get_resolvconf_bin()


class NameserverSetter(Daemon):
    """
    A daemon that will add leap nameserver inside the tunnel
    to the system `resolv.conf`
    """

    def run(self, *args):
        """
        Run when daemonized.

        :param args: the first argument, if any, is the nameserver ip to set.
        """
        if args:
            ip_address = args[0]
            self.set_dns_nameserver(ip_address)

    def set_dns_nameserver(self, ip_address):
        """
        Add the tunnel DNS server to `resolv.conf`

        :param ip_address: the ip to add to `resolv.conf`
        :type ip_address: str
        """
        if has_system_resolvconf():
            process = run(RESOLVCONF, "-a", "bitmask", input=True)
            # the pipe is a byte stream on python 3
            process.communicate(
                ("nameserver %s\n" % ip_address).encode("utf-8"))
        else:
            bail("ERROR: package openresolv or resolvconf not installed.")

# NOTE(review): the pidfiles live at predictable paths under /tmp and are
# written by root -- confirm this cannot be abused through symlinks.
nameserver_setter = NameserverSetter('/tmp/leap-dns-up.pid')


class NameserverRestorer(Daemon):
    """
    A daemon that will restore the previous nameservers.
    """

    def run(self, *args):
        """
        Run when daemonized.
        """
        self.restore_dns_nameserver()

    def restore_dns_nameserver(self):
        """
        Remove tunnel DNS server from `resolv.conf`
        """
        if has_system_resolvconf():
            run(RESOLVCONF, "-d", "bitmask")
        else:
            log("%s: ERROR: package openresolv "
                "or resolvconf not installed." %
                (SCRIPT,))

nameserver_restorer = NameserverRestorer('/tmp/leap-dns-down.pid')


##
## FIREWALL
##


def get_gateways(gateways):
    """
    Filter a passed sequence of gateways, returning only the valid ones.
    Bails out if none of them is valid.

    :param gateways: a sequence of gateways to filter.
    :type gateways: iterable
    :rtype: list
    """
    result = [gateway for gateway in gateways if is_valid_address(gateway)]
    if not result:
        bail("ERROR: No valid gateways specified")
    return result


def get_default_device():
    """
    Retrieve the current default network device, as reported by
    `ip route show`. Bails out if there is no default route.

    :rtype: str
    """
    # universal_newlines makes check_output return text on python 3 too
    routes = cmdcheck([IP, "route", "show"], universal_newlines=True)
    match = re.search(r"^default .*dev ([^\s]*) .*$", routes, flags=re.M)
    if match:
        return match.group(1)
    bail("Could not find default device")


def _get_local_network(device, pattern):
    """
    Return the first group of `pattern` found in the output of
    `ip -o address show dev <device>`, or None if there is no match.
    """
    addresses = cmdcheck([IP, "-o", "address", "show", "dev", device],
                         universal_newlines=True)
    match = re.search(pattern, addresses, flags=re.M)
    if match:
        return match.group(1)
    return None


def get_local_network_ipv4(device):
    """
    Get the local ipv4 address for a given device.

    :param device: the name of the network device.
    :type device: str
    :return: the token following `inet` in the `ip` output, or None.
    """
    return _get_local_network(device, "^.*inet ([^ ]*) .*$")


def get_local_network_ipv6(device):
    """
    Get the local ipv6 address for a given device.

    :param device: the name of the network device.
    :type device: str
    :return: the token following `inet6` in the `ip` output, or None.
    """
    return _get_local_network(device, "^.*inet6 ([^ ]*) .*$")


# iptables actions that are guarded by a `--check`, mapped to whether the
# rule must already exist for the action to be run.
_CHECKED_ACTIONS = (
    ("--insert", False),
    ("--append", False),
    ("--delete", True),
)


def run_iptable_with_check(cmd, *args, **options):
    """
    Run an iptables command checking to see if it should:
      for --append: run only if rule does not already exist.
      for --insert: run only if rule does not already exist.
      for --delete: run only if rule does exist.
    other commands are run normally.

    :param cmd: the iptables binary to run.
    :param args: the iptables arguments.
    :param options: options for `run`, applied to the command itself (the
                    preliminary `--check` always uses `exitcode=True`).
    """
    for action, must_exist in _CHECKED_ACTIONS:
        if action in args:
            check_args = [arg.replace(action, "--check") for arg in args]
            rule_exists = run(cmd, *check_args, exitcode=True) == 0
            if rule_exists == must_exist:
                run(cmd, *args, **options)
            return
    run(cmd, *args, **options)


def iptables(*args, **options):
    """
    Run iptables4 and iptables6.
    """
    ip4tables(*args, **options)
    ip6tables(*args, **options)


def ip4tables(*args, **options):
    """
    Run iptables4 with checks.
    """
    run_iptable_with_check(IPTABLES, *args, **options)


def ip6tables(*args, **options):
    """
    Run iptables6 with checks.
    """
    run_iptable_with_check(IP6TABLES, *args, **options)

#
# NOTE: these tests to see if a chain exists might incorrectly return false.
# This happens when there is an error in calling `iptables --list bitmask`.
#
# For this reason, when stopping the firewall, we do not trust the
# output of ipvx_chain_exists() but instead always attempt to delete
# the chain.
#


def _chain_exists(binary, chain):
    """
    Return True if `<binary> --list <chain> --numeric` exits with 0, False
    if it exits with 1; any other exit code is logged and reported as False.
    """
    code = run(binary, "--list", chain, "--numeric", exitcode=True)
    if code == 0:
        return True
    if code != 1:
        log("ERROR: Could not determine state of iptable chain")
    return False


def ipv4_chain_exists(chain):
    """
    Check if a given chain exists.

    :param chain: the chain to check against
    :type chain: str
    :rtype: bool
    """
    return _chain_exists(IPTABLES, chain)


def ipv6_chain_exists(chain):
    """
    Check if a given chain exists.

    :param chain: the chain to check against
    :type chain: str
    :rtype: bool
    """
    return _chain_exists(IP6TABLES, chain)


def firewall_start(args):
    """
    Bring up the firewall.

    :param args: list of gateways, to be sanitized.
    :type args: list
    """
    default_device = get_default_device()
    local_network_ipv4 = get_local_network_ipv4(default_device)
    local_network_ipv6 = get_local_network_ipv6(default_device)
    gateways = get_gateways(args)

    # add custom chain "bitmask" to front of OUTPUT chain
    if not ipv4_chain_exists(BITMASK_CHAIN):
        ip4tables("--new-chain", BITMASK_CHAIN)
    if not ipv6_chain_exists(BITMASK_CHAIN):
        ip6tables("--new-chain", BITMASK_CHAIN)
    iptables("--insert", "OUTPUT", "--jump", BITMASK_CHAIN)

    # allow DNS over VPN
    for allowed_dns in [NAMESERVER, "127.0.0.1", "127.0.1.1"]:
        ip4tables("--append", BITMASK_CHAIN, "--protocol", "udp",
                  "--dport", "53", "--destination", allowed_dns,
                  "--jump", "ACCEPT")

    # block DNS requests to anyone but the service provider or localhost
    # (when we actually route ipv6, we will need DNS rules for it too)
    ip4tables("--append", BITMASK_CHAIN, "--protocol", "udp", "--dport", "53",
              "--jump", "REJECT")

    # allow traffic to IPs on local network
    if local_network_ipv4:
        ip4tables("--append", BITMASK_CHAIN,
                  "--destination", local_network_ipv4, "-o", default_device,
                  "--jump", "ACCEPT")
        # allow multicast Simple Service Discovery Protocol
        ip4tables("--append", BITMASK_CHAIN,
                  "--protocol", "udp",
                  "--destination", "239.255.255.250", "--dport", "1900",
                  "-o", default_device, "--jump", "RETURN")
        # allow multicast Bonjour/mDNS
        ip4tables("--append", BITMASK_CHAIN,
                  "--protocol", "udp",
                  "--destination", "224.0.0.251", "--dport", "5353",
                  "-o", default_device, "--jump", "RETURN")
    if local_network_ipv6:
        ip6tables("--append", BITMASK_CHAIN,
                  "--destination", local_network_ipv6, "-o", default_device,
                  "--jump", "ACCEPT")
        # allow multicast Simple Service Discovery Protocol
        ip6tables("--append", BITMASK_CHAIN,
                  "--protocol", "udp",
                  "--destination", "FF05::C", "--dport", "1900",
                  "-o", default_device, "--jump", "RETURN")
        # allow multicast Bonjour/mDNS
        ip6tables("--append", BITMASK_CHAIN,
                  "--protocol", "udp",
                  "--destination", "FF02::FB", "--dport", "5353",
                  "-o", default_device, "--jump", "RETURN")

    # allow ipv4 traffic to gateways
    for gateway in gateways:
        ip4tables("--append", BITMASK_CHAIN, "--destination", gateway,
                  "-o", default_device, "--jump", "ACCEPT")

    # log rejected packets to syslog
    if DEBUG:
        iptables("--append", BITMASK_CHAIN, "-o", default_device,
                 "--jump", "LOG", "--log-prefix", "iptables denied: ",
                 "--log-level", "7")

    # for now, ensure all other ipv6 packets get rejected (regardless of
    # device)
    # (not sure why, but "-p any" doesn't work)
    ip6tables("--append", BITMASK_CHAIN, "-p", "tcp", "--jump", "REJECT")
    ip6tables("--append", BITMASK_CHAIN, "-p", "udp", "--jump", "REJECT")

    # reject all other ipv4 sent over the default device
    ip4tables("--append", BITMASK_CHAIN, "-o",
              default_device, "--jump", "REJECT")


def firewall_stop():
    """
    Stop the firewall. Because we really really always want the firewall to
    be stopped if at all possible, this function is cautious and contains a
    lot of trys and excepts.

    If there were any problems, we raise an exception at the end. This allows
    the calling code to retry stopping the firewall. Stopping the firewall
    can fail if iptables is being run by another process (only one iptables
    command can be run at a time).

    :raises Exception: if any of the teardown steps failed.
    """
    ok = True

    try:
        iptables("--wait", "--delete", "OUTPUT", "--jump", BITMASK_CHAIN,
                 throw=True)
    except subprocess.CalledProcessError as exc:
        log("INFO: not able to remove bitmask firewall from OUTPUT chain "
            "(maybe it is already removed?)", exc)
        ok = False

    try:
        ip4tables("--flush", BITMASK_CHAIN, throw=True)
        ip4tables("--delete-chain", BITMASK_CHAIN, throw=True)
    except subprocess.CalledProcessError as exc:
        log("INFO: not able to flush and delete bitmask ipv4 firewall "
            "chain (maybe it is already destroyed?)", exc)
        ok = False

    try:
        ip6tables("--flush", BITMASK_CHAIN, throw=True)
        ip6tables("--delete-chain", BITMASK_CHAIN, throw=True)
    except subprocess.CalledProcessError as exc:
        log("INFO: not able to flush and delete bitmask ipv6 firewall "
            "chain (maybe it is already destroyed?)", exc)
        ok = False

    if not ok:
        raise Exception("firewall might still be left up. "
                        "Please try `firewall stop` again.")


##
## MAIN
##


def main():
    """
    Entry point for cmdline execution.

    The command is built by joining the first two command line words with
    an underscore (e.g. `firewall start` => `firewall_start`); the remaining
    words are the arguments of the command.
    """
    # TODO use argparse instead.

    if len(sys.argv) < 2:
        bail("ERROR: No such command")

    command = "_".join(sys.argv[1:3])
    args = sys.argv[3:]

    is_restart = False
    if args and args[0] == "restart":
        is_restart = True
        args.pop(0)

    if command == "version":
        print(VERSION)
        sys.exit(0)

    if os.getuid() != 0:
        bail("ERROR: must be run as root")

    # NOTE(review): Daemon.start() daemonizes, so the original process exits
    # with status 0 at that point and whatever follows a `.start()` call
    # below runs in the daemonized child. In particular the error path of
    # `firewall_start` reports success to the caller when it is not a
    # restart -- confirm this is intended.
    if command == "openvpn_start":
        openvpn_start(args)

    elif command == "openvpn_stop":
        openvpn_stop(args)

    elif command == "firewall_start":
        try:
            firewall_start(args)
            nameserver_setter.start(NAMESERVER)
        except Exception as ex:
            if not is_restart:
                nameserver_restorer.start()
                firewall_stop()
            bail("ERROR: could not start firewall", ex)

    elif command == "firewall_stop":
        try:
            firewall_stop()
            nameserver_restorer.start()
        except Exception as ex:
            bail("ERROR: could not stop firewall", ex)

    elif command == "firewall_isup":
        if ipv4_chain_exists(BITMASK_CHAIN):
            log("%s: INFO: bitmask firewall is up" % (SCRIPT,))
        else:
            bail("INFO: bitmask firewall is down")

    else:
        bail("ERROR: No such command")


if __name__ == "__main__":
    debug(" ".join(sys.argv))
    main()
    log("%s: done" % (SCRIPT,))
    sys.exit(0)