#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014-2015, NORDUnet A/S.
# See LICENSE for licensing information.
#
# Generate a new 'sth' file.
# See catlfish/doc/merge.txt for more about the merge process.
#
import sys
import json
import time
import requests
import logging
from base64 import b64encode
from datetime import datetime, timedelta
from mergetools import parse_args, get_nfetched, hexencode, hexdecode, \
     get_logorder, get_sth, flock_ex_or_fail
from certtools import create_ssl_context, get_public_key_from_file, \
     timing_point, create_sth_signature, write_file, check_sth_signature, \
     build_merkle_tree

def merge_sth(args, config, localconfig):
    """
    Build, sign and write a new STH to <mergedb>/sth, if warranted.

    Candidate tree sizes are collected from this node's 'fetched' count and
    from the verified.<name> file of every non-primary merge node (a missing
    or unparsable file counts as size 0). The list is sorted in descending
    order and the entry at index backup-quorum-size is chosen, so at least
    backup-quorum-size + 1 of the listed sizes are >= the chosen size.

    If the chosen size is smaller than the tree size in the current 'sth',
    nothing is written. Otherwise the root hash is recomputed from logorder,
    a signature is requested from the configured signing nodes (first one
    that succeeds wins), the result is checked against the log public key
    and written to 'sth'.

    Returns -1 when the backup quorum exceeds the number of secondaries,
    otherwise 0 (also when no STH was written). AssertionError is raised on
    configuration or root-hash inconsistencies.
    """
    paths = localconfig["paths"]
    # (node name, path to this node's private key) passed to the signer.
    own_key = (localconfig["nodename"],
               "%s/%s-private.pem" % (paths["privatekeys"],
                                      localconfig["nodename"]))
    ctbaseurl = config["baseurl"]
    signingnodes = config["signingnodes"]
    mergenodes = config.get("mergenodes", [])
    mergedb = paths["mergedb"]
    sthfile = mergedb + "/sth"
    logorderfile = mergedb + "/logorder"
    currentsizefile = mergedb + "/fetched"
    logpublickey = get_public_key_from_file(paths["logpublickey"])
    backupquorum = config.get("backup-quorum-size", 0)
    assert backupquorum <= len(mergenodes) - 1
    create_ssl_context(cafile=paths["https_cacertfile"])
    timing = timing_point()

    # Our own fetched size is always a candidate. Its hash is left empty
    # because the root hash is recomputed from logorder below anyway.
    trees = [{'tree_size': get_nfetched(currentsizefile, logorderfile),
              'sha256_root_hash': ''}]
    logging.debug("starting point, trees: %s", trees)
    for mergenode in mergenodes:
        if mergenode["name"] == config["primarymergenode"]:
            continue
        verifiedfile = mergedb + "/verified." + mergenode["name"]
        try:
            with open(verifiedfile, "r") as verified:
                tree = json.load(verified)
        except (IOError, ValueError):
            # Missing or corrupt file: treat the secondary as having
            # verified nothing yet.
            tree = {'tree_size': 0, "sha256_root_hash": ''}
        trees.append(tree)
    trees.sort(key=lambda e: e['tree_size'], reverse=True)
    logging.debug("trees: %s", trees)

    if backupquorum > len(trees) - 1:
        logging.error("backup quorum > number of secondaries: %d > %d",
                      backupquorum, len(trees) - 1)
        return -1
    tree_size = trees[backupquorum]['tree_size']
    root_hash = hexdecode(trees[backupquorum]['sha256_root_hash'])
    logging.debug("tree size candidate at backupquorum %d: %d", backupquorum,
                  tree_size)

    cur_sth = get_sth(sthfile)
    if tree_size < cur_sth['tree_size']:
        # Never publish an STH for a smaller tree than the current one.
        logging.info("candidate tree < current tree: %d < %d",
                     tree_size, cur_sth['tree_size'])
        return 0

    assert tree_size >= 0         # Don't read logorder without limit.
    logorder = get_logorder(logorderfile, tree_size)
    timing_point(timing, "get logorder")
    if tree_size == -1:
        # Only reachable when asserts are disabled (python -O).
        tree_size = len(logorder)
    logging.info("new tree size will be %d", tree_size)

    root_hash_calc = build_merkle_tree(logorder)[-1][0]
    # A secondary-reported hash must agree with what logorder yields.
    assert root_hash == '' or root_hash == root_hash_calc
    root_hash = root_hash_calc
    timestamp = int(time.time() * 1000)

    tree_head_signature = None
    for signingnode in signingnodes:
        try:
            tree_head_signature = create_sth_signature(
                tree_size, timestamp, root_hash,
                "https://%s/" % signingnode["address"],
                key=own_key)
            break
        except requests.exceptions.RequestException as e:
            # HTTPError carries a response; connection errors and timeouts
            # do not, so fall back to the exception itself for the log and
            # move on to the next signing node instead of crashing.
            response = getattr(e, "response", None)
            logging.warning("create_sth_signature error: %s",
                            response if response is not None else e)
    if tree_head_signature is None:
        logging.error("Could not contact any signing nodes")
        return 0

    sth = {"tree_size": tree_size, "timestamp": timestamp,
           "sha256_root_hash": b64encode(root_hash),
           "tree_head_signature": b64encode(tree_head_signature)}

    check_sth_signature(ctbaseurl, sth, publickey=logpublickey)
    timing_point(timing, "build sth")

    logging.info("new root: %s %d %d", hexencode(root_hash), timestamp, tree_size)

    write_file(sthfile, sth)

    if args.timing:
        logging.debug("timing: merge_sth: %s", timing["deltatimes"])

    return 0

def main():
    """
    Run merge_sth() once, or repeatedly every `--mergeinterval' seconds.

    Each pass (see merge_sth) builds a list of tree sizes from this node's
    'fetched' count and the verified.<secondary> files, sorts it in
    descending order and takes the entry at index backup-quorum-size as
    the candidate size. If that is smaller than the tree size in the
    currently published 'sth', the pass stops there; otherwise it decides
    on a timestamp, gets an STH signed and writes it to 'sth'.

    Logging goes to stderr for a one-shot run, or to
    <logdir>/merge_sth.log when `--mergeinterval' is given.

    An exclusive lock on <mergedb>/.merge_sth.lock is taken via
    flock_ex_or_fail() first; if it cannot be taken, nothing is done.

    Returns 1 if the lock cannot be taken or merge_sth() returns a
    negative value, otherwise 0.
    """
    args, config, localconfig = parse_args()
    paths = localconfig["paths"]
    mergedb = paths["mergedb"]
    lockfile = mergedb + "/.merge_sth.lock"

    loglevel = getattr(logging, args.loglevel.upper())
    if args.mergeinterval is None:
        logging.basicConfig(level=loglevel)
    else:
        logging.basicConfig(filename=args.logdir + "/merge_sth.log",
                            level=loglevel)

    if not flock_ex_or_fail(lockfile):
        logging.critical("unable to take lock %s", lockfile)
        return 1

    while True:
        merge_start_time = datetime.now()
        ret = merge_sth(args, config, localconfig)
        if ret < 0:
            return 1
        if args.mergeinterval is None:
            break
        # Use total_seconds(), not .seconds: when a pass overruns the
        # interval the delta is negative and timedelta normalizes it to
        # days=-1, seconds~86399, which would make us sleep almost a day.
        deadline = merge_start_time + timedelta(seconds=args.mergeinterval)
        sleep = (deadline - datetime.now()).total_seconds()
        if sleep > 0:
            logging.debug("sleeping %d seconds", sleep)
            time.sleep(sleep)

    return 0

if __name__ == '__main__':
    # Hand main()'s return value to the interpreter as the exit status;
    # raising SystemExit is exactly what sys.exit() does internally.
    raise SystemExit(main())