#!/usr/bin/env python

# Copyright (c) 2014, NORDUnet A/S.
# See LICENSE for licensing information.

import urllib2
import urllib
import json
import base64
import sys
import struct
import hashlib
import itertools
from certtools import *
import os

from multiprocessing import Pool

# Command line: <baseurl> <certfile | certdir/>
if len(sys.argv) < 3:
    sys.stderr.write("usage: %s <baseurl> <certfile | certdir/>\n" % sys.argv[0])
    sys.exit(1)

baseurl = sys.argv[1]
certfilepath = sys.argv[2]

# When True, submitcert() fetches the submitted entry back from the log and
# compares it with what was submitted.
lookup_in_log = False
# When True, submitcert() verifies the SCT signature returned by add_chain().
check_sig = False

# A trailing "/" or an existing directory means "submit every file in it",
# in sorted filename order; anything else is a single certificate file.
if certfilepath.endswith("/") or os.path.isdir(certfilepath):
    certfiles = [os.path.join(certfilepath, filename)
                 for filename in sorted(os.listdir(certfilepath))]
else:
    certfiles = [certfilepath]

def submitcert(certfile):
    timing = timing_point()
    certs = get_certs_from_file(certfile)
    timing_point(timing, "readcerts")

    result = add_chain(baseurl, {"chain":map(base64.b64encode, certs)})

    timing_point(timing, "addchain")

    try:
        if check_sig:
            check_sct_signature(baseurl, certs[0], result)
            timing_point(timing, "checksig")
    except AssertionError, e:
        print "ERROR:", e
        sys.exit(1)
    except ecdsa.keys.BadSignatureError, e:
        print "ERROR: bad signature"
        sys.exit(1)

    if lookup_in_log:

        merkle_tree_leaf = pack_mtl(result["timestamp"], certs[0])

        leaf_hash = get_leaf_hash(merkle_tree_leaf)

        sth = get_sth(baseurl)

        proof = get_proof_by_hash(baseurl, leaf_hash, sth["tree_size"])

        leaf_index = proof["leaf_index"]

        entries = get_entries(baseurl, leaf_index, leaf_index)

        fetched_entry = entries["entries"][0]

        print "does the leaf_input of the fetched entry match what we calculated:", \
          base64.decodestring(fetched_entry["leaf_input"]) == merkle_tree_leaf

        extra_data = fetched_entry["extra_data"]

        certchain = decode_certificate_chain(base64.decodestring(extra_data))

        submittedcertchain = certs[1:]

        for (submittedcert, fetchedcert, i) in zip(submittedcertchain,
                                                   certchain, itertools.count(1)):
            print "cert", i, "in chain is the same:", submittedcert == fetchedcert

        if len(certchain) == len(submittedcertchain) + 1:
            last_issuer = get_cert_info(certs[-1])["issuer"]
            root_subject = get_cert_info(certchain[-1])["subject"]
            print "issuer of last cert in submitted chain and " \
                "subject of last cert in fetched chain is the same:", \
                last_issuer == root_subject
        elif len(certchain) == len(submittedcertchain):
            print "cert chains are the same length"
        else:
            print "ERROR: fetched cert chain has length", len(certchain),
            print "and submitted chain has length", len(submittedcertchain)

    timing_point(timing, "lookup")
    return timing["deltatimes"]

p = Pool(1)

for timing in p.imap_unordered(submitcert, certfiles):
    print timing