author    Linus Nordberg <linus@nordu.net>  2016-11-30 16:47:23 +0100
committer Linus Nordberg <linus@nordu.net>  2016-11-30 16:47:23 +0100
commit    bff5d58fcce0534cf4774df386ff448261b28c20 (patch)
tree      076adf9062bb24b210f8712b4aeb303b3023643d
parent    720473257b4b7ab9916826ae87e617d1df138260 (diff)
Parallelise merge_dist.
Also deduplicate some code.
-rwxr-xr-x  tools/merge_backup.py   16
-rwxr-xr-x  tools/merge_dist.py    131
-rw-r--r--  tools/mergetools.py     10
3 files changed, 97 insertions, 60 deletions
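
For context, the refactoring moves the Process/Pipe plumbing that merge_backup.py previously open-coded into a shared mergetools.start_worker() helper, and merge_dist() now starts one worker per frontend node and reaps them in a loop, counting non-zero exit codes as failures. Below is a minimal, standalone sketch of that pattern; the node names and the do_send() body are invented for illustration, while the real callers pass the merge arguments shown in the diff that follows.

    # Minimal sketch of the worker pattern factored out into
    # mergetools.start_worker() and reused by merge_backup.py / merge_dist.py.
    # Assumes the fork start method (Linux), as the original code does;
    # lambdas passed as Process targets are not picklable under spawn.
    import multiprocessing
    import time

    def start_worker(name, fun, args):
        # One pipe per worker: the child writes its result to pipe_theirs,
        # the parent reads it from pipe_mine.
        pipe_mine, pipe_theirs = multiprocessing.Pipe()
        p = multiprocessing.Process(target=fun,
                                    args=(pipe_theirs, args),
                                    name=name)
        p.start()
        return (p, pipe_mine)

    def do_send(nodename):
        # Stand-in for the per-node distribution work.
        return "result for %s" % nodename

    def main():
        procs = {}
        for nodename in ("frontend-1", "frontend-2"):  # hypothetical nodes
            p, pipe = start_worker('dist_%s' % nodename,
                                   lambda cpipe, argv: cpipe.send(do_send(argv)),
                                   nodename)
            procs[p] = (nodename, pipe)

        failures = 0
        while procs:
            # Reap finished workers; a non-zero exit code counts as a failure.
            for p in list(procs):
                if not p.is_alive():
                    p.join()
                    nodename, pipe = procs[p]
                    if p.exitcode != 0:
                        failures += 1
                    else:
                        print("%s: %s" % (nodename, pipe.recv()))
                    del procs[p]
            time.sleep(0.1)
        return failures

    if __name__ == '__main__':
        main()
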
diff --git a/tools/merge_backup.py b/tools/merge_backup.py
index 41b1014..cadcec7 100755
--- a/tools/merge_backup.py
+++ b/tools/merge_backup.py
@@ -22,7 +22,7 @@ from mergetools import chunks, backup_sendlog, get_logorder, \
get_verifiedsize, get_missingentriesforbackup, \
hexencode, setverifiedsize, sendentries_merge, verifyroot, \
get_nfetched, parse_args, perm, waitforfile, flock_ex_or_fail, \
- Status, loginit
+ Status, loginit, start_worker
def backup_loop(nodename, nodeaddress, own_key, paths, verifiedsize, chunk):
for trynumber in range(5, 0, -1):
@@ -166,12 +166,11 @@ def merge_backup(args, config, localconfig, secondaries):
backupargs = (secondary, localconfig, chainsdb, logorder, s, timing)
if args.mergeinterval:
- pipe_mine, pipe_theirs = Pipe()
- p = Process(target=lambda pipe, argv: pipe.send(do_send(argv)),
- args=(pipe_theirs, backupargs),
- name='backup_%s' % nodename)
- p.start()
- procs[p] = (nodename, pipe_mine)
+ name = 'backup_%s' % nodename
+ p, pipe = start_worker(name,
+ lambda cpipe, argv: cpipe.send(do_send(argv)),
+ backupargs)
+ procs[p] = (nodename, pipe)
else:
root_hash = do_send(backupargs)
update_backupfile(mergedb, nodename, tree_size, root_hash)
@@ -233,7 +232,6 @@ def main():
create_ssl_context(cafile=paths["https_cacertfile"])
fetched_statinfo = waitforfile(fetched_path)
- retval = 0
while True:
failures = merge_backup(args, config, localconfig, nodes)
if not args.mergeinterval:
@@ -245,7 +243,7 @@ def main():
break
fetched_statinfo = stat(fetched_path)
- return retval
+ return 0
if __name__ == '__main__':
sys.exit(main())
diff --git a/tools/merge_dist.py b/tools/merge_dist.py
index d612600..bc9c676 100755
--- a/tools/merge_dist.py
+++ b/tools/merge_dist.py
@@ -14,11 +14,12 @@ import logging
from time import sleep
from base64 import b64encode, b64decode
from os import stat
+from multiprocessing import Process, Pipe
from certtools import timing_point, create_ssl_context
from mergetools import get_curpos, get_logorder, chunks, get_missingentries, \
publish_sth, sendlog, sendentries, parse_args, perm, \
get_frontend_verifiedsize, frontend_verify_entries, \
- waitforfile, flock_ex_or_fail, Status, loginit
+ waitforfile, flock_ex_or_fail, Status, loginit, start_worker
def sendlog_helper(entries, curpos, nodename, nodeaddress, own_key, paths,
statusupdates):
@@ -70,12 +71,51 @@ def fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb,
own_key, paths)
timing_point(timing, "get missing")
-def merge_dist(args, localconfig, frontendnodes, timestamp):
- maxwindow = localconfig.get("maxwindow", 1000)
+def do_send(args, localconfig, frontendnode, logorder, sth, chainsdb, s):
+ timing = timing_point()
paths = localconfig["paths"]
own_key = (localconfig["nodename"],
"%s/%s-private.pem" % (paths["privatekeys"],
localconfig["nodename"]))
+ maxwindow = localconfig.get("maxwindow", 1000)
+ nodename = frontendnode["name"]
+ nodeaddress = "https://%s/" % frontendnode["address"]
+
+ logging.info("distributing for node %s", nodename)
+ curpos = get_curpos(nodename, nodeaddress, own_key, paths)
+ timing_point(timing, "get curpos")
+ logging.info("current position %d", curpos)
+
+ verifiedsize = \
+ get_frontend_verifiedsize(nodename, nodeaddress, own_key, paths)
+ timing_point(timing, "get verified size")
+ logging.info("verified size %d", verifiedsize)
+
+ assert verifiedsize >= curpos
+
+ while verifiedsize < len(logorder):
+ uptopos = min(verifiedsize + maxwindow, len(logorder))
+
+ entries = [b64encode(entry) for entry in logorder[verifiedsize:uptopos]]
+ sendlog_helper(entries, verifiedsize, nodename, nodeaddress, own_key, paths, s)
+ timing_point(timing, "sendlog")
+
+ fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb, timing, s)
+
+ verifiedsize = frontend_verify_entries(nodename, nodeaddress, own_key, paths, uptopos)
+
+ logging.info("sending sth to node %s", nodename)
+ publishsthresult = publish_sth(nodename, nodeaddress, own_key, paths, sth)
+ if publishsthresult["result"] != "ok":
+ logging.info("publishsth: %s", publishsthresult)
+ sys.exit(1)
+ timing_point(timing, "send sth")
+
+ if args.timing:
+ logging.debug("timing: merge_dist: %s", timing["deltatimes"])
+
+def merge_dist(args, localconfig, frontendnodes, timestamp):
+ paths = localconfig["paths"]
mergedb = paths["mergedb"]
chainsdb = perm(localconfig.get("dbbackend", "filedb"), mergedb + "/chains")
logorderfile = mergedb + "/logorder"
@@ -89,56 +129,49 @@ def merge_dist(args, localconfig, frontendnodes, timestamp):
sth = json.loads(open(sthfile, 'r').read())
except (IOError, ValueError):
logging.warning("No valid STH file found in %s", sthfile)
- return timestamp
+ return timestamp, 0
if sth['timestamp'] < timestamp:
logging.warning("New STH file older than the previous one: %d < %d",
- sth['timestamp'], timestamp)
- return timestamp
+ sth['timestamp'], timestamp)
+ return timestamp, 0
if sth['timestamp'] == timestamp:
- return timestamp
+ return timestamp, 0
timestamp = sth['timestamp']
logorder = get_logorder(logorderfile, sth['tree_size'])
timing_point(timing, "get logorder")
+ procs = {}
for frontendnode in frontendnodes:
- nodeaddress = "https://%s/" % frontendnode["address"]
nodename = frontendnode["name"]
- timing = timing_point()
-
- logging.info("distributing for node %s", nodename)
- curpos = get_curpos(nodename, nodeaddress, own_key, paths)
- timing_point(timing, "get curpos")
- logging.info("current position %d", curpos)
-
- verifiedsize = get_frontend_verifiedsize(nodename, nodeaddress, own_key, paths)
- timing_point(timing, "get verified size")
- logging.info("verified size %d", verifiedsize)
-
- assert verifiedsize >= curpos
- while verifiedsize < len(logorder):
- uptopos = min(verifiedsize + maxwindow, len(logorder))
-
- entries = [b64encode(entry) for entry in logorder[verifiedsize:uptopos]]
- sendlog_helper(entries, verifiedsize, nodename, nodeaddress, own_key, paths, s)
- timing_point(timing, "sendlog")
+ if args.mergeinterval:
+ name = 'dist_%s' % nodename
+ p, pipe = start_worker(name,
+ lambda _, argv: do_send(*(argv)),
+ (args, localconfig, frontendnode, logorder, sth, chainsdb, s))
+ procs[p] = (nodename, pipe)
+ else:
+ do_send(args, localconfig, frontendnode, logorder, sth, chainsdb, s)
- fill_in_missing_entries(nodename, nodeaddress, own_key, paths, chainsdb, timing, s)
+ if not args.mergeinterval:
+ return timestamp, 0
- verifiedsize = frontend_verify_entries(nodename, nodeaddress, own_key, paths, uptopos)
+ failures = 0
+ while True:
+ for p in list(procs):
+ if not p.is_alive():
+ p.join()
+ nodename, _ = procs[p]
+ if p.exitcode != 0:
+ logging.warning("%s failure: %d", nodename, p.exitcode)
+ failures += 1
+ del procs[p]
+ if not procs:
+ break
+ sleep(1)
- logging.info("sending sth to node %s", nodename)
- publishsthresult = publish_sth(nodename, nodeaddress, own_key, paths, sth)
- if publishsthresult["result"] != "ok":
- logging.info("publishsth: %s", publishsthresult)
- sys.exit(1)
- timing_point(timing, "send sth")
-
- if args.timing:
- logging.debug("timing: merge_dist: %s", timing["deltatimes"])
-
- return timestamp
+ return timestamp, failures
def main():
"""
@@ -146,12 +179,12 @@ def main():
Distribute missing entries and the STH to all frontend nodes.
- If `--mergeinterval', wait until 'sth' is updated and read it and
- start distributing again.
+ If `--mergeinterval', start over again.
"""
args, config, localconfig = parse_args()
paths = localconfig["paths"]
mergedb = paths["mergedb"]
+ sth_path = localconfig["paths"]["mergedb"] + "/sth"
lockfile = mergedb + "/.merge_dist.lock"
timestamp = 0
@@ -166,20 +199,18 @@ def main():
else:
nodes = [n for n in config["frontendnodes"] if n["name"] in args.node]
- if args.mergeinterval is None:
- if merge_dist(args, localconfig, nodes, timestamp) < 0:
- return 1
- return 0
-
- sth_path = localconfig["paths"]["mergedb"] + "/sth"
sth_statinfo = waitforfile(sth_path)
while True:
- if merge_dist(args, localconfig, nodes, timestamp) < 0:
- return 1
+ timestamp, failures = merge_dist(args, localconfig, nodes, timestamp)
+ if not args.mergeinterval:
+ break
sth_statinfo_old = sth_statinfo
while sth_statinfo == sth_statinfo_old:
- sleep(args.mergeinterval / 30)
+ sleep(max(3, args.mergeinterval / 10))
+ if failures > 0:
+ break
sth_statinfo = stat(sth_path)
+
return 0
if __name__ == '__main__':
diff --git a/tools/mergetools.py b/tools/mergetools.py
index 109e9d4..beb41bf 100644
--- a/tools/mergetools.py
+++ b/tools/mergetools.py
@@ -484,13 +484,21 @@ def flock_ex_or_fail(path):
return False
return True
+def start_worker(name, fun, args):
+ pipe_mine, pipe_theirs = multiprocessing.Pipe()
+ p = multiprocessing.Process(target=fun,
+ args=(pipe_theirs, args),
+ name=name)
+ p.start()
+ return (p, pipe_mine)
+
def terminate_child_procs():
for p in multiprocessing.active_children():
#print >>sys.stderr, "DEBUG: terminating pid", p.pid
p.terminate()
def loginit(args, fname):
- logfmt = '%(asctime)s %(message)s'
+ logfmt = '%(asctime)s %(name)s %(levelname)s %(message)s'
loglevel = getattr(logging, args.loglevel.upper())
if args.logdir is None:
logging.basicConfig(format=logfmt, level=loglevel)