#!/usr/bin/env python # Copyright (c) 2014, NORDUnet A/S. # See LICENSE for licensing information. import urllib2 import urllib import json import base64 import sys import struct import hashlib import itertools from certtools import * import os import signal import select import zipfile from multiprocessing import Pool baseurl = sys.argv[1] certfilepath = sys.argv[2] lookup_in_log = False check_sig = False if certfilepath[-1] == "/": certfiles = [certfilepath + filename for filename in sorted(os.listdir(certfilepath))] else: certfiles = [certfilepath] def submitcert((certfile, cert)): timing = timing_point() certchain = get_certs_from_string(cert) timing_point(timing, "readcerts") try: result = add_chain(baseurl, {"chain":map(base64.b64encode, certchain)}) except SystemExit: print "EXIT:", certfile select.select([], [], [], 1.0) return None timing_point(timing, "addchain") if result == None: print "ERROR for certfile", certfile return timing["deltatimes"] try: if check_sig: check_sct_signature(baseurl, certchain[0], result) timing_point(timing, "checksig") except AssertionError, e: print "ERROR:", certfile, e return None except urllib2.HTTPError, e: print "ERROR:", certfile, e return None except ecdsa.keys.BadSignatureError, e: print "ERROR: bad signature", certfile return None if lookup_in_log: merkle_tree_leaf = pack_mtl(result["timestamp"], certchain[0]) leaf_hash = get_leaf_hash(merkle_tree_leaf) sth = get_sth(baseurl) proof = get_proof_by_hash(baseurl, leaf_hash, sth["tree_size"]) leaf_index = proof["leaf_index"] entries = get_entries(baseurl, leaf_index, leaf_index) fetched_entry = entries["entries"][0] print "does the leaf_input of the fetched entry match what we calculated:", \ base64.decodestring(fetched_entry["leaf_input"]) == merkle_tree_leaf extra_data = fetched_entry["extra_data"] certchain = decode_certificate_chain(base64.decodestring(extra_data)) submittedcertchain = certchain[1:] for (submittedcert, fetchedcert, i) in zip(submittedcertchain, certchain, itertools.count(1)): print "cert", i, "in chain is the same:", submittedcert == fetchedcert if len(certchain) == len(submittedcertchain) + 1: last_issuer = get_cert_info(certchain[-1])["issuer"] root_subject = get_cert_info(certchain[-1])["subject"] print "issuer of last cert in submitted chain and " \ "subject of last cert in fetched chain is the same:", \ last_issuer == root_subject elif len(certchain) == len(submittedcertchain): print "cert chains are the same length" else: print "ERROR: fetched cert chain has length", len(certchain), print "and submitted chain has length", len(submittedcertchain) timing_point(timing, "lookup") return timing["deltatimes"] def get_ncerts(certfiles): n = 0 for certfile in certfiles: if certfile.endswith(".zip"): zf = zipfile.ZipFile(certfile) n += len(zf.namelist()) zf.close() else: n += 1 return n def get_all_certificates(certfiles): for certfile in certfiles: if certfile.endswith(".zip"): zf = zipfile.ZipFile(certfile) for name in zf.namelist(): yield (name, zf.read(name)) zf.close() else: yield (certfile, open(certfile).read()) p = Pool(16, lambda: signal.signal(signal.SIGINT, signal.SIG_IGN)) nsubmitted = 0 lastprinted = 0 ncerts = get_ncerts(certfiles) print ncerts, "certs" certs = get_all_certificates(certfiles) submitcert(certs.next()) nsubmitted += 1 select.select([], [], [], 3.0) starttime = datetime.datetime.now() try: for timing in p.imap_unordered(submitcert, certs): if timing == None: print "error" print "submitted", nsubmitted p.terminate() p.join() sys.exit(1) nsubmitted += 1 deltatime = datetime.datetime.now() - starttime deltatime_f = deltatime.seconds + deltatime.microseconds / 1000000.0 rate = nsubmitted / deltatime_f if nsubmitted > lastprinted + ncerts / 10: print nsubmitted, "rate %.1f" % rate lastprinted = nsubmitted #print timing, "rate %.1f" % rate print "submitted", nsubmitted except KeyboardInterrupt: p.terminate() p.join()