From 39ffd9977e2f6cb1ca1757e59173fc93e0eab72c Mon Sep 17 00:00:00 2001 From: Kuang-che Wu Date: Fri, 18 Oct 2024 23:32:08 +0800 Subject: [PATCH] sync: reduce multiprocessing serialization overhead Background: - Manifest object is large (for projects like Android) in terms of serialization cost and size (more than 1mb). - Lots of Project objects usually share only a few manifest objects. Before this CL, Project objects were passed to workers via function parameters. Function parameters are pickled separately (in chunk). In other words, manifests are serialized again and again. The major serialization overhead of repo sync was O(manifest_size * projects / chunksize) This CL uses following tricks to reduce serialization overhead. - All projects are pickled in one invocation. Because Project objects share manifests, pickle library remembers which objects are already seen and avoid the serialization cost. - Pass the Project objects to workers at worker intialization time. And pass project index as function parameters instead. The number of workers is much smaller than the number of projects. - Worker init state are shared on Linux (fork based). So it requires zero serialization for Project objects. On Linux (fork based), the serialization overhead is O(projects) --- one int per project On Windows (spawn based), the serialization overhead is O(manifest_size * min(workers, projects)) Moreover, use chunksize=1 to avoid the chance that some workers are idle while other workers still have more than one job in their chunk queue. Using 2.7k projects as the baseline, originally "repo sync" no-op sync takes 31s for fetch and 25s for checkout on my Linux workstation. With this CL, it takes 12s for fetch and 1s for checkout. Bug: b/371638995 Change-Id: Ifa22072ea54eacb4a5c525c050d84de371e87caa Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/439921 Tested-by: Kuang-che Wu Reviewed-by: Josip Sokcevic Commit-Queue: Kuang-che Wu --- command.py | 50 ++++++++++++-- subcmds/sync.py | 169 +++++++++++++++++++++++++----------------------- 2 files changed, 133 insertions(+), 86 deletions(-) diff --git a/command.py b/command.py index fa48264b..2a2ce138 100644 --- a/command.py +++ b/command.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import multiprocessing import optparse import os @@ -70,6 +71,14 @@ class Command: # migrated subcommands can set it to False. MULTI_MANIFEST_SUPPORT = True + # Shared data across parallel execution workers. + _parallel_context = None + + @classmethod + def get_parallel_context(cls): + assert cls._parallel_context is not None + return cls._parallel_context + def __init__( self, repodir=None, @@ -242,9 +251,36 @@ class Command: """Perform the action, after option parsing is complete.""" raise NotImplementedError - @staticmethod + @classmethod + @contextlib.contextmanager + def ParallelContext(cls): + """Obtains the context, which is shared to ExecuteInParallel workers. + + Callers can store data in the context dict before invocation of + ExecuteInParallel. The dict will then be shared to child workers of + ExecuteInParallel. + """ + assert cls._parallel_context is None + cls._parallel_context = {} + try: + yield + finally: + cls._parallel_context = None + + @classmethod + def _SetParallelContext(cls, context): + cls._parallel_context = context + + @classmethod def ExecuteInParallel( - jobs, func, inputs, callback, output=None, ordered=False + cls, + jobs, + func, + inputs, + callback, + output=None, + ordered=False, + chunksize=WORKER_BATCH_SIZE, ): """Helper for managing parallel execution boiler plate. @@ -269,6 +305,8 @@ class Command: output: An output manager. May be progress.Progess or color.Coloring. ordered: Whether the jobs should be processed in order. + chunksize: The number of jobs processed in batch by parallel + workers. Returns: The |callback| function's results are returned. @@ -278,12 +316,16 @@ class Command: if len(inputs) == 1 or jobs == 1: return callback(None, output, (func(x) for x in inputs)) else: - with multiprocessing.Pool(jobs) as pool: + with multiprocessing.Pool( + jobs, + initializer=cls._SetParallelContext, + initargs=(cls._parallel_context,), + ) as pool: submit = pool.imap if ordered else pool.imap_unordered return callback( pool, output, - submit(func, inputs, chunksize=WORKER_BATCH_SIZE), + submit(func, inputs, chunksize=chunksize), ) finally: if isinstance(output, progress.Progress): diff --git a/subcmds/sync.py b/subcmds/sync.py index bebe18b9..00fee776 100644 --- a/subcmds/sync.py +++ b/subcmds/sync.py @@ -141,7 +141,7 @@ class _FetchOneResult(NamedTuple): Attributes: success (bool): True if successful. - project (Project): The fetched project. + project_idx (int): The fetched project index. start (float): The starting time.time(). finish (float): The ending time.time(). remote_fetched (bool): True if the remote was actually queried. @@ -149,7 +149,7 @@ class _FetchOneResult(NamedTuple): success: bool errors: List[Exception] - project: Project + project_idx: int start: float finish: float remote_fetched: bool @@ -182,14 +182,14 @@ class _CheckoutOneResult(NamedTuple): Attributes: success (bool): True if successful. - project (Project): The project. + project_idx (int): The project index. start (float): The starting time.time(). finish (float): The ending time.time(). """ success: bool errors: List[Exception] - project: Project + project_idx: int start: float finish: float @@ -592,7 +592,8 @@ later is required to fix a server side protocol bug. branch = branch[len(R_HEADS) :] return branch - def _GetCurrentBranchOnly(self, opt, manifest): + @classmethod + def _GetCurrentBranchOnly(cls, opt, manifest): """Returns whether current-branch or use-superproject options are enabled. @@ -710,7 +711,8 @@ later is required to fix a server side protocol bug. if need_unload: m.outer_client.manifest.Unload() - def _FetchProjectList(self, opt, projects): + @classmethod + def _FetchProjectList(cls, opt, projects): """Main function of the fetch worker. The projects we're given share the same underlying git object store, so @@ -722,21 +724,23 @@ later is required to fix a server side protocol bug. opt: Program options returned from optparse. See _Options(). projects: Projects to fetch. """ - return [self._FetchOne(opt, x) for x in projects] + return [cls._FetchOne(opt, x) for x in projects] - def _FetchOne(self, opt, project): + @classmethod + def _FetchOne(cls, opt, project_idx): """Fetch git objects for a single project. Args: opt: Program options returned from optparse. See _Options(). - project: Project object for the project to fetch. + project_idx: Project index for the project to fetch. Returns: Whether the fetch was successful. """ + project = cls.get_parallel_context()["projects"][project_idx] start = time.time() k = f"{project.name} @ {project.relpath}" - self._sync_dict[k] = start + cls.get_parallel_context()["sync_dict"][k] = start success = False remote_fetched = False errors = [] @@ -746,7 +750,7 @@ later is required to fix a server side protocol bug. quiet=opt.quiet, verbose=opt.verbose, output_redir=buf, - current_branch_only=self._GetCurrentBranchOnly( + current_branch_only=cls._GetCurrentBranchOnly( opt, project.manifest ), force_sync=opt.force_sync, @@ -756,7 +760,7 @@ later is required to fix a server side protocol bug. optimized_fetch=opt.optimized_fetch, retry_fetches=opt.retry_fetches, prune=opt.prune, - ssh_proxy=self.ssh_proxy, + ssh_proxy=cls.get_parallel_context()["ssh_proxy"], clone_filter=project.manifest.CloneFilter, partial_clone_exclude=project.manifest.PartialCloneExclude, clone_filter_for_depth=project.manifest.CloneFilterForDepth, @@ -788,24 +792,20 @@ later is required to fix a server side protocol bug. type(e).__name__, e, ) - del self._sync_dict[k] errors.append(e) raise + finally: + del cls.get_parallel_context()["sync_dict"][k] finish = time.time() - del self._sync_dict[k] return _FetchOneResult( - success, errors, project, start, finish, remote_fetched + success, errors, project_idx, start, finish, remote_fetched ) - @classmethod - def _FetchInitChild(cls, ssh_proxy): - cls.ssh_proxy = ssh_proxy - def _GetSyncProgressMessage(self): earliest_time = float("inf") earliest_proj = None - items = self._sync_dict.items() + items = self.get_parallel_context()["sync_dict"].items() for project, t in items: if t < earliest_time: earliest_time = t @@ -813,7 +813,7 @@ later is required to fix a server side protocol bug. if not earliest_proj: # This function is called when sync is still running but in some - # cases (by chance), _sync_dict can contain no entries. Return some + # cases (by chance), sync_dict can contain no entries. Return some # text to indicate that sync is still working. return "..working.." @@ -835,7 +835,6 @@ later is required to fix a server side protocol bug. elide=True, ) - self._sync_dict = multiprocessing.Manager().dict() sync_event = _threading.Event() def _MonitorSyncLoop(): @@ -846,21 +845,13 @@ later is required to fix a server side protocol bug. sync_progress_thread = _threading.Thread(target=_MonitorSyncLoop) sync_progress_thread.daemon = True - sync_progress_thread.start() - objdir_project_map = dict() - for project in projects: - objdir_project_map.setdefault(project.objdir, []).append(project) - projects_list = list(objdir_project_map.values()) - - jobs = min(opt.jobs_network, len(projects_list)) - - def _ProcessResults(results_sets): + def _ProcessResults(pool, pm, results_sets): ret = True for results in results_sets: for result in results: success = result.success - project = result.project + project = projects[result.project_idx] start = result.start finish = result.finish self._fetch_times.Set(project, finish - start) @@ -884,45 +875,49 @@ later is required to fix a server side protocol bug. fetched.add(project.gitdir) pm.update() if not ret and opt.fail_fast: + if pool: + pool.close() break return ret - # We pass the ssh proxy settings via the class. This allows - # multiprocessing to pickle it up when spawning children. We can't pass - # it as an argument to _FetchProjectList below as multiprocessing is - # unable to pickle those. - Sync.ssh_proxy = None + with self.ParallelContext(): + self.get_parallel_context()["projects"] = projects + self.get_parallel_context()[ + "sync_dict" + ] = multiprocessing.Manager().dict() - # NB: Multiprocessing is heavy, so don't spin it up for one job. - if jobs == 1: - self._FetchInitChild(ssh_proxy) - if not _ProcessResults( - self._FetchProjectList(opt, x) for x in projects_list - ): - ret = False - else: + objdir_project_map = dict() + for index, project in enumerate(projects): + objdir_project_map.setdefault(project.objdir, []).append(index) + projects_list = list(objdir_project_map.values()) + + jobs = min(opt.jobs_network, len(projects_list)) + + # We pass the ssh proxy settings via the class. This allows + # multiprocessing to pickle it up when spawning children. We can't + # pass it as an argument to _FetchProjectList below as + # multiprocessing is unable to pickle those. + self.get_parallel_context()["ssh_proxy"] = ssh_proxy + + sync_progress_thread.start() if not opt.quiet: pm.update(inc=0, msg="warming up") - with multiprocessing.Pool( - jobs, initializer=self._FetchInitChild, initargs=(ssh_proxy,) - ) as pool: - results = pool.imap_unordered( + try: + ret = self.ExecuteInParallel( + jobs, functools.partial(self._FetchProjectList, opt), projects_list, - chunksize=_chunksize(len(projects_list), jobs), + callback=_ProcessResults, + output=pm, + # Use chunksize=1 to avoid the chance that some workers are + # idle while other workers still have more than one job in + # their chunk queue. + chunksize=1, ) - if not _ProcessResults(results): - ret = False - pool.close() + finally: + sync_event.set() + sync_progress_thread.join() - # Cleanup the reference now that we're done with it, and we're going to - # release any resources it points to. If we don't, later - # multiprocessing usage (e.g. checkouts) will try to pickle and then - # crash. - del Sync.ssh_proxy - - sync_event.set() - pm.end() self._fetch_times.Save() self._local_sync_state.Save() @@ -1008,14 +1003,15 @@ later is required to fix a server side protocol bug. return _FetchMainResult(all_projects) + @classmethod def _CheckoutOne( - self, + cls, detach_head, force_sync, force_checkout, force_rebase, verbose, - project, + project_idx, ): """Checkout work tree for one project @@ -1027,11 +1023,12 @@ later is required to fix a server side protocol bug. force_checkout: Force checking out of the repo content. force_rebase: Force rebase. verbose: Whether to show verbose messages. - project: Project object for the project to checkout. + project_idx: Project index for the project to checkout. Returns: Whether the fetch was successful. """ + project = cls.get_parallel_context()["projects"][project_idx] start = time.time() syncbuf = SyncBuffer( project.manifest.manifestProject.config, detach_head=detach_head @@ -1065,7 +1062,7 @@ later is required to fix a server side protocol bug. if not success: logger.error("error: Cannot checkout %s", project.name) finish = time.time() - return _CheckoutOneResult(success, errors, project, start, finish) + return _CheckoutOneResult(success, errors, project_idx, start, finish) def _Checkout(self, all_projects, opt, err_results, checkout_errors): """Checkout projects listed in all_projects @@ -1083,7 +1080,9 @@ later is required to fix a server side protocol bug. ret = True for result in results: success = result.success - project = result.project + project = self.get_parallel_context()["projects"][ + result.project_idx + ] start = result.start finish = result.finish self.event_log.AddSync( @@ -1110,22 +1109,28 @@ later is required to fix a server side protocol bug. return ret for projects in _SafeCheckoutOrder(all_projects): - proc_res = self.ExecuteInParallel( - opt.jobs_checkout, - functools.partial( - self._CheckoutOne, - opt.detach_head, - opt.force_sync, - opt.force_checkout, - opt.rebase, - opt.verbose, - ), - projects, - callback=_ProcessResults, - output=Progress( - "Checking out", len(all_projects), quiet=opt.quiet - ), - ) + with self.ParallelContext(): + self.get_parallel_context()["projects"] = projects + proc_res = self.ExecuteInParallel( + opt.jobs_checkout, + functools.partial( + self._CheckoutOne, + opt.detach_head, + opt.force_sync, + opt.force_checkout, + opt.rebase, + opt.verbose, + ), + range(len(projects)), + callback=_ProcessResults, + output=Progress( + "Checking out", len(all_projects), quiet=opt.quiet + ), + # Use chunksize=1 to avoid the chance that some workers are + # idle while other workers still have more than one job in + # their chunk queue. + chunksize=1, + ) self._local_sync_state.Save() return proc_res and not err_results