#!/usr/bin/env python

# Copyright (c) 2015, NORDUnet A/S.
# See LICENSE for licensing information.

import argparse
import urllib2
import urllib
import json
import base64
import sys
import struct
import hashlib
import itertools
from certtools import *
try:
    from precerttools import *
except ImportError:
    pass
import os
import signal
import select
import zipfile
import traceback

parser = argparse.ArgumentParser(description='')
# --store is mandatory: without it there is nothing to verify (previously a
# missing flag crashed below with a TypeError on None[-1]).
parser.add_argument('--store', required=True, metavar="dir",
                    help='Get certificates from directory dir (or from a single file / zip archive)')
parser.add_argument('--parallel', type=int, default=1, metavar="n", help="Number of parallel workers")
args = parser.parse_args()

from multiprocessing import Pool

certfilepath = args.store

# A directory contributes every regular file it contains, in sorted name
# order; any other path is processed as a single input file.  os.path.isdir
# also accepts directories given without a trailing "/", and os.path.join
# yields the same paths as plain concatenation when the "/" is present.
if os.path.isdir(certfilepath):
    certfiles = [os.path.join(certfilepath, filename)
                 for filename in sorted(os.listdir(certfilepath))
                 if os.path.isfile(os.path.join(certfilepath, filename))]
else:
    certfiles = [certfilepath]

def submitcert(item):
    """Check that one stored entry's recorded leaf hash matches a recomputed one.

    ``item`` is a ``(certfile, cert)`` pair: ``certfile`` names the entry and
    is used only in diagnostics, ``cert`` is the entry's raw contents.

    Returns True when the entry holds no certificates or when the leaf hash
    recomputed from its timestamp and certificate (or cleaned precertificate)
    equals the stored hash.  On a mismatch, prints the entry name, both hashes
    and whether a precertificate was involved, then returns None.  Any
    exception is reported (entry name plus traceback) and re-raised.
    """
    certfile, cert = item
    try:
        certchain = get_certs_from_string(cert)
        if not certchain:
            return True
        precerts = get_precerts_from_string(cert)
        stored_hash = get_hash_from_certfile(cert)
        timestamp = get_timestamp_from_certfile(cert)
        # Explicit check rather than assert, so it survives `python -O`.
        if len(precerts) > 1:
            raise ValueError("%s: expected at most one precertificate, found %d"
                             % (certfile, len(precerts)))
        precert = precerts[0] if precerts else None
        if precert:
            if ext_key_usage_precert_signing_cert in get_ext_key_usage(certchain[0]):
                # The first chain entry is a precert signing certificate, so the
                # key hash comes from the next certificate (its issuer), which
                # is also passed to cleanprecert.
                if len(certchain) < 2:
                    raise ValueError("%s: precert signing certificate without issuer"
                                     % certfile)
                issuer_key_hash = get_cert_key_hash(certchain[1])
                issuer = certchain[1]
            else:
                issuer_key_hash = get_cert_key_hash(certchain[0])
                issuer = None
            cleanedcert = cleanprecert(precert, issuer=issuer)
            mtl = pack_mtl_precert(timestamp, cleanedcert, issuer_key_hash)
        else:
            mtl = pack_mtl(timestamp, certchain[0])
        leaf_hash = get_leaf_hash(mtl)
        if leaf_hash == stored_hash:
            return True
        print("%s %s %s %s" % (certfile, repr(leaf_hash), repr(stored_hash),
                               precert is not None))
        return None
    except Exception:
        print(certfile)
        traceback.print_exc()
        # Bare raise keeps the original traceback (``raise e`` resets it on Python 2).
        raise

def get_all_certificates(certfiles):
    """Yield ``(name, contents)`` for every input entry, lazily.

    Paths ending in ".zip" are opened as zip archives, and each member is
    yielded as ``(member_name, member_bytes)``.  Any other path is read whole
    and yielded as ``(path, contents)``.

    Archives and files are opened with ``with``.  They are therefore closed
    even if the consumer stops iterating early, for example when the worker
    pool is terminated after an error.
    """
    for certfile in certfiles:
        if certfile.endswith(".zip"):
            with zipfile.ZipFile(certfile) as zf:
                for name in zf.namelist():
                    yield (name, zf.read(name))
        else:
            with open(certfile) as f:
                contents = f.read()
            yield (certfile, contents)

# Worker processes ignore SIGINT, so a Ctrl-C is handled only in the parent
# (as KeyboardInterrupt below), which then tears the pool down.
p = Pool(args.parallel, lambda: signal.signal(signal.SIGINT, signal.SIG_IGN))

# Lazily produced (name, contents) pairs, consumed by imap_unordered below.
certs = get_all_certificates(certfiles)

try:
    # Results arrive in completion order.  submitcert returns None on a hash
    # mismatch, and that aborts the whole run.
    for result in p.imap_unordered(submitcert, certs):
        if result is None:
            print("error")
            p.terminate()
            p.join()
            sys.exit(1)
except KeyboardInterrupt:
    p.terminate()
    p.join()
    # An interrupted run has not verified every entry; don't report success.
    sys.exit(1)
else:
    # Every entry verified: let the workers exit cleanly.
    p.close()
    p.join()