command: add a helper for the parallel execution boilerplate

Now that we have a bunch of subcommands doing parallel execution, a
common pattern arises that we can factor out for most of them.  We
leave forall alone as it's a bit too complicated atm to cut over.

Change-Id: I3617a4f7c66142bcd1ab030cb4cca698a65010ac
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/301942
Tested-by: Mike Frysinger <vapier@google.com>
Reviewed-by: Chris Mcdonald <cjmcdonald@google.com>
This commit is contained in:
Mike Frysinger 2021-03-01 00:56:38 -05:00
parent b8bf291ddb
commit b5d075d04f
10 changed files with 145 additions and 143 deletions

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import multiprocessing
import os import os
import optparse import optparse
import platform import platform
@ -21,6 +22,7 @@ import sys
from event_log import EventLog from event_log import EventLog
from error import NoSuchProjectError from error import NoSuchProjectError
from error import InvalidProjectGroupsError from error import InvalidProjectGroupsError
import progress
# Number of projects to submit to a single worker process at a time. # Number of projects to submit to a single worker process at a time.
@ -156,6 +158,44 @@ class Command(object):
""" """
raise NotImplementedError raise NotImplementedError
@staticmethod
def ExecuteInParallel(jobs, func, inputs, callback, output=None, ordered=False):
"""Helper for managing parallel execution boiler plate.
For subcommands that can easily split their work up.
Args:
jobs: How many parallel processes to use.
func: The function to apply to each of the |inputs|. Usually a
functools.partial for wrapping additional arguments. It will be run
in a separate process, so it must be pickalable, so nested functions
won't work. Methods on the subcommand Command class should work.
inputs: The list of items to process. Must be a list.
callback: The function to pass the results to for processing. It will be
executed in the main thread and process the results of |func| as they
become available. Thus it may be a local nested function. Its return
value is passed back directly. It takes three arguments:
- The processing pool (or None with one job).
- The |output| argument.
- An iterator for the results.
output: An output manager. May be progress.Progess or color.Coloring.
ordered: Whether the jobs should be processed in order.
Returns:
The |callback| function's results are returned.
"""
try:
# NB: Multiprocessing is heavy, so don't spin it up for one job.
if len(inputs) == 1 or jobs == 1:
return callback(None, output, (func(x) for x in inputs))
else:
with multiprocessing.Pool(jobs) as pool:
submit = pool.imap if ordered else pool.imap_unordered
return callback(pool, output, submit(func, inputs, chunksize=WORKER_BATCH_SIZE))
finally:
if isinstance(output, progress.Progress):
output.end()
def _ResetPathToProjectMap(self, projects): def _ResetPathToProjectMap(self, projects):
self._by_path = dict((p.worktree, p) for p in projects) self._by_path = dict((p.worktree, p) for p in projects)

View File

@ -15,10 +15,9 @@
from collections import defaultdict from collections import defaultdict
import functools import functools
import itertools import itertools
import multiprocessing
import sys import sys
from command import Command, DEFAULT_LOCAL_JOBS, WORKER_BATCH_SIZE from command import Command, DEFAULT_LOCAL_JOBS
from git_command import git from git_command import git
from progress import Progress from progress import Progress
@ -52,9 +51,9 @@ It is equivalent to "git branch -D <branchname>".
else: else:
args.insert(0, "'All local branches'") args.insert(0, "'All local branches'")
def _ExecuteOne(self, opt, nb, project): def _ExecuteOne(self, all_branches, nb, project):
"""Abandon one project.""" """Abandon one project."""
if opt.all: if all_branches:
branches = project.GetBranches() branches = project.GetBranches()
else: else:
branches = [nb] branches = [nb]
@ -72,7 +71,7 @@ It is equivalent to "git branch -D <branchname>".
success = defaultdict(list) success = defaultdict(list)
all_projects = self.GetProjects(args[1:]) all_projects = self.GetProjects(args[1:])
def _ProcessResults(states): def _ProcessResults(_pool, pm, states):
for (results, project) in states: for (results, project) in states:
for branch, status in results.items(): for branch, status in results.items():
if status: if status:
@ -81,17 +80,12 @@ It is equivalent to "git branch -D <branchname>".
err[branch].append(project) err[branch].append(project)
pm.update() pm.update()
pm = Progress('Abandon %s' % nb, len(all_projects), quiet=opt.quiet) self.ExecuteInParallel(
# NB: Multiprocessing is heavy, so don't spin it up for one job. opt.jobs,
if len(all_projects) == 1 or opt.jobs == 1: functools.partial(self._ExecuteOne, opt.all, nb),
_ProcessResults(self._ExecuteOne(opt, nb, x) for x in all_projects) all_projects,
else: callback=_ProcessResults,
with multiprocessing.Pool(opt.jobs) as pool: output=Progress('Abandon %s' % (nb,), len(all_projects), quiet=opt.quiet))
states = pool.imap_unordered(
functools.partial(self._ExecuteOne, opt, nb), all_projects,
chunksize=WORKER_BATCH_SIZE)
_ProcessResults(states)
pm.end()
width = max(itertools.chain( width = max(itertools.chain(
[25], (len(x) for x in itertools.chain(success, err)))) [25], (len(x) for x in itertools.chain(success, err))))

View File

@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
import itertools import itertools
import multiprocessing
import sys import sys
from color import Coloring from color import Coloring
from command import Command, DEFAULT_LOCAL_JOBS, WORKER_BATCH_SIZE from command import Command, DEFAULT_LOCAL_JOBS
class BranchColoring(Coloring): class BranchColoring(Coloring):
@ -102,15 +102,19 @@ is shown, then the branch appears in all projects.
out = BranchColoring(self.manifest.manifestProject.config) out = BranchColoring(self.manifest.manifestProject.config)
all_branches = {} all_branches = {}
project_cnt = len(projects) project_cnt = len(projects)
with multiprocessing.Pool(processes=opt.jobs) as pool:
project_branches = pool.imap_unordered(
expand_project_to_branches, projects, chunksize=WORKER_BATCH_SIZE)
for name, b in itertools.chain.from_iterable(project_branches): def _ProcessResults(_pool, _output, results):
for name, b in itertools.chain.from_iterable(results):
if name not in all_branches: if name not in all_branches:
all_branches[name] = BranchInfo(name) all_branches[name] = BranchInfo(name)
all_branches[name].add(b) all_branches[name].add(b)
self.ExecuteInParallel(
opt.jobs,
expand_project_to_branches,
projects,
callback=_ProcessResults)
names = sorted(all_branches) names = sorted(all_branches)
if not names: if not names:

View File

@ -13,10 +13,9 @@
# limitations under the License. # limitations under the License.
import functools import functools
import multiprocessing
import sys import sys
from command import Command, DEFAULT_LOCAL_JOBS, WORKER_BATCH_SIZE from command import Command, DEFAULT_LOCAL_JOBS
from progress import Progress from progress import Progress
@ -50,7 +49,7 @@ The command is equivalent to:
success = [] success = []
all_projects = self.GetProjects(args[1:]) all_projects = self.GetProjects(args[1:])
def _ProcessResults(results): def _ProcessResults(_pool, pm, results):
for status, project in results: for status, project in results:
if status is not None: if status is not None:
if status: if status:
@ -59,17 +58,12 @@ The command is equivalent to:
err.append(project) err.append(project)
pm.update() pm.update()
pm = Progress('Checkout %s' % nb, len(all_projects), quiet=opt.quiet) self.ExecuteInParallel(
# NB: Multiprocessing is heavy, so don't spin it up for one job. opt.jobs,
if len(all_projects) == 1 or opt.jobs == 1: functools.partial(self._ExecuteOne, nb),
_ProcessResults(self._ExecuteOne(nb, x) for x in all_projects) all_projects,
else: callback=_ProcessResults,
with multiprocessing.Pool(opt.jobs) as pool: output=Progress('Checkout %s' % (nb,), len(all_projects), quiet=opt.quiet))
results = pool.imap_unordered(
functools.partial(self._ExecuteOne, nb), all_projects,
chunksize=WORKER_BATCH_SIZE)
_ProcessResults(results)
pm.end()
if err: if err:
for p in err: for p in err:

View File

@ -14,9 +14,8 @@
import functools import functools
import io import io
import multiprocessing
from command import DEFAULT_LOCAL_JOBS, PagedCommand, WORKER_BATCH_SIZE from command import DEFAULT_LOCAL_JOBS, PagedCommand
class Diff(PagedCommand): class Diff(PagedCommand):
@ -36,7 +35,7 @@ to the Unix 'patch' command.
dest='absolute', action='store_true', dest='absolute', action='store_true',
help='Paths are relative to the repository root') help='Paths are relative to the repository root')
def _DiffHelper(self, absolute, project): def _ExecuteOne(self, absolute, project):
"""Obtains the diff for a specific project. """Obtains the diff for a specific project.
Args: Args:
@ -51,22 +50,20 @@ to the Unix 'patch' command.
return (ret, buf.getvalue()) return (ret, buf.getvalue())
def Execute(self, opt, args): def Execute(self, opt, args):
ret = 0
all_projects = self.GetProjects(args) all_projects = self.GetProjects(args)
# NB: Multiprocessing is heavy, so don't spin it up for one job. def _ProcessResults(_pool, _output, results):
if len(all_projects) == 1 or opt.jobs == 1: ret = 0
for project in all_projects: for (state, output) in results:
if not project.PrintWorkTreeDiff(opt.absolute): if output:
print(output, end='')
if not state:
ret = 1 ret = 1
else: return ret
with multiprocessing.Pool(opt.jobs) as pool:
states = pool.imap(functools.partial(self._DiffHelper, opt.absolute),
all_projects, WORKER_BATCH_SIZE)
for (state, output) in states:
if output:
print(output, end='')
if not state:
ret = 1
return ret return self.ExecuteInParallel(
opt.jobs,
functools.partial(self._ExecuteOne, opt.absolute),
all_projects,
callback=_ProcessResults,
ordered=True)

View File

@ -13,11 +13,10 @@
# limitations under the License. # limitations under the License.
import functools import functools
import multiprocessing
import sys import sys
from color import Coloring from color import Coloring
from command import DEFAULT_LOCAL_JOBS, PagedCommand, WORKER_BATCH_SIZE from command import DEFAULT_LOCAL_JOBS, PagedCommand
from error import GitError from error import GitError
from git_command import GitCommand from git_command import GitCommand
@ -173,7 +172,7 @@ contain a line that matches both expressions:
return (project, p.Wait(), p.stdout, p.stderr) return (project, p.Wait(), p.stdout, p.stderr)
@staticmethod @staticmethod
def _ProcessResults(out, full_name, have_rev, results): def _ProcessResults(full_name, have_rev, _pool, out, results):
git_failed = False git_failed = False
bad_rev = False bad_rev = False
have_match = False have_match = False
@ -256,18 +255,13 @@ contain a line that matches both expressions:
cmd_argv.extend(opt.revision) cmd_argv.extend(opt.revision)
cmd_argv.append('--') cmd_argv.append('--')
process_results = functools.partial( git_failed, bad_rev, have_match = self.ExecuteInParallel(
self._ProcessResults, out, full_name, have_rev) opt.jobs,
# NB: Multiprocessing is heavy, so don't spin it up for one job. functools.partial(self._ExecuteOne, cmd_argv),
if len(projects) == 1 or opt.jobs == 1: projects,
git_failed, bad_rev, have_match = process_results( callback=functools.partial(self._ProcessResults, full_name, have_rev),
self._ExecuteOne(cmd_argv, x) for x in projects) output=out,
else: ordered=True)
with multiprocessing.Pool(opt.jobs) as pool:
results = pool.imap(
functools.partial(self._ExecuteOne, cmd_argv), projects,
chunksize=WORKER_BATCH_SIZE)
git_failed, bad_rev, have_match = process_results(results)
if git_failed: if git_failed:
sys.exit(1) sys.exit(1)

View File

@ -13,10 +13,9 @@
# limitations under the License. # limitations under the License.
import itertools import itertools
import multiprocessing
from color import Coloring from color import Coloring
from command import DEFAULT_LOCAL_JOBS, PagedCommand, WORKER_BATCH_SIZE from command import DEFAULT_LOCAL_JOBS, PagedCommand
class Prune(PagedCommand): class Prune(PagedCommand):
@ -36,18 +35,15 @@ class Prune(PagedCommand):
# NB: Should be able to refactor this module to display summary as results # NB: Should be able to refactor this module to display summary as results
# come back from children. # come back from children.
def _ProcessResults(results): def _ProcessResults(_pool, _output, results):
return list(itertools.chain.from_iterable(results)) return list(itertools.chain.from_iterable(results))
# NB: Multiprocessing is heavy, so don't spin it up for one job. all_branches = self.ExecuteInParallel(
if len(projects) == 1 or opt.jobs == 1: opt.jobs,
all_branches = _ProcessResults(self._ExecuteOne(x) for x in projects) self._ExecuteOne,
else: projects,
with multiprocessing.Pool(opt.jobs) as pool: callback=_ProcessResults,
results = pool.imap( ordered=True)
self._ExecuteOne, projects,
chunksize=WORKER_BATCH_SIZE)
all_branches = _ProcessResults(results)
if not all_branches: if not all_branches:
return return

View File

@ -13,11 +13,10 @@
# limitations under the License. # limitations under the License.
import functools import functools
import multiprocessing
import os import os
import sys import sys
from command import Command, DEFAULT_LOCAL_JOBS, WORKER_BATCH_SIZE from command import Command, DEFAULT_LOCAL_JOBS
from git_config import IsImmutable from git_config import IsImmutable
from git_command import git from git_command import git
import gitc_utils import gitc_utils
@ -55,7 +54,7 @@ revision specified in the manifest.
if not git.check_ref_format('heads/%s' % nb): if not git.check_ref_format('heads/%s' % nb):
self.OptionParser.error("'%s' is not a valid name" % nb) self.OptionParser.error("'%s' is not a valid name" % nb)
def _ExecuteOne(self, opt, nb, project): def _ExecuteOne(self, revision, nb, project):
"""Start one project.""" """Start one project."""
# If the current revision is immutable, such as a SHA1, a tag or # If the current revision is immutable, such as a SHA1, a tag or
# a change, then we can't push back to it. Substitute with # a change, then we can't push back to it. Substitute with
@ -69,7 +68,7 @@ revision specified in the manifest.
try: try:
ret = project.StartBranch( ret = project.StartBranch(
nb, branch_merge=branch_merge, revision=opt.revision) nb, branch_merge=branch_merge, revision=revision)
except Exception as e: except Exception as e:
print('error: unable to checkout %s: %s' % (project.name, e), file=sys.stderr) print('error: unable to checkout %s: %s' % (project.name, e), file=sys.stderr)
ret = False ret = False
@ -123,23 +122,18 @@ revision specified in the manifest.
pm.update() pm.update()
pm.end() pm.end()
def _ProcessResults(results): def _ProcessResults(_pool, pm, results):
for (result, project) in results: for (result, project) in results:
if not result: if not result:
err.append(project) err.append(project)
pm.update() pm.update()
pm = Progress('Starting %s' % nb, len(all_projects), quiet=opt.quiet) self.ExecuteInParallel(
# NB: Multiprocessing is heavy, so don't spin it up for one job. opt.jobs,
if len(all_projects) == 1 or opt.jobs == 1: functools.partial(self._ExecuteOne, opt.revision, nb),
_ProcessResults(self._ExecuteOne(opt, nb, x) for x in all_projects) all_projects,
else: callback=_ProcessResults,
with multiprocessing.Pool(opt.jobs) as pool: output=Progress('Starting %s' % (nb,), len(all_projects), quiet=opt.quiet))
results = pool.imap_unordered(
functools.partial(self._ExecuteOne, opt, nb), all_projects,
chunksize=WORKER_BATCH_SIZE)
_ProcessResults(results)
pm.end()
if err: if err:
for p in err: for p in err:

View File

@ -15,10 +15,9 @@
import functools import functools
import glob import glob
import io import io
import multiprocessing
import os import os
from command import DEFAULT_LOCAL_JOBS, PagedCommand, WORKER_BATCH_SIZE from command import DEFAULT_LOCAL_JOBS, PagedCommand
from color import Coloring from color import Coloring
import platform_utils import platform_utils
@ -119,22 +118,23 @@ the following meanings:
def Execute(self, opt, args): def Execute(self, opt, args):
all_projects = self.GetProjects(args) all_projects = self.GetProjects(args)
counter = 0
if opt.jobs == 1: def _ProcessResults(_pool, _output, results):
for project in all_projects: ret = 0
state = project.PrintWorkTreeStatus(quiet=opt.quiet) for (state, output) in results:
if output:
print(output, end='')
if state == 'CLEAN': if state == 'CLEAN':
counter += 1 ret += 1
else: return ret
with multiprocessing.Pool(opt.jobs) as pool:
states = pool.imap(functools.partial(self._StatusHelper, opt.quiet), counter = self.ExecuteInParallel(
all_projects, chunksize=WORKER_BATCH_SIZE) opt.jobs,
for (state, output) in states: functools.partial(self._StatusHelper, opt.quiet),
if output: all_projects,
print(output, end='') callback=_ProcessResults,
if state == 'CLEAN': ordered=True)
counter += 1
if not opt.quiet and len(all_projects) == counter: if not opt.quiet and len(all_projects) == counter:
print('nothing to commit (working directory clean)') print('nothing to commit (working directory clean)')

View File

@ -51,7 +51,7 @@ import git_superproject
import gitc_utils import gitc_utils
from project import Project from project import Project
from project import RemoteSpec from project import RemoteSpec
from command import Command, MirrorSafeCommand, WORKER_BATCH_SIZE from command import Command, MirrorSafeCommand
from error import RepoChangedException, GitError, ManifestParseError from error import RepoChangedException, GitError, ManifestParseError
import platform_utils import platform_utils
from project import SyncBuffer from project import SyncBuffer
@ -428,11 +428,12 @@ later is required to fix a server side protocol bug.
return (ret, fetched) return (ret, fetched)
def _CheckoutOne(self, opt, project): def _CheckoutOne(self, detach_head, force_sync, project):
"""Checkout work tree for one project """Checkout work tree for one project
Args: Args:
opt: Program options returned from optparse. See _Options(). detach_head: Whether to leave a detached HEAD.
force_sync: Force checking out of the repo.
project: Project object for the project to checkout. project: Project object for the project to checkout.
Returns: Returns:
@ -440,10 +441,10 @@ later is required to fix a server side protocol bug.
""" """
start = time.time() start = time.time()
syncbuf = SyncBuffer(self.manifest.manifestProject.config, syncbuf = SyncBuffer(self.manifest.manifestProject.config,
detach_head=opt.detach_head) detach_head=detach_head)
success = False success = False
try: try:
project.Sync_LocalHalf(syncbuf, force_sync=opt.force_sync) project.Sync_LocalHalf(syncbuf, force_sync=force_sync)
success = syncbuf.Finish() success = syncbuf.Finish()
except Exception as e: except Exception as e:
print('error: Cannot checkout %s: %s: %s' % print('error: Cannot checkout %s: %s: %s' %
@ -464,44 +465,32 @@ later is required to fix a server side protocol bug.
opt: Program options returned from optparse. See _Options(). opt: Program options returned from optparse. See _Options().
err_results: A list of strings, paths to git repos where checkout failed. err_results: A list of strings, paths to git repos where checkout failed.
""" """
ret = True
jobs = opt.jobs_checkout if opt.jobs_checkout else self.jobs
# Only checkout projects with worktrees. # Only checkout projects with worktrees.
all_projects = [x for x in all_projects if x.worktree] all_projects = [x for x in all_projects if x.worktree]
pm = Progress('Checking out', len(all_projects), quiet=opt.quiet) def _ProcessResults(pool, pm, results):
ret = True
def _ProcessResults(results):
for (success, project, start, finish) in results: for (success, project, start, finish) in results:
self.event_log.AddSync(project, event_log.TASK_SYNC_LOCAL, self.event_log.AddSync(project, event_log.TASK_SYNC_LOCAL,
start, finish, success) start, finish, success)
# Check for any errors before running any more tasks. # Check for any errors before running any more tasks.
# ...we'll let existing jobs finish, though. # ...we'll let existing jobs finish, though.
if not success: if not success:
ret = False
err_results.append(project.relpath) err_results.append(project.relpath)
if opt.fail_fast: if opt.fail_fast:
return False if pool:
pool.close()
return ret
pm.update(msg=project.name) pm.update(msg=project.name)
return True return ret
# NB: Multiprocessing is heavy, so don't spin it up for one job. return self.ExecuteInParallel(
if len(all_projects) == 1 or jobs == 1: opt.jobs_checkout if opt.jobs_checkout else self.jobs,
if not _ProcessResults(self._CheckoutOne(opt, x) for x in all_projects): functools.partial(self._CheckoutOne, opt.detach_head, opt.force_sync),
ret = False all_projects,
else: callback=_ProcessResults,
with multiprocessing.Pool(jobs) as pool: output=Progress('Checking out', len(all_projects), quiet=opt.quiet)) and not err_results
results = pool.imap_unordered(
functools.partial(self._CheckoutOne, opt),
all_projects,
chunksize=WORKER_BATCH_SIZE)
if not _ProcessResults(results):
ret = False
pool.close()
pm.end()
return ret and not err_results
def _GCProjects(self, projects, opt, err_event): def _GCProjects(self, projects, opt, err_event):
pm = Progress('Garbage collecting', len(projects), delay=False, quiet=opt.quiet) pm = Progress('Garbage collecting', len(projects), delay=False, quiet=opt.quiet)