#!/usr/bin/env python

# Copyright (c) 2014, NORDUnet A/S.
# See LICENSE for licensing information.

import argparse
import urllib2
import urllib
import json
import base64
import sys
import struct
import hashlib
import itertools
from certtools import *
import os
import signal
import select
import zipfile

# Command line interface: the log URL is positional, everything else is an
# optional flag.
parser = argparse.ArgumentParser(description='')
parser.add_argument(
    'baseurl',
    help="Base URL for CT server")
parser.add_argument(
    '--store',
    default=None,
    metavar="dir",
    help='Get certificates from directory dir')
parser.add_argument(
    '--sct-file',
    default=None,
    metavar="file",
    help='Store SCT:s in file')
# Used as the size of the multiprocessing pool created further down.
parser.add_argument(
    '--parallel',
    type=int,
    default=16,
    metavar="n",
    help="Number of parallel submits")
parser.add_argument(
    '--check-sct',
    action='store_true',
    help="Check SCT signature")
parser.add_argument(
    '--pre-warm',
    action='store_true',
    help="Wait 3 seconds after first submit")
args = parser.parse_args()

from multiprocessing import Pool

# Settings taken from the command line.
baseurl = args.baseurl
certfilepath = args.store

# When True, submitcert() also fetches the submitted entry back from the log
# and prints a comparison with what was sent.
lookup_in_log = False

# --store defaults to None (and could be given as an empty string); indexing
# it would die with a TypeError/IndexError, so report it through argparse.
if not certfilepath:
    parser.error("--store is required")

# A trailing "/" or an existing directory means "every regular file in that
# directory, in sorted order"; anything else is taken as one single file.
if certfilepath.endswith("/") or os.path.isdir(certfilepath):
    certfiles = [os.path.join(certfilepath, filename)
                 for filename in sorted(os.listdir(certfilepath))
                 if os.path.isfile(os.path.join(certfilepath, filename))]
else:
    certfiles = [certfilepath]

# Fetched once at startup; its "tree_size" is used by submitcert() and the
# whole value is written next to every saved SCT.
# NOTE(review): presumably the log's current signed tree head - confirm
# against certtools.get_sth.
sth = get_sth(baseurl)

def submitcert((certfile, cert)):
    timing = timing_point()
    certchain = get_certs_from_string(cert)
    timing_point(timing, "readcerts")

    try:
        result = add_chain(baseurl, {"chain":map(base64.b64encode, certchain)})
    except SystemExit:
        print "EXIT:", certfile
        select.select([], [], [], 1.0)
        return (None, None)

    timing_point(timing, "addchain")

    if result == None:
        print "ERROR for certfile", certfile
        return (None, timing["deltatimes"])

    try:
        if args.check_sct:
            check_sct_signature(baseurl, certchain[0], result)
            timing_point(timing, "checksig")
    except AssertionError, e:
        print "ERROR:", certfile, e
        return (None, None)
    except urllib2.HTTPError, e:
        print "ERROR:", certfile, e
        return (None, None)
    except ecdsa.keys.BadSignatureError, e:
        print "ERROR: bad signature", certfile
        return (None, None)

    if lookup_in_log:

        merkle_tree_leaf = pack_mtl(result["timestamp"], certchain[0])

        leaf_hash = get_leaf_hash(merkle_tree_leaf)

        proof = get_proof_by_hash(baseurl, leaf_hash, sth["tree_size"])

        leaf_index = proof["leaf_index"]

        entries = get_entries(baseurl, leaf_index, leaf_index)

        fetched_entry = entries["entries"][0]

        print "does the leaf_input of the fetched entry match what we calculated:", \
          base64.decodestring(fetched_entry["leaf_input"]) == merkle_tree_leaf

        extra_data = fetched_entry["extra_data"]

        certchain = decode_certificate_chain(base64.decodestring(extra_data))

        submittedcertchain = certchain[1:]

        for (submittedcert, fetchedcert, i) in zip(submittedcertchain,
                                                   certchain, itertools.count(1)):
            print "cert", i, "in chain is the same:", submittedcert == fetchedcert

        if len(certchain) == len(submittedcertchain) + 1:
            last_issuer = get_cert_info(certchain[-1])["issuer"]
            root_subject = get_cert_info(certchain[-1])["subject"]
            print "issuer of last cert in submitted chain and " \
                "subject of last cert in fetched chain is the same:", \
                last_issuer == root_subject
        elif len(certchain) == len(submittedcertchain):
            print "cert chains are the same length"
        else:
            print "ERROR: fetched cert chain has length", len(certchain),
            print "and submitted chain has length", len(submittedcertchain)

    timing_point(timing, "lookup")
    return ((certchain[0], result), timing["deltatimes"])

def get_ncerts(certfiles):
    """Return how many entries the given paths contribute.

    A path ending in ".zip" is opened as a zip archive and counts once per
    archive member name; every other path counts as one.
    """
    total = 0
    for path in certfiles:
        if not path.endswith(".zip"):
            total += 1
            continue
        archive = zipfile.ZipFile(path)
        members = archive.namelist()
        total += len(members)
        archive.close()
    return total

def get_all_certificates(certfiles):
    """Lazily yield a (name, content) pair for every certificate to submit.

    A path ending in ".zip" is opened as a zip archive and yields one pair
    per member (member name, member bytes).  Any other path yields a single
    pair (path, file content).

    Files and archives are opened with "with" so they are closed even when
    reading fails or the generator is dropped before it is exhausted.
    """
    for certfile in certfiles:
        if certfile.endswith(".zip"):
            with zipfile.ZipFile(certfile) as zf:
                for name in zf.namelist():
                    yield (name, zf.read(name))
        else:
            with open(certfile) as f:
                content = f.read()
            yield (certfile, content)

def save_sct(sct, sth):
    """Append one JSON line for a received SCT to the --sct-file log.

    The line holds the base64 encoded leaf certificate together with the
    given sct and sth.  Nothing is written when --sct-file was not given.

    NOTE: the leaf certificate is not a parameter; it is read from the
    module-level name "leafcert", which the caller has to assign before
    calling this function.
    """
    if args.sct_file is None:
        # --sct-file defaults to None, and open(None, "a") raises TypeError.
        return
    with open(args.sct_file, "a") as sctlog:
        json.dump({"leafcert": base64.b64encode(leafcert), "sct": sct, "sth": sth}, sctlog)
        sctlog.write("\n")

p = Pool(args.parallel, lambda: signal.signal(signal.SIGINT, signal.SIG_IGN))

nsubmitted = 0
lastprinted = 0

print "listing certs"
ncerts = get_ncerts(certfiles)

print ncerts, "certs"

certs = get_all_certificates(certfiles)

(result, timing) = submitcert(certs.next())
if result != None:
    nsubmitted += 1
    (leafcert, sct) = result
    save_sct(sct, sth)

if args.pre_warm:
    select.select([], [], [], 3.0)

starttime = datetime.datetime.now()

try:
    for result, timing in p.imap_unordered(submitcert, certs):
        if timing == None:
            print "error"
            print "submitted", nsubmitted
            p.terminate()
            p.join()
            sys.exit(1)
        if result != None:
            nsubmitted += 1
            (leafcert, sct) = result
            save_sct(sct, sth)
        deltatime = datetime.datetime.now() - starttime
        deltatime_f = deltatime.seconds + deltatime.microseconds / 1000000.0
        rate = nsubmitted / deltatime_f
        if nsubmitted > lastprinted + ncerts / 10:
            print nsubmitted, "rate %.1f" % rate
            lastprinted = nsubmitted
        #print timing, "rate %.1f" % rate
    print "submitted", nsubmitted
except KeyboardInterrupt:
    p.terminate()
    p.join()