#!/usr/bin/env python
import urllib2
import urllib
import json
import base64
import sys
import struct
import hashlib
import itertools
from certtools import *

# Endpoint of the log under test.
baseurl = "https://127.0.0.1:8080/"

# Certificate chain files, in the order they are submitted by the script below.
certfiles = [
    "testcerts/cert1.txt",
    "testcerts/cert2.txt",
    "testcerts/cert3.txt",
    "testcerts/cert4.txt",
    "testcerts/cert5.txt",
]

# One chain per file, loaded in certfiles order. Each chain is indexable;
# element [0] is what gets logged as the leaf (presumably the end-entity
# certificate -- confirm against certtools.get_certs_from_file).
cc1, cc2, cc3, cc4, cc5 = map(get_certs_from_file, certfiles)

# Number of failed checks so far; bumped by the check helpers and read at the
# end of the script to decide the exit status.
failures = 0

def assert_equal(actual, expected, name):
    """Report whether a single check labelled `name` passed.

    If `actual != expected`, print an "ERROR:" line showing both values and
    increment the global `failures` counter. Otherwise print
    "<name> was correct". Nothing is returned.
    """
    global failures
    if actual != expected:
        print "ERROR:", name, "expected", expected, "got", actual
        failures += 1
        return
    print name, "was correct"

def print_and_check_tree_size(expected):
    global failures
    sth = get_sth(baseurl)
    try:
        check_sth_signature(baseurl, sth)
    except AssertionError, e:
        print "ERROR:", e
        failures += 1
    except ecdsa.keys.BadSignatureError, e:
        print "ERROR: bad STH signature"
        failures += 1
    tree_size = sth["tree_size"]
    if tree_size == expected:
        print "tree size", tree_size
    else:
        print "ERROR: tree size", tree_size, "expected", expected
        failures += 1

def do_add_chain(chain):
    global failures
    try:
        result = add_chain(baseurl, {"chain":map(base64.b64encode, chain)})
    except ValueError, e:
        print "ERROR:", e
        failures += 1
    try:
        check_sct_signature(baseurl, chain[0], result)
    except AssertionError, e:
        print "ERROR:", e
        failures += 1
    except ecdsa.keys.BadSignatureError, e:
        print "ERROR: bad SCT signature"
        failures += 1
    print "signature check succeeded"
    return result

def get_and_validate_proof(timestamp, cert, leaf_index, nentries):
    """Fetch the inclusion proof for one logged certificate and check its shape.

    The Merkle tree leaf is rebuilt from (timestamp, cert) and hashed. The
    proof is then requested for that hash against the tree size reported
    by the log's current STH.

    Two things are checked via assert_equal: the proof's leaf_index equals
    `leaf_index`, and its audit path has exactly `nentries` entries.

    NOTE: the audit path hashes are not verified against the STH root here;
    only the index and path length are checked.
    """
    hashed_leaf = get_leaf_hash(pack_mtl(timestamp, cert))
    current_size = get_sth(baseurl)["tree_size"]
    proof = get_proof_by_hash(baseurl, hashed_leaf, current_size)
    assert_equal(proof["leaf_index"], leaf_index, "leaf_index")
    assert_equal(len(proof["audit_path"]), nentries, "audit_path length")

print_and_check_tree_size(0)

result1 = do_add_chain(cc1)

print_and_check_tree_size(1)

result2 = do_add_chain(cc1)

assert_equal(result2["timestamp"], result1["timestamp"], "timestamp")

print_and_check_tree_size(1)

# TODO: add invalid cert and check that it generates an error
#       and that treesize still is 1

get_and_validate_proof(result1["timestamp"], cc1[0], 0, 0)

result3 = do_add_chain(cc2)

print_and_check_tree_size(2)

get_and_validate_proof(result1["timestamp"], cc1[0], 0, 1)
get_and_validate_proof(result3["timestamp"], cc2[0], 1, 1)

result4 = do_add_chain(cc3)

print_and_check_tree_size(3)

get_and_validate_proof(result1["timestamp"], cc1[0], 0, 2)
get_and_validate_proof(result3["timestamp"], cc2[0], 1, 2)
get_and_validate_proof(result4["timestamp"], cc3[0], 2, 1)

result5 = do_add_chain(cc4)

print_and_check_tree_size(4)

get_and_validate_proof(result1["timestamp"], cc1[0], 0, 2)
get_and_validate_proof(result3["timestamp"], cc2[0], 1, 2)
get_and_validate_proof(result4["timestamp"], cc3[0], 2, 2)
get_and_validate_proof(result5["timestamp"], cc4[0], 3, 2)

result6 = do_add_chain(cc5)

print_and_check_tree_size(5)

get_and_validate_proof(result1["timestamp"], cc1[0], 0, 3)
get_and_validate_proof(result3["timestamp"], cc2[0], 1, 3)
get_and_validate_proof(result4["timestamp"], cc3[0], 2, 3)
get_and_validate_proof(result5["timestamp"], cc4[0], 3, 3)
get_and_validate_proof(result6["timestamp"], cc5[0], 4, 1)

print "-------"
if failures:
    print failures, "failed tests" if failures != 1 else "failed test"
    sys.exit(1)
else:
    print "all tests succeeded"