#!/usr/bin/python2
# -*- coding: utf-8 -*-
#
# Copyright (C) 2014 LEAP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""
This is a privileged helper script for safely running certain commands as root.
It should only be called by the Bitmask application.

USAGE:
  bitmask-root firewall stop
  bitmask-root firewall start GATEWAY1 GATEWAY2 ...
  bitmask-root openvpn stop
  bitmask-root openvpn start CONFIG1 CONFIG2 ...
"""
# TODO: test with Python 3, which is the default python on some distros.

from __future__ import print_function
import os
import subprocess
import socket
import sys
import re

##
## CONSTANTS
##

# Absolute paths of the binaries and scripts this helper runs as root.
OPENVPN = "/usr/sbin/openvpn"
IPTABLES = "/sbin/iptables"
IP6TABLES = "/sbin/ip6tables"
# iproute2 "ip" binary, used to find the default device and its local
# addresses. It was referenced but never defined (NameError on firewall start).
# NOTE(review): /sbin/ip is assumed (next to iptables); confirm on target distros.
IP = "/sbin/ip"
UPDATE_RESOLV_CONF = "/etc/openvpn/update-resolv-conf"

# Flags always passed to openvpn; the validated caller-supplied flags are
# appended after these by openvpn_start.
FIXED_FLAGS = [
    "--setenv", "LEAPOPENVPN", "1",
    "--nobind",
    "--client",
    "--dev", "tun",
    "--tls-client",
    "--remote-cert-tls", "server",
    "--management-signal",
    "--management", "/tmp/openvpn.socket", "unix",
    "--up", UPDATE_RESOLV_CONF,
    "--down", UPDATE_RESOLV_CONF,
    # --up/--down run an external script, which openvpn only permits at
    # script-security level 2 or higher (see openvpn(8)).
    "--script-security", "2"
]

# openvpn flags a caller of "openvpn start" may pass. Each flag maps to the
# ordered list of parameter types (keys of PARAM_FORMATS) it must be followed
# by. Flags not listed here are dropped by parse_openvpn_flags.
ALLOWED_FLAGS = {
    # server endpoint: address, port, transport protocol
    "--remote": ["IP", "NUMBER", "PROTO"],
    # cipher / digest selection
    "--cipher": ["CIPHER"],
    "--tls-cipher": ["CIPHER"],
    "--auth": ["CIPHER"],
    # certificate and key files
    "--ca": ["FILE"],
    "--cert": ["FILE"],
    "--key": ["FILE"],
    # management interface access
    "--management-client-user": ["USER"],
}

# Validators for the parameter types named in ALLOWED_FLAGS. Each takes the raw
# argument string and returns a truthy value when it is acceptable.
# Patterns are anchored with \Z rather than $: in Python's re, $ also matches
# just before a trailing newline, so "443\n" would otherwise pass as NUMBER.
# [0-9] is used instead of \d, which also matches non-ASCII digits on Python 3.
PARAM_FORMATS = {
    "NUMBER": lambda s: re.match(r"^[0-9]+\Z", s),
    "PROTO":  lambda s: re.match(r"^(tcp|udp)\Z", s),
    "IP":     lambda s: is_valid_address(s),
    "CIPHER": lambda s: re.match(r"^[A-Z0-9-]+\Z", s),
    # Portable user name, per IEEE Std 1003.1-2001 (optional trailing "$").
    "USER":   lambda s: re.match(r"^[a-zA-Z0-9_\.\@][a-zA-Z0-9_\-\.\@]*\$?\Z", s),
    "FILE":   lambda s: os.path.isfile(s)
}

# Debug mode is on whenever the DEBUG environment variable is non-empty.
DEBUG = os.getenv("DEBUG")
if DEBUG:
    # Only pull in logging when it is needed; log the full invocation.
    import logging
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    logger.debug(" ".join(sys.argv))

##
## UTILITY
##

def is_valid_address(value):
    """
    Validate that the passed value is a dotted-quad IPv4 address.

    socket.inet_pton is used rather than socket.inet_aton: inet_aton also
    accepts shorthand forms such as "127.1" (and, depending on the C library,
    trailing text after whitespace). That is too lax for arguments that end
    up on iptables/openvpn command lines run as root.

    A message is printed for every rejected value.

    :param value: the value to be validated
    :type value: str
    :rtype: bool
    """
    try:
        socket.inet_pton(socket.AF_INET, value)
    except (socket.error, ValueError, TypeError):
        # socket.error: not an IPv4 address; ValueError/TypeError: embedded
        # NUL byte or a non-string value.
        print("MALFORMED IP: %s!" % value)
        return False
    return True

def split_list(list, regex):
    """
    Split a list into groups, starting a new group at every item matching regex:
    e.g. split_list(["xx", "yy", "x1", "zz"], "^x") => [["xx", "yy"], ["x1", "zz"]]

    Items before the first matching item belong to no group and are dropped.
    An empty input yields an empty result (previously an IndexError).

    :param list: the list to be split. (The name shadows the builtin but is
                 kept so keyword callers keep working.)
    :type list: list
    :param regex: a pattern string or an already compiled pattern; each item
                  is tested with ``.match``, i.e. anchored at its start.
    :rtype: list
    """
    if not hasattr(regex, "match"):
        regex = re.compile(regex)
    result = []
    for item in list:
        if regex.match(item):
            result.append([item])
        elif result:
            # Continuation of the current group.
            result[-1].append(item)
    return result

# Not needed: run() passes an argument list and leaves shell=False, so no shell quoting is required.
#def sanify(command, *args):
#    return [command] + [pipes.quote(a) for a in args]

def run(command, *args, **options):
    """
    Run ``command`` with ``args`` as a plain argument list (no shell).

    Keyword options:
      check (default True):   when false, start the process and do not wait.
      detach (default False): when true, start the process and do not wait.
      exitcode (default False): when true, a failing command returns its exit
                                status instead of bailing out.

    When waiting, the command's stdout and stderr are discarded.

    :returns: None when the process was started without waiting; otherwise 0
              on success, or the non-zero exit status when ``exitcode`` is set.
              On failure without ``exitcode`` the script bails out.
    """
    parts = [command]
    parts.extend(args)
    if DEBUG:
        print("run: " + " ".join(parts))
    if not options.get("check", True) or options.get("detach", False):
        # Fire and forget: output is inherited and no status is collected.
        subprocess.Popen(parts)
        return None
    # The context manager closes the /dev/null handle on every path,
    # including when bail() raises SystemExit.
    with open(os.devnull, "w") as devnull:
        try:
            subprocess.check_call(parts, stdout=devnull, stderr=devnull)
        except subprocess.CalledProcessError as ex:
            if options.get("exitcode", False):
                return ex.returncode
            # check_call never captures output (ex.output is always None),
            # so report the exit status instead.
            bail("Could not run %s: exit code %s" % (" ".join(ex.cmd), ex.returncode))
        else:
            return 0

##
## OPENVPN
##

def parse_openvpn_flags(args):
    """
    Parse the openvpn options given on the command line, keeping only allowed ones.

    ``args`` is split into groups, each starting with an item beginning with
    "--". A group whose flag is in ALLOWED_FLAGS must carry exactly the
    parameters listed there, and each parameter must pass the PARAM_FORMATS
    check for its type. Unknown flags are reported and dropped together with
    their parameters.

    :param args: the command-line arguments following "openvpn start".
    :type args: list
    :returns: the accepted flags and parameters as one flat list, or None if
              an allowed flag has the wrong number of parameters, a parameter
              fails validation, or parsing fails unexpectedly.
    :rtype: list or None
    """
    result = []
    try:
        for flag in split_list(args, "^--"):
            flag_name = flag[0]
            if flag_name not in ALLOWED_FLAGS:
                print("WARNING: unrecognized openvpn flag %s" % flag_name)
                continue
            result.append(flag_name)
            required_params = ALLOWED_FLAGS[flag_name]
            if not required_params:
                continue
            flag_params = flag[1:]
            if len(flag_params) != len(required_params):
                print("ERROR: wrong number of params for %s (expected %d, got %d)"
                      % (flag_name, len(required_params), len(flag_params)))
                return None
            for param, param_type in zip(flag_params, required_params):
                if not PARAM_FORMATS[param_type](param):
                    print("ERROR: Bad argument %s" % param)
                    return None
                result.append(param)
        return result
    except Exception as ex:
        # Boundary guard: any unexpected failure is reported and treated as a
        # parse error, so the caller bails out instead of crashing.
        print(ex)
        return None


def openvpn_start(args):
    """
    Launch openvpn in the background with FIXED_FLAGS plus the validated flags
    from ``args``.

    Bails out when parsing fails or yields no flags at all.
    """
    user_flags = parse_openvpn_flags(args)
    if not user_flags:
        bail('ERROR: could not parse openvpn options')
        return
    run(OPENVPN, *(FIXED_FLAGS + user_flags), detach=True)

def openvpn_stop(args):
    """
    Placeholder for stopping openvpn: it only prints "stop" and does not yet
    act on any openvpn process.

    :param args: remaining command-line arguments (currently unused).
    """
    # TODO: actually stop the openvpn process launched by openvpn_start.
    print("stop")

##
## FIREWALL
##

def get_gateways(gateways):
    """
    Return the entries of ``gateways`` that are valid IP addresses, in their
    original order. Bails out when none of them is valid.
    """
    valid = list(filter(is_valid_address, gateways))
    if valid:
        return valid
    bail("No valid gateways specified")

def get_default_device():
    """
    Return the name of the network device used by the default route.

    Parses the output of ``ip route show`` and bails out when there is no
    "default ... dev NAME" line (previously this crashed with AttributeError).

    :rtype: str
    """
    # universal_newlines makes check_output return text rather than bytes on
    # Python 3, so the str regex below can be applied.
    routes = subprocess.check_output([IP, "route", "show"], universal_newlines=True)
    # Nothing is required after the device name: "dev NAME" may end the line.
    match = re.search(r"^default .*dev (\S+)", routes, flags=re.M)
    if match:
        return match.group(1)
    bail("could not find default device")

def _local_address(device, family):
    """
    Return the token following ``family`` ("inet" or "inet6") on the first
    matching line of ``ip -o address show dev DEVICE``, or None if the device
    has no such address.

    (For ip this token is normally the address with its prefix length, e.g.
    "192.168.1.5/24" -- confirm against the iproute2 version in use.)
    """
    # universal_newlines makes check_output return text on Python 3 as well.
    addresses = subprocess.check_output(
        [IP, "-o", "address", "show", "dev", device], universal_newlines=True)
    pattern = r"^.*%s ([^ ]*) .*$" % re.escape(family)
    match = re.search(pattern, addresses, flags=re.M)
    # A missing address used to raise AttributeError on match.groups().
    return match.group(1) if match else None

def get_local_network_ipv4(device):
    """Return the first IPv4 address (inet) of ``device``, or None."""
    return _local_address(device, "inet")

def get_local_network_ipv6(device):
    """Return the first IPv6 address (inet6) of ``device``, or None."""
    return _local_address(device, "inet6")

def run_iptable_with_check(cmd, *args, **options):
    """
    Run an iptables command, making --insert and --delete idempotent:
      for --insert: run only if the rule does not already exist.
      for --delete: run only if the rule does exist.
    Other commands are run normally.

    Existence is probed by re-running the same rule spec with the action
    replaced by --check and testing for exit status 0.
    """
    if "--insert" in args:
        action, run_if_exists = "--insert", False
    elif "--delete" in args:
        action, run_if_exists = "--delete", True
    else:
        run(cmd, *args, **options)
        return
    # Swap only the exact action token. str.replace would also rewrite any
    # other argument that merely contains "--insert"/"--delete" as a substring.
    check_args = ["--check" if arg == action else arg for arg in args]
    rule_exists = run(cmd, *check_args, exitcode=True) == 0
    if rule_exists == run_if_exists:
        run(cmd, *args, **options)

def iptables(*args, **options):
    """Apply the same rule command to both the IPv4 and the IPv6 tables (IPv4 first)."""
    for single_stack in (ip4tables, ip6tables):
        single_stack(*args, **options)

def ip4tables(*args, **options):
    """Run an IPv4 iptables command through run_iptable_with_check."""
    binary = IPTABLES
    run_iptable_with_check(binary, *args, **options)

def ip6tables(*args, **options):
    """Run an IPv6 ip6tables command through run_iptable_with_check."""
    binary = IP6TABLES
    run_iptable_with_check(binary, *args, **options)

def _chain_listable(binary, table):
    """Return True when ``binary --list table --numeric`` exits with status 0."""
    status = run(binary, "--list", table, "--numeric", exitcode=True)
    return status == 0

def ipv4_chain_exists(table):
    """Return True if the IPv4 chain ``table`` can be listed (i.e. exists)."""
    return _chain_listable(IPTABLES, table)

def ipv6_chain_exists(table):
    """Return True if the IPv6 chain ``table`` can be listed (i.e. exists)."""
    return _chain_listable(IP6TABLES, table)

def firewall_start(args):
    """
    Install the "bitmask" OUTPUT filter so that traffic leaving through the
    default device can only reach the given gateways and the local network.

    Every rule below is added with --insert and no rule number; iptables(8)
    documents that this inserts at the head of the chain, so rules added
    later in this function are evaluated first. This assumes none of the
    rules existed before (run_iptable_with_check skips --insert for rules
    that are already present). Under that assumption the IPv4 "bitmask"
    chain ends up, top to bottom, as:

      1. ACCEPT udp/53 to 127.0.1.1, 127.0.0.1 and each gateway
      2. REJECT any other udp/53
      3. ACCEPT to the local IPv4 network of the default device, if found
      4. ACCEPT to each gateway via the default device
      5. REJECT everything else via the default device

    The IPv6 chain ends up as: ACCEPT to the local IPv6 network (if found),
    then REJECT everything else via the default device.

    :param args: gateway addresses from the command line; get_gateways keeps
                 only the valid ones and bails out if none is valid.
    """
    default_device     = get_default_device()
    local_network_ipv4 = get_local_network_ipv4(default_device)
    local_network_ipv6 = get_local_network_ipv6(default_device)
    gateways           = get_gateways(args)

    # add custom chain "bitmask" (per address family, only if missing) and
    # hook it into OUTPUT for both IPv4 and IPv6
    if not ipv4_chain_exists("bitmask"):
        ip4tables("--new-chain", "bitmask")
    if not ipv6_chain_exists("bitmask"):
        ip6tables("--new-chain", "bitmask")
    iptables("--insert", "OUTPUT", "--jump", "bitmask")

    # reject everything leaving through the default device; inserted first,
    # so it ends up as the last rule of the chain
    iptables("--insert", "bitmask", "-o", default_device, "--jump", "REJECT")

    # allow traffic to gateways (IPv4 only)
    for gateway in gateways:
        ip4tables("--insert", "bitmask", "--destination", gateway, "-o", default_device, "--jump", "ACCEPT")

    # allow traffic to IPs on local network
    if local_network_ipv4:
        ip4tables("--insert", "bitmask", "--destination", local_network_ipv4, "-o", default_device, "--jump", "ACCEPT")
    if local_network_ipv6:
        ip6tables("--insert", "bitmask", "--destination", local_network_ipv6, "-o", default_device, "--jump", "ACCEPT")

    # block DNS requests to anyone but the service provider or localhost.
    # These udp/53 rules carry no "-o", so they apply to every output device,
    # not only the default one. The REJECT is inserted before the ACCEPTs,
    # so the ACCEPTs end up above it.
    # NOTE(review): only IPv4 UDP port 53 is matched here; TCP/53 and IPv6
    # DNS are not covered by these rules -- confirm that is intended.
    ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53", "--jump", "REJECT")
    for allowed_dns in gateways + ["127.0.0.1","127.0.1.1"]:
        ip4tables("--insert", "bitmask", "--protocol", "udp", "--dport", "53", "--destination", allowed_dns, "--jump", "ACCEPT")

def firewall_stop(args):
    """
    Tear down the "bitmask" chain: unhook it from OUTPUT (IPv4 and IPv6), then
    flush and delete it in each address family where it still exists.

    :param args: remaining command-line arguments (unused).
    """
    iptables("--delete", "OUTPUT", "--jump", "bitmask")
    families = (
        (ipv4_chain_exists, ip4tables),
        (ipv6_chain_exists, ip6tables),
    )
    for chain_exists, tables in families:
        if not chain_exists("bitmask"):
            continue
        # iptables(8): a chain must be empty before --delete-chain succeeds.
        tables("--flush", "bitmask")
        tables("--delete-chain", "bitmask")


def bail(msg=""):
    """
    Print ``msg`` (if non-empty) and terminate the script with exit status 1.

    :param msg: optional message, printed to stdout before exiting.
    :raises SystemExit: always, with code 1.
    """
    if msg:
        print(msg)
    # sys.exit is always available. The bare exit() builtin is injected by the
    # site module and is missing e.g. under "python -S".
    sys.exit(1)

def main():
    """
    Dispatch ``bitmask-root SUBSYSTEM ACTION [ARGS...]`` to its handler.

    SUBSYSTEM and ACTION are joined with "_" to name the command; all later
    arguments are handed to the handler. Missing or unknown commands bail out.
    """
    handlers = {
        "openvpn_start": openvpn_start,
        "openvpn_stop": openvpn_stop,
        "firewall_start": firewall_start,
        "firewall_stop": firewall_stop,
    }
    argv = sys.argv
    handler = handlers.get("_".join(argv[1:3])) if len(argv) >= 3 else None
    if handler is None:
        bail("no such command")
    else:
        handler(argv[3:])

if __name__ == "__main__":
    main()
    # print() form is required: the file imports print_function, which makes
    # the print statement a syntax error.
    print("done")
    sys.exit(0)