# Copyright (c) 2015, NORDUnet A/S.
# See LICENSE for licensing information.

import os
import base64
import hashlib
import sys
import struct
import json
import yaml
import argparse
import requests
try:
    import permdb
except ImportError:
    pass
from certtools import get_leaf_hash, http_request, get_leaf_hash

def parselogrow(row):
    """Decode one hex-encoded log-order row into raw bytes.

    Hex digits of either case are accepted.
    """
    decoded = base64.b16decode(row, casefold=True)
    return decoded

def get_logorder(filename, items=-1):
    """Read the log-order file and return its rows decoded to raw bytes.

    Each line of *filename* is one hex-encoded value (decoded with
    parselogrow after stripping trailing whitespace).  At most *items*
    rows are read; the default of -1 (or any negative value) reads the
    whole file.

    Returns a list of decoded values in file order.
    """
    logorder = []
    # Context manager so the file is closed even if decoding a row raises.
    with open(filename, "r") as f:
        for n, row in enumerate(f):
            if n == items:
                break
            logorder.append(parselogrow(row.rstrip()))
    return logorder

def get_nfetched(currentsizefile, logorderfile):
    """Return how many log-order entries have been fetched.

    *currentsizefile* holds a JSON object with an ``index`` (0-based
    position of the last fetched entry; negative when none) and the
    ``hash`` expected at that position in *logorderfile*.  The log-order
    file is read as fixed-width rows of 64 characters plus a newline
    (65 bytes per row), which lets us seek directly to the row.

    Returns index + 1, or -1 if the size file is missing or not valid
    JSON.  Raises AssertionError if the recorded hash does not match the
    log-order file.
    """
    try:
        # Close the size file promptly (the original left it open).
        with open(currentsizefile) as f:
            limit = json.load(f)
    except (IOError, ValueError):
        return -1
    if limit['index'] >= 0:
        with open(logorderfile, 'r') as f:
            f.seek(limit['index'] * 65)
            assert f.read(64).lower() == limit['hash']
    return limit['index'] + 1

def get_sth(filename):
    """Load a signed tree head (STH) from the JSON file *filename*.

    Returns the parsed JSON object.  If the file is missing or not valid
    JSON, returns a placeholder STH with tree_size -1, timestamp 0 and
    empty root hash and signature.
    """
    try:
        # Context manager so the handle is closed (the original leaked it).
        with open(filename, 'r') as f:
            return json.load(f)
    except (IOError, ValueError):
        return {'tree_size': -1,
                'timestamp': 0,
                'sha256_root_hash': '',
                'tree_head_signature': ''}

def read_chain_open(chainsdir, filename):
    """Open a stored chain file for reading and return the file object.

    Files are sharded three directory levels deep, named after the first
    three character pairs of *filename*:
    ``chainsdir/ab/cd/ef/abcdef...``.  The caller must close the file.
    """
    shards = [filename[i:i + 2] for i in (0, 2, 4)]
    fullpath = "/".join([chainsdir] + shards + [filename])
    return open(fullpath, "r")

def read_chain(chainsdir, key):
    """Return the stored chain for the raw key *key*.

    The file name is the hex encoding of *key*.  The upper-case name is
    tried first; if opening it raises IOError, the lower-case name is
    tried instead (an IOError from that attempt propagates).
    """
    filename = base64.b16encode(key).upper()
    try:
        f = read_chain_open(chainsdir, filename)
    except IOError:
        f = read_chain_open(chainsdir, filename.lower())
    # Context manager so the handle is closed even if read() fails.
    with f:
        return f.read()

def tlv_decode(data):
    """Split the first TLV record off the front of *data*.

    A record consists of a 4-byte big-endian length, which covers the
    whole record including the 8-byte header, then a 4-byte type tag,
    then the value.

    Returns (dtype, value, rest), where *rest* is whatever follows the
    record.  Raises ValueError if the length is smaller than the header
    or larger than the available data.  Without this check, a zero
    length would make tlv_decodelist loop forever, and a truncated
    record would be silently accepted.
    """
    (length,) = struct.unpack(">I", data[0:4])
    if length < 8 or length > len(data):
        raise ValueError("malformed TLV record: length %d, %d bytes available"
                         % (length, len(data)))
    dtype = data[4:8]
    value = data[8:length]
    rest = data[length:]
    return (dtype, value, rest)

def tlv_encode(dtype, value):
    """Encode a single TLV record.

    Layout: 4-byte big-endian total length (8-byte header plus value),
    the 4-byte type tag *dtype*, then *value*.
    """
    assert len(dtype) == 4
    header = struct.pack(">I", 8 + len(value))
    return header + dtype + value

def tlv_decodelist(data):
    """Decode consecutive TLV records until *data* is exhausted.

    Returns a list of (dtype, value) pairs in the order they appear.
    """
    records = []
    remaining = data
    while remaining:
        dtype, value, remaining = tlv_decode(remaining)
        records.append((dtype, value))
    return records

def tlv_encodelist(l):
    """Concatenate the TLV encodings of a list of (dtype, value) pairs.

    Inverse of tlv_decodelist.  Returns "" for an empty list.
    """
    # join is linear; repeated += on immutable strings is quadratic in general.
    return "".join(tlv_encode(dtype, value) for (dtype, value) in l)

def unwrap_entry(entry):
    """Verify and strip the PLOP envelope around a stored entry.

    *entry* must decode to exactly two TLV records: ("PLOP", payload)
    followed by ("S256", SHA-256 digest of payload).  Returns the
    payload.  Raises AssertionError if the record count, tags or
    checksum are wrong.
    """
    records = tlv_decodelist(entry)
    assert len(records) == 2
    (ploptype, plopdata), (plopchecksumtype, plopchecksum) = records
    assert ploptype == "PLOP"
    assert plopchecksumtype == "S256"
    assert hashlib.sha256(plopdata).digest() == plopchecksum
    return plopdata

def wrap_entry(entry):
    """Wrap *entry* in a PLOP envelope followed by its SHA-256 checksum.

    Inverse of unwrap_entry.
    """
    checksum = hashlib.sha256(entry).digest()
    records = [("PLOP", entry), ("S256", checksum)]
    return tlv_encodelist(records)

def verify_entry(verifycert, entry, ehash):
    """Check a stored entry and have the external verifier validate it.

    *entry* is unwrapped (see unwrap_entry).  Its first TLV record must
    be of type "MTL1", and the leaf hash of that record must equal
    *ehash*.  The unwrapped payload is then written to
    ``verifycert.stdin`` as a 4-byte big-endian length followed by the
    payload.  The reply read from ``verifycert.stdout`` is a 4-byte
    length followed by a result whose first byte is an error code
    (0 = OK); on error, the rest of the result is printed to stderr.

    Exits with status 1 if writing to the verifier fails or it reports an
    error.  Raises AssertionError on a type or hash mismatch or a short
    reply.
    """
    packed = unwrap_entry(entry)
    unpacked = tlv_decodelist(packed)
    (mtltype, mtl) = unpacked[0]
    # Check the tag first so a wrong record type is reported as such
    # rather than as a hash mismatch.
    assert mtltype == "MTL1"
    assert ehash == get_leaf_hash(mtl)
    s = struct.pack(">I", len(packed)) + packed
    try:
        verifycert.stdin.write(s)
        # Flush so the request reaches the verifier even if its stdin is
        # buffered; otherwise the blocking read below could wait forever.
        verifycert.stdin.flush()
    except IOError:
        sys.stderr.write("merge: unable to write to verifycert process: ")
        # Relay whatever the verifier wrote to stdout, then give up.
        while 1:
            line = verifycert.stdout.readline()
            if line:
                sys.stderr.write(line)
            else:
                sys.exit(1)
    result_length_packed = verifycert.stdout.read(4)
    (result_length,) = struct.unpack(">I", result_length_packed)
    result = verifycert.stdout.read(result_length)
    assert len(result) == result_length
    (error_code,) = struct.unpack("B", result[0:1])
    if error_code != 0:
        print >>sys.stderr, result[1:]
        sys.exit(1)

def hexencode(key):
    """Return the lower-case hexadecimal encoding of the bytes *key*."""
    upper = base64.b16encode(key)
    return upper.lower()

def hexdecode(s):
    """Decode the hexadecimal string *s* (digits of either case) to bytes."""
    normalized = s.upper()
    return base64.b16decode(normalized)

def write_chain(key, value, chainsdir, hashed_dir=True):
    """Store *value* in a file named after the hex encoding of *key*.

    With *hashed_dir* (the default), the file goes into a three-level
    shard directory (``chainsdir/ab/cd/ef/abcdef...``) that is created
    on demand.  Otherwise it is written directly into *chainsdir*.  An
    existing file is overwritten.

    Raises OSError if the shard directory cannot be created, and IOError
    if the file cannot be written.
    """
    import errno
    filename = hexencode(key)
    if hashed_dir:
        path = chainsdir + "/" \
          + filename[0:2] + "/" + filename[2:4] + "/" + filename[4:6]
        try:
            os.makedirs(path)
        except OSError as e:
            # An existing shard directory is the normal case.  Anything else
            # (permissions, a file in the way, ...) used to be swallowed here
            # and only surfaced later as a confusing failure from open().
            if e.errno != errno.EEXIST or not os.path.isdir(path):
                raise
    else:
        path = chainsdir
    with open(path + "/" + filename, "w") as f:
        f.write(value)

def add_to_logorder(logorderfile, key):
    """Append *key*, hex-encoded, as a new line at the end of *logorderfile*.

    The data is not fsynced; see fsync_logorder.
    """
    # Context manager so the handle is closed even if the write fails.
    with open(logorderfile, "a") as f:
        f.write(hexencode(key) + "\n")

def fsync_logorder(logorderfile):
    """fsync *logorderfile* so earlier appends reach stable storage.

    The file is opened in append mode, so it is created (empty) if it
    does not exist and its contents are never modified.
    """
    # Context manager so the descriptor is closed even if fsync fails.
    with open(logorderfile, "a") as f:
        os.fsync(f.fileno())

def get_new_entries(node, baseurl, own_key, paths):
    """Query storage *node*'s fetchnewentries endpoint.

    Returns the base64-decoded items of the reply's "entries" list.
    Exits the program on an HTTP error or a reply whose "result" is not
    "ok".
    """
    try:
        result = http_request(baseurl + "plop/v1/storage/fetchnewentries",
                              key=own_key, verifynode=node,
                              publickeydir=paths["publickeys"])
        reply = json.loads(result)
        if reply.get(u"result") != u"ok":
            print >>sys.stderr, "ERROR: fetchnewentries", reply
            sys.exit(1)
        return [base64.b64decode(item) for item in reply[u"entries"]]
    except requests.exceptions.HTTPError as err:
        print >>sys.stderr, "ERROR: fetchnewentries", err.response
        sys.exit(1)

def get_entries(node, baseurl, own_key, paths, hashes, session=None):
    """Fetch the entries for *hashes* from storage *node*.

    *hashes* are raw hashes; they are base64-encoded as repeated "hash"
    query parameters.  Returns a dict mapping each decoded hash in the
    reply to its decoded entry.  Asserts that the reply covers exactly
    the requested hashes.  Exits the program on an HTTP error or a
    reply whose "result" is not "ok".
    """
    # NOTE(review): duplicate values in *hashes* make the length assertion
    # fail even when every hash was returned; confirm callers never pass them.
    try:
        query = {"hash": [base64.b64encode(h) for h in hashes]}
        result = http_request(baseurl + "plop/v1/storage/getentry",
                              params=query,
                              key=own_key, verifynode=node,
                              publickeydir=paths["publickeys"], session=session)
        reply = json.loads(result)
        if reply.get(u"result") != u"ok":
            print >>sys.stderr, "ERROR: getentry", reply
            sys.exit(1)
        entries = {}
        for item in reply[u"entries"]:
            entries[base64.b64decode(item["hash"])] = \
                base64.b64decode(item["entry"])
        assert len(entries) == len(hashes)
        assert set(entries.keys()) == set(hashes)
        return entries
    except requests.exceptions.HTTPError as err:
        print >>sys.stderr, "ERROR: getentry", err.request.url, err.response
        sys.exit(1)

def get_curpos(node, baseurl, own_key, paths):
    """Return the "position" reported by frontend *node*'s currentposition endpoint.

    Exits the program on an HTTP error or a reply whose "result" is not
    "ok".
    """
    try:
        result = http_request(baseurl + "plop/v1/frontend/currentposition",
                              key=own_key, verifynode=node,
                              publickeydir=paths["publickeys"])
        reply = json.loads(result)
        if reply.get(u"result") != u"ok":
            print >>sys.stderr, "ERROR: currentposition", reply
            sys.exit(1)
        return reply[u"position"]
    except requests.exceptions.HTTPError as err:
        print >>sys.stderr, "ERROR: currentposition", err.response
        sys.exit(1)

def get_frontend_verifiedsize(node, baseurl, own_key, paths):
    """Return the frontend's verified size.

    Implemented as a verify-entries request with verify_to=0.  Presumably
    the frontend then only reports its current verified size; confirm
    against the frontend implementation.
    """
    verified = frontend_verify_entries(node, baseurl, own_key, paths, size=0)
    return verified

def frontend_verify_entries(node, baseurl, own_key, paths, size):
    """Ask frontend *node* to verify its entries up to *size*.

    Returns the "verified" value from the reply.  Exits the program on an
    HTTP error or a reply whose "result" is not "ok".
    """
    try:
        body = json.dumps({"verify_to": size})
        result = http_request(baseurl + "plop/v1/frontend/verify-entries",
                              body,
                              key=own_key, verifynode=node,
                              publickeydir=paths["publickeys"])
        reply = json.loads(result)
        if reply.get(u"result") != u"ok":
            print >>sys.stderr, "ERROR: verify-entries", reply
            sys.exit(1)
        return reply[u"verified"]
    except requests.exceptions.HTTPError as err:
        print >>sys.stderr, "ERROR: verify-entries", err.response
        sys.exit(1)

def get_verifiedsize(node, baseurl, own_key, paths):
    """Return the "size" reported by merge *node*'s verifiedsize endpoint.

    Exits the program on an HTTP error or a reply whose "result" is not
    "ok".
    """
    try:
        result = http_request(baseurl + "plop/v1/merge/verifiedsize",
                              key=own_key, verifynode=node,
                              publickeydir=paths["publickeys"])
        reply = json.loads(result)
        if reply.get(u"result") != u"ok":
            print >>sys.stderr, "ERROR: verifiedsize", reply
            sys.exit(1)
        return reply[u"size"]
    except requests.exceptions.HTTPError as err:
        print >>sys.stderr, "ERROR: verifiedsize", err.response
        sys.exit(1)


def sendlog(node, baseurl, own_key, paths, submission):
    """POST *submission*, JSON-encoded, to frontend *node*'s sendlog endpoint.

    Returns the decoded JSON reply, or None on an HTTP error.  On a
    ValueError (e.g. a reply that is not valid JSON), the request and
    the raw response (None if none was received) are dumped to stderr,
    and the exception is re-raised with its original traceback.
    """
    # Pre-bind so the ValueError handler can't hit UnboundLocalError when
    # the error comes from json.dumps or http_request rather than the reply.
    result = None
    try:
        result = http_request(baseurl + "plop/v1/frontend/sendlog",
                              json.dumps(submission), key=own_key,
                              verifynode=node, publickeydir=paths["publickeys"])
        return json.loads(result)
    except requests.exceptions.HTTPError as e:
        print >>sys.stderr, "ERROR: sendlog", e.response
        sys.stderr.flush()
        return None
    except ValueError:
        print >>sys.stderr, "==== FAILED REQUEST ===="
        print >>sys.stderr, submission
        print >>sys.stderr, "======= RESPONSE ======="
        print >>sys.stderr, result
        print >>sys.stderr, "========================"
        sys.stderr.flush()
        # Bare raise keeps the original traceback; in Python 2, ``raise e``
        # would restart it here.
        raise

def backup_sendlog(node, baseurl, own_key, paths, submission):
    """POST *submission*, JSON-encoded, to merge backup *node*'s sendlog endpoint.

    Returns the decoded JSON reply, or None if the node answers with an
    HTTP error or cannot be connected to.  On a ValueError (e.g. a reply
    that is not valid JSON), the request and the raw response (None if
    none was received) are dumped to stderr, and the exception is
    re-raised with its original traceback.
    """
    # Pre-bind so the ValueError handler can't hit UnboundLocalError when
    # the error comes from json.dumps or http_request rather than the reply.
    result = None
    try:
        result = http_request(baseurl + "plop/v1/merge/sendlog",
                              json.dumps(submission), key=own_key,
                              verifynode=node, publickeydir=paths["publickeys"])
        return json.loads(result)
    except requests.exceptions.HTTPError as e:
        print >>sys.stderr, "ERROR: backup_sendlog", e.response
        sys.stderr.flush()
        return None
    except requests.exceptions.ConnectionError as e:
        # requests reports refused/unreachable connections as its own
        # ConnectionError (wrapping urllib3's NewConnectionError), so the
        # urllib3 clause below rarely sees anything on its own.
        print >>sys.stderr, "ERROR: backup_sendlog connection error", e
        sys.stderr.flush()
        return None
    except requests.packages.urllib3.exceptions.NewConnectionError:
        print >>sys.stderr, "ERROR: backup_sendlog new connection error"
        sys.stderr.flush()
        return None
    except ValueError:
        print >>sys.stderr, "==== FAILED REQUEST ===="
        print >>sys.stderr, submission
        print >>sys.stderr, "======= RESPONSE ======="
        print >>sys.stderr, result
        print >>sys.stderr, "========================"
        sys.stderr.flush()
        # Bare raise keeps the original traceback.
        raise

def sendentries(node, baseurl, own_key, paths, entries, session=None):
    """Send (leaf hash, entry) pairs to frontend *node*'s sendentry endpoint.

    *entries* is an iterable of (treeleafhash, entry) pairs; both values
    are base64-encoded into a JSON list.  Returns the decoded JSON reply.
    Exits the program on an HTTP or connection error.  On a ValueError
    (e.g. an undecodable reply), the leaf hashes being sent and the raw
    response are dumped to stderr, and the exception is re-raised.
    """
    json_entries = None
    result = None
    try:
        json_entries = [{"entry": base64.b64encode(entry),
                         "treeleafhash": base64.b64encode(leafhash)}
                        for leafhash, entry in entries]
        result = http_request(
            baseurl + "plop/v1/frontend/sendentry",
            json.dumps(json_entries),
            key=own_key, verifynode=node, publickeydir=paths["publickeys"],
            session=session)
        return json.loads(result)
    except requests.exceptions.HTTPError as e:
        print >>sys.stderr, "ERROR: sendentries", e.response
        sys.exit(1)
    except ValueError:
        # The original printed an undefined name (ehash) here, so this
        # handler itself died with NameError and hid the real error.
        leafhashes = None
        if json_entries is not None:
            leafhashes = [item["treeleafhash"] for item in json_entries]
        print >>sys.stderr, "==== FAILED REQUEST ===="
        print >>sys.stderr, leafhashes
        print >>sys.stderr, "======= RESPONSE ======="
        print >>sys.stderr, result
        print >>sys.stderr, "========================"
        sys.stderr.flush()
        # Bare raise keeps the original traceback.
        raise
    except requests.exceptions.ConnectionError as e:
        print >>sys.stderr, "ERROR: sendentries", baseurl, e.request, e.response
        sys.exit(1)

def sendentries_merge(node, baseurl, own_key, paths, entries, session=None):
    """Send (leaf hash, entry) pairs to merge *node*'s sendentry endpoint.

    *entries* is an iterable of (treeleafhash, entry) pairs; both values
    are base64-encoded into a JSON list.  Returns the decoded JSON reply.
    Exits the program on an HTTP or connection error.  On a ValueError
    (e.g. an undecodable reply), the leaf hashes being sent and the raw
    response are dumped to stderr, and the exception is re-raised.
    """
    json_entries = None
    result = None
    try:
        json_entries = [{"entry": base64.b64encode(entry),
                         "treeleafhash": base64.b64encode(leafhash)}
                        for leafhash, entry in entries]
        result = http_request(
            baseurl + "plop/v1/merge/sendentry",
            json.dumps(json_entries),
            key=own_key, verifynode=node, publickeydir=paths["publickeys"],
            session=session)
        return json.loads(result)
    except requests.exceptions.HTTPError as e:
        print >>sys.stderr, "ERROR: sendentries_merge", e.response
        sys.exit(1)
    except ValueError:
        # The original printed an undefined name (ehash) here, so this
        # handler itself died with NameError and hid the real error.
        leafhashes = None
        if json_entries is not None:
            leafhashes = [item["treeleafhash"] for item in json_entries]
        print >>sys.stderr, "==== FAILED REQUEST ===="
        print >>sys.stderr, leafhashes
        print >>sys.stderr, "======= RESPONSE ======="
        print >>sys.stderr, result
        print >>sys.stderr, "========================"
        sys.stderr.flush()
        # Bare raise keeps the original traceback.
        raise
    except requests.exceptions.ConnectionError as e:
        print >>sys.stderr, "ERROR: sendentries_merge", baseurl, e.request, e.response
        sys.exit(1)

def publish_sth(node, baseurl, own_key, paths, submission):
    """POST *submission*, JSON-encoded, to frontend *node*'s publish-sth endpoint.

    Returns the decoded JSON reply.  Exits the program on an HTTP error.
    On a ValueError (e.g. a reply that is not valid JSON), the request
    and the raw response (None if none was received) are dumped to
    stderr, and the exception is re-raised with its original traceback.
    """
    # Pre-bind so the ValueError handler can't hit UnboundLocalError.
    result = None
    try:
        result = http_request(baseurl + "plop/v1/frontend/publish-sth",
                              json.dumps(submission), key=own_key,
                              verifynode=node, publickeydir=paths["publickeys"])
        return json.loads(result)
    except requests.exceptions.HTTPError as e:
        print >>sys.stderr, "ERROR: publish-sth", e.response
        sys.exit(1)
    except ValueError:
        print >>sys.stderr, "==== FAILED REQUEST ===="
        print >>sys.stderr, submission
        print >>sys.stderr, "======= RESPONSE ======="
        print >>sys.stderr, result
        print >>sys.stderr, "========================"
        sys.stderr.flush()
        # Bare raise keeps the original traceback.
        raise

def verifyroot(node, baseurl, own_key, paths, treesize):
    """Ask merge *node* to verify the tree root at size *treesize*.

    Returns the decoded JSON reply.  Exits the program on an HTTP error.
    On a ValueError (e.g. a reply that is not valid JSON), the tree size
    and the raw response (None if none was received) are dumped to
    stderr, and the exception is re-raised with its original traceback.
    """
    # Pre-bind so the ValueError handler can't hit UnboundLocalError.
    result = None
    try:
        result = http_request(baseurl + "plop/v1/merge/verifyroot",
                              json.dumps({"tree_size":treesize}), key=own_key,
                              verifynode=node, publickeydir=paths["publickeys"])
        return json.loads(result)
    except requests.exceptions.HTTPError as e:
        print >>sys.stderr, "ERROR: verifyroot", e.response
        sys.exit(1)
    except ValueError:
        print >>sys.stderr, "==== FAILED REQUEST ===="
        print >>sys.stderr, treesize
        print >>sys.stderr, "======= RESPONSE ======="
        print >>sys.stderr, result
        print >>sys.stderr, "========================"
        sys.stderr.flush()
        # Bare raise keeps the original traceback.
        raise

def setverifiedsize(node, baseurl, own_key, paths, treesize):
    """Tell merge *node* to set its verified size to *treesize*.

    Returns the decoded JSON reply.  Exits the program on an HTTP error.
    On a ValueError (e.g. a reply that is not valid JSON), the size and
    the raw response (None if none was received) are dumped to stderr,
    and the exception is re-raised with its original traceback.
    """
    # Pre-bind so the ValueError handler can't hit UnboundLocalError.
    result = None
    try:
        result = http_request(baseurl + "plop/v1/merge/setverifiedsize",
                              json.dumps({"size":treesize}), key=own_key,
                              verifynode=node, publickeydir=paths["publickeys"])
        return json.loads(result)
    except requests.exceptions.HTTPError as e:
        print >>sys.stderr, "ERROR: setverifiedsize", e.response
        sys.exit(1)
    except ValueError:
        print >>sys.stderr, "==== FAILED REQUEST ===="
        print >>sys.stderr, treesize
        print >>sys.stderr, "======= RESPONSE ======="
        print >>sys.stderr, result
        print >>sys.stderr, "========================"
        sys.stderr.flush()
        # Bare raise keeps the original traceback.
        raise

def get_missingentries(node, baseurl, own_key, paths):
    """Return the "entries" list from frontend *node*'s missingentries endpoint.

    The items are returned as found in the JSON reply (not decoded).
    Exits the program on an HTTP error or a reply whose "result" is not
    "ok".
    """
    try:
        result = http_request(baseurl + "plop/v1/frontend/missingentries",
                              key=own_key, verifynode=node,
                              publickeydir=paths["publickeys"])
        reply = json.loads(result)
        if reply.get(u"result") != u"ok":
            print >>sys.stderr, "ERROR: missingentries", reply
            sys.exit(1)
        return reply[u"entries"]
    except requests.exceptions.HTTPError as err:
        print >>sys.stderr, "ERROR: missingentries", err.response
        sys.exit(1)

def get_missingentriesforbackup(node, baseurl, own_key, paths):
    """Return the "entries" list from merge backup *node*'s missingentries endpoint.

    The items are returned as found in the JSON reply (not decoded).
    Exits the program on an HTTP error or a reply whose "result" is not
    "ok".
    """
    try:
        result = http_request(baseurl + "plop/v1/merge/missingentries",
                              key=own_key, verifynode=node,
                              publickeydir=paths["publickeys"])
        reply = json.loads(result)
        if reply.get(u"result") != u"ok":
            print >>sys.stderr, "ERROR: missingentriesforbackup", reply
            sys.exit(1)
        return reply[u"entries"]
    except requests.exceptions.HTTPError as err:
        print >>sys.stderr, "ERROR: missingentriesforbackup", err.response
        sys.exit(1)

def chunks(l, n):
    """Split sequence *l* into consecutive slices of length *n*.

    The final slice may be shorter.  Returns a list of slices.
    """
    starts = range(0, len(l), n)
    return [l[start:start + n] for start in starts]

def parse_args():
    """Parse the command line and load the two YAML configuration files.

    Returns (args, config, localconfig), where *config* and *localconfig*
    are the parsed contents of the files named by --config and
    --localconfig.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument('node', nargs='*', help="Node to operate on")
    parser.add_argument('--config', help="System configuration",
                        required=True)
    parser.add_argument('--localconfig', help="Local configuration",
                        required=True)
    parser.add_argument('--interval', type=int, metavar="n",
                        help="Repeat every N seconds")
    parser.add_argument("--timing", action='store_true',
                        help="Print timing information")
    args = parser.parse_args()

    # safe_load: yaml.load without an explicit Loader can construct
    # arbitrary Python objects, is deprecated since PyYAML 5.1, and raises
    # TypeError in PyYAML 6.  `with` closes the files promptly.
    with open(args.config) as f:
        config = yaml.safe_load(f)
    with open(args.localconfig) as f:
        localconfig = yaml.safe_load(f)

    return (args, config, localconfig)

def perm(dbtype, path):
    """Open the permanent entry store at *path*.

    *dbtype* selects the backend: "filedb" returns a FileDB, "permdb"
    returns a PermDB.  Raises AssertionError for any other value.
    """
    if dbtype == "filedb":
        return FileDB(path)
    elif dbtype == "permdb":
        return PermDB(path)
    # Explicit raise instead of ``assert False``: the assert is stripped
    # under ``python -O``, which made this silently return None.
    raise AssertionError("unknown dbtype: %r" % (dbtype,))

class FileDB:
    """Entry store that keeps one file per key under a directory tree."""

    def __init__(self, path):
        # Root directory handed to read_chain / write_chain.
        self.path = path

    def get(self, key):
        """Return the value stored for *key* (via read_chain)."""
        root = self.path
        return read_chain(root, key)

    def add(self, key, value):
        """Store *value* under *key* (via write_chain); returns its result."""
        root = self.path
        return write_chain(key, value, root)

    def commit(self):
        """Do nothing: add() already writes each value to its own file."""
        return None

class PermDB:
    """Entry store backed by the optional permdb extension module.

    The module is imported only if available; using this class without
    it fails with NameError.
    """

    def __init__(self, path):
        # Handle returned by permdb.alloc for the database at *path*.
        self.permdbobj = permdb.alloc(path)

    def get(self, key):
        """Return the value stored under *key*."""
        handle = self.permdbobj
        return permdb.getvalue(handle, key)

    def add(self, key, value):
        """Add *value* under *key*; returns permdb.addvalue's result."""
        handle = self.permdbobj
        return permdb.addvalue(handle, key, value)

    def commit(self):
        """Commit pending additions via permdb.committree."""
        handle = self.permdbobj
        permdb.committree(handle)