diff --git a/command.py b/command.py index be2d6a6e..9b1220dc 100644 --- a/command.py +++ b/command.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import multiprocessing import os import optparse import platform @@ -21,6 +22,7 @@ import sys from event_log import EventLog from error import NoSuchProjectError from error import InvalidProjectGroupsError +import progress # Number of projects to submit to a single worker process at a time. @@ -156,6 +158,44 @@ class Command(object): """ raise NotImplementedError + @staticmethod + def ExecuteInParallel(jobs, func, inputs, callback, output=None, ordered=False): + """Helper for managing parallel execution boilerplate. + + For subcommands that can easily split their work up. + + Args: + jobs: How many parallel processes to use. + func: The function to apply to each of the |inputs|. Usually a + functools.partial for wrapping additional arguments. It will be run + in a separate process, so it must be picklable, so nested functions + won't work. Methods on the subcommand Command class should work. + inputs: The list of items to process. Must be a list. + callback: The function to pass the results to for processing. It will be + executed in the main thread and process the results of |func| as they + become available. Thus it may be a local nested function. Its return + value is passed back directly. It takes three arguments: + - The processing pool (or None with one job). + - The |output| argument. + - An iterator for the results. + output: An output manager. May be progress.Progress or color.Coloring. + ordered: Whether the jobs should be processed in order. + + Returns: + The |callback| function's results are returned. + """ + try: + # NB: Multiprocessing is heavy, so don't spin it up for one job. 
+ if len(inputs) == 1 or jobs == 1: + return callback(None, output, (func(x) for x in inputs)) + else: + with multiprocessing.Pool(jobs) as pool: + submit = pool.imap if ordered else pool.imap_unordered + return callback(pool, output, submit(func, inputs, chunksize=WORKER_BATCH_SIZE)) + finally: + if isinstance(output, progress.Progress): + output.end() + def _ResetPathToProjectMap(self, projects): self._by_path = dict((p.worktree, p) for p in projects) diff --git a/subcmds/abandon.py b/subcmds/abandon.py index 1d22917e..c7c127d6 100644 --- a/subcmds/abandon.py +++ b/subcmds/abandon.py @@ -15,10 +15,9 @@ from collections import defaultdict import functools import itertools -import multiprocessing import sys -from command import Command, DEFAULT_LOCAL_JOBS, WORKER_BATCH_SIZE +from command import Command, DEFAULT_LOCAL_JOBS from git_command import git from progress import Progress @@ -52,9 +51,9 @@ It is equivalent to "git branch -D ". else: args.insert(0, "'All local branches'") - def _ExecuteOne(self, opt, nb, project): + def _ExecuteOne(self, all_branches, nb, project): """Abandon one project.""" - if opt.all: + if all_branches: branches = project.GetBranches() else: branches = [nb] @@ -72,7 +71,7 @@ It is equivalent to "git branch -D ". success = defaultdict(list) all_projects = self.GetProjects(args[1:]) - def _ProcessResults(states): + def _ProcessResults(_pool, pm, states): for (results, project) in states: for branch, status in results.items(): if status: @@ -81,17 +80,12 @@ It is equivalent to "git branch -D ". err[branch].append(project) pm.update() - pm = Progress('Abandon %s' % nb, len(all_projects), quiet=opt.quiet) - # NB: Multiprocessing is heavy, so don't spin it up for one job. 
- if len(all_projects) == 1 or opt.jobs == 1: - _ProcessResults(self._ExecuteOne(opt, nb, x) for x in all_projects) - else: - with multiprocessing.Pool(opt.jobs) as pool: - states = pool.imap_unordered( - functools.partial(self._ExecuteOne, opt, nb), all_projects, - chunksize=WORKER_BATCH_SIZE) - _ProcessResults(states) - pm.end() + self.ExecuteInParallel( + opt.jobs, + functools.partial(self._ExecuteOne, opt.all, nb), + all_projects, + callback=_ProcessResults, + output=Progress('Abandon %s' % (nb,), len(all_projects), quiet=opt.quiet)) width = max(itertools.chain( [25], (len(x) for x in itertools.chain(success, err)))) diff --git a/subcmds/branches.py b/subcmds/branches.py index d5ea580c..2dc102bb 100644 --- a/subcmds/branches.py +++ b/subcmds/branches.py @@ -13,10 +13,10 @@ # limitations under the License. import itertools -import multiprocessing import sys + from color import Coloring -from command import Command, DEFAULT_LOCAL_JOBS, WORKER_BATCH_SIZE +from command import Command, DEFAULT_LOCAL_JOBS class BranchColoring(Coloring): @@ -102,15 +102,19 @@ is shown, then the branch appears in all projects. 
out = BranchColoring(self.manifest.manifestProject.config) all_branches = {} project_cnt = len(projects) - with multiprocessing.Pool(processes=opt.jobs) as pool: - project_branches = pool.imap_unordered( - expand_project_to_branches, projects, chunksize=WORKER_BATCH_SIZE) - for name, b in itertools.chain.from_iterable(project_branches): + def _ProcessResults(_pool, _output, results): + for name, b in itertools.chain.from_iterable(results): if name not in all_branches: all_branches[name] = BranchInfo(name) all_branches[name].add(b) + self.ExecuteInParallel( + opt.jobs, + expand_project_to_branches, + projects, + callback=_ProcessResults) + names = sorted(all_branches) if not names: diff --git a/subcmds/checkout.py b/subcmds/checkout.py index 6b71a8fa..4d8009b1 100644 --- a/subcmds/checkout.py +++ b/subcmds/checkout.py @@ -13,10 +13,9 @@ # limitations under the License. import functools -import multiprocessing import sys -from command import Command, DEFAULT_LOCAL_JOBS, WORKER_BATCH_SIZE +from command import Command, DEFAULT_LOCAL_JOBS from progress import Progress @@ -50,7 +49,7 @@ The command is equivalent to: success = [] all_projects = self.GetProjects(args[1:]) - def _ProcessResults(results): + def _ProcessResults(_pool, pm, results): for status, project in results: if status is not None: if status: @@ -59,17 +58,12 @@ The command is equivalent to: err.append(project) pm.update() - pm = Progress('Checkout %s' % nb, len(all_projects), quiet=opt.quiet) - # NB: Multiprocessing is heavy, so don't spin it up for one job. 
- if len(all_projects) == 1 or opt.jobs == 1: - _ProcessResults(self._ExecuteOne(nb, x) for x in all_projects) - else: - with multiprocessing.Pool(opt.jobs) as pool: - results = pool.imap_unordered( - functools.partial(self._ExecuteOne, nb), all_projects, - chunksize=WORKER_BATCH_SIZE) - _ProcessResults(results) - pm.end() + self.ExecuteInParallel( + opt.jobs, + functools.partial(self._ExecuteOne, nb), + all_projects, + callback=_ProcessResults, + output=Progress('Checkout %s' % (nb,), len(all_projects), quiet=opt.quiet)) if err: for p in err: diff --git a/subcmds/diff.py b/subcmds/diff.py index cdc262e6..4966bb1a 100644 --- a/subcmds/diff.py +++ b/subcmds/diff.py @@ -14,9 +14,8 @@ import functools import io -import multiprocessing -from command import DEFAULT_LOCAL_JOBS, PagedCommand, WORKER_BATCH_SIZE +from command import DEFAULT_LOCAL_JOBS, PagedCommand class Diff(PagedCommand): @@ -36,7 +35,7 @@ to the Unix 'patch' command. dest='absolute', action='store_true', help='Paths are relative to the repository root') - def _DiffHelper(self, absolute, project): + def _ExecuteOne(self, absolute, project): """Obtains the diff for a specific project. Args: @@ -51,22 +50,20 @@ to the Unix 'patch' command. return (ret, buf.getvalue()) def Execute(self, opt, args): - ret = 0 all_projects = self.GetProjects(args) - # NB: Multiprocessing is heavy, so don't spin it up for one job. 
- if len(all_projects) == 1 or opt.jobs == 1: - for project in all_projects: - if not project.PrintWorkTreeDiff(opt.absolute): + def _ProcessResults(_pool, _output, results): + ret = 0 + for (state, output) in results: + if output: + print(output, end='') + if not state: ret = 1 - else: - with multiprocessing.Pool(opt.jobs) as pool: - states = pool.imap(functools.partial(self._DiffHelper, opt.absolute), - all_projects, WORKER_BATCH_SIZE) - for (state, output) in states: - if output: - print(output, end='') - if not state: - ret = 1 + return ret - return ret + return self.ExecuteInParallel( + opt.jobs, + functools.partial(self._ExecuteOne, opt.absolute), + all_projects, + callback=_ProcessResults, + ordered=True) diff --git a/subcmds/grep.py b/subcmds/grep.py index 9a4a8a36..6cb1445a 100644 --- a/subcmds/grep.py +++ b/subcmds/grep.py @@ -13,11 +13,10 @@ # limitations under the License. import functools -import multiprocessing import sys from color import Coloring -from command import DEFAULT_LOCAL_JOBS, PagedCommand, WORKER_BATCH_SIZE +from command import DEFAULT_LOCAL_JOBS, PagedCommand from error import GitError from git_command import GitCommand @@ -173,7 +172,7 @@ contain a line that matches both expressions: return (project, p.Wait(), p.stdout, p.stderr) @staticmethod - def _ProcessResults(out, full_name, have_rev, results): + def _ProcessResults(full_name, have_rev, _pool, out, results): git_failed = False bad_rev = False have_match = False @@ -256,18 +255,13 @@ contain a line that matches both expressions: cmd_argv.extend(opt.revision) cmd_argv.append('--') - process_results = functools.partial( - self._ProcessResults, out, full_name, have_rev) - # NB: Multiprocessing is heavy, so don't spin it up for one job. 
- if len(projects) == 1 or opt.jobs == 1: - git_failed, bad_rev, have_match = process_results( - self._ExecuteOne(cmd_argv, x) for x in projects) - else: - with multiprocessing.Pool(opt.jobs) as pool: - results = pool.imap( - functools.partial(self._ExecuteOne, cmd_argv), projects, - chunksize=WORKER_BATCH_SIZE) - git_failed, bad_rev, have_match = process_results(results) + git_failed, bad_rev, have_match = self.ExecuteInParallel( + opt.jobs, + functools.partial(self._ExecuteOne, cmd_argv), + projects, + callback=functools.partial(self._ProcessResults, full_name, have_rev), + output=out, + ordered=True) if git_failed: sys.exit(1) diff --git a/subcmds/prune.py b/subcmds/prune.py index 4084c8b6..236b647f 100644 --- a/subcmds/prune.py +++ b/subcmds/prune.py @@ -13,10 +13,9 @@ # limitations under the License. import itertools -import multiprocessing from color import Coloring -from command import DEFAULT_LOCAL_JOBS, PagedCommand, WORKER_BATCH_SIZE +from command import DEFAULT_LOCAL_JOBS, PagedCommand class Prune(PagedCommand): @@ -36,18 +35,15 @@ class Prune(PagedCommand): # NB: Should be able to refactor this module to display summary as results # come back from children. - def _ProcessResults(results): + def _ProcessResults(_pool, _output, results): return list(itertools.chain.from_iterable(results)) - # NB: Multiprocessing is heavy, so don't spin it up for one job. 
- if len(projects) == 1 or opt.jobs == 1: - all_branches = _ProcessResults(self._ExecuteOne(x) for x in projects) - else: - with multiprocessing.Pool(opt.jobs) as pool: - results = pool.imap( - self._ExecuteOne, projects, - chunksize=WORKER_BATCH_SIZE) - all_branches = _ProcessResults(results) + all_branches = self.ExecuteInParallel( + opt.jobs, + self._ExecuteOne, + projects, + callback=_ProcessResults, + ordered=True) if not all_branches: return diff --git a/subcmds/start.py b/subcmds/start.py index aa2f915a..ff2bae56 100644 --- a/subcmds/start.py +++ b/subcmds/start.py @@ -13,11 +13,10 @@ # limitations under the License. import functools -import multiprocessing import os import sys -from command import Command, DEFAULT_LOCAL_JOBS, WORKER_BATCH_SIZE +from command import Command, DEFAULT_LOCAL_JOBS from git_config import IsImmutable from git_command import git import gitc_utils @@ -55,7 +54,7 @@ revision specified in the manifest. if not git.check_ref_format('heads/%s' % nb): self.OptionParser.error("'%s' is not a valid name" % nb) - def _ExecuteOne(self, opt, nb, project): + def _ExecuteOne(self, revision, nb, project): """Start one project.""" # If the current revision is immutable, such as a SHA1, a tag or # a change, then we can't push back to it. Substitute with @@ -69,7 +68,7 @@ revision specified in the manifest. try: ret = project.StartBranch( - nb, branch_merge=branch_merge, revision=opt.revision) + nb, branch_merge=branch_merge, revision=revision) except Exception as e: print('error: unable to checkout %s: %s' % (project.name, e), file=sys.stderr) ret = False @@ -123,23 +122,18 @@ revision specified in the manifest. pm.update() pm.end() - def _ProcessResults(results): + def _ProcessResults(_pool, pm, results): for (result, project) in results: if not result: err.append(project) pm.update() - pm = Progress('Starting %s' % nb, len(all_projects), quiet=opt.quiet) - # NB: Multiprocessing is heavy, so don't spin it up for one job. 
- if len(all_projects) == 1 or opt.jobs == 1: - _ProcessResults(self._ExecuteOne(opt, nb, x) for x in all_projects) - else: - with multiprocessing.Pool(opt.jobs) as pool: - results = pool.imap_unordered( - functools.partial(self._ExecuteOne, opt, nb), all_projects, - chunksize=WORKER_BATCH_SIZE) - _ProcessResults(results) - pm.end() + self.ExecuteInParallel( + opt.jobs, + functools.partial(self._ExecuteOne, opt.revision, nb), + all_projects, + callback=_ProcessResults, + output=Progress('Starting %s' % (nb,), len(all_projects), quiet=opt.quiet)) if err: for p in err: diff --git a/subcmds/status.py b/subcmds/status.py index dc223a00..1b48dcea 100644 --- a/subcmds/status.py +++ b/subcmds/status.py @@ -15,10 +15,9 @@ import functools import glob import io -import multiprocessing import os -from command import DEFAULT_LOCAL_JOBS, PagedCommand, WORKER_BATCH_SIZE +from command import DEFAULT_LOCAL_JOBS, PagedCommand from color import Coloring import platform_utils @@ -119,22 +118,23 @@ the following meanings: def Execute(self, opt, args): all_projects = self.GetProjects(args) - counter = 0 - if opt.jobs == 1: - for project in all_projects: - state = project.PrintWorkTreeStatus(quiet=opt.quiet) + def _ProcessResults(_pool, _output, results): + ret = 0 + for (state, output) in results: + if output: + print(output, end='') if state == 'CLEAN': - counter += 1 - else: - with multiprocessing.Pool(opt.jobs) as pool: - states = pool.imap(functools.partial(self._StatusHelper, opt.quiet), - all_projects, chunksize=WORKER_BATCH_SIZE) - for (state, output) in states: - if output: - print(output, end='') - if state == 'CLEAN': - counter += 1 + ret += 1 + return ret + + counter = self.ExecuteInParallel( + opt.jobs, + functools.partial(self._StatusHelper, opt.quiet), + all_projects, + callback=_ProcessResults, + ordered=True) + if not opt.quiet and len(all_projects) == counter: print('nothing to commit (working directory clean)') diff --git a/subcmds/sync.py b/subcmds/sync.py index 
21166af5..4763fadc 100644 --- a/subcmds/sync.py +++ b/subcmds/sync.py @@ -51,7 +51,7 @@ import git_superproject import gitc_utils from project import Project from project import RemoteSpec -from command import Command, MirrorSafeCommand, WORKER_BATCH_SIZE +from command import Command, MirrorSafeCommand from error import RepoChangedException, GitError, ManifestParseError import platform_utils from project import SyncBuffer @@ -428,11 +428,12 @@ later is required to fix a server side protocol bug. return (ret, fetched) - def _CheckoutOne(self, opt, project): + def _CheckoutOne(self, detach_head, force_sync, project): """Checkout work tree for one project Args: - opt: Program options returned from optparse. See _Options(). + detach_head: Whether to leave a detached HEAD. + force_sync: Force checking out of the repo. project: Project object for the project to checkout. Returns: @@ -440,10 +441,10 @@ later is required to fix a server side protocol bug. """ start = time.time() syncbuf = SyncBuffer(self.manifest.manifestProject.config, - detach_head=opt.detach_head) + detach_head=detach_head) success = False try: - project.Sync_LocalHalf(syncbuf, force_sync=opt.force_sync) + project.Sync_LocalHalf(syncbuf, force_sync=force_sync) success = syncbuf.Finish() except Exception as e: print('error: Cannot checkout %s: %s: %s' % @@ -464,44 +465,32 @@ later is required to fix a server side protocol bug. opt: Program options returned from optparse. See _Options(). err_results: A list of strings, paths to git repos where checkout failed. """ - ret = True - jobs = opt.jobs_checkout if opt.jobs_checkout else self.jobs - # Only checkout projects with worktrees. 
all_projects = [x for x in all_projects if x.worktree] - pm = Progress('Checking out', len(all_projects), quiet=opt.quiet) - - def _ProcessResults(results): + def _ProcessResults(pool, pm, results): + ret = True for (success, project, start, finish) in results: self.event_log.AddSync(project, event_log.TASK_SYNC_LOCAL, start, finish, success) # Check for any errors before running any more tasks. # ...we'll let existing jobs finish, though. if not success: + ret = False err_results.append(project.relpath) if opt.fail_fast: - return False + if pool: + pool.close() + return ret pm.update(msg=project.name) - return True + return ret - # NB: Multiprocessing is heavy, so don't spin it up for one job. - if len(all_projects) == 1 or jobs == 1: - if not _ProcessResults(self._CheckoutOne(opt, x) for x in all_projects): - ret = False - else: - with multiprocessing.Pool(jobs) as pool: - results = pool.imap_unordered( - functools.partial(self._CheckoutOne, opt), - all_projects, - chunksize=WORKER_BATCH_SIZE) - if not _ProcessResults(results): - ret = False - pool.close() - - pm.end() - - return ret and not err_results + return self.ExecuteInParallel( + opt.jobs_checkout if opt.jobs_checkout else self.jobs, + functools.partial(self._CheckoutOne, opt.detach_head, opt.force_sync), + all_projects, + callback=_ProcessResults, + output=Progress('Checking out', len(all_projects), quiet=opt.quiet)) and not err_results def _GCProjects(self, projects, opt, err_event): pm = Progress('Garbage collecting', len(projects), delay=False, quiet=opt.quiet)