#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2014 LEAP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""
This is a privileged helper script for safely running certain commands as root.
It should only be called by the Bitmask application.

USAGE:
  bitmask-root firewall stop
  bitmask-root firewall start GATEWAY1 GATEWAY2 ...
  bitmask-root firewall isup
  bitmask-root openvpn stop
  bitmask-root openvpn start CONFIG1 CONFIG2 ...

All actions return exit code 0 for success, non-zero otherwise.

The `openvpn start` action is special: it calls exec on openvpn and replaces
the current process.
"""
# TODO should be tested with python3, which can be the default on some distro.
from __future__ import print_function
import atexit
import errno
import os
import re
import signal
import socket
import subprocess
import sys
import time
import traceback


# Short alias for subprocess.check_output: runs a command, returns what it
# wrote to stdout and raises CalledProcessError on a non-zero exit status.
cmdcheck = subprocess.check_output

##
## CONSTANTS
##

# Name used as the prefix of the messages printed by this script.
SCRIPT = "bitmask-root"
# DNS server address that firewall_start() whitelists and that main() hands
# to the nameserver setter.
NAMESERVER = "10.42.0.1"
# Name of the custom iptables chain that holds the firewall rules.
BITMASK_CHAIN = "bitmask"

# Absolute paths of the system tools this script invokes.
IP = "/bin/ip"
IPTABLES = "/sbin/iptables"
IP6TABLES = "/sbin/ip6tables"

# Candidate resolvconf locations, probed in this order by
# get_resolvconf_bin().
RESOLVCONF_SYSTEM_BIN = "/sbin/resolvconf"
RESOLVCONF_LEAP_BIN = "/usr/local/sbin/leap-resolvconf"

# User and group names matching the --user / --group values in FIXED_FLAGS.
OPENVPN_USER = "nobody"
OPENVPN_GROUP = "nogroup"
# Marker that openvpn_stop() looks for in a process command line; FIXED_FLAGS
# passes the same word to openvpn through --setenv.
LEAPOPENVPN = "LEAPOPENVPN"
# Candidate openvpn locations, probed in this order by get_openvpn_bin().
OPENVPN_SYSTEM_BIN = "/usr/sbin/openvpn"  # Debian location
OPENVPN_LEAP_BIN = "/usr/sbin/leap-openvpn"  # installed by bundle


# Paths to the scripts used to update resolv.conf.
#
# XXX We have to check if we have a valid resolvconf, and use
# the old resolv-update if not.
LEAP_RESOLV_UPDATE = "/etc/leap/resolv-update"
LEAP_UPDATE_RESOLVCONF_FILE = "/etc/leap/update-resolv-conf"

# Flags that openvpn_start() always places on the openvpn command line, ahead
# of the user-supplied flags that survived validation.
FIXED_FLAGS = [
    # exports LEAPOPENVPN=1; the word also ends up in the process command
    # line, which is what openvpn_stop() searches for
    "--setenv", "LEAPOPENVPN", "1",
    "--nobind",
    "--client",
    "--dev", "tun",
    "--tls-client",
    "--remote-cert-tls", "server",
    "--management-signal",
    "--script-security", "1",
    # same values as OPENVPN_USER / OPENVPN_GROUP
    "--user", "nobody",
    "--group", "nogroup",
    "--remap-usr1", "SIGTERM",
]

# Whitelist of the openvpn flags a caller may pass. Each flag maps to the
# ordered list of parameter types it must be followed by; the type names are
# keys of PARAM_FORMATS and are checked one by one in parse_openvpn_flags().
ALLOWED_FLAGS = {
    "--remote": ["IP", "NUMBER", "PROTO"],
    "--tls-cipher": ["CIPHER"],
    "--cipher": ["CIPHER"],
    "--auth": ["CIPHER"],
    "--management": ["DIR", "UNIXSOCKET"],
    "--management-client-user": ["USER"],
    "--cert": ["FILE"],
    "--key": ["FILE"],
    "--ca": ["FILE"]
}

# Validators for the parameter types named in ALLOWED_FLAGS. Each entry maps
# a type name to a callable that receives the raw command-line string and
# returns a truthy value when the string is acceptable.
#
# The patterns are raw strings, so regex escapes are never read as string
# escapes, and they are anchored with \Z instead of $: $ also matches right
# before a trailing newline, which would let a value such as "1194\n" through.
# [0-9] is spelled out because \d also matches non-ASCII digits on Python 3.
PARAM_FORMATS = {
    "NUMBER": lambda s: re.match(r"^[0-9]+\Z", s),
    "PROTO": lambda s: re.match(r"^(tcp|udp)\Z", s),
    # is_valid_address is defined further down the file; wrapping it in a
    # lambda defers the name lookup until the validator is actually called.
    "IP": lambda s: is_valid_address(s),
    "CIPHER": lambda s: re.match(r"^[A-Z0-9-]+\Z", s),
    "USER": lambda s: re.match(
        r"^[a-zA-Z0-9_.@][a-zA-Z0-9_.@-]*\$?\Z", s),  # IEEE Std 1003.1-2001
    "FILE": lambda s: os.path.isfile(s),
    # only the directory part of the given path has to exist
    "DIR": lambda s: os.path.isdir(os.path.split(s)[0]),
    "UNIXSOCKET": lambda s: s == "unix"
}


# Both switches are read from the environment once, when the module is
# loaded; any non-empty value turns them on.
DEBUG = os.getenv("DEBUG")
TEST = os.getenv("TEST")

# `logger` is only defined when DEBUG is set, so every use of it elsewhere in
# the file has to be guarded by `if DEBUG:`.
if DEBUG:
    import logging
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    # StreamHandler() without arguments writes to sys.stderr
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(ch)

##
## UTILITY
##


def is_valid_address(value):
    """
    Validate that the passed value is a well-formed IPv4 address.

    The check uses `socket.inet_pton`, which only accepts the full
    dotted-quad notation. `socket.inet_aton` is deliberately avoided: it also
    accepts shorthand forms with fewer than three dots, and what else it
    lets through depends on the underlying C library, which is too lenient
    for input that ends up on a root command line.

    A message is printed for every rejected value.

    :param value: the value to be validated
    :type value: str
    :rtype: bool
    """
    try:
        socket.inet_pton(socket.AF_INET, value)
    except (socket.error, TypeError, ValueError):
        # socket.error: malformed address; TypeError: not a string;
        # ValueError: embedded NUL character.
        print("%s: ERROR: MALFORMED IP: %s!" % (SCRIPT, value))
        return False
    return True


def has_system_resolvconf():
    """
    Return True if a resolvconf binary is found in the system.

    RESOLVCONF holds whatever get_resolvconf_bin() returned when the module
    was loaded, which is None if no binary was found.

    :rtype: bool
    """
    # os.path.isfile(None) raises TypeError instead of returning False, so
    # the "nothing found" case has to be answered before calling it.
    return bool(RESOLVCONF) and os.path.isfile(RESOLVCONF)


def has_valid_update_resolvconf():
    """
    Tell whether the update-resolv-conf script is present.

    Only the existence of a regular file at LEAP_UPDATE_RESOLVCONF_FILE is
    checked; the contents of the file are not inspected.

    :rtype: bool
    """
    script_path = LEAP_UPDATE_RESOLVCONF_FILE
    return os.path.isfile(script_path)


def has_valid_leap_resolv_update():
    """
    Tell whether the resolv-update script is present.

    Only the existence of a regular file at LEAP_RESOLV_UPDATE is checked;
    the contents of the file are not inspected.

    :rtype: bool
    """
    script_path = LEAP_RESOLV_UPDATE
    return os.path.isfile(script_path)


def split_list(_list, regex):
    """
    Cut a list into groups, opening a new group at every item that matches
    a regex:
    e.g. split_list(["xx", "yy", "x1", "zz"], "^x") => [["xx", "yy"], ["x1",
    "zz"]]

    Items that come before the first matching item belong to no group and
    are left out of the result.

    :param _list: the list to be split.
    :type _list: list
    :param regex: the pattern that marks the start of a group, either as a
                  string or as an already compiled pattern.
    :type regex: str

    :rtype: list
    """
    pattern = regex if hasattr(regex, "match") else re.compile(regex)
    groups = []
    if not _list:
        return groups
    for item in _list:
        if pattern.match(item):
            # a matching item always starts a fresh group
            groups.append([item])
        elif groups:
            # anything else joins the group that is currently open
            groups[-1].append(item)
    return groups


def get_process_list():
    """
    Get a process list by reading the `/proc` filesystem.

    The command line is returned exactly as read from `/proc/<pid>/cmdline`
    (opened in binary mode), so it may be empty.

    :return: a list of (pid, command line) tuples; the pid is the name of
             the `/proc` entry, as a string.
    :rtype: list of tuples
    """
    res = []
    for pid in os.listdir('/proc'):
        if not pid.isdigit():
            continue
        try:
            # `with` closes the handle right away instead of leaving one
            # open file per process behind.
            with open(os.path.join('/proc', pid, 'cmdline'), 'rb') as cmdline:
                res.append((pid, cmdline.read()))
        except IOError:  # proc has already terminated
            continue
    # A real list is returned on purpose: callers index into the result.
    return res


class Daemon(object):
    """
    A generic daemon class.

    Subclasses override `run`. `start` detaches from the calling process
    with a double fork, records the pid of the detached process in
    `pidfile` and then calls `run` with the arguments it was given.
    """
    def __init__(self, pidfile, stdin='/dev/null',
                 stdout='/dev/null', stderr='/dev/null'):
        """
        :param pidfile: path of the file the daemon pid is written to
        :type pidfile: str
        :param stdin: path the daemon's standard input is redirected from
        :type stdin: str
        :param stdout: path the daemon's standard output is appended to
        :type stdout: str
        :param stderr: path the daemon's standard error is appended to
        :type stderr: str
        """
        self.stdin = stdin
        self.stdout = stdout
        self.stderr = stderr
        self.pidfile = pidfile

    def _read_pid(self):
        """
        Return the pid stored in the pidfile.

        :return: the pid, or None if the pidfile cannot be read or does not
                 contain a number.
        :rtype: int or None
        """
        try:
            with open(self.pidfile, 'r') as pf:
                return int(pf.read().strip())
        except (IOError, ValueError):
            return None

    def daemonize(self):
        """
        Do the UNIX double-fork magic, see Stevens' "Advanced
        Programming in the UNIX Environment" for details (ISBN 0201563177)
        http://www.erlenstar.demon.co.uk/unix/faq_2.html#SEC16

        Both parent processes leave through sys.exit(0); only the second
        child returns from this method.
        """
        try:
            pid = os.fork()
            if pid > 0:
                # exit first parent
                sys.exit(0)
        except OSError as e:
            sys.stderr.write(
                "fork #1 failed: %d (%s)\n" % (e.errno, e.strerror))
            sys.exit(1)

        # decouple from parent environment
        os.chdir("/")
        os.setsid()
        os.umask(0)

        # do second fork
        try:
            pid = os.fork()
            if pid > 0:
                # exit from second parent
                sys.exit(0)
        except OSError as e:
            sys.stderr.write(
                "fork #2 failed: %d (%s)\n" % (e.errno, e.strerror))
            sys.exit(1)

        # redirect standard file descriptors
        sys.stdout.flush()
        sys.stderr.flush()
        si = open(self.stdin, 'r')
        so = open(self.stdout, 'a+')
        # Unbuffered: only binary mode may be opened with buffering=0 on
        # Python 3, and only the file descriptor is used below anyway.
        se = open(self.stderr, 'ab+', 0)
        os.dup2(si.fileno(), sys.stdin.fileno())
        os.dup2(so.fileno(), sys.stdout.fileno())
        os.dup2(se.fileno(), sys.stderr.fileno())

        # write pidfile
        atexit.register(self.delpid)
        pid = str(os.getpid())
        with open(self.pidfile, 'w+') as pf:
            pf.write("%s\n" % pid)

    def delpid(self):
        """
        Delete the pidfile.

        A pidfile that is already gone is not an error.
        """
        try:
            os.remove(self.pidfile)
        except OSError as e:
            if e.errno != errno.ENOENT:
                raise

    def start(self, *args):
        """
        Start the daemon.

        Exits with status 1 if the pidfile already holds a pid. Otherwise
        daemonizes and calls `run` with the same positional arguments.

        :param args: arguments handed over to `run`.
        """
        # Check for a pidfile to see if the daemon already runs
        if self._read_pid():
            message = "pidfile %s already exist. Daemon already running?\n"
            sys.stderr.write(message % self.pidfile)
            sys.exit(1)

        # Start the daemon
        self.daemonize()
        # Unpack, so that `run` receives the arguments the way `start` got
        # them instead of one tuple wrapping all of them.
        self.run(*args)

    def stop(self):
        """
        Stop the daemon.

        Sends SIGTERM to the pid found in the pidfile until the process is
        gone, then removes the pidfile. Any other error from kill is
        printed and ends the program with status 1.
        """
        # Get the pid from the pidfile
        pid = self._read_pid()

        if not pid:
            message = "pidfile %s does not exist. Daemon not running?\n"
            sys.stderr.write(message % self.pidfile)
            return  # not an error in a restart

        # Try killing the daemon process
        try:
            while True:
                os.kill(pid, signal.SIGTERM)
                time.sleep(0.1)
        except OSError as err:
            # ESRCH ("No such process") means the daemon is gone; compare
            # the error number rather than the text of the message.
            if err.errno == errno.ESRCH:
                if os.path.exists(self.pidfile):
                    os.remove(self.pidfile)
            else:
                print(str(err))
                sys.exit(1)

    def restart(self, *args):
        """
        Restart the daemon.

        :param args: arguments handed over to `start`.
        """
        self.stop()
        self.start(*args)

    def run(self, *args):
        """
        This should be overridden by derived classes.

        :param args: the positional arguments `start` was called with.
        """


def run(command, *args, **options):
    """
    Run an external command.

    Options:

      `check`: If True, check the command's exit status and bail if it is
               non-zero. (the default is true unless detach or input is
               true)
      `exitcode`: like `check`, but return exitcode instead of bailing.
      `detach`: If True, run in detached process.
      `input`: If True, open command for writing stream to, returning the Popen
               object.

    :return: the Popen object when `input` is set; None when the command is
             launched without waiting for it; 0 when a checked command
             succeeds; the exit status of a failed checked command when
             `exitcode` is set.
    """
    parts = [command]
    parts.extend(args)
    if TEST or DEBUG:
        print("%s run: %s " % (SCRIPT, " ".join(parts)))

    _check = options.get("check", True)
    _detach = options.get("detach", False)
    _input = options.get("input", False)
    _exitcode = options.get("exitcode", False)

    if _input:
        return subprocess.Popen(parts, stdin=subprocess.PIPE)

    if not _check or _detach:
        # fire and forget: the command is neither waited for nor checked
        subprocess.Popen(parts)
        return None

    try:
        # the command's output is discarded; the handle is closed as soon
        # as the command has finished
        with open('/dev/null', 'w') as devnull:
            subprocess.check_call(parts, stdout=devnull, stderr=devnull)
        return 0
    except subprocess.CalledProcessError as exc:
        if DEBUG:
            logger.exception(exc)
        if _exitcode:
            return exc.returncode
        # check_call captures no output, so exc.output is always None;
        # the exit status is the only detail available to report.
        bail("ERROR: Could not run %s: exit code %s" %
             (exc.cmd, exc.returncode),
             exception=exc)


def bail(msg=None, exception=None):
    """
    Abnormal exit: print what is known and leave with exit status 1.

    :param msg: optional error message, printed with the script name in
                front of it.
    :type msg: str
    :param exception: optional exception. When given, the traceback of the
                      exception that is currently being handled is printed,
                      so this is meant to be passed from inside an `except`
                      block.
    :type exception: Exception
    """
    if msg is not None:
        print("%s: %s" % (SCRIPT, msg))
    if exception is not None:
        traceback.print_exc()
    # sys.exit rather than the bare `exit`: the latter is only injected by
    # the `site` module for interactive use and is absent under `python -S`.
    sys.exit(1)

##
## OPENVPN
##


def get_openvpn_bin():
    """
    Locate the openvpn binary.

    The system location is preferred; the location used by the bundle is
    the fallback.

    :return: the first of the two paths that is an existing file, or None
             if neither is.
    """
    # the bundle option should be removed from the debian package.
    for candidate in (OPENVPN_SYSTEM_BIN, OPENVPN_LEAP_BIN):
        if os.path.isfile(candidate):
            return candidate
    return None


def parse_openvpn_flags(args):
    """
    Filter the command-line arguments meant for openvpn, keeping only the
    whitelisted flags and their validated parameters.

    Flags missing from ALLOWED_FLAGS are skipped with a warning. A
    whitelisted flag with the wrong number of parameters, or with a
    parameter that its PARAM_FORMATS validator rejects, makes the whole
    parse fail.

    :type args: list
    :return: the flat list of accepted flags and parameters, or None if the
             parse failed.
    :rtype: list or None
    """
    accepted = []
    try:
        for group in split_list(args, "^--"):
            name = group[0]
            if name not in ALLOWED_FLAGS:
                print("WARNING: unrecognized openvpn flag %s" % name)
                continue

            accepted.append(name)
            expected_types = ALLOWED_FLAGS[name]
            if not expected_types:
                continue

            given = group[1:]
            if len(given) != len(expected_types):
                print("%s: ERROR: not enough params for %s" %
                      (SCRIPT, name))
                return None

            for value, type_name in zip(given, expected_types):
                if not PARAM_FORMATS[type_name](value):
                    print("%s: ERROR: Bad argument %s" %
                          (SCRIPT, value))
                    return None
                accepted.append(value)
        return accepted
    except Exception as exc:
        print("%s: ERROR PARSING FLAGS: %s" % (SCRIPT, exc))
        if DEBUG:
            logger.exception(exc)
        return None


def openvpn_start(args):
    """
    Launch openvpn, sanitizing input, and replacing the current process with
    the openvpn process.

    Bails out if the arguments do not yield any accepted flag, or if no
    openvpn binary can be found. On success this function does not return.

    :param args: arguments to be passed to openvpn
    :type args: list
    """
    openvpn_flags = parse_openvpn_flags(args)
    if not openvpn_flags:
        bail('ERROR: could not parse openvpn options')

    openvpn_bin = get_openvpn_bin()
    if openvpn_bin is None:
        # os.execv(None, ...) would only raise an opaque TypeError
        bail("ERROR: openvpn binary not found")

    flags = [openvpn_bin] + FIXED_FLAGS + openvpn_flags
    if DEBUG:
        print("%s: running openvpn with flags:" % (SCRIPT,))
        print(flags)
    # note: first argument to command is ignored, but customarily set to
    # the command.
    os.execv(openvpn_bin, flags)


def openvpn_stop(args):
    """
    Stop the openvpn that has likely been launched by bitmask.

    The first process whose command line starts with the openvpn binary
    path and contains the LEAPOPENVPN marker is sent SIGTERM. Nothing
    happens if no such process exists, or if no openvpn binary is found.

    :param args: arguments to openvpn (not used)
    :type args: list
    """
    openvpn_bin = get_openvpn_bin()
    if openvpn_bin is None:
        # without a binary path there is nothing to match against
        return

    for pid, cmdline in get_process_list():
        # The command line is read in binary mode; on Python 3 that yields
        # bytes, which cannot be compared with the str constants.
        if not isinstance(cmdline, str):
            cmdline = cmdline.decode("utf-8", "replace")
        if cmdline.startswith(openvpn_bin) and LEAPOPENVPN in cmdline:
            os.kill(int(pid), signal.SIGTERM)
            break

##
## DNS
##


def get_resolvconf_bin():
    """
    Locate the resolvconf binary.

    The system location is preferred; the location used by the bundle is
    the fallback.

    :return: the first of the two paths that is an existing file, or None
             if neither is.
    """
    # the bundle option should be removed from the debian package.
    for candidate in (RESOLVCONF_SYSTEM_BIN, RESOLVCONF_LEAP_BIN):
        if os.path.isfile(candidate):
            return candidate
    return None


# Resolved once, when the module is loaded; None if no binary was found.
RESOLVCONF = get_resolvconf_bin()


class NameserverSetter(Daemon):
    """
    A daemon that will add leap nameserver inside the tunnel
    to the system `resolv.conf`
    """

    def run(self, *args):
        """
        Run when daemonized.

        :param args: the first argument is the nameserver address to set;
                     without arguments nothing is done.
        """
        if not args:
            return
        ip_address = args[0]
        # Daemon.start() may hand over all of its arguments wrapped in one
        # tuple; take the address out of it instead of relying on
        # %-formatting to treat a one-element tuple like its element.
        if isinstance(ip_address, (tuple, list)):
            if not ip_address:
                return
            ip_address = ip_address[0]
        self.set_dns_nameserver(ip_address)

    def set_dns_nameserver(self, ip_address):
        """
        Add the tunnel DNS server to `resolv.conf`

        Bails out if no resolvconf binary is available.

        :param ip_address: the ip to add to `resolv.conf`
        :type ip_address: str
        """
        # RESOLVCONF is None when no binary was found, and
        # os.path.isfile(None) raises TypeError instead of returning False.
        if RESOLVCONF and os.path.isfile(RESOLVCONF):
            process = run(RESOLVCONF, "-a", "bitmask", input=True)
            record = "nameserver %s\n" % ip_address
            # the pipe is opened in binary mode, so it takes bytes
            process.communicate(record.encode("ascii"))
        else:
            bail("ERROR: package openresolv or resolvconf not installed.")


nameserver_setter = NameserverSetter('/tmp/leap-dns-up.pid')


class NameserverRestorer(Daemon):
    """
    A daemon that will restore the previous nameservers.
    """

    def run(self, *args):
        """
        Run when daemonized.

        :param args: ignored.
        """
        self.restore_dns_nameserver()

    def restore_dns_nameserver(self):
        """
        Remove tunnel DNS server from `resolv.conf`

        Only prints an error if no resolvconf binary is available.
        """
        # RESOLVCONF is None when no binary was found, and
        # os.path.isfile(None) raises TypeError instead of returning False.
        if RESOLVCONF and os.path.isfile(RESOLVCONF):
            run(RESOLVCONF, "-d", "bitmask")
        else:
            print("%s: ERROR: package openresolv "
                  "or resolvconf not installed." %
                  (SCRIPT,))


nameserver_restorer = NameserverRestorer('/tmp/leap-dns-down.pid')


##
## FIREWALL
##


def get_gateways(gateways):
    """
    Filter a passed sequence of gateways, returning only the valid ones.

    Bails out if none of the gateways is a valid address.

    :param gateways: a sequence of gateways to filter.
    :type gateways: iterable
    :return: the gateways accepted by is_valid_address, in their original
             order.
    :rtype: list
    """
    # Build a real list: the lazy object filter() returns on Python 3 is
    # always truthy, which would disable the emptiness check below.
    result = [gateway for gateway in gateways if is_valid_address(gateway)]
    if not result:
        bail("ERROR: No valid gateways specified")
    return result


def get_default_device():
    """
    Retrieve the current default network device.

    The device is taken from the `default` line printed by `ip route show`.
    Bails out if there is no such line.

    :rtype: str
    """
    routes = subprocess.check_output([IP, "route", "show"])
    if not isinstance(routes, str):
        # Python 3 returns bytes, which a str pattern cannot search
        routes = routes.decode("utf-8", "replace")
    # `dev` has to be a word of its own, and the device name may be the
    # last thing on the line.
    match = re.search(r"^default\b.*?\sdev\s+(\S+)", routes, flags=re.M)
    if match is None:
        # re.search returns None when nothing matches
        bail("Could not find default device")
    return match.group(1)


def get_local_network_ipv4(device):
    """
    Get the local ipv4 address for a given device.

    :param device: name of the network device to query.
    :type device: str
    :return: the word following the first `inet` entry printed by
             `ip -o address show dev <device>`, or None if there is no such
             entry.
    :rtype: str or None
    """
    addresses = cmdcheck([IP, "-o", "address", "show", "dev", device])
    if not isinstance(addresses, str):
        # Python 3 returns bytes, which a str pattern cannot search
        addresses = addresses.decode("utf-8", "replace")
    match = re.search(r"^.*inet ([^ ]*) .*$", addresses, flags=re.M)
    # re.search returns None, not an empty match, when nothing is found
    if match is None:
        return None
    return match.group(1)


def get_local_network_ipv6(device):
    """
    Get the local ipv6 address for a given device.

    :param device: name of the network device to query.
    :type device: str
    :return: the word following the first `inet6` entry printed by
             `ip -o address show dev <device>`, or None if there is no such
             entry.
    :rtype: str or None
    """
    addresses = cmdcheck([IP, "-o", "address", "show", "dev", device])
    if not isinstance(addresses, str):
        # Python 3 returns bytes, which a str pattern cannot search
        addresses = addresses.decode("utf-8", "replace")
    match = re.search(r"^.*inet6 ([^ ]*) .*$", addresses, flags=re.M)
    # re.search returns None, not an empty match, when nothing is found
    if match is None:
        return None
    return match.group(1)


def run_iptable_with_check(cmd, *args, **options):
    """
    Run an iptables command, first probing with --check whether it is
    needed:
      --insert and --append only run if the rule does not exist yet.
      --delete only runs if the rule exists.
    Any other command is run as it is.

    :param cmd: path of the iptables binary to run.
    :param args: arguments for the binary.
    :param options: options handed over to `run` for the actual command
                    (the probe always runs with exitcode=True only).
    """
    # (action, whether the rule has to exist for the action to be run),
    # looked at in this order
    guarded_actions = (
        ("--insert", False),
        ("--append", False),
        ("--delete", True),
    )
    for action, rule_must_exist in guarded_actions:
        if action not in args:
            continue
        probe_args = [arg.replace(action, "--check") for arg in args]
        rule_exists = run(cmd, *probe_args, exitcode=True) == 0
        if rule_exists == rule_must_exist:
            run(cmd, *args, **options)
        return

    run(cmd, *args, **options)


def iptables(*args, **options):
    """
    Apply the same command to both firewalls: ipv4 first, then ipv6.
    """
    for runner in (ip4tables, ip6tables):
        runner(*args, **options)


def ip4tables(*args, **options):
    """
    Hand a command to the ipv4 iptables binary, through the rule-existence
    checks of run_iptable_with_check.
    """
    binary = IPTABLES
    run_iptable_with_check(binary, *args, **options)


def ip6tables(*args, **options):
    """
    Hand a command to the ipv6 iptables binary, through the rule-existence
    checks of run_iptable_with_check.
    """
    binary = IP6TABLES
    run_iptable_with_check(binary, *args, **options)


def ipv4_chain_exists(table):
    """
    Tell whether the ipv4 firewall knows a given chain.

    The chain is considered to exist when `iptables --list <name> --numeric`
    exits with status 0.

    :param table: the name of the chain to look for
    :type table: str
    :rtype: bool
    """
    return run(IPTABLES, "--list", table, "--numeric", exitcode=True) == 0


def ipv6_chain_exists(table):
    """
    Tell whether the ipv6 firewall knows a given chain.

    The chain is considered to exist when `ip6tables --list <name> --numeric`
    exits with status 0.

    :param table: the name of the chain to look for
    :type table: str
    :rtype: bool
    """
    return run(IP6TABLES, "--list", table, "--numeric", exitcode=True) == 0


def firewall_start(args):
    """
    Bring up the firewall.

    Creates the BITMASK_CHAIN chain (ipv4 and ipv6) if needed, hooks it into
    OUTPUT and fills it with rules. The rules are appended in the order in
    which they appear below.

    :param args: list of gateways, to be sanitized.
    :type args: list
    """
    # Gathered before any rule is touched: the device of the default route,
    # its first inet / inet6 address entry, and the validated gateways.
    default_device = get_default_device()
    local_network_ipv4 = get_local_network_ipv4(default_device)
    local_network_ipv6 = get_local_network_ipv6(default_device)
    gateways = get_gateways(args)

    # add custom chain "bitmask" to front of OUTPUT chain
    if not ipv4_chain_exists(BITMASK_CHAIN):
        ip4tables("--new-chain", BITMASK_CHAIN)
    if not ipv6_chain_exists(BITMASK_CHAIN):
        ip6tables("--new-chain", BITMASK_CHAIN)
    iptables("--insert", "OUTPUT", "--jump", BITMASK_CHAIN)

    # allow DNS over VPN
    for allowed_dns in [NAMESERVER, "127.0.0.1", "127.0.1.1"]:
        ip4tables("--append", BITMASK_CHAIN, "--protocol", "udp",
                  "--dport", "53", "--destination", allowed_dns,
                  "--jump", "ACCEPT")

    # block DNS requests to anyone but the service provider or localhost
    # (when we actually route ipv6, we will need DNS rules for it too)
    # NOTE(review): the DNS rules only match udp port 53; DNS over tcp is
    # not matched by them -- confirm this is intended.
    ip4tables("--append", BITMASK_CHAIN, "--protocol", "udp", "--dport", "53",
              "--jump", "REJECT")

    # allow traffic to IPs on local network
    if local_network_ipv4:
        ip4tables("--append", BITMASK_CHAIN,
                  "--destination", local_network_ipv4, "-o", default_device,
                  "--jump", "ACCEPT")
        # allow multicast Simple Service Discovery Protocol
        ip4tables("--append", BITMASK_CHAIN,
                  "--protocol", "udp", "--destination", "239.255.255.250", "--dport", "1900",
                  "-o", default_device, "--jump", "RETURN")
        # allow multicast Bonjour/mDNS
        ip4tables("--append", BITMASK_CHAIN,
                  "--protocol", "udp", "--destination", "224.0.0.251", "--dport", "5353",
                  "-o", default_device, "--jump", "RETURN")
    if local_network_ipv6:
        ip6tables("--append", BITMASK_CHAIN,
                  "--destination", local_network_ipv6, "-o", default_device,
                  "--jump", "ACCEPT")
        # allow multicast Simple Service Discovery Protocol
        ip6tables("--append", BITMASK_CHAIN,
                  "--protocol", "udp", "--destination", "FF05::C", "--dport", "1900",
                  "-o", default_device, "--jump", "RETURN")
        # allow multicast Bonjour/mDNS
        ip6tables("--append", BITMASK_CHAIN,
                  "--protocol", "udp", "--destination", "FF02::FB", "--dport", "5353",
                  "-o", default_device, "--jump", "RETURN")

    # allow ipv4 traffic to gateways
    for gateway in gateways:
        ip4tables("--append", BITMASK_CHAIN, "--destination", gateway,
                  "-o", default_device, "--jump", "ACCEPT")

    # log rejected packets to syslog
    # (the LOG rule is appended right before the final REJECT rules, so it
    # only sees packets that none of the rules above accepted or returned)
    if DEBUG:
        iptables("--append", BITMASK_CHAIN, "-o", default_device,
                 "--jump", "LOG", "--log-prefix", "iptables denied: ", "--log-level", "7")

    # for now, ensure all other ipv6 packets get rejected (regardless of device)
    # (not sure why, but "-p any" doesn't work)
    ip6tables("--append", BITMASK_CHAIN, "-p", "tcp", "--jump", "REJECT")
    ip6tables("--append", BITMASK_CHAIN, "-p", "udp", "--jump", "REJECT")

    # reject all other ipv4 sent over the default device
    ip4tables("--append", BITMASK_CHAIN, "-o", default_device, "--jump", "REJECT")

def firewall_stop():
    """
    Tear the firewall down.

    The jump from OUTPUT into the bitmask chain is removed first; then, for
    ipv4 and ipv6 in turn, the chain is flushed and deleted if it exists.
    """
    iptables("--delete", "OUTPUT", "--jump", BITMASK_CHAIN)

    per_protocol = (
        (ipv4_chain_exists, ip4tables),
        (ipv6_chain_exists, ip6tables),
    )
    for chain_exists, runner in per_protocol:
        if chain_exists(BITMASK_CHAIN):
            # a chain has to be empty before it can be deleted
            runner("--flush", BITMASK_CHAIN)
            runner("--delete-chain", BITMASK_CHAIN)

##
## MAIN
##


def main():
    """
    Dispatch on the first two command-line words (e.g. `firewall start`);
    everything after them is handed to the selected action.

    Bails out with "No such command" when fewer than two words are given or
    when the pair is not a known command.
    """
    if len(sys.argv) < 3:
        bail("ERROR: No such command")
        return

    command = "_".join(sys.argv[1:3])
    args = sys.argv[3:]

    if command == "openvpn_start":
        openvpn_start(args)
        return

    if command == "openvpn_stop":
        openvpn_stop(args)
        return

    if command == "firewall_start":
        # NOTE(review): Daemon.start() ends the calling process inside
        # daemonize(), so whatever follows a .start() call appears to run
        # only in the daemonized child -- confirm.
        try:
            firewall_start(args)
            nameserver_setter.start(NAMESERVER)
        except Exception as ex:
            nameserver_restorer.start()
            firewall_stop()
            bail("ERROR: could not start firewall", ex)
        return

    if command == "firewall_stop":
        try:
            firewall_stop()
            nameserver_restorer.start()
        except Exception as ex:
            bail("ERROR: could not stop firewall", ex)
        return

    if command == "firewall_isup":
        if ipv4_chain_exists(BITMASK_CHAIN):
            print("%s: INFO: bitmask firewall is up" % (SCRIPT,))
        else:
            bail("INFO: bitmask firewall is down")
        return

    bail("ERROR: No such command")

# Script entry point: run the requested action and report success.
if __name__ == "__main__":
    if DEBUG:
        logger.debug(" ".join(sys.argv))
    main()
    print("%s: done" % (SCRIPT,))
    # sys.exit rather than the bare `exit`: the latter is only injected by
    # the `site` module for interactive use and is absent under `python -S`.
    sys.exit(0)