#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2014-2015, NORDUnet A/S.
# See LICENSE for licensing information.

import sys
import json
import urllib2
import time
from base64 import b64encode, b64decode
from mergetools import parse_args, get_nfetched, hexencode, hexdecode, \
     get_logorder, get_sth
from certtools import create_ssl_context, get_public_key_from_file, \
     timing_point, create_sth_signature, write_file, check_sth_signature, \
     build_merkle_tree

def merge_sth(args, config, localconfig):
    paths = localconfig["paths"]
    own_key = (localconfig["nodename"],
               "%s/%s-private.pem" % (paths["privatekeys"],
                                      localconfig["nodename"]))
    ctbaseurl = config["baseurl"]
    signingnodes = config["signingnodes"]
    mergenodes = config.get("mergenodes", [])
    mergedb = paths["mergedb"]
    sthfile = mergedb + "/sth"
    logorderfile = mergedb + "/logorder"
    logpublickey = get_public_key_from_file(paths["logpublickey"])
    backupquorum = localconfig.get("backupquorum", 0)
    assert backupquorum <= len(mergenodes) - 1
    create_ssl_context(cafile=paths["https_cacertfile"])
    timing = timing_point()

    trees = [{'tree_size': 0, 'sha256_root_hash': ''}]
    for mergenode in mergenodes:
        if mergenode["name"] == config["primarymergenode"]:
            continue
        verifiedfile = mergedb + "/verified." + mergenode["name"]
        try:
            tree = json.loads(open(verifiedfile, "r").read())
        except (IOError, ValueError):
            tree = {'tree_size': 0, "sha256_root_hash": ''}
        trees.append(tree)
    trees.sort(key=lambda e: e['tree_size'], reverse=True)
    print >>sys.stderr, "DEBUG: trees:", trees
    tree_size = trees[backupquorum]['tree_size']
    root_hash = hexdecode(trees[backupquorum]['sha256_root_hash'])
    print >>sys.stderr, "DEBUG: tree size candidate at backupquorum", \
      backupquorum, ":", tree_size

    cur_sth = get_sth(sthfile)
    if tree_size < cur_sth['tree_size']:
        print >>sys.stderr, "candidate tree < current tree:", \
          tree_size, "<", cur_sth['tree_size']
        return

    assert tree_size >= 0         # Don't read logorder without limit.
    logorder = get_logorder(logorderfile, tree_size)
    timing_point(timing, "get logorder")
    if tree_size == -1:
        tree_size = len(logorder)
    print >>sys.stderr, "new tree size will be", tree_size

    root_hash_calc = build_merkle_tree(logorder)[-1][0]
    assert root_hash == '' or root_hash == root_hash_calc
    root_hash = root_hash_calc
    timestamp = int(time.time() * 1000)

    tree_head_signature = None
    for signingnode in signingnodes:
        try:
            tree_head_signature = \
              create_sth_signature(tree_size, timestamp,
                                   root_hash,
                                   "https://%s/" % signingnode["address"],
                                   key=own_key)
            break
        except urllib2.URLError, err:
            print >>sys.stderr, err
            sys.stderr.flush()
    if tree_head_signature == None:
        print >>sys.stderr, "Could not contact any signing nodes"
        sys.exit(1)

    sth = {"tree_size": tree_size, "timestamp": timestamp,
           "sha256_root_hash": b64encode(root_hash),
           "tree_head_signature": b64encode(tree_head_signature)}

    check_sth_signature(ctbaseurl, sth, publickey=logpublickey)
    timing_point(timing, "build sth")

    print hexencode(root_hash), timestamp, tree_size
    sys.stdout.flush()

    write_file(sthfile, sth)

    if args.timing:
        print >>sys.stderr, timing["deltatimes"]
        sys.stderr.flush()

def main():
    """
    Read file 'sth' to get current tree size, assuming zero if file not
    found.

    Read the verified tree heads from the verified.<secondary> files (one
    per merge node other than the primary), put them in a list together
    with an empty sentinel tree and sort it by tree size, largest first.
    Let new tree size equal list[backup-quorum]. If that is smaller than
    the currently published tree size, log it and write nothing.

    Otherwise decide on a timestamp, build an STH and write it to file
    'sth'.

    If args.interval is set, repeat the above forever, sleeping that many
    seconds between rounds; otherwise run once.
    """
    args, config, localconfig = parse_args()

    while True:
        merge_sth(args, config, localconfig)
        if args.interval is None:
            break
        print >>sys.stderr, "sleeping", args.interval, "seconds"
        time.sleep(args.interval)

if __name__ == '__main__':
    # Run the script and hand main()'s return value to the interpreter as
    # the process exit status.
    exit_status = main()
    sys.exit(exit_status)