#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2014 LEAP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
"""
This is a privileged helper script for safely running certain commands as
root. It should only be called by the Bitmask application.

USAGE:

  bitmask-root firewall stop
  bitmask-root firewall start GATEWAY1 GATEWAY2 ...
  bitmask-root openvpn stop
  bitmask-root openvpn start --FLAG1 [PARAM ...] --FLAG2 [PARAM ...] ...

The flags accepted by `openvpn start` are restricted to a whitelist, and
each flag's parameters are validated before openvpn is launched.

All actions return exit code 0 for success, non-zero otherwise.

The `openvpn start` action is special: it calls exec on openvpn and replaces
the current process.
"""
# TODO should be tested with python3, which can be the default on some distro.

from __future__ import print_function

import atexit
import os
import re
import signal
import socket
import subprocess
import sys
import time
import traceback

# Shorthand used to capture the output of helper commands (e.g. `ip`).
cmdcheck = subprocess.check_output

##
## CONSTANTS
##

SCRIPT = "bitmask-root"
# DNS server reachable inside the VPN tunnel.
NAMESERVER = "10.42.0.1"
# Name of the custom iptables/ip6tables chain this script manages.
BITMASK_CHAIN = "bitmask"

IP = "/bin/ip"
IPTABLES = "/sbin/iptables"
IP6TABLES = "/sbin/ip6tables"
RESOLVCONF = "/sbin/resolvconf"

OPENVPN_USER = "nobody"
OPENVPN_GROUP = "nogroup"
# Marker exported to the openvpn process environment so that our own openvpn
# can later be told apart from other openvpn processes.
LEAPOPENVPN = "LEAPOPENVPN"
OPENVPN_SYSTEM_BIN = "/usr/sbin/openvpn"  # Debian location
OPENVPN_LEAP_BIN = "/usr/sbin/leap-openvpn"  # installed by bundle

# The path to the script to update resolv.conf.
# XXX We have to check if we have a valid resolvconf, and use
# the old resolv-update if not.
LEAP_UPDATE_RESOLVCONF_FILE = "/etc/leap/update-resolv-conf" LEAP_RESOLV_UPDATE = "/etc/leap/resolv-update" FIXED_FLAGS = [ "--setenv", "LEAPOPENVPN", "1", "--nobind", "--client", "--dev", "tun", "--tls-client", "--remote-cert-tls", "server", "--management-signal", "--script-security", "1", "--user", "nobody", "--group", "nogroup", ] ALLOWED_FLAGS = { "--remote": ["IP", "NUMBER", "PROTO"], "--tls-cipher": ["CIPHER"], "--cipher": ["CIPHER"], "--auth": ["CIPHER"], "--management": ["DIR", "UNIXSOCKET"], "--management-client-user": ["USER"], "--cert": ["FILE"], "--key": ["FILE"], "--ca": ["FILE"] } PARAM_FORMATS = { "NUMBER": lambda s: re.match("^\d+$", s), "PROTO": lambda s: re.match("^(tcp|udp)$", s), "IP": lambda s: is_valid_address(s), "CIPHER": lambda s: re.match("^[A-Z0-9-]+$", s), "USER": lambda s: re.match( "^[a-zA-Z0-9_\.\@][a-zA-Z0-9_\-\.\@]*\$?$", s), # IEEE Std 1003.1-2001 "FILE": lambda s: os.path.isfile(s), "DIR": lambda s: os.path.isdir(os.path.split(s)[0]), "UNIXSOCKET": lambda s: s == "unix" } DEBUG = os.getenv("DEBUG") TEST = os.getenv("TEST") if DEBUG: import logging formatter = logging.Formatter( "%(asctime)s - %(name)s - %(levelname)s - %(message)s") ch = logging.StreamHandler() ch.setLevel(logging.DEBUG) ch.setFormatter(formatter) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) logger.addHandler(ch) ## ## UTILITY ## def is_valid_address(value): """ Validate that the passed ip is a valid IP address. :param value: the value to be validated :type value: str :rtype: bool """ try: socket.inet_aton(value) return True except Exception: print("%s: ERROR: MALFORMED IP: %s!" % (SCRIPT, value)) return False def has_system_resolvconf(): """ Return True if resolvconf is found in the system. :rtype: bool """ return os.path.isfile(RESOLVCONF) def has_valid_update_resolvconf(): """ Return True if a valid update-resolv-conf script is found in the system. 
:rtype: bool """ return os.path.isfile(LEAP_UPDATE_RESOLVCONF_FILE) def has_valid_leap_resolv_update(): """ Return True if a valid resolv-update script is found in the system. :rtype: bool """ return os.path.isfile(LEAP_RESOLV_UPDATE) def split_list(_list, regex): """ Split a list based on a regex: e.g. split_list(["xx", "yy", "x1", "zz"], "^x") => [["xx", "yy"], ["x1", "zz"]] :param _list: the list to be split. :type _list: list :param regex: the regex expression to filter with. :type regex: str :rtype: list """ if not hasattr(regex, "match"): regex = re.compile(regex) result = [] i = 0 if not _list: return result while True: if regex.match(_list[i]): result.append([]) while True: result[-1].append(_list[i]) i += 1 if i >= len(_list) or regex.match(_list[i]): break else: i += 1 if i >= len(_list): break return result def get_process_list(): """ Get a process list by reading `/proc` filesystem. :return: a list of tuples, each containing pid and command string. :rtype: tuple if lists """ res = [] pids = [pid for pid in os.listdir('/proc') if pid.isdigit()] for pid in pids: try: res.append((pid, open( os.path.join( '/proc', pid, 'cmdline'), 'rb').read())) except IOError: # proc has already terminated continue return filter(None, res) class Daemon(object): """ A generic daemon class. 
""" def __init__(self, pidfile, stdin='/dev/null', stdout='/dev/null', stderr='/dev/null'): self.stdin = stdin self.stdout = stdout self.stderr = stderr self.pidfile = pidfile def daemonize(self): """ Do the UNIX double-fork magic, see Stevens' "Advanced Programming in the UNIX Environment" for details (ISBN 0201563177) http://www.erlenstar.demon.co.uk/unix/faq_2.html#SEC16 """ try: pid = os.fork() if pid > 0: # exit first parent sys.exit(0) except OSError, e: sys.stderr.write( "fork #1 failed: %d (%s)\n" % (e.errno, e.strerror)) sys.exit(1) # decouple from parent environment os.chdir("/") os.setsid() os.umask(0) # do second fork try: pid = os.fork() if pid > 0: # exit from second parent sys.exit(0) except OSError, e: sys.stderr.write( "fork #2 failed: %d (%s)\n" % (e.errno, e.strerror)) sys.exit(1) # redirect standard file descriptors sys.stdout.flush() sys.stderr.flush() si = file(self.stdin, 'r') so = file(self.stdout, 'a+') se = file(self.stderr, 'a+', 0) os.dup2(si.fileno(), sys.stdin.fileno()) os.dup2(so.fileno(), sys.stdout.fileno()) os.dup2(se.fileno(), sys.stderr.fileno()) # write pidfile atexit.register(self.delpid) pid = str(os.getpid()) file(self.pidfile, 'w+').write("%s\n" % pid) def delpid(self): """ Delete the pidfile. """ os.remove(self.pidfile) def start(self, *args): """ Start the daemon. """ # Check for a pidfile to see if the daemon already runs try: pf = file(self.pidfile, 'r') pid = int(pf.read().strip()) pf.close() except IOError: pid = None if pid: message = "pidfile %s already exist. Daemon already running?\n" sys.stderr.write(message % self.pidfile) sys.exit(1) # Start the daemon self.daemonize() self.run(args) def stop(self): """ Stop the daemon. """ # Get the pid from the pidfile try: pf = file(self.pidfile, 'r') pid = int(pf.read().strip()) pf.close() except IOError: pid = None if not pid: message = "pidfile %s does not exist. 
Daemon not running?\n" sys.stderr.write(message % self.pidfile) return # not an error in a restart # Try killing the daemon process try: while 1: os.kill(pid, signal.SIGTERM) time.sleep(0.1) except OSError, err: err = str(err) if err.find("No such process") > 0: if os.path.exists(self.pidfile): os.remove(self.pidfile) else: print(str(err)) sys.exit(1) def restart(self): """ Restart the daemon. """ self.stop() self.start() def run(self): """ This should be overridden by derived classes. """ def run(command, *args, **options): """ Run an external command. Options: `check`: If True, check the command's output. bail if non-zero. (the default is true unless detach or input is true) `exitcode`: like `check`, but return exitcode instead of bailing. `detach`: If True, run in detached process. `input`: If True, open command for writing stream to, returning the Popen object. """ parts = [command] parts.extend(args) if TEST or DEBUG: print("%s run: %s " (SCRIPT, " ".join(parts))) _check = options.get("check", True) _detach = options.get("detach", False) _input = options.get("input", False) _exitcode = options.get("exitcode", False) if not _check or _detach or _input: if _input: return subprocess.Popen(parts, stdin=subprocess.PIPE) else: # XXX ok with return None ?? subprocess.Popen(parts) else: try: devnull = open('/dev/null', 'w') subprocess.check_call(parts, stdout=devnull, stderr=devnull) return 0 except subprocess.CalledProcessError as exc: if DEBUG: logger.exception(exc) if _exitcode: return exc.returncode else: bail("ERROR: Could not run %s: %s" % (exc.cmd, exc.output), exception=exc) def bail(msg=None, exception=None): """ Abnormal exit. :param msg: optional error message. :type msg: str """ if msg is not None: print("%s: %s" % (SCRIPT, msg)) if exception is not None: traceback.print_exc() exit(1) ## ## OPENVPN ## def get_openvpn_bin(): """ Return the path for either the system openvpn or the one the bundle has put there. 
""" if os.path.isfile(OPENVPN_SYSTEM_BIN): return OPENVPN_SYSTEM_BIN # the bundle option should be removed from the debian package. if os.path.isfile(OPENVPN_LEAP_BIN): return OPENVPN_LEAP_BIN def parse_openvpn_flags(args): """ Take argument list from the command line and parse it, only allowing some configuration flags. :type args: list """ result = [] try: for flag in split_list(args, "^--"): flag_name = flag[0] if flag_name in ALLOWED_FLAGS: result.append(flag_name) required_params = ALLOWED_FLAGS[flag_name] if required_params: flag_params = flag[1:] if len(flag_params) != len(required_params): print("%s: ERROR: not enough params for %s" % (SCRIPT, flag_name)) return None for param, param_type in zip(flag_params, required_params): if PARAM_FORMATS[param_type](param): result.append(param) else: print("%s: ERROR: Bad argument %s" % (SCRIPT, param)) return None else: print("WARNING: unrecognized openvpn flag %s" % flag_name) return result except Exception as exc: print("%s: ERROR PARSING FLAGS: %s" % (SCRIPT, exc)) if DEBUG: logger.exception(exc) return None def openvpn_start(args): """ Launch openvpn, sanitizing input, and replacing the current process with the openvpn process. :param args: arguments to be passed to openvpn :type args: list """ openvpn_flags = parse_openvpn_flags(args) if openvpn_flags: OPENVPN = get_openvpn_bin() flags = [OPENVPN] + FIXED_FLAGS + openvpn_flags if DEBUG: print("%s: running openvpn with flags:" % (SCRIPT,)) print(flags) # note: first argument to command is ignored, but customarily set to # the command. os.execv(OPENVPN, flags) else: bail('ERROR: could not parse openvpn options') def openvpn_stop(args): """ Stop the openvpn that has likely been launched by bitmask. 
:param args: arguments to openvpn :type args: list """ plist = get_process_list() OPENVPN_BIN = get_openvpn_bin() found_leap_openvpn = filter( lambda (p, s): s.startswith(OPENVPN_BIN) and LEAPOPENVPN in s, plist) if found_leap_openvpn: pid = found_leap_openvpn[0][0] os.kill(int(pid), signal.SIGTERM) ## ## DNS ## class NameserverSetter(Daemon): """ A daemon that will add leap nameserver inside the tunnel to the system `resolv.conf` """ def run(self, *args): """ Run when daemonized. """ if args: ip_address = args[0] self.set_dns_nameserver(ip_address) def set_dns_nameserver(self, ip_address): """ Add the tunnel DNS server to `resolv.conf` :param ip_address: the ip to add to `resolv.conf` :type ip_address: str """ if os.path.isfile(RESOLVCONF): process = run(RESOLVCONF, "-a", "bitmask", input=True) process.communicate("nameserver %s\n" % ip_address) else: bail("ERROR: package openresolv or resolvconf not installed.") nameserver_setter = NameserverSetter('/tmp/leap-dns-up.pid') class NameserverRestorer(Daemon): """ A daemon that will restore the previous nameservers. """ def run(self): """ Run when daemonized. """ self.restore_dns_nameserver() def restore_dns_nameserver(self): """ Remove tunnel DNS server from `resolv.conf` """ if os.path.isfile(RESOLVCONF): run(RESOLVCONF, "-d", "bitmask") else: print("%s: ERROR: package openresolv " "or resolvconf not installed." % (SCRIPT,)) nameserver_restorer = NameserverRestorer('/tmp/leap-dns-down.pid') ## ## FIREWALL ## def get_gateways(gateways): """ Filter a passed sequence of gateways, returning only the valid ones. :param gateways: a sequence of gateways to filter. :type gateways: iterable :rtype: iterable """ result = filter(is_valid_address, gateways) if not result: bail("ERROR: No valid gateways specified") else: return result def get_default_device(): """ Retrieve the current default network device. 
:rtype: str """ routes = subprocess.check_output([IP, "route", "show"]) match = re.search("^default .*dev ([^\s]*) .*$", routes, flags=re.M) if match.groups(): return match.group(1) else: bail("Could not find default device") def get_local_network_ipv4(device): """ Get the local ipv4 addres for a given device. :param device: :type device: str """ addresses = cmdcheck([IP, "-o", "address", "show", "dev", device]) match = re.search("^.*inet ([^ ]*) .*$", addresses, flags=re.M) if match.groups(): return match.group(1) else: return None def get_local_network_ipv6(device): """ Get the local ipv6 addres for a given device. :param device: :type device: str """ addresses = cmdcheck([IP, "-o", "address", "show", "dev", device]) match = re.search("^.*inet6 ([^ ]*) .*$", addresses, flags=re.M) if match.groups(): return match.group(1) else: return None def run_iptable_with_check(cmd, *args, **options): """ Run an iptables command checking to see if it should: for --insert: run only if rule does not already exist. for --delete: run only if rule does exist. other commands are run normally. """ if "--insert" in args: check_args = [arg.replace("--insert", "--check") for arg in args] check_code = run(cmd, *check_args, exitcode=True) if check_code != 0: run(cmd, *args, **options) elif "--delete" in args: check_args = [arg.replace("--delete", "--check") for arg in args] check_code = run(cmd, *check_args, exitcode=True) if check_code == 0: run(cmd, *args, **options) else: run(cmd, *args, **options) def iptables(*args, **options): """ Run iptables4 and iptables6. """ ip4tables(*args, **options) ip6tables(*args, **options) def ip4tables(*args, **options): """ Run iptables4 with checks. """ run_iptable_with_check(IPTABLES, *args, **options) def ip6tables(*args, **options): """ Run iptables6 with checks. """ run_iptable_with_check(IP6TABLES, *args, **options) def ipv4_chain_exists(table): """ Check if a given chain exists. 
:param table: the table to check against :type table: str :rtype: bool """ code = run(IPTABLES, "--list", table, "--numeric", exitcode=True) return code == 0 def ipv6_chain_exists(table): """ Check if a given chain exists. :param table: the table to check against :type table: str :rtype: bool """ code = run(IP6TABLES, "--list", table, "--numeric", exitcode=True) return code == 0 def firewall_start(args): """ Bring up the firewall. :param args: list of gateways, to be sanitized. :type args: list """ default_device = get_default_device() local_network_ipv4 = get_local_network_ipv4(default_device) local_network_ipv6 = get_local_network_ipv6(default_device) gateways = get_gateways(args) # add custom chain "bitmask" if not ipv4_chain_exists(BITMASK_CHAIN): ip4tables("--new-chain", BITMASK_CHAIN) if not ipv6_chain_exists(BITMASK_CHAIN): ip6tables("--new-chain", BITMASK_CHAIN) iptables("--insert", "OUTPUT", "--jump", BITMASK_CHAIN) # reject everything iptables("--insert", BITMASK_CHAIN, "-o", default_device, "--jump", "REJECT") # allow traffic to gateways for gateway in gateways: ip4tables("--insert", BITMASK_CHAIN, "--destination", gateway, "-o", default_device, "--jump", "ACCEPT") # allow traffic to IPs on local network if local_network_ipv4: ip4tables("--insert", BITMASK_CHAIN, "--destination", local_network_ipv4, "-o", default_device, "--jump", "ACCEPT") if local_network_ipv6: ip6tables("--insert", BITMASK_CHAIN, "--destination", local_network_ipv6, "-o", default_device, "--jump", "ACCEPT") # block DNS requests to anyone but the service provider or localhost # when we actually route ipv6, we will need dns rules for it too ip4tables("--insert", BITMASK_CHAIN, "--protocol", "udp", "--dport", "53", "--jump", "REJECT") for allowed_dns in [NAMESERVER, "127.0.0.1", "127.0.1.1"]: ip4tables("--insert", BITMASK_CHAIN, "--protocol", "udp", "--dport", "53", "--destination", allowed_dns, "--jump", "ACCEPT") def firewall_stop(): """ Stop the firewall. 
""" iptables("--delete", "OUTPUT", "--jump", BITMASK_CHAIN) if ipv4_chain_exists(BITMASK_CHAIN): ip4tables("--flush", BITMASK_CHAIN) ip4tables("--delete-chain", BITMASK_CHAIN) if ipv6_chain_exists(BITMASK_CHAIN): ip6tables("--flush", BITMASK_CHAIN) ip6tables("--delete-chain", BITMASK_CHAIN) ## ## MAIN ## def main(): if len(sys.argv) >= 3: command = "_".join(sys.argv[1:3]) args = sys.argv[3:] if command == "openvpn_start": openvpn_start(args) elif command == "openvpn_stop": openvpn_stop(args) elif command == "firewall_start": try: firewall_start(args) nameserver_setter.start(NAMESERVER) except Exception as ex: nameserver_restorer.start() firewall_stop() bail("ERROR: could not start firewall", ex) elif command == "firewall_stop": try: firewall_stop() nameserver_restorer.start() except Exception as ex: bail("ERROR: could not stop firewall", ex) else: bail("ERROR: No such command") else: bail("ERROR: No such command") if __name__ == "__main__": if DEBUG: logger.debug(" ".join(sys.argv)) main() print("%s: done" % (SCRIPT,)) exit(0)