diff options
Diffstat (limited to 'tools/merge_fetch.py')
-rwxr-xr-x | tools/merge_fetch.py | 247 |
1 files changed, 219 insertions, 28 deletions
diff --git a/tools/merge_fetch.py b/tools/merge_fetch.py index 8c3a997..7e0dfd8 100755 --- a/tools/merge_fetch.py +++ b/tools/merge_fetch.py @@ -11,17 +11,24 @@ import sys import struct import subprocess import requests +import signal +import logging from time import sleep +from multiprocessing import Process, Pipe +from random import Random from mergetools import get_logorder, verify_entry, get_new_entries, \ chunks, fsync_logorder, get_entries, add_to_logorder, \ - hexencode, parse_args, perm + hexencode, hexdecode, parse_args, perm, flock_ex_or_fail, Status, \ + terminate_child_procs from certtools import timing_point, write_file, create_ssl_context -def merge_fetch(args, config, localconfig): +def merge_fetch_sequenced(args, config, localconfig): paths = localconfig["paths"] storagenodes = config["storagenodes"] mergedb = paths["mergedb"] logorderfile = mergedb + "/logorder" + statusfile = mergedb + "/merge_fetch.status" + s = Status(statusfile) chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains") own_key = (localconfig["nodename"], "%s/%s-private.pem" % (paths["privatekeys"], @@ -38,8 +45,7 @@ def merge_fetch(args, config, localconfig): entries_to_fetch = {} for storagenode in storagenodes: - print >>sys.stderr, "getting new entries from", storagenode["name"] - sys.stderr.flush() + logging.info("getting new entries from %s", storagenode["name"]) new_entries_per_node[storagenode["name"]] = \ set(get_new_entries(storagenode["name"], "https://%s/" % storagenode["address"], @@ -49,8 +55,7 @@ def merge_fetch(args, config, localconfig): timing_point(timing, "get new entries") new_entries -= certsinlog - print >>sys.stderr, "adding", len(new_entries), "entries" - sys.stderr.flush() + logging.info("adding %d entries", len(new_entries)) for ehash in new_entries: for storagenode in storagenodes: @@ -64,9 +69,8 @@ def merge_fetch(args, config, localconfig): added_entries = 0 for storagenode in storagenodes: - print >>sys.stderr, "getting %d entries from %s:" % \ - (len(entries_to_fetch[storagenode["name"]]), storagenode["name"]), - sys.stderr.flush() + nentries = len(entries_to_fetch[storagenode["name"]]) + logging.info("getting %d entries from %s", nentries, storagenode["name"]) with requests.sessions.Session() as session: for chunk in chunks(entries_to_fetch[storagenode["name"]], 100): entries = get_entries(storagenode["name"], @@ -80,21 +84,17 @@ def merge_fetch(args, config, localconfig): logorder.append(ehash) certsinlog.add(ehash) added_entries += 1 - print >>sys.stderr, added_entries, - sys.stderr.flush() - print >>sys.stderr - sys.stderr.flush() + s.status("PROG getting %d entries from %s: %d" % + (nentries, storagenode["name"], added_entries)) chainsdb.commit() fsync_logorder(logorderfile) timing_point(timing, "add entries") - print >>sys.stderr, "added", added_entries, "entries" - sys.stderr.flush() + logging.info("added %d entries", added_entries) verifycert.communicate(struct.pack("I", 0)) if args.timing: - print >>sys.stderr, "timing: merge_fetch:", timing["deltatimes"] - sys.stderr.flush() + logging.debug("timing: merge_fetch: %s", timing["deltatimes"]) tree_size = len(logorder) if tree_size == 0: @@ -102,30 +102,221 @@ def merge_fetch(args, config, localconfig): else: return (tree_size, logorder[tree_size-1]) +def merge_fetch_worker(args, localconfig, storagenode, pipe): + paths = localconfig["paths"] + mergedb = paths["mergedb"] + chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains") + own_key = (localconfig["nodename"], + "%s/%s-private.pem" % (paths["privatekeys"], + localconfig["nodename"])) + to_fetch = set() + timeout = max(3, args.mergeinterval / 10) + while True: + if pipe.poll(timeout): + msg = pipe.recv().split() + if len(msg) < 2: + continue + cmd = msg[0] + ehash = msg[1] + if cmd == 'FETCH': + to_fetch.add(hexdecode(ehash)) + else: + logging.warning("%s: unknown command from parent: %s", + storagenode["name"], msg) + + if len(to_fetch) > 0: + logging.info("%s: fetching %d entries", storagenode["name"], + len(to_fetch)) + # TODO: Consider running the verifycert process longer. + verifycert = subprocess.Popen( + [paths["verifycert_bin"], paths["known_roots"]], + stdin=subprocess.PIPE, stdout=subprocess.PIPE) + # Chunking for letting other workers take the chainsdb lock. + for chunk in chunks(list(to_fetch), 100): + chainsdb.lock_ex() + with requests.sessions.Session() as session: + entries = get_entries(storagenode["name"], + "https://%s/" % storagenode["address"], + own_key, paths, chunk, session=session) + for ehash in chunk: + entry = entries[ehash] + verify_entry(verifycert, entry, ehash) + chainsdb.add(ehash, entry) + chainsdb.commit() + chainsdb.release_lock() + for ehash in chunk: + pipe.send('FETCHED %s' % hexencode(ehash)) + to_fetch.remove(ehash) + verifycert.communicate(struct.pack("I", 0)) + + new_entries = get_new_entries(storagenode["name"], + "https://%s/" % storagenode["address"], + own_key, paths) + if len(new_entries) > 0: + logging.info("%s: got %d new entries", storagenode["name"], + len(new_entries)) + for ehash in new_entries: + pipe.send('NEWENTRY %s' % hexencode(ehash)) + +def term(signal, arg): + terminate_child_procs() + sys.exit(1) + +def newworker(name, args): + my_conn, child_conn = Pipe() + p = Process(target=merge_fetch_worker, + args=tuple(args + [child_conn]), + name='merge_fetch_%s' % name) + p.daemon = True + p.start() + logging.debug("%s started, pid %d", name, p.pid) + return (name, my_conn, p) + +def merge_fetch_parallel(args, config, localconfig): + paths = localconfig["paths"] + storagenodes = config["storagenodes"] + mergedb = paths["mergedb"] + logorderfile = mergedb + "/logorder" + currentsizefile = mergedb + "/fetched" + + rand = Random() + signal.signal(signal.SIGTERM, term) + + procs = {} + for storagenode in storagenodes: + name = storagenode['name'] + procs[name] = newworker(name, [args, localconfig, storagenode]) + + logorder = get_logorder(logorderfile) # List of entries in log. + entries_in_log = set(logorder) # Set of entries in log. + entries_to_fetch = set() # Set of entries to fetch. + fetch = {} # Dict with entries to fetch. + while procs: + assert(not entries_to_fetch) + # Poll worker processes. + for name, pipe, p in procs.values(): + if not p.is_alive(): + logging.warning("%s is gone, restarting", name) + procs[name] = newworker(name, [args, localconfig, + storagenodes[name]]) + continue + logging.info("polling %s", name) + if pipe.poll(1): + msg = pipe.recv().split() + if len(msg) < 2: + logging.warning("unknown command from %s: %s", name, msg) + continue + cmd = msg[0] + ehash = msg[1] + if cmd == 'NEWENTRY': + logging.info("NEWENTRY at %s: %s", name, ehash) + entries_to_fetch.add(ehash) + logging.debug("entries_to_fetch: %s", entries_to_fetch) + elif cmd == 'FETCHED': + logging.info("FETCHED from %s: %s", name, ehash) + logorder.append(ehash) + add_to_logorder(logorderfile, hexdecode(ehash)) + fsync_logorder(logorderfile) + entries_in_log.add(ehash) + if ehash in entries_to_fetch: + entries_to_fetch.remove(ehash) + del fetch[ehash] + else: + logging.warning("unknown command from %s: %s", name, msg) + + # Ask workers to fetch entries. + logging.debug("nof entries to fetch including entries in log: %d", + len(entries_to_fetch)) + entries_to_fetch -= entries_in_log + logging.info("entries to fetch: %d", len(entries_to_fetch)) + # Add entries in entries_to_fetch as keys in dictionary fetch, + # values being a list of storage nodes, in randomised order. + for e in entries_to_fetch: + if not e in fetch: + l = procs.values() + rand.shuffle(l) + fetch[e] = l + # For each entry to fetch, treat its list of nodes as a + # circular list and ask the one in the front to fetch the + # entry. + while entries_to_fetch: + ehash = entries_to_fetch.pop() + nodes = fetch[ehash] + node = nodes.pop(0) + fetch[ehash] = nodes.append(node) + name, pipe, p = node + logging.info("asking %s to FETCH %s", name, ehash) + pipe.send("FETCH %s" % ehash) + + # Update the 'fetched' file. + logsize = len(logorder) + if logsize == 0: + last_hash = '' + else: + last_hash = logorder[logsize - 1] + logging.info("updating 'fetched' file: %d %s", logsize-1, last_hash) + currentsize = {"index": logsize - 1, "hash": last_hash} + logging.debug("writing to %s: %s", currentsizefile, currentsize) + write_file(currentsizefile, currentsize) + + return 0 + def main(): """ - Fetch new entries from all storage nodes. + If no `--mergeinterval': + Fetch new entries from all storage nodes, in sequence, updating + the 'logorder' file and the 'chains' database. + + Write 'fetched' to reflect how far in 'logorder' we've succesfully + fetched and verified. + + If `--mergeinterval': + Start one process per storage node, read their stdout for learning + about two things: (i) new entries ready for fetching ("NEWENTRY") and + (ii) new entries being succesfully fetched ("FETCHED"). - Indicate current position by writing the index in the logorder file - (0-based) to the 'fetched' file. + Write to their stdin ("FETCH") when they should fetch another entry. + Update 'logorder' and the 'chains' database as we see new FETCHED + messages. - Sleep some and start over. + Write 'fetched' to reflect how far in 'logorder' we've succesfully + fetched and verified. + + Keep doing this forever. + + NOTE: The point of having 'fetched' is that it can be atomically + written while 'logorder' cannot (unless we're fine with rewriting it + for each and every update, which we're not). + + TODO: Deduplicate some code. """ args, config, localconfig = parse_args() paths = localconfig["paths"] mergedb = paths["mergedb"] currentsizefile = mergedb + "/fetched" + lockfile = mergedb + "/.merge_fetch.lock" + + loglevel = getattr(logging, args.loglevel.upper()) + if args.mergeinterval is None: + logging.basicConfig(level=loglevel) + else: + logging.basicConfig(filename=args.logdir + "/merge_fetch.log", + level=loglevel) + + if not flock_ex_or_fail(lockfile): + logging.critical("unable to take lock %s", lockfile) + return 1 + create_ssl_context(cafile=paths["https_cacertfile"]) - while True: - logsize, last_hash = merge_fetch(args, config, localconfig) + if args.mergeinterval: + return merge_fetch_parallel(args, config, localconfig) + else: + logsize, last_hash = merge_fetch_sequenced(args, config, localconfig) currentsize = {"index": logsize - 1, "hash": hexencode(last_hash)} - #print >>sys.stderr, "DEBUG: writing to", currentsizefile, ":", currentsize + logging.debug("writing to %s: %s", currentsizefile, currentsize) write_file(currentsizefile, currentsize) - if args.interval is None: - break - print >>sys.stderr, "sleeping", args.interval, "seconds" - sleep(args.interval) + return 0 if __name__ == '__main__': sys.exit(main()) |