#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Author: Kali Kaneko
# Copyright (C) 2015-2017 LEAP Encryption Access Project
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""
This is a privileged helper script for safely running certain commands as root
under OSX.

It should be run by launchd, and it exposes a Unix Domain Socket to which
the following commands can be written by the Bitmask application:

  firewall_start [restart] GATEWAY1 GATEWAY2 ...
  firewall_stop
  openvpn_start CONFIG1 CONFIG1 ...
  openvpn_stop
  fw_email_start uid
  fw_email_stop

To load it manually:

  sudo launchctl load /Library/LaunchDaemons/se.leap.bitmask-helper

To see the loaded rules (the anchor name is the value of BITMASK_ANCHOR):

  sudo pfctl -s rules -a com.apple/250.BitmaskFirewall

To test the commands, you can write directly to the unix socket. Remember to
terminate the command properly:

 echo 'firewall_stop/CMD' | socat - UNIX-CONNECT:/tmp/bitmask-helper.socket

"""
import os
import re
import signal
import socket
import subprocess
import syslog
import threading

from commands import getoutput as exec_cmd
from functools import partial

import daemon

# Version and name of this helper. Neither is referenced by the code visible
# in this file.
VERSION = "1"
SCRIPT = "bitmask-helper"
# Address that is added to the pf gateways table alongside the VPN gateways
# (see _reset_bitmask_gateways_table).
NAMESERVER = "10.42.0.1"
# pf anchor that every pfctl command issued by this helper operates on.
BITMASK_ANCHOR = "com.apple/250.BitmaskFirewall"
# NOTE(review): presumably the anchor for the email firewall; it is not used
# by the code visible in this file.
BITMASK_ANCHOR_EMAIL = "bitmask_email"

# Unprivileged identity for openvpn (OPENVPN_GROUP is used in FIXED_FLAGS).
OPENVPN_USER = 'nobody'
OPENVPN_GROUP = 'nogroup'
LEAPOPENVPN = 'LEAPOPENVPN'
# Locations inside the installed application bundle.
APP_PATH = '/Applications/Bitmask.app/'
RESOURCES_PATH = APP_PATH + 'Contents/Resources/'
OPENVPN_LEAP_BIN = RESOURCES_PATH + 'openvpn.leap'

# openvpn flags that do not depend on client input. The user, group and
# environment-variable name come from the module constants so that they are
# defined in a single place.
# NOTE(review): this list is not referenced by openvpn_start in the code
# visible in this file.
FIXED_FLAGS = [
    "--setenv", LEAPOPENVPN, "1",
    "--nobind",
    "--client",
    "--tls-client",
    "--remote-cert-tls", "server",
    "--management-signal",
    "--script-security", "1",
    "--user", OPENVPN_USER,
    "--remap-usr1", "SIGTERM",
    "--group", OPENVPN_GROUP,
]

# Maps an openvpn command-line flag to the ordered list of parameter kinds it
# takes. Every kind named here is a key of PARAM_FORMATS.
# NOTE(review): not referenced by the code visible in this file; presumably
# intended for the option sanitizing that openvpn_start still lacks.
ALLOWED_FLAGS = {
    "--remote": ["IP", "NUMBER", "PROTO"],
    "--tls-cipher": ["CIPHER"],
    "--cipher": ["CIPHER"],
    "--auth": ["CIPHER"],
    "--management": ["DIR", "UNIXSOCKET"],
    "--management-client-user": ["USER"],
    "--cert": ["FILE"],
    "--key": ["FILE"],
    "--ca": ["FILE"],
    "--fragment": ["NUMBER"]
}

# Validators for the parameter kinds named in ALLOWED_FLAGS. Each value is a
# callable that takes the candidate string and returns a truthy value when it
# is acceptable (a match object for the regex-based ones, a bool otherwise).
# The patterns are raw strings so that backslashes reach the regex engine
# untouched; the regex validators need the `re` module imported at module
# level.
PARAM_FORMATS = {
    "NUMBER": lambda s: re.match(r"^\d+$", s),
    "PROTO": lambda s: re.match(r"^(tcp|udp)$", s),
    "IP": lambda s: is_valid_address(s),
    "CIPHER": lambda s: re.match(r"^[A-Z0-9-]+$", s),
    "USER": lambda s: re.match(
        r"^[a-zA-Z0-9_\.\@][a-zA-Z0-9_\-\.\@]*\$?$", s),  # IEEE Std 1003.1-2001
    "FILE": lambda s: os.path.isfile(s),
    # Only the parent directory of the given path has to exist.
    "DIR": lambda s: os.path.isdir(os.path.split(s)[0]),
    "UNIXSOCKET": lambda s: s == "unix",
    "UID": lambda s: re.match(r"^[a-zA-Z0-9]+$", s)
}

#
# paths (must use absolute paths, since this script is run as root)
#

# pf control utility; every firewall function in this file shells out to it.
PFCTL = '/sbin/pfctl'
# ROUTE, AWK and GREP are combined by _get_default_device into one pipeline.
ROUTE = '/sbin/route'
AWK = '/usr/bin/awk'
GREP = '/usr/bin/grep'
# Not referenced by the code visible in this file.
CAT = '/bin/cat'

# User id this helper runs as, captured once at import time.
UID = os.getuid()
# Filesystem path of the Unix domain socket that serve_forever listens on.
SERVER_ADDRESS = '/tmp/bitmask-helper.socket'


#
# COMMAND DISPATCH
#

def serve_forever():
    """
    Listen on the helper's Unix domain socket and dispatch commands.

    Any socket file left at SERVER_ADDRESS is removed first. Afterwards every
    accepted connection is handed to handle_command in its own daemon thread.
    This function does not return.
    """
    try:
        os.unlink(SERVER_ADDRESS)
    except OSError:
        # A missing file is fine; anything that left the file in place is not.
        stale_socket_remains = os.path.exists(SERVER_ADDRESS)
        if stale_socket_remains:
            raise

    syslog.syslog(syslog.LOG_WARNING, "serving forever")
    # XXX should check permissions on the socket file
    listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    listener.bind(SERVER_ADDRESS)
    listener.listen(1)
    syslog.syslog(syslog.LOG_WARNING, "Binded to %s" % SERVER_ADDRESS)

    while True:
        accepted = listener.accept()
        connection = accepted[0]
        worker = threading.Thread(target=handle_command, args=[connection])
        worker.daemon = True
        worker.start()
    
def recv_until_marker(sock):
    """
    Read from a connected socket until the command terminator is seen.

    The terminator is the literal ``/CMD``. Everything before its first
    occurrence is returned; anything after it is discarded. The terminator is
    searched for in the whole accumulated buffer, so it is found no matter
    how the transport splits it across reads.

    If the peer closes the connection before sending the terminator, the
    partial data is dropped and an empty string is returned, instead of
    looping forever on empty reads.

    :param sock: a connected socket to read the command from
    :type sock: socket.socket
    :returns: the data received before the terminator, or an empty string
    :rtype: str
    """
    # Byte literals are plain str under Python 2, and they keep the function
    # usable on interpreters where recv() returns bytes.
    end = b'/CMD'
    received = b''
    while True:
        chunk = sock.recv(8192)
        if not chunk:
            # recv() returns an empty string once the peer has closed.
            return b''
        received += chunk
        marker_at = received.find(end)
        if marker_at != -1:
            return received[:marker_at]


def handle_command(sock):
    """
    Serve a single client connection.

    Reads one terminated command from the socket, splits it on spaces into a
    command name and its arguments, runs the matching handler and writes a
    one-line reply: ``<command>: YES`` or ``<command>: NO`` when the handler
    returns exactly that string, ``<command>: OK`` for any other result, and
    ``<command>: ERROR`` for an unknown command. An exception raised while
    dispatching is logged and no reply is sent. The socket is closed once
    dispatching has been attempted.

    :param sock: the accepted client connection
    :type sock: socket.socket
    """
    log = partial(syslog.syslog, syslog.LOG_WARNING)
    log("handle")

    received = recv_until_marker(sock)
    log("GOT -----> %s" % received)
    tokens = received.replace('\n', '').split(' ')

    command = tokens[0]
    args = tokens[1:]
    log('command %s' % (command))

    # command name -> (handler, positional arguments for the handler)
    dispatch = {
        'firewall_start': (firewall_start, args),
        'firewall_stop': (firewall_stop, []),
        'firewall_isup': (firewall_isup, []),
        'openvpn_start': (openvpn_start, args),
        'openvpn_stop': (openvpn_stop, []),
        'openvpn_force_stop': (openvpn_stop, ['KILL']),
        'openvpn_set_watcher': (openvpn_set_watcher, args)
    }

    entry = dispatch.get(command)
    log('call: %s' % (str(entry)))
    try:
        if not entry:
            log('invalid command: %s' % (command,))
            sock.sendall("%s: ERROR\n" % command)
        else:
            log('GOT "%s"' % (command))
            handler, handler_args = entry

            # TODO Use a MUTEX in here
            result = handler(*handler_args)
            log("Executed")
            log("Result: %s" % (str(result)))

            if result == 'YES':
                reply = "%s: YES\n"
            elif result == 'NO':
                reply = "%s: NO\n"
            else:
                reply = "%s: OK\n"
            sock.sendall(reply % command)
    except Exception as exc:
        log("error executing function %r" % (exc))
    finally:
        sock.close()



#
# OPENVPN
#


# subprocess.Popen instance of the openvpn process launched by openvpn_start,
# or None when nothing has been started yet.
openvpn_proc = None
# Pid registered through openvpn_set_watcher; openvpn_stop signals it after
# openvpn itself has exited. None until a watcher is registered.
openvpn_watcher_pid = None


def openvpn_start(*args):
    """
    Launch openvpn as a child of this long-running daemon.

    The first element of ``args`` is skipped; the remaining ones are handed
    to the bundled openvpn binary, followed by the DNS option, the up/down
    scripts and the tun device option. The resulting Popen instance is kept
    in the module-level ``openvpn_proc``.

    Note that the received options are NOT sanitized yet.

    :param args: arguments to be passed to openvpn
    :type args: list
    """
    global openvpn_proc

    syslog.syslog(syslog.LOG_WARNING, "OPENVPN START")

    # TODO sanitize options
    argv = [RESOURCES_PATH + 'openvpn.leap']
    argv.extend(args[1:])
    argv.extend([
        '--dhcp-option', 'DNS', '10.42.0.1',
        '--up', RESOURCES_PATH + 'client.up.sh',
        '--down', RESOURCES_PATH + 'client.down.sh',
    ])
    argv.extend(["--dev", "tun"])

    openvpn_proc = subprocess.Popen(argv, shell=False, bufsize=-1)


def openvpn_stop(sig='TERM'):
    """
    Stop the openvpn that has been launched by this privileged helper.

    The openvpn process is signalled and waited for, then the registered
    watcher process (if any) receives the same signal. Both references are
    cleared afterwards, so a repeated stop request cannot signal a pid that
    may have been reused by an unrelated process. Nothing happens when no
    openvpn process is tracked.

    :param sig: which signal to use, either 'TERM' or 'KILL'
    :type sig: str
    :raises ValueError: if ``sig`` is neither 'TERM' nor 'KILL'
    """
    global openvpn_proc, openvpn_watcher_pid

    # Validate first: an unknown name would otherwise leave the process
    # unsignalled and block forever in wait().
    signals = {'KILL': signal.SIGKILL, 'TERM': signal.SIGTERM}
    if sig not in signals:
        raise ValueError("unsupported stop signal: %r" % (sig,))
    stop_signal = signals[sig]

    proc = openvpn_proc
    if not proc:
        return

    syslog.syslog(syslog.LOG_WARNING, "OVPN PROC: %s" % str(proc.pid))

    # Only signal a process that has not been reaped yet.
    if proc.poll() is None:
        try:
            proc.send_signal(stop_signal)
        except OSError as exc:
            # The process went away between poll() and the signal.
            syslog.syslog(
                syslog.LOG_WARNING, "could not signal openvpn: %r" % (exc,))

    returncode = proc.wait()
    openvpn_proc = None
    syslog.syslog(
        syslog.LOG_WARNING, "openvpn return code: %s" % str(returncode))

    watcher_pid = openvpn_watcher_pid
    syslog.syslog(
        syslog.LOG_WARNING, "openvpn_watcher_pid: %s" % str(watcher_pid))
    if watcher_pid:
        openvpn_watcher_pid = None
        try:
            os.kill(watcher_pid, stop_signal)
        except OSError as exc:
            # A watcher that is already gone is not an error for the caller.
            syslog.syslog(
                syslog.LOG_WARNING, "could not signal watcher: %r" % (exc,))


def openvpn_set_watcher(pid, *args):
    """
    Remember the pid of the watcher process so openvpn_stop can signal it.

    :param pid: the watcher pid; converted with int()
    :type pid: str
    :param args: extra arguments, ignored
    :raises ValueError: if ``pid`` cannot be converted to an integer
    """
    global openvpn_watcher_pid
    watcher = int(pid)
    openvpn_watcher_pid = watcher
    syslog.syslog(syslog.LOG_WARNING, "Watcher PID: %s" % pid)


#
# FIREWALL
#


def firewall_start(*gateways):
    """
    Bring up the firewall.

    The gateways are filtered with get_gateways. When at least one is valid,
    pf is enabled, the gateways table is repopulated and the bitmask anchor
    is loaded for the current default network device.

    :param gateways: gateways to allow, to be sanitized.
    :type gateways: list
    :returns: False when no valid gateway was given, None otherwise
    """
    valid_gateways = get_gateways(gateways)

    if valid_gateways:
        _enable_pf()
        _reset_bitmask_gateways_table(valid_gateways)
        _load_bitmask_anchor(_get_default_device())
        return None

    return False


def firewall_stop():
    """
    Flush everything from the bitmask anchor.

    :returns: the output of the pfctl command
    :rtype: str
    """
    flush_cmd = '{pfctl} -a {anchor} -F all'.format(
        anchor=BITMASK_ANCHOR, pfctl=PFCTL)
    output = exec_cmd(flush_cmd)
    return output


def firewall_isup():
    """
    Return YES if the bitmask anchor is loaded with rules, NO otherwise.

    The rules are counted with ``wc -l``. The captured output may also carry
    other lines (for instance diagnostics printed by pfctl), so the count is
    taken from the last line that consists only of digits. The whole number
    is parsed, not just its last character, so counts such as 10 or 20 are
    reported correctly.

    :rtype: str
    :raises ValueError: if no line count can be found in the command output
    """
    syslog.syslog(syslog.LOG_WARNING, 'PID---->%s' % os.getpid())
    cmd = '{pfctl} -s rules -a {anchor} | wc -l'.format(
        pfctl=PFCTL, anchor=BITMASK_ANCHOR)
    output = exec_cmd(cmd)

    counts = [line.strip() for line in output.splitlines()
              if line.strip().isdigit()]
    if not counts:
        raise ValueError("unexpected pfctl output: %r" % (output,))

    if int(counts[-1]) > 0:
        return 'YES'
    return 'NO'


def _enable_pf():
    """
    Run ``pfctl -e``; the command output is discarded.
    """
    enable_cmd = '{pfctl} -e'.format(pfctl=PFCTL)
    exec_cmd(enable_cmd)


def _reset_bitmask_gateways_table(gateways):
    """
    Repopulate the ``bitmask_gateways`` pf table inside the bitmask anchor.

    A ``-T delete`` command is issued first, then every gateway is added and
    finally the nameserver address is added to the same table.

    :param gateways: the (already validated) gateway addresses
    :type gateways: iterable
    """
    # NOTE(review): `-T delete` is called without addresses; confirm against
    # pfctl(8) whether this empties the table or `-T flush` was intended.
    exec_cmd('{pfctl} -a {anchor} -t bitmask_gateways -T delete'.format(
        anchor=BITMASK_ANCHOR, pfctl=PFCTL))

    add_gateway = '{pfctl} -a {anchor} -t bitmask_gateways -T add {gw}'
    for address in gateways:
        exec_cmd(add_gateway.format(
            pfctl=PFCTL, anchor=BITMASK_ANCHOR, gw=address))
        syslog.syslog(syslog.LOG_WARNING, "adding gw %s" % address)

    # The nameserver goes into the gateways table as well.
    add_nameserver = '{pfctl} -a {anchor} -t bitmask_gateways -T add {ns}'
    exec_cmd(add_nameserver.format(
        pfctl=PFCTL, anchor=BITMASK_ANCHOR, ns=NAMESERVER))
    syslog.syslog(syslog.LOG_WARNING, "adding ns %s" % NAMESERVER)

def _load_bitmask_anchor(default_device):
    """
    Load the bundled pf rule file into the bitmask anchor.

    :param default_device: interface name, passed to pfctl as the
                           ``default_device`` macro
    :type default_device: str
    :returns: the output of the pfctl command
    :rtype: str
    """
    fields = {
        'pfctl': PFCTL,
        'defaultdevice': default_device,
        'anchor': BITMASK_ANCHOR,
        'rulefile': RESOURCES_PATH + 'bitmask-helper/bitmask.pf.conf',
    }
    template = ('{pfctl} -D default_device={defaultdevice} '
                '-a {anchor} -f {rulefile}')
    load_cmd = template.format(**fields)
    syslog.syslog(syslog.LOG_WARNING, "LOADING CMD: %s" % load_cmd)
    return exec_cmd(load_cmd)


def _get_default_device():
    """
    Retrieve the current default network device.

    The interface line of ``route -n get -net default`` is extracted with a
    shell pipeline; the ``interface: `` label is then removed and the result
    is stripped of surrounding whitespace.

    :rtype: str
    """
    # NOTE(review): the awk program sits inside double quotes, so the shell
    # presumably expands $2 before awk sees it; the label removal below makes
    # the result correct either way.
    pipeline = (
        '{route} -n get -net default | '
        '{grep} interface | {awk} "{{print $2}}"').format(
            awk=AWK, grep=GREP, route=ROUTE)
    raw_output = exec_cmd(pipeline)
    without_label = raw_output.replace("interface: ", "")
    device = without_label.strip()
    syslog.syslog(syslog.LOG_WARNING, "default device %s" % device)
    return device



#
# UTILITY
#


def is_valid_address(value):
    """
    Validate that the passed value is a dotted-quad IPv4 address.

    The check uses ``inet_pton``, which accepts only the strict
    ``a.b.c.d`` form. ``inet_aton`` is deliberately avoided: it also accepts
    abbreviated forms such as ``1.2`` and, on common libc implementations,
    an address followed by whitespace and arbitrary text. Values accepted
    here end up interpolated into shell commands run as root, so anything
    beyond a plain address must be rejected.

    :param value: the value to be validated
    :type value: str
    :rtype: bool
    """
    try:
        socket.inet_pton(socket.AF_INET, value)
    except (socket.error, TypeError, ValueError):
        # socket.error: not an address; TypeError/ValueError: not a usable
        # string at all (wrong type, embedded NUL, unencodable text).
        syslog.syslog(syslog.LOG_WARNING, 'MALFORMED IP: %s!' % (value,))
        return False
    return True


#
# FIREWALL
#


def get_gateways(gateways):
    """
    Keep only the gateways that pass is_valid_address.

    :param gateways: a sequence of gateways to filter.
    :type gateways: iterable
    :returns: the valid gateways, or False when none of them is valid
    :rtype: iterable
    """
    syslog.syslog(syslog.LOG_WARNING, 'Filtering %s' % str(gateways))
    valid = filter(is_valid_address, gateways)
    if valid:
        return valid

    syslog.syslog(syslog.LOG_ERR, 'No valid gateways specified')
    return False



if __name__ == "__main__":
    # Everything runs inside the python-daemon context; the server loop below
    # never returns.
    daemon_context = daemon.DaemonContext()
    with daemon_context:
        syslog.syslog(syslog.LOG_WARNING, "Serving...")
        serve_forever()