ssh: rewrite proxy management for multiprocessing usage

We changed sync to use multiprocessing for parallel work.  This broke
the ssh proxy code, which is all based on threads.  Rewrite the logic
to be multiprocessing-safe.
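A minimal illustration of the failure mode (hypothetical names, not the real
ssh module): state mutated at module level inside a worker only changes that
worker process's copy of the module, so the parent never learns about the ssh
masters and clients the children started and cannot clean them up on exit.

```python
import multiprocessing

# Hypothetical stand-in for the old module-level, thread-oriented state.
_masters = set()


def _fetch(project):
    # Visible to every thread under threading; under multiprocessing this
    # only mutates the child process's copy of the module.
    _masters.add(project)
    return project


if __name__ == '__main__':
    with multiprocessing.Pool(2) as pool:
        pool.map(_fetch, ['platform/a', 'platform/b'])
    # The parent's set is still empty, so nothing would get torn down here.
    print(_masters)  # set()
```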

Now, instead of the module acting as a stateful object, callers have to
instantiate the new ProxyManager class, which holds all the state, and
pass it down to any users.
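Concretely, the last hunk below replaces the old ssh.init()/ssh.close() pair
in sync.py with roughly the following pattern; run_fetch is a hypothetical
stand-in for the code that now receives the proxy explicitly:

```python
import multiprocessing

import ssh  # repo's ssh module

with multiprocessing.Manager() as manager:
    with ssh.ProxyManager(manager) as ssh_proxy:
        # Create the control-socket directory once, in the parent, so the
        # children all point their ssh masters at the same place.
        ssh_proxy.sock()
        run_fetch(projects, ssh_proxy=ssh_proxy)  # hypothetical caller
```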

Bug: https://crbug.com/gerrit/12389
Change-Id: I4b1af116f7306b91e825d3c56fb4274c9b033562
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/305486
Tested-by: Mike Frysinger <vapier@google.com>
Reviewed-by: Chris Mcdonald <cjmcdonald@google.com>
Author: Mike Frysinger
Date:   2021-05-06 00:44:42 -04:00
Parent: 19e409c818
Commit: 339f2df1dd
6 changed files with 221 additions and 154 deletions

@@ -358,7 +358,7 @@ later is required to fix a server side protocol bug.
optimized_fetch=opt.optimized_fetch,
retry_fetches=opt.retry_fetches,
prune=opt.prune,
ssh_proxy=True,
ssh_proxy=self.ssh_proxy,
clone_filter=self.manifest.CloneFilter,
partial_clone_exclude=self.manifest.PartialCloneExclude)
@@ -380,7 +380,11 @@ later is required to fix a server side protocol bug.
finish = time.time()
return (success, project, start, finish)
def _Fetch(self, projects, opt, err_event):
@classmethod
def _FetchInitChild(cls, ssh_proxy):
cls.ssh_proxy = ssh_proxy
def _Fetch(self, projects, opt, err_event, ssh_proxy):
ret = True
jobs = opt.jobs_network if opt.jobs_network else self.jobs
@@ -410,8 +414,14 @@ later is required to fix a server side protocol bug.
break
return ret
# We pass the ssh proxy settings via the class. This allows multiprocessing
# to pickle it up when spawning children. We can't pass it as an argument
# to _FetchProjectList below as multiprocessing is unable to pickle those.
Sync.ssh_proxy = None
# NB: Multiprocessing is heavy, so don't spin it up for one job.
if len(projects_list) == 1 or jobs == 1:
self._FetchInitChild(ssh_proxy)
if not _ProcessResults(self._FetchProjectList(opt, x) for x in projects_list):
ret = False
else:
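The Sync.ssh_proxy class attribute plus the _FetchInitChild initializer above
is the standard workaround when an object should not (or cannot) be pickled
into every task: hand it to each worker exactly once via the pool initializer
and read it back through the class. A generic sketch of the pattern, with
illustrative names rather than sync.py's:

```python
import multiprocessing


class Fetcher:
    # Populated per worker process by the pool initializer, so it is never
    # pickled along with each individual task.
    ssh_proxy = None

    @classmethod
    def _init_child(cls, ssh_proxy):
        cls.ssh_proxy = ssh_proxy

    def _work(self, item):
        # Reads the class attribute set in this process by _init_child.
        return (item, self.ssh_proxy)

    def fetch(self, items, ssh_proxy):
        if len(items) == 1:
            # Cheap path: skip the pool and initialize in-process.
            self._init_child(ssh_proxy)
            return [self._work(items[0])]
        try:
            with multiprocessing.Pool(
                    2, initializer=self._init_child,
                    initargs=(ssh_proxy,)) as pool:
                return pool.map(self._work, items)
        finally:
            # Mirror the diff's `del Sync.ssh_proxy`: don't leave the proxy
            # hanging off the class for later pools to trip over.
            Fetcher.ssh_proxy = None
```

Calling Fetcher().fetch(['a', 'b'], proxy) with any picklable proxy handle
returns each item paired with the per-worker handle, then resets the class
attribute so later multiprocessing use doesn't pickle a stale reference.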
@@ -429,7 +439,8 @@ later is required to fix a server side protocol bug.
else:
pm.update(inc=0, msg='warming up')
chunksize = 4
with multiprocessing.Pool(jobs) as pool:
with multiprocessing.Pool(
jobs, initializer=self._FetchInitChild, initargs=(ssh_proxy,)) as pool:
results = pool.imap_unordered(
functools.partial(self._FetchProjectList, opt),
projects_list,
@@ -438,6 +449,11 @@ later is required to fix a server side protocol bug.
ret = False
pool.close()
# Cleanup the reference now that we're done with it, and we're going to
# release any resources it points to. If we don't, later multiprocessing
# usage (e.g. checkouts) will try to pickle and then crash.
del Sync.ssh_proxy
pm.end()
self._fetch_times.Save()
@@ -447,7 +463,7 @@ later is required to fix a server side protocol bug.
return (ret, fetched)
def _FetchMain(self, opt, args, all_projects, err_event, manifest_name,
load_local_manifests):
load_local_manifests, ssh_proxy):
"""The main network fetch loop.
Args:
@@ -457,6 +473,7 @@ later is required to fix a server side protocol bug.
err_event: Whether an error was hit while processing.
manifest_name: Manifest file to be reloaded.
load_local_manifests: Whether to load local manifests.
ssh_proxy: SSH manager for clients & masters.
"""
rp = self.manifest.repoProject
@@ -467,7 +484,7 @@ later is required to fix a server side protocol bug.
to_fetch.extend(all_projects)
to_fetch.sort(key=self._fetch_times.Get, reverse=True)
success, fetched = self._Fetch(to_fetch, opt, err_event)
success, fetched = self._Fetch(to_fetch, opt, err_event, ssh_proxy)
if not success:
err_event.set()
@@ -498,7 +515,7 @@ later is required to fix a server side protocol bug.
if previously_missing_set == missing_set:
break
previously_missing_set = missing_set
success, new_fetched = self._Fetch(missing, opt, err_event)
success, new_fetched = self._Fetch(missing, opt, err_event, ssh_proxy)
if not success:
err_event.set()
fetched.update(new_fetched)
@@ -985,12 +1002,15 @@ later is required to fix a server side protocol bug.
self._fetch_times = _FetchTimes(self.manifest)
if not opt.local_only:
try:
ssh.init()
self._FetchMain(opt, args, all_projects, err_event, manifest_name,
load_local_manifests)
finally:
ssh.close()
with multiprocessing.Manager() as manager:
with ssh.ProxyManager(manager) as ssh_proxy:
# Initialize the socket dir once in the parent.
ssh_proxy.sock()
self._FetchMain(opt, args, all_projects, err_event, manifest_name,
load_local_manifests, ssh_proxy)
if opt.network_only:
return
# If we saw an error, exit with code 1 so that other scripts can check.
if err_event.is_set():
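For context, the usage in the last hunk implies the rewritten ssh module
exposes a ProxyManager that is constructed with a multiprocessing.Manager,
acts as a context manager, and owns the control-socket path via sock(). A
rough, assumption-heavy sketch of that shape (illustrative only; the real
ssh.py tracks and tears down actual ssh client and master processes):

```python
import multiprocessing
import os
import tempfile


class ProxyManager:
    """Illustrative multiprocessing-safe holder for ssh proxy state."""

    def __init__(self, manager):
        # Manager-backed containers are visible to every worker process,
        # so bookkeeping done in children survives back in the parent.
        self._clients = manager.list()   # ssh client PIDs to wait for
        self._masters = manager.list()   # ssh control-master PIDs to kill
        self._sock_path = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def sock(self, create=True):
        """Return the control-socket path template, creating its dir once."""
        if self._sock_path is None and create:
            tmp_dir = tempfile.mkdtemp('', 'ssh-')
            # %r@%h:%p are ssh ControlPath tokens expanded per connection.
            self._sock_path = os.path.join(tmp_dir, 'master-%r@%h:%p')
        return self._sock_path

    def add_client(self, pid):
        self._clients.append(pid)

    def add_master(self, pid):
        self._masters.append(pid)

    def close(self):
        # The real implementation terminates the masters and waits for the
        # clients before dropping this bookkeeping; elided in this sketch.
        pass
```

The key design point is that all mutable state hangs off an instance that is
passed around explicitly (and, where shared, lives in the Manager process)
instead of module globals that each worker process would silently copy.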