#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (c) 2014-2016, NORDUnet A/S.
# See LICENSE for licensing information.

import base64
import hashlib
import itertools
import json
import os.path
import struct
import subprocess
import sys
import urllib
import urllib2
from time import sleep

from certtools import *

# Command line: <log base URL> <log public key file> <CA cert file> <test dir>
baseurls = [sys.argv[1]]
logpublickeyfile, cacertfile = sys.argv[2], sys.argv[3]
toolsdir = os.path.dirname(sys.argv[0])
testdir = sys.argv[4]

# The five test certificate chains live under <toolsdir>/testcerts/.
certfiles = ["%s/testcerts/cert%d.txt" % (toolsdir, certno) for certno in range(1, 6)]
cc1, cc2, cc3, cc4, cc5 = [get_certs_from_file(certfile) for certfile in certfiles]

create_ssl_context(cafile=cacertfile)

# Reporting state shared by testgroup/print_error/print_success.
failures = 0
indentation = ""

logpublickey = get_public_key_from_file(logpublickeyfile)

def testgroup(name):
    global indentation
    print "testgroup " + name + ":"
    indentation = "    "

def print_error(message, *args):
    """Report a failed check and count it.

    `message` is a %-format string that is filled from `args`. The line is
    prefixed with the current indentation and "ERROR:", and the global
    `failures` counter is incremented.
    """
    global failures
    text = message % args
    print(indentation + "ERROR: " + text)
    failures += 1

def print_success(message, *args):
    """Print a passing-check line (%-formatted from `args`) at the current indentation."""
    text = message % args
    print(indentation + text)

def assert_equal(actual, expected, name, quiet=False, nodata=False, fatal=False):
    """Compare `actual` with `expected` and report the result under `name`.

    On a match, print "<name> was correct" unless `quiet` is set.
    On a mismatch, record an error: only "<name> differs" when `nodata` is set
    (e.g. for binary hashes), otherwise both values. With `fatal`, exit the
    process with status 1 after reporting a mismatch.
    """
    mismatch = actual != expected
    if not mismatch:
        if not quiet:
            print_success("%s was correct", name)
        return

    if not nodata:
        print_error("%s expected %s got %s", name, expected, actual)
    else:
        print_error("%s differs", name)

    if fatal:
        sys.exit(1)

def print_and_check_tree_size(expected, baseurl):
    """Fetch the STH from `baseurl`, verify its signature and check its tree size.

    Signature problems are recorded as errors but do not stop the size check.
    The size comparison is quiet: only a mismatch is printed.
    """
    sth = get_sth(baseurl)
    try:
        check_sth_signature(baseurl, sth, publickey=logpublickey)
    except AssertionError as err:
        print_error("%s", err)
    except ecdsa.keys.BadSignatureError:
        print_error("bad STH signature")
    assert_equal(sth["tree_size"], expected, "tree size", quiet=True)

def do_add_chain(chain, baseurl):
    """Submit `chain` to the log at `baseurl` and verify the returned SCT.

    Each certificate in `chain` is base64-encoded for the add-chain request.
    chain[0] is the certificate being logged; its packed form is used to check
    the SCT signature against the log's public key. Signature problems are
    recorded as test failures.

    A ValueError from add_chain is fatal (exit status 1). In that case there is
    no SCT to verify or return, and the callers in this script index the result.

    Returns the add_chain result (the script reads its "timestamp" field).
    """
    try:
        result = add_chain(baseurl, {"chain": [base64.b64encode(cert) for cert in chain]})
    except ValueError as e:
        print_error("%s", e)
        # Without a result the signature check below would hit an unbound
        # local; stop with a clear report instead.
        sys.exit(1)
    try:
        signed_entry = pack_cert(chain[0])
        check_sct_signature(baseurl, signed_entry, result, publickey=logpublickey)
        print_success("signature check succeeded")
    except AssertionError as e:
        print_error("%s", e)
    except ecdsa.keys.BadSignatureError as e:
        print(e)
        print_error("bad SCT signature")
    return result

def get_and_validate_proof(timestamp, chain, leaf_index, nentries, baseurl):
    """Check the inclusion proof and stored entry for a submitted certificate.

    timestamp  -- SCT timestamp returned when `chain` was submitted
    chain      -- submitted chain; chain[0] is the logged certificate
    leaf_index -- index the leaf is expected to have in the tree
    nentries   -- expected number of hashes in the audit path
    baseurl    -- log front end to query

    Builds the leaf hash and fetches a proof by hash for the current STH's
    tree size. It then checks the returned leaf index and audit-path length,
    recomputes the root hash from the proof and compares it with the STH root.
    Finally it checks the stored entry via get_and_check_entry.
    """
    cert = chain[0]
    merkle_tree_leaf = pack_mtl(timestamp, cert)
    leaf_hash = get_leaf_hash(merkle_tree_leaf)
    sth = get_sth(baseurl)
    tree_size = sth["tree_size"]
    proof = get_proof_by_hash(baseurl, leaf_hash, tree_size)

    # Keep the server-reported index separate from the expected one so the
    # comparison below can actually fail.
    fetched_leaf_index = proof["leaf_index"]
    inclusion_proof = [base64.b64decode(e) for e in proof["audit_path"]]
    assert_equal(fetched_leaf_index, leaf_index, "leaf_index", quiet=True)
    assert_equal(len(inclusion_proof), nentries, "audit_path length", quiet=True)

    # Verify what the server claims: recompute the root from its index and path.
    calc_root_hash = verify_inclusion_proof(inclusion_proof, fetched_leaf_index, tree_size, leaf_hash)
    root_hash = base64.b64decode(sth["sha256_root_hash"])

    assert_equal(root_hash, calc_root_hash, "verified root hash", nodata=True, quiet=True)
    get_and_check_entry(timestamp, chain, fetched_leaf_index, baseurl)

def get_and_validate_consistency_proof(sth1, sth2, size1, size2, baseurl):
    """Check a consistency proof between two recorded root hashes.

    sth1, sth2   -- decoded root hashes recorded at tree sizes size1 and size2
    size1, size2 -- the two tree sizes (size1 < size2 in this script)
    baseurl      -- log front end to query

    Fetches the proof, reconstructs both tree heads from it and compares them
    with sth1 and sth2 (mismatches are reported without dumping the data).
    """
    encoded_nodes = get_consistency_proof(baseurl, size1, size2)
    consistency_proof = [base64.decodestring(node) for node in encoded_nodes]
    old_treehead, new_treehead = verify_consistency_proof(consistency_proof, size1, size2, sth1)
    for computed, recorded, label in ((old_treehead, sth1, "sth1"),
                                      (new_treehead, sth2, "sth2")):
        assert_equal(computed, recorded, label, nodata=True, quiet=True)


def get_and_check_entry(timestamp, chain, leaf_index, baseurl):
    """Fetch the log entry at `leaf_index` and compare it with what was submitted.

    The following are checked:
      * exactly one entry is returned;
      * its leaf_input equals the leaf packed from (timestamp, chain[0]);
      * the certificate chain in extra_data matches chain[1:] element by element;
      * the fetched chain is either as long as chain[1:] or has exactly one extra
        certificate whose subject is the issuer of the last submitted cert
        (an appended root).
    Problems are reported through print_error/assert_equal; nothing is returned.
    """
    entries = get_entries(baseurl, leaf_index, leaf_index)
    # The entries themselves are the list under "entries"; count that list,
    # not the keys of the wrapping response.
    fetched_entries = entries["entries"]
    assert_equal(len(fetched_entries), 1, "get_entries", quiet=True)
    if not fetched_entries:
        # Nothing to compare; the count mismatch has already been recorded.
        return
    fetched_entry = fetched_entries[0]

    merkle_tree_leaf = pack_mtl(timestamp, chain[0])
    leaf_input = base64.b64decode(fetched_entry["leaf_input"])
    assert_equal(leaf_input, merkle_tree_leaf, "entry", nodata=True, quiet=True)
    extra_data = base64.b64decode(fetched_entry["extra_data"])
    certchain = decode_certificate_chain(extra_data)

    submittedcertchain = chain[1:]

    for i, (submittedcert, fetchedcert) in enumerate(zip(submittedcertchain, certchain), 1):
        assert_equal(fetchedcert, submittedcert, "cert %d in chain" % (i,), quiet=True)

    if len(certchain) == len(submittedcertchain) + 1:
        # chain[-1] is the last submitted intermediate, or the leaf itself when
        # no intermediates were submitted (chain[1:] empty).
        last_issuer = get_cert_info(chain[-1])["issuer"]
        root_subject = get_cert_info(certchain[-1])["subject"]
        if last_issuer == root_subject:
            print_success("fetched chain has an appended root cert")
        else:
            print_error("fetched chain has an extra entry")
    elif len(certchain) == len(submittedcertchain):
        print_success("cert chains are the same length")
    else:
        print_error("cert chain length %d expected %d or %d",
                    len(certchain),
                    len(submittedcertchain),
                    len(submittedcertchain) + 1)

def correct_tree_size(expected):
    """Return True if every front end in `baseurls` reports `expected` as its tree size.

    Stops querying at the first front end that disagrees.
    """
    return all(get_sth(url)["tree_size"] == expected for url in baseurls)

def merge(expected=None, wait=0, attempts=10):
    """Run the merge tool until the log front ends reach an expected tree size.

    Invokes `<toolsdir>/merge` with the test configuration files from
    `testdir`, at most `attempts` times. After each successful run:
      * during the first `wait` runs, sleep one second and do not poll;
      * afterwards, return 0 once every front end in `baseurls` reports
        `expected` as its tree size, or right away when `expected` is None
        (no target size requested).

    Returns the merge tool's exit status as soon as it is non-zero, otherwise 0.
    It also returns 0 when the expected size was never observed; the callers
    in this script check the tree size themselves afterwards.
    """
    for attempt in range(attempts):
        rv = subprocess.call([toolsdir + "/merge", "--config", testdir + "/catlfish-test.cfg",
                              "--localconfig", testdir + "/catlfish-test-local-merge.cfg"])
        if rv:
            return rv
        if attempt < wait:
            sleep(1)
            continue
        # A tree size never equals None, so without a target one successful
        # merge has to be enough; otherwise every attempt would be used up.
        if expected is None or correct_tree_size(expected):
            return 0
    return 0

# Start from a clean state: merge (the first three runs just sleep 1 s each
# before the tree size is polled), abort if the merge tool fails, and require
# every front end to report an empty tree.
mergeresult = merge(expected=0, wait=3)
assert_equal(mergeresult, 0, "merge", quiet=True, fatal=True)

for baseurl in baseurls:
    print_and_check_tree_size(0, baseurl)

testgroup("cert1")

# First submission: the tree should grow to one leaf.
result1 = do_add_chain(cc1, baseurls[0])

mergeresult = merge(expected=1)
assert_equal(mergeresult, 0, "merge", quiet=True, fatal=True)

# size_sth maps tree size -> decoded root hash at that size; it is used for
# the consistency-proof checks at the end of the script.
size_sth = {}

for baseurl in baseurls:
    print_and_check_tree_size(1, baseurl)
size_sth[1] = base64.b64decode(get_sth(baseurls[0])["sha256_root_hash"])

# Resubmitting the same chain must return the original SCT timestamp and must
# leave both the tree size and the root hash unchanged.
result2 = do_add_chain(cc1, baseurls[0])

assert_equal(result2["timestamp"], result1["timestamp"], "timestamp")

mergeresult = merge(expected=1, wait=3)
assert_equal(mergeresult, 0, "merge", quiet=True, fatal=True)

for baseurl in baseurls:
    print_and_check_tree_size(1, baseurl)
size1_v2_sth = base64.b64decode(get_sth(baseurls[0])["sha256_root_hash"])

assert_equal(size_sth[1], size1_v2_sth, "sth", nodata=True)

# TODO: add invalid cert and check that it generates an error
#       and that treesize still is 1

# In a one-leaf tree the audit path for leaf 0 is empty.
get_and_validate_proof(result1["timestamp"], cc1, 0, 0, baseurls[0])

# Each group below adds one more chain, merges, checks the new tree size and
# records its root hash, then re-validates the inclusion proof of every leaf
# so far. The expected audit-path lengths (4th argument) follow the shape of a
# Merkle tree of the current size: e.g. in the 5-leaf tree, leaves 0-3 need
# 3 hashes while leaf 4 needs only 1.
testgroup("cert2")

result3 = do_add_chain(cc2, baseurls[0])

mergeresult = merge(expected=2)
assert_equal(mergeresult, 0, "merge", quiet=True, fatal=True)

for baseurl in baseurls:
    print_and_check_tree_size(2, baseurl)
size_sth[2] = base64.b64decode(get_sth(baseurls[0])["sha256_root_hash"])

get_and_validate_proof(result1["timestamp"], cc1, 0, 1, baseurls[0])
get_and_validate_proof(result3["timestamp"], cc2, 1, 1, baseurls[0])

testgroup("cert3")

result4 = do_add_chain(cc3, baseurls[0])

mergeresult = merge(expected=3)
assert_equal(mergeresult, 0, "merge", quiet=True, fatal=True)

for baseurl in baseurls:
    print_and_check_tree_size(3, baseurl)
size_sth[3] = base64.b64decode(get_sth(baseurls[0])["sha256_root_hash"])

get_and_validate_proof(result1["timestamp"], cc1, 0, 2, baseurls[0])
get_and_validate_proof(result3["timestamp"], cc2, 1, 2, baseurls[0])
get_and_validate_proof(result4["timestamp"], cc3, 2, 1, baseurls[0])

testgroup("cert4")

result5 = do_add_chain(cc4, baseurls[0])

mergeresult = merge(expected=4)
assert_equal(mergeresult, 0, "merge", quiet=True, fatal=True)

for baseurl in baseurls:
    print_and_check_tree_size(4, baseurl)
size_sth[4] = base64.b64decode(get_sth(baseurls[0])["sha256_root_hash"])

get_and_validate_proof(result1["timestamp"], cc1, 0, 2, baseurls[0])
get_and_validate_proof(result3["timestamp"], cc2, 1, 2, baseurls[0])
get_and_validate_proof(result4["timestamp"], cc3, 2, 2, baseurls[0])
get_and_validate_proof(result5["timestamp"], cc4, 3, 2, baseurls[0])

testgroup("cert5")

result6 = do_add_chain(cc5, baseurls[0])

mergeresult = merge(expected=5)
assert_equal(mergeresult, 0, "merge", quiet=True, fatal=True)

for baseurl in baseurls:
    print_and_check_tree_size(5, baseurl)
size_sth[5] = base64.b64decode(get_sth(baseurls[0])["sha256_root_hash"])

get_and_validate_proof(result1["timestamp"], cc1, 0, 3, baseurls[0])
get_and_validate_proof(result3["timestamp"], cc2, 1, 3, baseurls[0])
get_and_validate_proof(result4["timestamp"], cc3, 2, 3, baseurls[0])
get_and_validate_proof(result5["timestamp"], cc4, 3, 3, baseurls[0])
get_and_validate_proof(result6["timestamp"], cc5, 4, 1, baseurls[0])

# One more merge round, then check a consistency proof between every pair of
# recorded tree sizes (smaller size first), in the order (1,2), (1,3), ... (4,5).
mergeresult = merge(expected=5, wait=3)
assert_equal(mergeresult, 0, "merge", quiet=True, fatal=True)

for first_size, second_size in itertools.combinations(range(1, 6), 2):
    get_and_validate_consistency_proof(size_sth[first_size], size_sth[second_size],
                                       first_size, second_size, baseurls[0])

# Final report; a non-zero exit status signals that at least one check failed.
print("-------")
if not failures:
    print("all tests succeeded")
else:
    noun = "failed test" if failures == 1 else "failed tests"
    print("%d %s" % (failures, noun))
    sys.exit(1)