diff --git a/command.py b/command.py
index fa48264b..2a2ce138 100644
--- a/command.py
+++ b/command.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import multiprocessing
 import optparse
 import os
@@ -70,6 +71,14 @@ class Command:
     # migrated subcommands can set it to False.
     MULTI_MANIFEST_SUPPORT = True
 
+    # Shared data across parallel execution workers.
+    _parallel_context = None
+
+    @classmethod
+    def get_parallel_context(cls):
+        assert cls._parallel_context is not None
+        return cls._parallel_context
+
     def __init__(
         self,
         repodir=None,
@@ -242,9 +251,36 @@ class Command:
         """Perform the action, after option parsing is complete."""
         raise NotImplementedError
 
-    @staticmethod
+    @classmethod
+    @contextlib.contextmanager
+    def ParallelContext(cls):
+        """Obtains the context, which is shared to ExecuteInParallel workers.
+
+        Callers can store data in the context dict before invocation of
+        ExecuteInParallel. The dict will then be shared to child workers of
+        ExecuteInParallel.
+        """
+        assert cls._parallel_context is None
+        cls._parallel_context = {}
+        try:
+            yield
+        finally:
+            cls._parallel_context = None
+
+    @classmethod
+    def _SetParallelContext(cls, context):
+        cls._parallel_context = context
+
+    @classmethod
     def ExecuteInParallel(
-        jobs, func, inputs, callback, output=None, ordered=False
+        cls,
+        jobs,
+        func,
+        inputs,
+        callback,
+        output=None,
+        ordered=False,
+        chunksize=WORKER_BATCH_SIZE,
     ):
         """Helper for managing parallel execution boiler plate.
 
@@ -269,6 +305,8 @@ class Command:
             output: An output manager.  May be progress.Progess or
                 color.Coloring.
             ordered: Whether the jobs should be processed in order.
+            chunksize: The number of jobs processed in batch by parallel
+                workers.
 
         Returns:
             The |callback| function's results are returned.
@@ -278,12 +316,16 @@ class Command:
             if len(inputs) == 1 or jobs == 1:
                 return callback(None, output, (func(x) for x in inputs))
             else:
-                with multiprocessing.Pool(jobs) as pool:
+                with multiprocessing.Pool(
+                    jobs,
+                    initializer=cls._SetParallelContext,
+                    initargs=(cls._parallel_context,),
+                ) as pool:
                     submit = pool.imap if ordered else pool.imap_unordered
                     return callback(
                         pool,
                         output,
-                        submit(func, inputs, chunksize=WORKER_BATCH_SIZE),
+                        submit(func, inputs, chunksize=chunksize),
                     )
         finally:
             if isinstance(output, progress.Progress):
diff --git a/subcmds/sync.py b/subcmds/sync.py
index bebe18b9..00fee776 100644
--- a/subcmds/sync.py
+++ b/subcmds/sync.py
@@ -141,7 +141,7 @@ class _FetchOneResult(NamedTuple):
 
     Attributes:
         success (bool): True if successful.
-        project (Project): The fetched project.
+        project_idx (int): The fetched project index.
         start (float): The starting time.time().
         finish (float): The ending time.time().
         remote_fetched (bool): True if the remote was actually queried.
@@ -149,7 +149,7 @@
 
     success: bool
     errors: List[Exception]
-    project: Project
+    project_idx: int
     start: float
     finish: float
     remote_fetched: bool
@@ -182,14 +182,14 @@ class _CheckoutOneResult(NamedTuple):
 
     Attributes:
         success (bool): True if successful.
-        project (Project): The project.
+        project_idx (int): The project index.
         start (float): The starting time.time().
         finish (float): The ending time.time().
     """
 
     success: bool
    errors: List[Exception]
-    project: Project
+    project_idx: int
     start: float
     finish: float
 
@@ -592,7 +592,8 @@ later is required to fix a server side protocol bug.
             branch = branch[len(R_HEADS) :]
         return branch
 
-    def _GetCurrentBranchOnly(self, opt, manifest):
+    @classmethod
+    def _GetCurrentBranchOnly(cls, opt, manifest):
         """Returns whether current-branch or use-superproject options are
         enabled.
 
@@ -710,7 +711,8 @@ later is required to fix a server side protocol bug.
             if need_unload:
                 m.outer_client.manifest.Unload()
 
-    def _FetchProjectList(self, opt, projects):
+    @classmethod
+    def _FetchProjectList(cls, opt, projects):
         """Main function of the fetch worker.
 
         The projects we're given share the same underlying git object store, so
@@ -722,21 +724,23 @@ later is required to fix a server side protocol bug.
             opt: Program options returned from optparse.  See _Options().
             projects: Projects to fetch.
         """
-        return [self._FetchOne(opt, x) for x in projects]
+        return [cls._FetchOne(opt, x) for x in projects]
 
-    def _FetchOne(self, opt, project):
+    @classmethod
+    def _FetchOne(cls, opt, project_idx):
         """Fetch git objects for a single project.
 
         Args:
             opt: Program options returned from optparse.  See _Options().
-            project: Project object for the project to fetch.
+            project_idx: Project index for the project to fetch.
 
         Returns:
             Whether the fetch was successful.
         """
+        project = cls.get_parallel_context()["projects"][project_idx]
         start = time.time()
         k = f"{project.name} @ {project.relpath}"
-        self._sync_dict[k] = start
+        cls.get_parallel_context()["sync_dict"][k] = start
         success = False
         remote_fetched = False
         errors = []
@@ -746,7 +750,7 @@ later is required to fix a server side protocol bug.
                     quiet=opt.quiet,
                     verbose=opt.verbose,
                     output_redir=buf,
-                    current_branch_only=self._GetCurrentBranchOnly(
+                    current_branch_only=cls._GetCurrentBranchOnly(
                         opt, project.manifest
                     ),
                     force_sync=opt.force_sync,
@@ -756,7 +760,7 @@ later is required to fix a server side protocol bug.
                     optimized_fetch=opt.optimized_fetch,
                     retry_fetches=opt.retry_fetches,
                     prune=opt.prune,
-                    ssh_proxy=self.ssh_proxy,
+                    ssh_proxy=cls.get_parallel_context()["ssh_proxy"],
                     clone_filter=project.manifest.CloneFilter,
                     partial_clone_exclude=project.manifest.PartialCloneExclude,
                     clone_filter_for_depth=project.manifest.CloneFilterForDepth,
@@ -788,24 +792,20 @@ later is required to fix a server side protocol bug.
                     type(e).__name__,
                     e,
                 )
-                del self._sync_dict[k]
                 errors.append(e)
                 raise
+            finally:
+                del cls.get_parallel_context()["sync_dict"][k]
 
         finish = time.time()
-        del self._sync_dict[k]
         return _FetchOneResult(
-            success, errors, project, start, finish, remote_fetched
+            success, errors, project_idx, start, finish, remote_fetched
         )
 
-    @classmethod
-    def _FetchInitChild(cls, ssh_proxy):
-        cls.ssh_proxy = ssh_proxy
-
     def _GetSyncProgressMessage(self):
         earliest_time = float("inf")
         earliest_proj = None
-        items = self._sync_dict.items()
+        items = self.get_parallel_context()["sync_dict"].items()
         for project, t in items:
             if t < earliest_time:
                 earliest_time = t
@@ -813,7 +813,7 @@ later is required to fix a server side protocol bug.
 
         if not earliest_proj:
             # This function is called when sync is still running but in some
-            # cases (by chance), _sync_dict can contain no entries. Return some
+            # cases (by chance), sync_dict can contain no entries. Return some
             # text to indicate that sync is still working.
             return "..working.."
 
@@ -835,7 +835,6 @@ later is required to fix a server side protocol bug.
             elide=True,
         )
 
-        self._sync_dict = multiprocessing.Manager().dict()
         sync_event = _threading.Event()
 
         def _MonitorSyncLoop():
@@ -846,21 +845,13 @@ later is required to fix a server side protocol bug.
 
         sync_progress_thread = _threading.Thread(target=_MonitorSyncLoop)
         sync_progress_thread.daemon = True
-        sync_progress_thread.start()
 
-        objdir_project_map = dict()
-        for project in projects:
-            objdir_project_map.setdefault(project.objdir, []).append(project)
-        projects_list = list(objdir_project_map.values())
-
-        jobs = min(opt.jobs_network, len(projects_list))
-
-        def _ProcessResults(results_sets):
+        def _ProcessResults(pool, pm, results_sets):
             ret = True
             for results in results_sets:
                 for result in results:
                     success = result.success
-                    project = result.project
+                    project = projects[result.project_idx]
                     start = result.start
                     finish = result.finish
                     self._fetch_times.Set(project, finish - start)
@@ -884,45 +875,49 @@ later is required to fix a server side protocol bug.
                     fetched.add(project.gitdir)
                     pm.update()
                 if not ret and opt.fail_fast:
+                    if pool:
+                        pool.close()
                     break
             return ret
 
-        # We pass the ssh proxy settings via the class.  This allows
-        # multiprocessing to pickle it up when spawning children.  We can't pass
-        # it as an argument to _FetchProjectList below as multiprocessing is
-        # unable to pickle those.
-        Sync.ssh_proxy = None
+        with self.ParallelContext():
+            self.get_parallel_context()["projects"] = projects
+            self.get_parallel_context()[
+                "sync_dict"
+            ] = multiprocessing.Manager().dict()
 
-        # NB: Multiprocessing is heavy, so don't spin it up for one job.
-        if jobs == 1:
-            self._FetchInitChild(ssh_proxy)
-            if not _ProcessResults(
-                self._FetchProjectList(opt, x) for x in projects_list
-            ):
-                ret = False
-        else:
+            objdir_project_map = dict()
+            for index, project in enumerate(projects):
+                objdir_project_map.setdefault(project.objdir, []).append(index)
+            projects_list = list(objdir_project_map.values())
+
+            jobs = min(opt.jobs_network, len(projects_list))
+
+            # We pass the ssh proxy settings via the class.  This allows
+            # multiprocessing to pickle it up when spawning children.  We can't
+            # pass it as an argument to _FetchProjectList below as
+            # multiprocessing is unable to pickle those.
+            self.get_parallel_context()["ssh_proxy"] = ssh_proxy
+
+            sync_progress_thread.start()
             if not opt.quiet:
                 pm.update(inc=0, msg="warming up")
-            with multiprocessing.Pool(
-                jobs, initializer=self._FetchInitChild, initargs=(ssh_proxy,)
-            ) as pool:
-                results = pool.imap_unordered(
+            try:
+                ret = self.ExecuteInParallel(
+                    jobs,
                     functools.partial(self._FetchProjectList, opt),
                     projects_list,
-                    chunksize=_chunksize(len(projects_list), jobs),
+                    callback=_ProcessResults,
+                    output=pm,
+                    # Use chunksize=1 to avoid the chance that some workers are
+                    # idle while other workers still have more than one job in
+                    # their chunk queue.
+                    chunksize=1,
                )
-                if not _ProcessResults(results):
-                    ret = False
-                pool.close()
+            finally:
+                sync_event.set()
+                sync_progress_thread.join()
 
-        # Cleanup the reference now that we're done with it, and we're going to
-        # release any resources it points to.  If we don't, later
-        # multiprocessing usage (e.g. checkouts) will try to pickle and then
-        # crash.
-        del Sync.ssh_proxy
-
-        sync_event.set()
-
         pm.end()
         self._fetch_times.Save()
         self._local_sync_state.Save()
@@ -1008,14 +1003,15 @@ later is required to fix a server side protocol bug.
 
         return _FetchMainResult(all_projects)
 
+    @classmethod
     def _CheckoutOne(
-        self,
+        cls,
         detach_head,
         force_sync,
         force_checkout,
         force_rebase,
         verbose,
-        project,
+        project_idx,
     ):
         """Checkout work tree for one project
 
@@ -1027,11 +1023,12 @@ later is required to fix a server side protocol bug.
             force_checkout: Force checking out of the repo content.
             force_rebase: Force rebase.
             verbose: Whether to show verbose messages.
-            project: Project object for the project to checkout.
+            project_idx: Project index for the project to checkout.
 
         Returns:
             Whether the fetch was successful.
         """
+        project = cls.get_parallel_context()["projects"][project_idx]
         start = time.time()
         syncbuf = SyncBuffer(
             project.manifest.manifestProject.config, detach_head=detach_head
@@ -1065,7 +1062,7 @@ later is required to fix a server side protocol bug.
         if not success:
             logger.error("error: Cannot checkout %s", project.name)
         finish = time.time()
-        return _CheckoutOneResult(success, errors, project, start, finish)
+        return _CheckoutOneResult(success, errors, project_idx, start, finish)
 
     def _Checkout(self, all_projects, opt, err_results, checkout_errors):
         """Checkout projects listed in all_projects
@@ -1083,7 +1080,9 @@ later is required to fix a server side protocol bug.
             ret = True
             for result in results:
                 success = result.success
-                project = result.project
+                project = self.get_parallel_context()["projects"][
+                    result.project_idx
+                ]
                 start = result.start
                 finish = result.finish
                 self.event_log.AddSync(
@@ -1110,22 +1109,28 @@ later is required to fix a server side protocol bug.
             return ret
 
         for projects in _SafeCheckoutOrder(all_projects):
-            proc_res = self.ExecuteInParallel(
-                opt.jobs_checkout,
-                functools.partial(
-                    self._CheckoutOne,
-                    opt.detach_head,
-                    opt.force_sync,
-                    opt.force_checkout,
-                    opt.rebase,
-                    opt.verbose,
-                ),
-                projects,
-                callback=_ProcessResults,
-                output=Progress(
-                    "Checking out", len(all_projects), quiet=opt.quiet
-                ),
-            )
+            with self.ParallelContext():
+                self.get_parallel_context()["projects"] = projects
+                proc_res = self.ExecuteInParallel(
+                    opt.jobs_checkout,
+                    functools.partial(
+                        self._CheckoutOne,
+                        opt.detach_head,
+                        opt.force_sync,
+                        opt.force_checkout,
+                        opt.rebase,
+                        opt.verbose,
+                    ),
+                    range(len(projects)),
+                    callback=_ProcessResults,
+                    output=Progress(
+                        "Checking out", len(all_projects), quiet=opt.quiet
+                    ),
+                    # Use chunksize=1 to avoid the chance that some workers
+                    # are idle while other workers still have more than one
+                    # job in their chunk queue.
+                    chunksize=1,
+                )
 
         self._local_sync_state.Save()
         return proc_res and not err_results
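
Usage sketch (illustrative, not part of the patch): the pattern this change
establishes is to stash shared state in the parallel context, hand each worker
a cheap integer index, and let ExecuteInParallel deliver the context to child
processes through the Pool initializer. A minimal sketch against the
command.py API above; the Greet subcommand, _GreetOne, and _ProcessResults are
hypothetical:

    import functools

    from command import Command


    class Greet(Command):
        """Hypothetical subcommand illustrating ParallelContext."""

        @classmethod
        def _GreetOne(cls, opt, project_idx):
            # Runs in a worker process: look the Project up by index in the
            # shared context instead of pickling the object once per job.
            project = cls.get_parallel_context()["projects"][project_idx]
            return f"hello {project.name}"

        def Execute(self, opt, args):
            projects = self.GetProjects(args)

            def _ProcessResults(pool, output, results):
                for msg in results:
                    print(msg)

            with self.ParallelContext():
                # Stored before ExecuteInParallel runs, so it is shipped to
                # every worker exactly once by the Pool initializer.
                self.get_parallel_context()["projects"] = projects
                self.ExecuteInParallel(
                    opt.jobs,
                    functools.partial(self._GreetOne, opt),
                    range(len(projects)),
                    callback=_ProcessResults,
                    # As in sync above: chunksize=1 favors load balance
                    # across workers over batching overhead.
                    chunksize=1,
                )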
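One reviewing note on the semantics: in the serial fast path
(len(inputs) == 1 or jobs == 1) ExecuteInParallel never spawns a pool, so the
worker function reads the very same context dict the caller populated; only
the multiprocessing path pickles the context into each child via the
initializer. That is why _FetchOne's in-flight "sync_dict" must be a
multiprocessing.Manager().dict(), whose worker-side writes stay visible to the
parent's progress thread, while the read-only "projects" list can be shared
as a plain pickled copy.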