#!/usr/libexec/platform-python

import argparse
import errno
import os
import select
import socket
import sys
import subprocess
import termios
import threading
import traceback

# make every print() flush stdout immediately (sys.stdout.write is unaffected)
import functools
print = functools.partial(print, flush=True)

def e2str(msg=None):
    """Describe the exception currently being handled as a one-line string.

    Must be called from inside an ``except`` block. The result has the form
    ``file.py:LINE: [msg, ]module.ExceptionType, exception text``. File and
    line come from the first traceback entry, which is the frame handling the
    exception. The module prefix is omitted when the exception type has none.
    """
    exc_type, exc, exc_tb = sys.exc_info()
    where = '{}:{}'.format(
        os.path.basename(exc_tb.tb_frame.f_code.co_filename),
        exc_tb.tb_lineno)
    if exc_type.__module__:
        type_name = exc_type.__module__ + '.' + exc_type.__name__
    else:
        type_name = exc_type.__name__
    parts = [where + ': ']
    if msg:
        parts.append(msg + ', ')
    parts.append(type_name + ', ' + str(exc))
    return ''.join(parts)

def stdin_tty_echo(b=None):
    """Query or set the terminal echo flag of stdin (file descriptor 0).

    b: None to only query, a truthy value to switch echo on, a falsy
       non-None value to switch it off.
    Returns the echo state found before any change (True/False). Returns None
    when stdin is not a terminal or its attributes could not be read.
    Best effort: any error while reading or setting the attributes is
    swallowed.
    """
    if not sys.stdin.isatty():
        return None
    previous = None
    try:
        mode = termios.tcgetattr(0)
        lflags = mode[3]
        previous = bool(lflags & termios.ECHO)
        if b is not None:
            mode[3] = (lflags | termios.ECHO) if b else (lflags & ~termios.ECHO)
            termios.tcsetattr(0, termios.TCSADRAIN, mode)
    except BaseException:
        # best effort, deliberately silent
        pass
    return previous

# CLIENT

def _client_read_passphrase():
    """Prompt on stdout and read a non-empty passphrase from stdin, echo off.

    Re-prompts on an empty line. Returns the passphrase without its trailing
    newline. Returns None, after printing an error, when reading fails or
    stdin hits EOF.
    """
    while True:
        sys.stdout.write('Passphrase:\n')
        sys.stdout.flush()
        echo_ori = stdin_tty_echo(False)
        try:
            line = sys.stdin.readline()
        except BaseException:
            # read error, or Ctrl-C at the prompt: fail cleanly
            line = None
        finally:
            stdin_tty_echo(echo_ori)
        if not line:
            # Also covers EOF: readline() keeps returning '' there, so
            # prompting again would loop forever.
            print('Failed to read passphrase')
            return None
        if line.endswith('\n'):
            line = line[:-1]
        if line:
            return line


def run_client_mode(sockname, args):
    """Talk to the server listening on the Unix socket *sockname*.

    args: command words.
        - Non-empty: the words are sent as one request line, and the
          function returns after the reply.
        - A lone 'unlock': the passphrase is first read locally (echo off)
          and appended as the argument. NOTE: the server splits requests on
          whitespace, so such a passphrase cannot contain spaces.
        - Empty: interactive. stdin lines are forwarded to the server and
          replies are printed until stdin reaches EOF with no reply pending,
          or the server disconnects.
    Returns False when connecting fails, a reply starts with 'ERR', the
    server leaves while a reply is pending, or an unexpected error occurs.
    Returns True otherwise.
    """
    args = list(args)  # don't mutate the caller's list

    if len(args) == 1 and args[0] == 'unlock':
        passphrase = _client_read_passphrase()
        if passphrase is None:
            return False
        args.append(passphrase)

    srvsock = None
    try:
        srvsock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        srvsock.connect(sockname)
    except Exception:
        print(e2str('Could not connect'))
        if srvsock is not None:
            srvsock.close()
        return False

    # Raw (unbuffered) binary file for line-oriented reads; writes go
    # through sendall() so a short write cannot truncate a request.
    srvio = srvsock.makefile(mode='rwb', buffering=0)

    dead = {sys.stdin: True, srvio: False}
    wait_server = False
    rlist = [srvio]
    retval = True

    # best effort
    try:
        if args:
            srvsock.sendall(bytes(' '.join(args) + '\n', 'utf-8'))
            wait_server = True
        else:  # interactive
            dead[sys.stdin] = False
            rlist.append(sys.stdin)

        while True:
            # select detects server replies/disconnect and stdin input
            readable, _, _ = select.select(rlist, [], [])
            for s in readable:
                line = s.readline()
                if not line:
                    # EOF on stdin or disconnect from server. Stop selecting
                    # on it: an EOF stream stays readable and would make
                    # select() spin.
                    dead[s] = True
                    rlist.remove(s)
                    continue
                if s is srvio:
                    text = line.decode('utf-8')
                    sys.stdout.write(text)
                    sys.stdout.flush()
                    if text[:3] == 'ERR':
                        retval = False
                    wait_server = False
                else:
                    srvsock.sendall(bytes(line, 'utf-8'))
                    wait_server = True
            if dead[srvio]:
                if wait_server:
                    retval = False
                    print('Server left while waiting for reply')
                break
            if not wait_server and dead[sys.stdin]:
                break
    except KeyboardInterrupt:
        pass
    except Exception:
        retval = False
        print(e2str())
    finally:
        # best effort
        for close in (srvio.close,
                      lambda: srvsock.shutdown(socket.SHUT_RDWR),
                      srvsock.close):
            try:
                close()
            except OSError:
                pass
    return retval

# VERIFY

# Bash snippet prepended to the sign and verify scripts. cat_file writes the
# content of file $1 to stdout. A name ending in ".xz" is decompressed with
# xzcat, so for .xz files the signature covers the uncompressed bytes. Any
# other file is copied as-is with cat.
SHELL_CAT = '''
function cat_file() {
    case "${1##*.}" in
        xz) xzcat "$1" ;;
        *) cat "$1" ;;
    esac
}
'''

# Bash script run by run_verify_mode after SHELL_CAT.
# Positional parameters:
#   $1  public key file
#   $2  data file
#   $3  signature file (base64 text, as written by SHELL_SIGN)
# The data (decompressed by cat_file when it is .xz) is piped into
# `openssl dgst -verify`. The script exits with openssl's status (the last
# command of the pipeline), so 0 means the signature matched.
# The doubled backslash is a Python escape; bash receives a single
# line-continuation backslash.
SHELL_VERIFY = '''
keyfile=$1; datfile=$2; sigfile=$3
[[ -z $keyfile ]] && { echo 'Empty key file argument'; exit 1; }
[[ -z $datfile ]] && { echo 'Empty data file argument'; exit 1; }
[[ -z $sigfile ]] && { echo 'Empty sign file argument'; exit 1; }
cat_file "$datfile" |
    nice -n 19 openssl dgst -verify "$keyfile" \\
        -signature <(base64 -d < "$sigfile")
exit $?
'''

def run_verify_mode(pubkey, datfile, sigfile):
    """Check that *sigfile* is a valid signature of *datfile* for *pubkey*.

    The bash verify script's output goes straight to this process's
    stdout/stderr. Returns True when the script exits with status 0.
    """
    script = SHELL_CAT + SHELL_VERIFY
    # With shell=True and a list, items after the script become bash's extra
    # arguments: '--' lands in $0 and the three paths in $1..$3.
    completed = subprocess.run([script, '--', pubkey, datfile, sigfile],
                               shell=True, executable='/bin/bash')
    return not completed.returncode

# SERVER

def format_result(v):
    """Turn a command handler result into the text of a reply line.

    v: a bool, or a tuple (success, message, ...) with at least two items.
    Returns:
        - 'OK' or 'ERR' for a bool;
        - 'OK, <message>' or 'ERR, <message>' for such a tuple, where
          success is judged by the truthiness of the first item;
        - 'ERR, Cannot parse result' for anything else.
    """
    def status(flag):
        return 'OK' if flag else 'ERR'

    if isinstance(v, bool):
        return status(v)
    if isinstance(v, tuple) and len(v) >= 2:
        return ', '.join((status(v[0]), str(v[1])))
    return 'ERR, Cannot parse result'

class ClientThread(threading.Thread):
    """Serve one accepted client connection on its own thread.

    Protocol: each request is one UTF-8 line of whitespace-separated words.
    The first word is a command name looked up in HANDLERS. Each request gets
    one reply line built by format_result: 'OK' or 'ERR', optionally followed
    by ', <message>'.
    The loop stops on any of:
        - EOF;
        - a read error or timeout;
        - a malformed or empty line;
        - a failed reply;
        - a handler exception;
        - a handler clearing client_loop_running.
    """

    def __init__(self, args):
        """Prepare the thread for one connection.

        args: the (socket, address) pair returned by socket.accept(); only
            the socket is used.
        """
        threading.Thread.__init__(self)
        self.sock = args[0]
        if not ARGS.quiet:
            self.log('New connection')

        # reads raise socket.timeout after CLIENT_TIMEOUT seconds
        self.sock.settimeout(CLIENT_TIMEOUT)
        # unbuffered (raw) binary socket file, read one line at a time
        self.io = self.sock.makefile(mode='rwb', buffering=0)
        self.client_loop_running = True  # cleared to end run()
        self.icmd = 0    # number of read_cmd() calls, shown in log lines
        self.cmd = None  # current request split into words, or None

    def log(self, *args):
        """Print a log line prefixed with the thread name."""
        print(self.name + '>', *args)

    def read_cmd(self):
        """Read the next request line into self.cmd.

        Returns True with self.cmd set to the non-empty list of words.
        Returns False when the connection should end (client left, timeout,
        read error, malformed or empty input). An error reply is sent first
        where the client can still receive one.
        """
        self.icmd += 1
        self.cmd = None  # reset

        try:
            line = self.io.readline().decode('utf-8')
        except socket.timeout:
            self.reply((False, 'Client timeout'))
            return False
        except OSError as e:
            # server is stopping, client is disconnecting: don't log those
            if e.errno not in (errno.EBADF, errno.ECONNRESET):
                self.log(e2str('Read failed'))
            return False
        except Exception:
            self.log(e2str('Read failed'))
            self.reply((False, 'Input read error'))
            return False

        if line == '':
            # nothing read at all: assume the client has left
            return False
        if not line.endswith('\n'):
            self.reply((False, 'Invalid input'))
            return False

        # using shlex would allow quoting, but no need so far
        words = line[:-1].split()
        if not words:
            # Empty or whitespace-only line. It must not reach run(), which
            # indexes self.cmd[0].
            self.reply((False, 'Empty command'))
            return False
        self.cmd = words
        return True

    def close(self):
        """Stop the loop and close the connection (best effort, repeatable)."""
        self.client_loop_running = False
        for close in (self.io.close,
                      lambda: self.sock.shutdown(socket.SHUT_RDWR),
                      self.sock.close):
            try:
                close()
            except Exception:
                pass

    def reply(self, result, handler=None):
        """Log the current request with its result and send the reply line.

        result: a handler result accepted by format_result.
        handler: the HANDLERS entry of the command, if known. When its args
            carry 'hide', arguments are logged as '<hidden>'.
        Returns True if the reply was written. Returns False otherwise; the
        failure is logged.
        """
        try:
            line = format_result(result)
            if not self.cmd:
                logcmd = '<none>'
            elif handler and handler['args'].get('hide'):
                logcmd = ' '.join(
                    [self.cmd[0]] + ['<hidden>'] * (len(self.cmd) - 1))
            else:
                logcmd = ' '.join(self.cmd)
            self.log(f'Command[{self.icmd}]: {logcmd} => {line}')
            self.io.write(bytes(line + '\n', 'utf-8'))
            self.io.flush()
        except Exception:
            self.log(e2str('Reply failed'))
            return False
        return True

    def run(self):
        """Read, dispatch and answer requests, then close the connection."""
        while self.client_loop_running:
            hdl = None

            # read command
            if not self.read_cmd():
                break

            # run command
            try:
                hdl = HANDLERS.get(self.cmd[0])
                if hdl is None:
                    result = (False, 'Invalid command')
                else:
                    nargs = len(self.cmd) - 1
                    if hdl['args']['min'] <= nargs <= hdl['args']['max']:
                        result = hdl['func'](self, *self.cmd[1:])
                    else:
                        result = (False, 'Invalid number of argument')
            except Exception:
                self.log(e2str('Command failed'))
                self.reply((False, 'Internal error'), hdl)
                break

            # reply command result and continue
            if not self.reply(result, hdl):
                break

        # done, close connection
        self.close()
        if not ARGS.quiet:
            self.log('Disconnected')

def cmd_handler__echo(ct, msg):
    """Handle 'echo MSG': succeed and send MSG back as the reply message."""
    result = (True, msg)
    return result

def cmd_handler__exit(ct):
    """Handle 'exit': end this client's command loop once the reply is sent."""
    setattr(ct, 'client_loop_running', False)
    return True

def cmd_handler__lock(ct):
    """Handle 'lock': forget the passphrase, so signing needs a new unlock."""
    global PASSPHRASE
    LOCK.acquire()
    try:
        PASSPHRASE = None
    finally:
        LOCK.release()
    return True

# Bash script run by cmd_handler__sign after SHELL_CAT.
# Positional parameters:
#   $1  private key file
#   $2  data file
#   $3  signature file to create
# The openssl passphrase is read from the environment variable P.
# The signature is written base64-encoded to a temporary file, which is then
# renamed onto $3. On failure the first stdout line is an error message and
# the exit status is 1.
# pipefail makes a failure anywhere in the pipeline fail the signing. Without
# it, e.g. an xzcat error mid-stream would yield a valid signature over
# truncated data.
SHELL_SIGN = '''
set -f
set -o pipefail
keyfile=$1; datfile=$2; sigfile=$3
[[ -z $keyfile ]] && { echo 'Empty key file argument'; exit 1; }
[[ -z $datfile ]] && { echo 'Empty data file argument'; exit 1; }
[[ -z $sigfile ]] && { echo 'Empty sign file argument'; exit 1; }
[[ -r $datfile ]] || { echo 'Cannot read data file'; exit 1; }
[[ -e $sigfile ]] && { echo 'Sign file already exists, remove first'; exit 1; }
sig=$(cat_file "$datfile" |
        nice -n 19 openssl dgst -sign "$keyfile" -passin env:P |
        base64 -w 0) || { echo 'Cannot sign file, check status'; exit 1; }
[[ -n $sig ]] || { echo 'Cannot sign file, check status'; exit 1; }
rand=$RANDOM
echo "$sig" > "$sigfile.new.$rand" || { echo 'Cannot build sign file 1/2'; exit 1; }
mv "$sigfile.new.$rand" "$sigfile" || { echo 'Cannot build sign file 2/2'; exit 1; }
exit 0
'''

def cmd_handler__sign(ct, datfile, sigfile):
    """Handle 'sign DATFILE SIGFILE': sign DATFILE with the server key.

    Runs SHELL_CAT + SHELL_SIGN under bash with ARGS.keyfile, datfile and
    sigfile as positional parameters. The passphrase is passed to openssl
    through the variable P, set only in the child's environment.
    Returns True on success. Otherwise returns (False, message), where the
    message is the script's first stdout line, or a generic message when
    there is none.
    """
    with LOCK:
        if PASSPHRASE is None:
            return False, 'Passphrase not set, need unlock'
        cmd = [SHELL_CAT + SHELL_SIGN, '--', ARGS.keyfile, datfile, sigfile]
        # P is set only in the child's environment; os.environ is untouched,
        # so other children never inherit it. preexec_fn is avoided because
        # the subprocess docs warn it is unsafe in multithreaded programs,
        # and this server runs one thread per client.
        env = dict(os.environ, P=PASSPHRASE)
        p = subprocess.run(cmd, shell=True, executable='/bin/bash',
                stdout=subprocess.PIPE, universal_newlines=True, env=env)
    if p.returncode == 0:
        return True
    msg = p.stdout.split('\n')[0]
    return False, msg or f'Sign script failed with status {p.returncode}'

def cmd_handler__shutdown(ct):
    """Handle 'shutdown': stop the server's accept loop and this client.

    The main thread's accept loop sees SERVER_LOOP_RUNNING cleared on its
    next wake-up. This client's loop ends after the reply is sent.
    """
    global SERVER_LOOP_RUNNING
    LOCK.acquire()
    try:
        SERVER_LOOP_RUNNING = False
    finally:
        LOCK.release()
    ct.client_loop_running = False
    return True

def cmd_handler__status(ct):
    """Handle 'status': report whether the stored passphrase opens the key.

    The check loads ARGS.keyfile with `openssl pkey`. Unlike `openssl rsa`,
    pkey accepts any private key type that `openssl dgst -sign` can use.
    Returns True if the key opens. Otherwise returns (False, message); the
    message says whether no passphrase is set or the check failed.
    """
    with LOCK:
        if PASSPHRASE is None:
            return False, 'Passphrase not set, need unlock'
        cmd = ['openssl', 'pkey', '-in', ARGS.keyfile,
               '-passin', 'env:P', '-noout']
        # P is set only in the child's environment. preexec_fn is avoided:
        # it is documented as unsafe in multithreaded programs.
        env = dict(os.environ, P=PASSPHRASE)
        p = subprocess.run(cmd, env=env)
    if p.returncode == 0:
        return True
    return False, 'Passphrase check failed, need unlock'

def cmd_handler__unlock(ct, v=None):
    """Handle 'unlock [PASSPHRASE]': store the key passphrase and check it.

    v: the passphrase. When omitted, 'Passphrase:' is sent to the client
        and the next line read from the connection is used.
    The value is kept only if cmd_handler__status accepts it. Otherwise
    PASSPHRASE is reset to None, so a known-bad passphrase is never left in
    place.
    Returns the status result, or (False, message) for read errors or an
    empty passphrase. A read error also ends the client loop.
    """
    global PASSPHRASE
    if v is None:
        # best effort
        try:
            ct.io.write(b'Passphrase:\n')
            ct.io.flush()
            line = ct.io.readline().decode('utf-8')
        except socket.timeout:
            ct.client_loop_running = False
            return False, 'Client timeout'
        except Exception:
            ct.client_loop_running = False
            return False, 'Failed to read passphrase'
        # strip only a real trailing newline; a last line without one keeps
        # all of its characters
        v = line[:-1] if line.endswith('\n') else line
    if len(v) == 0:
        return False, 'Empty passphrase'
    # LOCK must be re-entrant: cmd_handler__status acquires it again
    with LOCK:
        PASSPHRASE = v
        result = cmd_handler__status(ct)
        if result is not True:
            PASSPHRASE = None
        return result

# MAIN

# parse command line arguments
ap = argparse.ArgumentParser()

# -c/-s/-v all store into ARGS.mode; client mode is the default
ap.add_argument('-c', '--client', dest='mode', action='store_const',
                const='client', default='client', help='run as client')
ap.add_argument('-s', '--server', dest='mode', action='store_const',
                const='server', help='run as server')
ap.add_argument('-v', '--verify', dest='mode', action='store_const',
                const='verify', help='run verify mode')
ap.add_argument('-k', '--key', dest='keyfile',
                default=os.getenv('ZLC_SIGFILED_KEYFILE'),
                help='key filename, public (verify) or private (server)')
ap.add_argument('-S', '--socket', dest='sockname',
                default=os.getenv('ZLC_SIGFILED_SOCKNAME'),
                help='server socket filename')
ap.add_argument('-q', '--quiet', dest='quiet', action='store_true',
                default=False, help='server emit less output')
ap.add_argument('rest', metavar='arg', nargs='*',
                help='command and arguments (client), data file and sign file (verify)')

ARGS = ap.parse_args()

# (failed?, message) pairs checked in order; the first failure exits with 1
_usage_errors = (
    (ARGS.mode in ('client', 'server') and ARGS.sockname is None,
     'Server socket filename not set, check program usage'),
    (ARGS.mode in ('server', 'verify') and ARGS.keyfile is None,
     'Key filename not set, check program usage'),
    (ARGS.mode == 'verify' and len(ARGS.rest) != 2,
     'Verify mode need two arguments, check program usage'),
)
for _failed, _message in _usage_errors:
    if _failed:
        print(_message)
        sys.exit(1)

# Client and verify modes run once and exit (0 on success, 2 on failure).
# Server mode falls through to the code below.
_one_shot = {
    'client': lambda: run_client_mode(ARGS.sockname, ARGS.rest),
    'verify': lambda: run_verify_mode(ARGS.keyfile, ARGS.rest[0], ARGS.rest[1]),
}
if ARGS.mode in _one_shot:
    retval = _one_shot[ARGS.mode]()
    sys.exit(0 if retval else 2)

# run server mode

# Seconds a client socket read may block before socket.timeout is raised
# (applied with settimeout in ClientThread).
CLIENT_TIMEOUT = 10
# The accept loop runs while this is True; the 'shutdown' command clears it.
SERVER_LOOP_RUNNING = True
# Private key passphrase. None means locked: 'unlock' sets it, 'lock' clears it.
PASSPHRASE = None
# Guards PASSPHRASE and SERVER_LOOP_RUNNING. It must be re-entrant because
# cmd_handler__unlock calls cmd_handler__status while holding it.
LOCK = threading.RLock()
# Command table: name -> handler function and the allowed argument count
# range [min, max]. 'hide' keeps the arguments (the passphrase) out of the log.
HANDLERS = dict(
    echo=dict(func=cmd_handler__echo, args=dict(min=1, max=1)),
    exit=dict(func=cmd_handler__exit, args=dict(min=0, max=0)),
    lock=dict(func=cmd_handler__lock, args=dict(min=0, max=0)),
    shutdown=dict(func=cmd_handler__shutdown, args=dict(min=0, max=0)),
    sign=dict(func=cmd_handler__sign, args=dict(min=2, max=2)),
    status=dict(func=cmd_handler__status, args=dict(min=0, max=0)),
    unlock=dict(func=cmd_handler__unlock,
                args=dict(min=0, max=1, hide=True)),
)

# create signature files with mode 400
os.umask(0o0277)

bound = False
try:
    server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    server.settimeout(1) # easier to handle shutdown
    server.bind(ARGS.sockname)
    bound = True
    # the umask above also applies to the socket file; set it to owner rw
    os.chmod(ARGS.sockname, 0o0600)
    server.listen()
except Exception:
    print(e2str('Cannot start server'))
    # Remove the socket file only if this process created it. A failed bind
    # (e.g. address in use) may mean the file belongs to another server.
    if bound:
        try: os.unlink(ARGS.sockname)
        except OSError: pass
    sys.exit(2)

print('Server started')

try:
    while SERVER_LOOP_RUNNING:
        try:
            csock = server.accept()
        except socket.timeout:
            # wakes up every second to re-check SERVER_LOOP_RUNNING
            continue
        try:
            ClientThread(csock).start()
        except Exception:
            print(e2str('Cannot start client'))
            try: csock[0].close()
            except OSError: pass
except KeyboardInterrupt:
    pass
except Exception:
    # e.g. accept() failing: stop serving, but still run the cleanup below
    # rather than dying with the socket file left behind
    print(e2str('Server loop failed'))

print('Server shutdown')

# best effort
try: server.shutdown(socket.SHUT_RDWR)
except OSError: pass
try: server.close()
except OSError: pass
try: os.unlink(ARGS.sockname)
except OSError: pass

# close the remaining client connections and wait for their threads
th_main = threading.current_thread()
for th in threading.enumerate():
    if th is th_main or not isinstance(th, ClientThread):
        continue
    th.log('Disconnect pending client')
    th.close()
    th.join()