#!/usr/bin/python2.7
# -*- coding: utf-8 -*-
#
# Copyright (C) 2014 LEAP
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""
This is a privileged helper script for safely running certain commands as root.
It should only be called by the Bitmask application.

USAGE:
  bitmask-root firewall stop
  bitmask-root firewall start [restart] GATEWAY1 GATEWAY2 ...
  bitmask-root openvpn stop
  bitmask-root openvpn start CONFIG1 CONFIG2 ...
  bitmask-root fw-email stop
  bitmask-root fw-email start uid

All actions return exit code 0 for success, non-zero otherwise.

The `openvpn start` action is special: it calls exec on openvpn and replaces
the current process. If the `restart` parameter is passed to a `start` action
of one of the firewalls, that firewall will not be torn down in the case of an
error during launch.
"""
# TODO should be tested with python3, which can be the default on some distro.
from __future__ import print_function
import os
import re
import signal
import socket
import syslog
import subprocess
import sys
import traceback

# Short alias for subprocess.check_output, used to capture command output.
cmdcheck = subprocess.check_output

#
# CONSTANTS
#

# Printed by the `version` command.
VERSION = "3"
# Script name: used as the syslog ident and as prefix of printed messages.
SCRIPT = "bitmask-root"
# DNS server that outgoing DNS traffic is redirected to while the firewall
# is up.
NAMESERVER = "10.42.0.1"
# Names of the custom iptables chains managed by this script.
# NOTE: BITMASK_CHAIN lives in the 'filter' table and BITMASK_CHAIN_NAT_OUT
# in the 'nat' table; they currently share the same name.
BITMASK_CHAIN = "bitmask"
BITMASK_CHAIN_NAT_OUT = "bitmask"
BITMASK_CHAIN_NAT_POST = "bitmask_postrouting"
BITMASK_CHAIN_EMAIL = "bitmask_email"
BITMASK_CHAIN_EMAIL_OUT = "bitmask_email_output"
# Interface and tcp ports protected by the email firewall.
LOCAL_INTERFACE = "lo"
IMAP_PORT = "1984"
SMTP_PORT = "2013"

# Absolute paths of the system tools that are executed.
IP = "/bin/ip"
IPTABLES = "/sbin/iptables"
IP6TABLES = "/sbin/ip6tables"

OPENVPN_USER = "nobody"
OPENVPN_GROUP = "nogroup"
# Marker looked for in a process command line to recognize an openvpn that
# was launched by this script.
LEAPOPENVPN = "LEAPOPENVPN"
OPENVPN_SYSTEM_BIN = "/usr/sbin/openvpn"  # Debian location
OPENVPN_LEAP_BIN = "/usr/local/sbin/leap-openvpn"  # installed by bundle

# Flags that are always passed to openvpn, ahead of the whitelisted flags
# supplied by the caller.
# NOTE(review): the "LEAPOPENVPN", "nobody" and "nogroup" literals duplicate
# the LEAPOPENVPN, OPENVPN_USER and OPENVPN_GROUP constants; keep them in
# sync.
FIXED_FLAGS = [
    "--setenv", "LEAPOPENVPN", "1",
    "--nobind",
    "--client",
    "--dev", "tun",
    "--tls-client",
    "--remote-cert-tls", "server",
    "--management-signal",
    "--script-security", "1",
    "--user", "nobody",
    "--group", "nogroup",
    "--remap-usr1", "SIGTERM",
]

# Whitelist of the openvpn flags a caller may supply. Each flag maps to the
# ordered list of parameter types it takes; every type is a key of
# PARAM_FORMATS, whose validator the corresponding parameter must satisfy.
ALLOWED_FLAGS = {
    "--remote": ["IP", "NUMBER", "PROTO"],
    "--tls-cipher": ["CIPHER"],
    "--cipher": ["CIPHER"],
    "--auth": ["CIPHER"],
    "--management": ["DIR", "UNIXSOCKET"],
    "--management-client-user": ["USER"],
    "--cert": ["FILE"],
    "--key": ["FILE"],
    "--ca": ["FILE"]
}

# Validators for the parameter types named in ALLOWED_FLAGS (plus "UID",
# used by the email firewall). Each validator takes the candidate string
# and returns a truthy value if it is acceptable.
#
# The patterns are raw strings and are anchored with \Z rather than $:
# `$` also matches just before a trailing newline, which would let a value
# such as "1194\n" through.
PARAM_FORMATS = {
    "NUMBER": lambda s: re.match(r"^\d+\Z", s),
    "PROTO": lambda s: re.match(r"^(tcp|udp)\Z", s),
    "IP": lambda s: is_valid_address(s),
    "CIPHER": lambda s: re.match(r"^[A-Z0-9-]+\Z", s),
    "USER": lambda s: re.match(
        r"^[a-zA-Z0-9_.@][a-zA-Z0-9_\-.@]*\$?\Z", s),  # IEEE Std 1003.1-2001
    "FILE": lambda s: os.path.isfile(s),
    # only the directory that would contain the path has to exist
    "DIR": lambda s: os.path.isdir(os.path.split(s)[0]),
    "UNIXSOCKET": lambda s: s == "unix",
    "UID": lambda s: re.match(r"^[a-zA-Z0-9]+\Z", s)
}


# Both switches are read from the environment; they are only ever tested for
# truthiness, so any non-empty value enables them.
DEBUG = os.getenv("DEBUG")
TEST = os.getenv("TEST")

if DEBUG:
    # The logging module is only imported, and `logger` only defined, when
    # DEBUG is set: any use of `logger` must be guarded by `if DEBUG`.
    import logging
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(ch)

# Use the script name as the ident of the messages sent to syslog.
syslog.openlog(SCRIPT)

#
# UTILITY
#


def is_valid_address(value):
    """
    Validate that the passed value is a well-formed IPv4 address.

    Only the strict dotted-quad notation (e.g. "10.42.0.1") is accepted.
    A rejected value is logged.

    :param value: the value to be validated
    :type value: str
    :return: True if the value is a dotted-quad IPv4 address, False otherwise
    :rtype: bool
    """
    try:
        # inet_pton() is strict, whereas inet_aton() also accepts shorthand
        # notations with fewer than three dots (e.g. "127.1"), which we do
        # not want to hand over to iptables or openvpn.
        socket.inet_pton(socket.AF_INET, value)
        return True
    except Exception:
        log("%s: ERROR: MALFORMED IP: %s!" % (SCRIPT, value))
        return False


def split_list(_list, regex):
    """
    Group the items of a list, opening a new group at every item that
    matches a regex. Items found before the first match are dropped.

    e.g. split_list(["xx", "yy", "x1", "zz"], "^x") => [["xx", "yy"], ["x1",
    "zz"]]

    :param _list: the list to be split.
    :type _list: list
    :param regex: the pattern (string or compiled) that starts a new group.
    :type regex: str

    :return: the list of groups, each one starting with a matching item.
    :rtype: list
    """
    pattern = regex if hasattr(regex, "match") else re.compile(regex)
    groups = []
    if not _list:
        return groups
    for item in _list:
        if pattern.match(item):
            # a matching item opens a new group
            groups.append([item])
        elif groups:
            # a non-matching item belongs to the group opened last
            groups[-1].append(item)
    return groups


def get_process_list():
    """
    Get a process list by reading the `/proc` filesystem.

    :return: a list of (pid, cmdline) tuples, where pid is the name of the
             `/proc` entry (a string of digits) and cmdline is the raw
             content of `/proc/<pid>/cmdline`, read in binary mode.
    :rtype: list of tuples
    """
    res = []
    for pid in os.listdir('/proc'):
        if not pid.isdigit():
            continue
        try:
            # the context manager closes the file, instead of leaking one
            # open handle per process
            with open(os.path.join('/proc', pid, 'cmdline'), 'rb') as f:
                res.append((pid, f.read()))
        except IOError:  # proc has already terminated
            continue
    return res


def run(command, *args, **options):
    """
    Run an external command.

    :param command: the executable to run.
    :type command: str
    :param args: the arguments passed to the executable.
    :param options: see below.

    Options:

      `check`: If True (the default), wait for the command to finish and
               check its exit status, bailing if it is non-zero.
      `exitcode`: like `check`, but return the exit status instead of
                  bailing. Takes precedence over `throw`.
      `detach`: If True, launch the command without waiting for it.
      `input`: If True, open command for writing stream to, returning the Popen
               object.
      `throw`: If True, raise an exception if there is an error instead
               of bailing.

    :return: 0 if the command was waited for and succeeded, its exit status
             if it failed and `exitcode` is set, the Popen object if `input`
             is set, and None if it was launched without waiting.
    :raises subprocess.CalledProcessError: if the command failed and `throw`
                                           is set (and `exitcode` is not).
    """
    parts = [command]
    parts.extend(args)
    debug("%s run: %s " % (SCRIPT, " ".join(parts)))

    _check = options.get("check", True)
    _detach = options.get("detach", False)
    _input = options.get("input", False)
    _exitcode = options.get("exitcode", False)
    _throw = options.get("throw", False)

    if not (_check or _throw) or _detach or _input:
        # fire and forget: the exit status is never looked at
        if _input:
            return subprocess.Popen(parts, stdin=subprocess.PIPE)
        subprocess.Popen(parts)
        return None

    try:
        # the output of the command is discarded; the handle is closed as
        # soon as the command has finished.
        with open('/dev/null', 'w') as devnull:
            subprocess.check_call(parts, stdout=devnull, stderr=devnull)
        return 0
    except subprocess.CalledProcessError as exc:
        if _exitcode:
            if exc.returncode != 1:
                # 0 or 1 is to be expected, but anything else
                # should be logged.
                debug("ERROR: Could not run %s: %s" %
                      (exc.cmd, exc.output), exception=exc)
            return exc.returncode
        elif _throw:
            # bare raise keeps the original traceback
            raise
        else:
            bail("ERROR: Could not run %s: %s" % (exc.cmd, exc.output),
                 exception=exc)


def log(msg=None, exception=None, priority=syslog.LOG_INFO):
    """
    print and log warning message or exception.

    The traceback that gets logged is the one of the exception currently
    being handled, so `exception` should only be passed from within an
    `except` block.

    :param msg: optional error message.
    :type msg: str
    :param exception: optional exception.
    :type exception: Exception
    :param priority: syslog level
    :type priority: one of LOG_EMERG, LOG_ALERT, LOG_CRIT, LOG_ERR,
                    LOG_WARNING, LOG_NOTICE, LOG_INFO, LOG_DEBUG
    """
    if msg is not None:
        print("%s: %s" % (SCRIPT, msg))
        # syslog expects a string, but callers may hand over other objects
        # (e.g. a list of flags): format the message the same way print does.
        syslog.syslog(priority, "%s" % (msg,))
    if exception is not None:
        if TEST or DEBUG:
            traceback.print_exc()
        syslog.syslog(priority, traceback.format_exc())


def debug(msg=None, exception=None):
    """
    Forward the message and exception to log(), but only when one of the
    TEST or DEBUG switches is set; otherwise do nothing.

    The message is logged with the default priority of log()
    (syslog.LOG_INFO), so that debug messages are not missed.
    """
    if not (TEST or DEBUG):
        return
    log(msg, exception)


def bail(msg=None, exception=None):
    """
    abnormal exit. like log(), but exits with error status code.

    :param msg: optional error message, passed on to log().
    :param exception: optional exception, passed on to log().
    :raises SystemExit: always, with exit status 1.
    """
    log(msg, exception)
    # sys.exit() rather than the `exit` builtin, which is only injected by
    # the site module and is meant for the interactive interpreter.
    sys.exit(1)

#
# OPENVPN
#


def get_openvpn_bin():
    """
    Locate the openvpn executable: the system one is preferred, the one
    installed by the bundle is the fallback.

    :return: the path of the first candidate that exists as a file, or None
             if neither does.
    """
    # the bundle option should be removed from the debian package.
    for candidate in (OPENVPN_SYSTEM_BIN, OPENVPN_LEAP_BIN):
        if os.path.isfile(candidate):
            return candidate
    return None


def parse_openvpn_flags(args):
    """
    Sanitize the openvpn arguments received from the command line.

    The arguments are grouped per flag (an item starting with "--" plus the
    items that follow it). Flags that are not in ALLOWED_FLAGS are dropped
    with a warning; the parameters of allowed flags must match, in number
    and in format, the types declared in ALLOWED_FLAGS.

    :param args: the raw arguments.
    :type args: list
    :return: the flat list of accepted flags and parameters, or None if a
             flag has a wrong number of parameters, a parameter is
             malformed, or parsing failed.
    """
    accepted = []
    try:
        for group in split_list(args, "^--"):
            flag_name, flag_params = group[0], group[1:]
            if flag_name not in ALLOWED_FLAGS:
                log("WARNING: unrecognized openvpn flag %s" % flag_name)
                continue

            accepted.append(flag_name)
            param_types = ALLOWED_FLAGS[flag_name]
            if not param_types:
                # flag without parameters: whatever followed it is ignored
                continue

            if len(flag_params) != len(param_types):
                log("%s: ERROR: not enough params for %s" %
                    (SCRIPT, flag_name))
                return None

            for value, kind in zip(flag_params, param_types):
                if not PARAM_FORMATS[kind](value):
                    log("%s: ERROR: Bad argument %s" %
                        (SCRIPT, value))
                    return None
                accepted.append(value)
        return accepted
    except Exception as exc:
        log("%s: ERROR PARSING FLAGS: %s" % (SCRIPT, exc))
        if DEBUG:
            logger.exception(exc)
        return None


def openvpn_start(args):
    """
    Launch openvpn, sanitizing input, and replacing the current process with
    the openvpn process.

    Bails (exit status 1) if the arguments cannot be parsed, if no allowed
    flag is left after sanitizing, or if no openvpn binary is installed.

    :param args: arguments to be passed to openvpn
    :type args: list
    """
    openvpn_flags = parse_openvpn_flags(args)
    if not openvpn_flags:
        bail('ERROR: could not parse openvpn options')

    OPENVPN = get_openvpn_bin()
    if OPENVPN is None:
        # without this check, os.execv() would fail with a TypeError
        bail('ERROR: could not find the openvpn binary')

    flags = [OPENVPN] + FIXED_FLAGS + openvpn_flags
    if DEBUG:
        log("%s: running openvpn with flags:" % (SCRIPT,))
        # log() forwards the message to syslog, which expects a string
        log(" ".join(flags))
    # note: first argument to command is ignored, but customarily set to
    # the command.
    os.execv(OPENVPN, flags)


def openvpn_stop(args):
    """
    Stop the openvpn that has likely been launched by bitmask.

    A process is considered a match if its command line starts with the
    path of the openvpn binary and contains the LEAPOPENVPN marker. SIGTERM
    is sent to the first match only.

    :param args: arguments to openvpn (unused)
    :type args: list
    """
    openvpn_bin = get_openvpn_bin()
    if openvpn_bin is None:
        # no openvpn binary installed: nothing we could have launched
        return

    # get_process_list() reads the command lines in binary mode, so they
    # are compared against byte strings.
    prefix = openvpn_bin.encode("utf-8")
    marker = LEAPOPENVPN.encode("utf-8")
    found_leap_openvpn = [
        pid for pid, cmdline in get_process_list()
        if cmdline.startswith(prefix) and marker in cmdline]

    if found_leap_openvpn:
        os.kill(int(found_leap_openvpn[0]), signal.SIGTERM)

#
# FIREWALL
#


def get_gateways(gateways):
    """
    Filter a passed sequence of gateways, returning only the valid ones.
    Bails (exit status 1) if none of them is valid.

    :param gateways: a sequence of gateways to filter.
    :type gateways: iterable
    :return: the gateways accepted by is_valid_address()
    :rtype: list
    """
    # a real list, so that the emptiness test below is meaningful whatever
    # filter() returns on the running python version.
    result = [gateway for gateway in gateways if is_valid_address(gateway)]
    if not result:
        bail("ERROR: No valid gateways specified")
    else:
        return result


def get_default_device():
    """
    Look up the network device of the default route, as reported by
    `ip route show`. Bails if no default route with a device is found in
    the output.

    :return: the name of the device
    :rtype: str
    """
    route_table = subprocess.check_output([IP, "route", "show"])
    default_route = re.search(
        "^default .*dev ([^\s]*) .*$", route_table, flags=re.M)
    if not default_route or not default_route.groups():
        bail("Could not find default device")
    else:
        return default_route.group(1)


def get_local_network_ipv4(device):
    """
    Get the first ipv4 address listed for a given device by
    `ip -o address show`.

    :param device: the name of the network device
    :type device: str
    :return: the token that follows "inet" in the listing, or None if the
             device has no ipv4 address.
    """
    listing = cmdcheck([IP, "-o", "address", "show", "dev", device])
    found = re.search("^.*inet ([^ ]*) .*$", listing, flags=re.M)
    return found.group(1) if found and found.groups() else None


def get_local_network_ipv6(device):
    """
    Get the first ipv6 address listed for a given device by
    `ip -o address show`.

    :param device: the name of the network device
    :type device: str
    :return: the token that follows "inet6" in the listing, or None if the
             device has no ipv6 address.
    """
    listing = cmdcheck([IP, "-o", "address", "show", "dev", device])
    found = re.search("^.*inet6 ([^ ]*) .*$", listing, flags=re.M)
    return found.group(1) if found and found.groups() else None


def run_iptable_with_check(cmd, *args, **options):
    """
    Run an iptables command, probing first with `--check` when the command
    adds or removes a rule, so that the operation is idempotent:

      --insert / --append: run only if the rule does not exist yet.
      --delete: run only if the rule does exist.

    Any other command is run unconditionally.

    :param cmd: the iptables executable to use.
    :param args: the iptables arguments.
    :param options: options passed on to run() for the actual command.
    """
    # (action, whether the rule must already exist for the action to run)
    guarded_actions = (
        ("--insert", False),
        ("--append", False),
        ("--delete", True),
    )
    for action, needs_existing_rule in guarded_actions:
        if action not in args:
            continue
        probe_args = [arg.replace(action, "--check") for arg in args]
        rule_exists = run(cmd, *probe_args, exitcode=True) == 0
        if rule_exists == needs_existing_rule:
            run(cmd, *args, **options)
        return

    run(cmd, *args, **options)


def iptables(*args, **options):
    """
    Apply the same command to both the ipv4 and the ipv6 tables, in that
    order, through ip4tables() and ip6tables().
    """
    for apply_command in (ip4tables, ip6tables):
        apply_command(*args, **options)


def ip4tables(*args, **options):
    """
    Run the ipv4 iptables binary (IPTABLES) through
    run_iptable_with_check(), so that rules are only inserted or appended
    if missing, and only deleted if present.

    :param args: the iptables arguments.
    :param options: options passed on to run().
    """
    run_iptable_with_check(IPTABLES, *args, **options)


def ip6tables(*args, **options):
    """
    Run the ipv6 iptables binary (IP6TABLES) through
    run_iptable_with_check(), so that rules are only inserted or appended
    if missing, and only deleted if present.

    :param args: the ip6tables arguments.
    :param options: options passed on to run().
    """
    run_iptable_with_check(IP6TABLES, *args, **options)

#
# NOTE: these tests to see if a chain exists might incorrectly return false.
# This happens when there is an error in calling `iptables --list bitmask`.
#
# For this reason, when stopping the firewall, we do not trust the
# output of ipvx_chain_exists() but instead always attempt to delete
# the chain.
#


def ipv4_chain_exists(chain, table=None):
    """
    Tell whether an ipv4 chain exists, by listing it. The answer is True
    only if the listing succeeded; if iptables failed for another reason
    than a missing chain, an error is logged and False is returned even
    though the chain might exist.

    :param chain: the chain to check against
    :type chain: str
    :param table: the table holding the chain; if None, no table is passed
                  to iptables.
    :type table: str
    :rtype: bool
    """
    listing_args = ["--list", chain, "--numeric"]
    if table is not None:
        listing_args = ["-t", table] + listing_args
    status = run(IPTABLES, *listing_args, exitcode=True)
    # 0: the chain was listed, 1: no such chain, anything else: unknown
    if status not in (0, 1):
        log("ERROR: Could not determine state of iptable chain")
    return status == 0


def ipv6_chain_exists(chain):
    """
    Tell whether an ipv6 chain exists, by listing it. The answer is True
    only if the listing succeeded; if ip6tables failed for another reason
    than a missing chain, an error is logged and False is returned even
    though the chain might exist.

    :param chain: the chain to check against
    :type chain: str
    :rtype: bool
    """
    status = run(IP6TABLES, "--list", chain, "--numeric", exitcode=True)
    # 0: the chain was listed, 1: no such chain, anything else: unknown
    if status not in (0, 1):
        log("ERROR: Could not determine state of iptable chain")
    return status == 0


def enable_ip_forwarding():
    """
    Turn on ipv4 forwarding by writing "1" to the matching /proc switch.
    The firewall needs it in order to work.
    """
    forwarding_switch = '/proc/sys/net/ipv4/ip_forward'
    with open(forwarding_switch, 'w') as switch:
        switch.write('1\n')


def firewall_start(args):
    """
    Bring up the firewall.

    While it is up, ipv4 traffic leaving through the default device is
    rejected unless it goes to the local network or to one of the gateways,
    DNS queries are redirected to NAMESERVER, and ipv6 tcp/udp traffic is
    rejected except towards the local network.

    Bails if the default device cannot be determined or if no valid gateway
    is given.

    :param args: list of gateways, to be sanitized.
    :type args: list
    """
    default_device = get_default_device()
    local_network_ipv4 = get_local_network_ipv4(default_device)
    local_network_ipv6 = get_local_network_ipv6(default_device)
    gateways = get_gateways(args)

    # add custom chain "bitmask" to front of OUTPUT chain for both
    # the 'filter' and the 'nat' tables.
    if not ipv4_chain_exists(BITMASK_CHAIN):
        ip4tables("--new-chain", BITMASK_CHAIN)
    if not ipv4_chain_exists(BITMASK_CHAIN_NAT_OUT, 'nat'):
        ip4tables("--table", "nat", "--new-chain", BITMASK_CHAIN_NAT_OUT)
    if not ipv4_chain_exists(BITMASK_CHAIN_NAT_POST, 'nat'):
        ip4tables("--table", "nat", "--new-chain", BITMASK_CHAIN_NAT_POST)
    if not ipv6_chain_exists(BITMASK_CHAIN):
        ip6tables("--new-chain", BITMASK_CHAIN)
    ip4tables("--table", "nat", "--insert", "OUTPUT",
              "--jump", BITMASK_CHAIN_NAT_OUT)
    ip4tables("--table", "nat", "--insert", "POSTROUTING",
              "--jump", BITMASK_CHAIN_NAT_POST)
    iptables("--insert", "OUTPUT", "--jump", BITMASK_CHAIN)

    # route all ipv4 DNS over VPN
    # (note: NAT does not work with ipv6 until kernel 3.7)
    enable_ip_forwarding()
    # allow dns to localhost
    # (this rule lives in the 'nat' table, hence the nat chain constant
    # rather than the filter one, even though both currently share the
    # same name)
    ip4tables("-t", "nat", "--append", BITMASK_CHAIN_NAT_OUT,
              "--protocol", "udp",
              "--dest", "127.0.1.1,127.0.0.1", "--dport", "53",
              "--jump", "ACCEPT")
    # rewrite all outgoing packets to use VPN DNS server
    # (DNS does sometimes use TCP!)
    ip4tables("-t", "nat", "--append", BITMASK_CHAIN_NAT_OUT, "-p", "udp",
              "--dport", "53", "--jump", "DNAT", "--to", NAMESERVER+":53")
    ip4tables("-t", "nat", "--append", BITMASK_CHAIN_NAT_OUT, "-p", "tcp",
              "--dport", "53", "--jump", "DNAT", "--to", NAMESERVER+":53")
    # enable masquerading, so that DNS packets rewritten by DNAT will
    # have the correct source IPs
    ip4tables("-t", "nat", "--append", BITMASK_CHAIN_NAT_POST,
              "--protocol", "udp", "--dport", "53", "--jump", "MASQUERADE")
    ip4tables("-t", "nat", "--append", BITMASK_CHAIN_NAT_POST,
              "--protocol", "tcp", "--dport", "53", "--jump", "MASQUERADE")

    # allow local network traffic
    if local_network_ipv4:
        # allow local network destinations
        ip4tables("--append", BITMASK_CHAIN,
                  "--destination", local_network_ipv4, "-o", default_device,
                  "--jump", "ACCEPT")
        # allow local network sources for DNS
        # (required to allow local network DNS that gets rewritten by NAT
        #  to get passed through so that MASQUERADE can set correct source IP)
        ip4tables("--append", BITMASK_CHAIN,
                  "--source", local_network_ipv4, "-o", default_device,
                  "-p", "udp", "--dport", "53", "--jump", "ACCEPT")
        ip4tables("--append", BITMASK_CHAIN,
                  "--source", local_network_ipv4, "-o", default_device,
                  "-p", "tcp", "--dport", "53", "--jump", "ACCEPT")
        # allow multicast Simple Service Discovery Protocol
        ip4tables("--append", BITMASK_CHAIN,
                  "--protocol", "udp",
                  "--destination", "239.255.255.250", "--dport", "1900",
                  "-o", default_device, "--jump", "RETURN")
        # allow multicast Bonjour/mDNS
        ip4tables("--append", BITMASK_CHAIN,
                  "--protocol", "udp",
                  "--destination", "224.0.0.251", "--dport", "5353",
                  "-o", default_device, "--jump", "RETURN")
    if local_network_ipv6:
        ip6tables("--append", BITMASK_CHAIN,
                  "--destination", local_network_ipv6, "-o", default_device,
                  "--jump", "ACCEPT")
        # allow multicast Simple Service Discovery Protocol
        ip6tables("--append", BITMASK_CHAIN,
                  "--protocol", "udp",
                  "--destination", "FF05::C", "--dport", "1900",
                  "-o", default_device, "--jump", "RETURN")
        # allow multicast Bonjour/mDNS
        ip6tables("--append", BITMASK_CHAIN,
                  "--protocol", "udp",
                  "--destination", "FF02::FB", "--dport", "5353",
                  "-o", default_device, "--jump", "RETURN")

    # allow ipv4 traffic to gateways
    for gateway in gateways:
        ip4tables("--append", BITMASK_CHAIN, "--destination", gateway,
                  "-o", default_device, "--jump", "ACCEPT")

    # log rejected packets to syslog
    if DEBUG:
        iptables("--append", BITMASK_CHAIN, "-o", default_device,
                 "--jump", "LOG", "--log-prefix", "iptables denied: ",
                 "--log-level", "7")

    # for now, ensure all other ipv6 packets get rejected (regardless of
    # device). not sure why, but "-p any" doesn't work.
    ip6tables("--append", BITMASK_CHAIN, "-p", "tcp", "--jump", "REJECT")
    ip6tables("--append", BITMASK_CHAIN, "-p", "udp", "--jump", "REJECT")

    # reject all other ipv4 sent over the default device
    ip4tables("--append", BITMASK_CHAIN, "-o",
              default_device, "--jump", "REJECT")


def firewall_stop():
    """
    Stop the firewall. Because we really really always want the firewall to
    be stopped if at all possible, every removal step is attempted, even if
    a previous one failed.

    If any step failed and one of the bitmask chains is still there at the
    end, we raise an exception. This allows the calling code to retry
    stopping the firewall. Stopping the firewall can fail if iptables is
    being run by another process (only one iptables command can be run at a
    time).

    :raises Exception: if the firewall might still be up.
    """
    def attempt(failure_msg, *calls):
        """
        Run each (function, arguments) pair with throw=True. Return True if
        all of them succeeded; at the first CalledProcessError, log
        failure_msg at debug level and return False.
        """
        try:
            for func, func_args in calls:
                func(*func_args, throw=True)
        except subprocess.CalledProcessError as exc:
            debug(failure_msg, exc)
            return False
        return True

    nat_chain_msg = (
        "INFO: not able to flush and delete bitmask ipv4 firewall "
        "chain in 'nat' table (maybe it is already destroyed?)")

    # a list (and not a chain of `and`s), so that all steps are run
    outcomes = [
        # -t filter -D OUTPUT -j bitmask
        attempt(
            "INFO: not able to remove bitmask firewall from OUTPUT chain "
            "(maybe it is already removed?)",
            (iptables, ("--delete", "OUTPUT", "--jump", BITMASK_CHAIN))),
        # -t nat -D OUTPUT -j bitmask
        attempt(
            "INFO: not able to remove bitmask firewall from OUTPUT chain "
            "in 'nat' table (maybe it is already removed?)",
            (ip4tables, ("-t", "nat", "--delete", "OUTPUT",
                         "--jump", BITMASK_CHAIN_NAT_OUT))),
        # -t nat -D POSTROUTING -j bitmask_postrouting
        attempt(
            "INFO: not able to remove bitmask firewall from POSTROUTING "
            "chain in 'nat' table (maybe it is already removed?)",
            (ip4tables, ("-t", "nat", "--delete", "POSTROUTING",
                         "--jump", BITMASK_CHAIN_NAT_POST))),
        # -t filter --delete-chain bitmask
        attempt(
            "INFO: not able to flush and delete bitmask ipv4 firewall "
            "chain (maybe it is already destroyed?)",
            (ip4tables, ("--flush", BITMASK_CHAIN)),
            (ip4tables, ("--delete-chain", BITMASK_CHAIN))),
        # -t nat --delete-chain bitmask
        attempt(
            nat_chain_msg,
            (ip4tables, ("-t", "nat", "--flush", BITMASK_CHAIN_NAT_OUT)),
            (ip4tables, ("-t", "nat", "--delete-chain",
                         BITMASK_CHAIN_NAT_OUT))),
        # -t nat --delete-chain bitmask_postrouting
        attempt(
            nat_chain_msg,
            (ip4tables, ("-t", "nat", "--flush", BITMASK_CHAIN_NAT_POST)),
            (ip4tables, ("-t", "nat", "--delete-chain",
                         BITMASK_CHAIN_NAT_POST))),
        # -t filter --delete-chain bitmask (ipv6)
        attempt(
            "INFO: not able to flush and delete bitmask ipv6 firewall "
            "chain (maybe it is already destroyed?)",
            (ip6tables, ("--flush", BITMASK_CHAIN)),
            (ip6tables, ("--delete-chain", BITMASK_CHAIN))),
    ]
    if all(outcomes):
        return

    # Some step failed. This is expected when the firewall was already
    # down, so only complain if one of our chains can still be listed.
    # (The existence checks have to be *called*: the bare function objects
    # are always true.)
    still_up = (
        ipv4_chain_exists(BITMASK_CHAIN) or
        ipv4_chain_exists(BITMASK_CHAIN_NAT_OUT, 'nat') or
        ipv4_chain_exists(BITMASK_CHAIN_NAT_POST, 'nat') or
        ipv6_chain_exists(BITMASK_CHAIN))
    if still_up:
        raise Exception("firewall might still be left up. "
                        "Please try `firewall stop` again.")


def fw_email_start(args):
    """
    Bring up the email firewall: the imap and smtp ports are only reachable
    through the local interface, and only by processes of the given user.

    :param args: the user uid of the bitmask process
    :type args: list
    :raises Exception: if no uid is given, or if it is malformed.
    """
    # validate the input before touching iptables, so that a bad call does
    # not leave a half-built firewall behind.
    if not args or not PARAM_FORMATS["UID"](args[0]):
        raise Exception("No uid given")
    uid = args[0]

    # add custom chain "bitmask_email" to front of INPUT chain
    if not ipv4_chain_exists(BITMASK_CHAIN_EMAIL):
        ip4tables("--new-chain", BITMASK_CHAIN_EMAIL)
    if not ipv6_chain_exists(BITMASK_CHAIN_EMAIL):
        ip6tables("--new-chain", BITMASK_CHAIN_EMAIL)
    iptables("--insert", "INPUT", "--jump", BITMASK_CHAIN_EMAIL)

    # add custom chain "bitmask_email_output" to front of OUTPUT chain
    if not ipv4_chain_exists(BITMASK_CHAIN_EMAIL_OUT):
        ip4tables("--new-chain", BITMASK_CHAIN_EMAIL_OUT)
    if not ipv6_chain_exists(BITMASK_CHAIN_EMAIL_OUT):
        ip6tables("--new-chain", BITMASK_CHAIN_EMAIL_OUT)
    iptables("--insert", "OUTPUT", "--jump", BITMASK_CHAIN_EMAIL_OUT)

    # Disable the access to imap and smtp from outside
    iptables("--append", BITMASK_CHAIN_EMAIL,
             "--in-interface", LOCAL_INTERFACE, "--protocol", "tcp",
             "--dport", IMAP_PORT, "--jump", "ACCEPT")
    iptables("--append", BITMASK_CHAIN_EMAIL,
             "--in-interface", LOCAL_INTERFACE, "--protocol", "tcp",
             "--dport", SMTP_PORT, "--jump", "ACCEPT")
    iptables("--append", BITMASK_CHAIN_EMAIL,
             "--protocol", "tcp", "--dport", IMAP_PORT, "--jump", "REJECT")
    iptables("--append", BITMASK_CHAIN_EMAIL,
             "--protocol", "tcp", "--dport", SMTP_PORT, "--jump", "REJECT")

    # Only the unix 'uid' have access to the email imap and smtp ports
    iptables("--append", BITMASK_CHAIN_EMAIL_OUT,
             "--out-interface", LOCAL_INTERFACE,
             "--match", "owner", "--uid-owner", uid, "--protocol", "tcp",
             "--dport", IMAP_PORT, "--jump", "ACCEPT")
    iptables("--append", BITMASK_CHAIN_EMAIL_OUT,
             "--out-interface", LOCAL_INTERFACE,
             "--match", "owner", "--uid-owner", uid, "--protocol", "tcp",
             "--dport", SMTP_PORT, "--jump", "ACCEPT")
    iptables("--append", BITMASK_CHAIN_EMAIL_OUT,
             "--out-interface", LOCAL_INTERFACE,
             "--protocol", "tcp", "--dport", IMAP_PORT, "--jump", "REJECT")
    iptables("--append", BITMASK_CHAIN_EMAIL_OUT,
             "--out-interface", LOCAL_INTERFACE,
             "--protocol", "tcp", "--dport", SMTP_PORT, "--jump", "REJECT")


def fw_email_stop():
    """
    Stop the email firewall. Every removal step is attempted, even if a
    previous one failed.

    If any step failed and one of the email chains is still there at the
    end, an exception is raised so that the caller can retry.

    :raises Exception: if the email firewall might still be up.
    """
    def attempt(failure_msg, *calls):
        """
        Run each (function, arguments) pair with throw=True. Return True if
        all of them succeeded; at the first CalledProcessError, log
        failure_msg at debug level and return False.
        """
        try:
            for func, func_args in calls:
                func(*func_args, throw=True)
        except subprocess.CalledProcessError as exc:
            debug(failure_msg, exc)
            return False
        return True

    ipv4_chain_msg = (
        "INFO: not able to flush and delete bitmask ipv4 email firewall "
        "chain (maybe it is already destroyed?)")
    ipv6_chain_msg = (
        "INFO: not able to flush and delete bitmask ipv6 email firewall "
        "chain (maybe it is already destroyed?)")

    # a list (and not a chain of `and`s), so that all steps are run
    outcomes = [
        attempt(
            "INFO: not able to remove bitmask email firewall from INPUT "
            "chain (maybe it is already removed?)",
            (iptables, ("--delete", "INPUT",
                        "--jump", BITMASK_CHAIN_EMAIL))),
        attempt(
            "INFO: not able to remove bitmask email firewall from OUTPUT "
            "chain (maybe it is already removed?)",
            (iptables, ("--delete", "OUTPUT",
                        "--jump", BITMASK_CHAIN_EMAIL_OUT))),
        attempt(
            ipv4_chain_msg,
            (ip4tables, ("--flush", BITMASK_CHAIN_EMAIL)),
            (ip4tables, ("--delete-chain", BITMASK_CHAIN_EMAIL))),
        attempt(
            ipv6_chain_msg,
            (ip6tables, ("--flush", BITMASK_CHAIN_EMAIL)),
            (ip6tables, ("--delete-chain", BITMASK_CHAIN_EMAIL))),
        attempt(
            ipv4_chain_msg,
            (ip4tables, ("--flush", BITMASK_CHAIN_EMAIL_OUT)),
            (ip4tables, ("--delete-chain", BITMASK_CHAIN_EMAIL_OUT))),
        attempt(
            ipv6_chain_msg,
            (ip6tables, ("--flush", BITMASK_CHAIN_EMAIL_OUT)),
            (ip6tables, ("--delete-chain", BITMASK_CHAIN_EMAIL_OUT))),
    ]
    if all(outcomes):
        return

    # Some step failed. This is expected when the email firewall was
    # already down, so only complain if one of our chains can still be
    # listed. (The existence checks have to be *called*: the bare function
    # objects are always true.)
    still_up = (
        ipv4_chain_exists(BITMASK_CHAIN_EMAIL) or
        ipv4_chain_exists(BITMASK_CHAIN_EMAIL_OUT) or
        ipv6_chain_exists(BITMASK_CHAIN_EMAIL) or
        ipv6_chain_exists(BITMASK_CHAIN_EMAIL_OUT))
    if still_up:
        raise Exception("email firewall might still be left up. "
                        "Please try `fw-email stop` again.")


#
# MAIN
#


def main():
    """
    Entry point for cmdline execution.

    The command is made of the first two arguments joined by an underscore
    (e.g. `firewall start` => "firewall_start"); the remaining arguments
    are handed over to the command. A leading `restart` argument is
    stripped: it disables the tear down of a firewall that failed to start.

    Known commands: version, openvpn start|stop, firewall start|stop|isup
    and fw-email start|stop|isup. All of them but `version` require root.
    Any failure ends the script with a non-zero exit status.
    """
    # TODO use argparse instead.

    if len(sys.argv) >= 2:
        command = "_".join(sys.argv[1:3])
        args = sys.argv[3:]

        is_restart = False
        if args and args[0] == "restart":
            is_restart = True
            args.remove('restart')

        if command == "version":
            print(VERSION)
            # sys.exit() rather than the `exit` builtin, which is only
            # injected by the site module.
            sys.exit(0)

        if os.getuid() != 0:
            bail("ERROR: must be run as root")

        if command == "openvpn_start":
            openvpn_start(args)

        elif command == "openvpn_stop":
            openvpn_stop(args)

        elif command == "firewall_start":
            try:
                firewall_start(args)
            except Exception as ex:
                if not is_restart:
                    firewall_stop()
                bail("ERROR: could not start firewall", ex)

        elif command == "firewall_stop":
            try:
                firewall_stop()
            except Exception as ex:
                bail("ERROR: could not stop firewall", ex)

        elif command == "firewall_isup":
            if ipv4_chain_exists(BITMASK_CHAIN):
                log("%s: INFO: bitmask firewall is up" % (SCRIPT,))
            else:
                bail("INFO: bitmask firewall is down")

        elif command == "fw-email_start":
            try:
                fw_email_start(args)
            except Exception as ex:
                if not is_restart:
                    fw_email_stop()
                bail("ERROR: could not start email firewall", ex)

        elif command == "fw-email_stop":
            try:
                fw_email_stop()
            except Exception as ex:
                bail("ERROR: could not stop email firewall", ex)

        elif command == "fw-email_isup":
            if ipv4_chain_exists(BITMASK_CHAIN_EMAIL):
                log("%s: INFO: bitmask email firewall is up" % (SCRIPT,))
            else:
                bail("INFO: bitmask email firewall is down")

        else:
            bail("ERROR: No such command")
    else:
        bail("ERROR: No such command")

if __name__ == "__main__":
    # the command line is only logged when TEST or DEBUG is set
    debug(" ".join(sys.argv))
    main()
    log("done")
    # sys.exit() rather than the `exit` builtin, which is only injected by
    # the site module.
    sys.exit(0)