#!/usr/bin/env python # Copyright (c) 2014, NORDUnet A/S. # See LICENSE for licensing information. import argparse import urllib2 import urllib import json import base64 import sys import struct import hashlib import itertools from certtools import * import os import signal import select import zipfile parser = argparse.ArgumentParser(description='') parser.add_argument('baseurl', help="Base URL for CT server") parser.add_argument('--store', default=None, metavar="dir", help='Get certificates from directory dir') parser.add_argument('--sct-file', default=None, metavar="file", help='Store SCT:s in file') parser.add_argument('--parallel', type=int, default=16, metavar="n", help="Number of parallel submits") parser.add_argument('--check-sct', action='store_true', help="Check SCT signature") parser.add_argument('--pre-warm', action='store_true', help="Wait 3 seconds after first submit") args = parser.parse_args() from multiprocessing import Pool baseurl = args.baseurl certfilepath = args.store lookup_in_log = False if certfilepath[-1] == "/": certfiles = [certfilepath + filename for filename in sorted(os.listdir(certfilepath)) if os.path.isfile(certfilepath + filename)] else: certfiles = [certfilepath] sth = get_sth(baseurl) def submitcert((certfile, cert)): timing = timing_point() certchain = get_certs_from_string(cert) timing_point(timing, "readcerts") try: result = add_chain(baseurl, {"chain":map(base64.b64encode, certchain)}) except SystemExit: print "EXIT:", certfile select.select([], [], [], 1.0) return (None, None) timing_point(timing, "addchain") if result == None: print "ERROR for certfile", certfile return (None, timing["deltatimes"]) try: if args.check_sct: check_sct_signature(baseurl, certchain[0], result) timing_point(timing, "checksig") except AssertionError, e: print "ERROR:", certfile, e return (None, None) except urllib2.HTTPError, e: print "ERROR:", certfile, e return (None, None) except ecdsa.keys.BadSignatureError, e: print "ERROR: bad signature", certfile return (None, None) if lookup_in_log: merkle_tree_leaf = pack_mtl(result["timestamp"], certchain[0]) leaf_hash = get_leaf_hash(merkle_tree_leaf) proof = get_proof_by_hash(baseurl, leaf_hash, sth["tree_size"]) leaf_index = proof["leaf_index"] entries = get_entries(baseurl, leaf_index, leaf_index) fetched_entry = entries["entries"][0] print "does the leaf_input of the fetched entry match what we calculated:", \ base64.decodestring(fetched_entry["leaf_input"]) == merkle_tree_leaf extra_data = fetched_entry["extra_data"] certchain = decode_certificate_chain(base64.decodestring(extra_data)) submittedcertchain = certchain[1:] for (submittedcert, fetchedcert, i) in zip(submittedcertchain, certchain, itertools.count(1)): print "cert", i, "in chain is the same:", submittedcert == fetchedcert if len(certchain) == len(submittedcertchain) + 1: last_issuer = get_cert_info(certchain[-1])["issuer"] root_subject = get_cert_info(certchain[-1])["subject"] print "issuer of last cert in submitted chain and " \ "subject of last cert in fetched chain is the same:", \ last_issuer == root_subject elif len(certchain) == len(submittedcertchain): print "cert chains are the same length" else: print "ERROR: fetched cert chain has length", len(certchain), print "and submitted chain has length", len(submittedcertchain) timing_point(timing, "lookup") return ((certchain[0], result), timing["deltatimes"]) def get_ncerts(certfiles): n = 0 for certfile in certfiles: if certfile.endswith(".zip"): zf = zipfile.ZipFile(certfile) n += len(zf.namelist()) zf.close() else: n += 1 return n def get_all_certificates(certfiles): for certfile in certfiles: if certfile.endswith(".zip"): zf = zipfile.ZipFile(certfile) for name in zf.namelist(): yield (name, zf.read(name)) zf.close() else: yield (certfile, open(certfile).read()) def save_sct(sct, sth): sctlog = open(args.sct_file, "a") json.dump({"leafcert": base64.b64encode(leafcert), "sct": sct, "sth": sth}, sctlog) sctlog.write("\n") sctlog.close() p = Pool(args.parallel, lambda: signal.signal(signal.SIGINT, signal.SIG_IGN)) nsubmitted = 0 lastprinted = 0 print "listing certs" ncerts = get_ncerts(certfiles) print ncerts, "certs" certs = get_all_certificates(certfiles) (result, timing) = submitcert(certs.next()) if result != None: nsubmitted += 1 (leafcert, sct) = result save_sct(sct, sth) if args.pre_warm: select.select([], [], [], 3.0) starttime = datetime.datetime.now() try: for result, timing in p.imap_unordered(submitcert, certs): if timing == None: print "error" print "submitted", nsubmitted p.terminate() p.join() sys.exit(1) if result != None: nsubmitted += 1 (leafcert, sct) = result save_sct(sct, sth) deltatime = datetime.datetime.now() - starttime deltatime_f = deltatime.seconds + deltatime.microseconds / 1000000.0 rate = nsubmitted / deltatime_f if nsubmitted > lastprinted + ncerts / 10: print nsubmitted, "rate %.1f" % rate lastprinted = nsubmitted #print timing, "rate %.1f" % rate print "submitted", nsubmitted except KeyboardInterrupt: p.terminate() p.join()