#!/usr/bin/env python

# Copyright (c) 2014, NORDUnet A/S.
# See LICENSE for licensing information.

import argparse
import urllib2
import urllib
import json
import base64
import sys
import struct
import hashlib
import itertools
from certtools import *
from precerttools import *
import os
import signal
import select
import zipfile

# Command-line interface. `args` is read as a module global by the
# functions below (e.g. args.check_sct, args.sct_file, args.parallel).
parser = argparse.ArgumentParser(description='Submit certificates to a CT log and optionally store the returned SCTs')
parser.add_argument('baseurl', help="Base URL for CT server")
parser.add_argument('--store', default=None, metavar="dir",
                    help='Get certificates from directory dir (or from a single file or .zip archive); '
                         'a certificate is read from stdin when omitted')
parser.add_argument('--sct-file', default=None, metavar="file", help='Store SCT:s in file')
parser.add_argument('--parallel', type=int, default=16, metavar="n", help="Number of parallel submits")
parser.add_argument('--check-sct', action='store_true', help="Check SCT signature")
parser.add_argument('--pre-warm', action='store_true', help="Wait 3 seconds after first submit")
parser.add_argument('--publickey', default=None, metavar="file", help='Public key for the CT log')
parser.add_argument('--cafile', default=None, metavar="file", help='File containing the CA cert')
args = parser.parse_args()


from multiprocessing import Pool

baseurl = args.baseurl
certfilepath = args.store

logpublickey = get_public_key_from_file(args.publickey) if args.publickey else None

# When True, submitcert() fetches each submitted entry back from the log
# and prints diagnostic comparisons.
lookup_in_log = False

# Decide where certificates come from:
#   * no --store            -> None (certificates are read from stdin)
#   * --store is a directory -> every regular file in it, in sorted order
#   * otherwise             -> that single file
if certfilepath is None:
    certfiles = None
elif os.path.isdir(certfilepath):
    certfiles = [os.path.join(certfilepath, filename)
                 for filename in sorted(os.listdir(certfilepath))
                 if os.path.isfile(os.path.join(certfilepath, filename))]
else:
    certfiles = [certfilepath]

# Tree head fetched once at startup; used when saving SCTs and for lookups.
sth = get_sth(baseurl)

def submitcert((certfile, blob)):
    timing = timing_point()
    timing_point(timing, "readcerts")

        signed_entry = pack_cert(blob)
        issuer_key_hash = None
        result = add_chain(baseurl, {"blob":base64.b64encode(blob)})
    except SystemExit:
        print "EXIT:", certfile
        select.select([], [], [], 1.0)
        return (None, None)

    timing_point(timing, "addchain")

    if result == None:
        print "ERROR for certfile", certfile
        return (None, timing["deltatimes"])

        if args.check_sct:
            check_sct_signature(baseurl, signed_entry, result, publickey=logpublickey)
            timing_point(timing, "checksig")
    except AssertionError, e:
        print "ERROR:", certfile, e
        return (None, None)
    except urllib2.HTTPError, e:
        print "ERROR:", certfile, e
        return (None, None)
    except ecdsa.keys.BadSignatureError, e:
        print "ERROR: bad signature", certfile
        return (None, None)

    if lookup_in_log:

        merkle_tree_leaf = pack_mtl(result["timestamp"], blob)

        leaf_hash = get_leaf_hash(merkle_tree_leaf)

        proof = get_proof_by_hash(baseurl, leaf_hash, sth["tree_size"])

        leaf_index = proof["leaf_index"]

        entries = get_entries(baseurl, leaf_index, leaf_index)

        fetched_entry = entries["entries"][0]

        print "does the leaf_input of the fetched entry match what we calculated:", \
          base64.decodestring(fetched_entry["leaf_input"]) == merkle_tree_leaf

        extra_data = fetched_entry["extra_data"]

        certchain = decode_certificate_chain(base64.decodestring(extra_data))

        submittedcertchain = certchain[1:]

        for (submittedcert, fetchedcert, i) in zip(submittedcertchain,
                                                   certchain, itertools.count(1)):
            print "cert", i, "in chain is the same:", submittedcert == fetchedcert

        if len(certchain) == len(submittedcertchain) + 1:
            last_issuer = get_cert_info(certchain[-1])["issuer"]
            root_subject = get_cert_info(certchain[-1])["subject"]
            print "issuer of last cert in submitted chain and " \
                "subject of last cert in fetched chain is the same:", \
                last_issuer == root_subject
        elif len(certchain) == len(submittedcertchain):
            print "cert chains are the same length"
            print "ERROR: fetched cert chain has length", len(certchain),
            print "and submitted chain has length", len(submittedcertchain)

    timing_point(timing, "lookup")
    return ((blob, issuer_key_hash, result), timing["deltatimes"])

def get_ncerts(certfiles):
    """Count the certificates contained in certfiles (used for progress output).

    A path ending in ".zip" contributes one certificate per archive member;
    any other path contributes exactly one.
    """
    n = 0
    for certfile in certfiles:
        if certfile.endswith(".zip"):
            # Close the archive as soon as its member list has been read.
            with zipfile.ZipFile(certfile) as zf:
                n += len(zf.namelist())
        else:
            n += 1
    return n

def get_all_certificates(certfiles):
    """Yield (name, data) for every certificate in certfiles.

    For a ".zip" path, each archive member is yielded as (member name,
    member bytes). Any other path is yielded as (path, file bytes). Data is
    read in binary mode because it is forwarded unchanged to the log.
    """
    for certfile in certfiles:
        if certfile.endswith(".zip"):
            with zipfile.ZipFile(certfile) as zf:
                for name in zf.namelist():
                    yield (name, zf.read(name))
        else:
            # Read and close the file before yielding so no handle stays
            # open while the consumer processes the data.
            with open(certfile, "rb") as certfh:
                data = certfh.read()
            yield (certfile, data)

def save_sct(sct, sth, leafcert, issuer_key_hash):
    """Append one SCT record as a JSON line to the --sct-file file.

    The record holds the base64 leaf certificate, the SCT, the tree head and,
    when present, the base64 issuer key hash. Nothing is written when
    --sct-file was not given (args.sct_file is None).
    """
    if args.sct_file is None:
        return
    sctentry = {"leafcert": base64.b64encode(leafcert), "sct": sct, "sth": sth}
    if issuer_key_hash:
        sctentry["issuer_key_hash"] = base64.b64encode(issuer_key_hash)
    # One JSON object per line so successive appends stay separable; the
    # with-block closes (and flushes) the file after every record.
    with open(args.sct_file, "a") as sctlog:
        json.dump(sctentry, sctlog)
        sctlog.write("\n")

def _ignore_sigint():
    """Pool worker initializer: ignore SIGINT in the worker processes.

    This leaves Ctrl-C handling to the parent process, whose main loop
    catches KeyboardInterrupt.
    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)

p = Pool(args.parallel, _ignore_sigint)

# Progress counters updated by the top-level submit loop.
nsubmitted = lastprinted = 0

def get_cert_from_stdin():
    """Yield a single ('<stdin>', data) pair holding all of standard input.

    The data is yielded raw, like the file contents from
    get_all_certificates(). submitcert() base64-encodes the blob itself
    before sending it, so encoding here as well would double-encode it.
    """
    yield ('<stdin>', sys.stdin.read())

print "listing certs"
if certfiles is not None:
    ncerts = get_ncerts(certfiles)
    print ncerts, "certs"
    certs = get_all_certificates(certfiles)
    ncerts = 1
    certs = get_cert_from_stdin()

# Submit the first certificate synchronously in this process, before the
# parallel run starts. An empty certificate source is skipped instead of
# crashing with StopIteration.
first_cert = next(certs, None)
if first_cert is not None:
    (result, timing) = submitcert(first_cert)
    if result is not None:
        nsubmitted += 1
        (leafcert, issuer_key_hash, sct) = result
        save_sct(sct, sth, leafcert, issuer_key_hash)

# --pre-warm: pause 3 seconds after the first submit.
if args.pre_warm:
    select.select([], [], [], 3.0)

starttime = datetime.datetime.now()

    for result, timing in p.imap_unordered(submitcert, certs):
        if timing == None:
            print "error"
            print "submitted", nsubmitted
        if result != None:
            nsubmitted += 1
            (leafcert, issuer_key_hash, sct) = result
            save_sct(sct, sth, leafcert, issuer_key_hash)
        deltatime = datetime.datetime.now() - starttime
        deltatime_f = deltatime.seconds + deltatime.microseconds / 1000000.0
        rate = nsubmitted / deltatime_f
        if nsubmitted > lastprinted + ncerts / 10:
            print nsubmitted, "rate %.1f" % rate
            lastprinted = nsubmitted
        #print timing, "rate %.1f" % rate
    print "submitted", nsubmitted
except KeyboardInterrupt: