sync: Preserve errors on KeyboardInterrupt

If a KeyboardInterrupt is encountered before an error is aggregated then
the context surrounding the interrupt is lost. This change aggregates
errors as soon as possible for the sync command

Bug: b/293344017
Change-Id: Iac14f9d59723cc9dedbb960f14fdc1fa5b348ea3
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/384974
Tested-by: Jason Chang <jasonnc@google.com>
Commit-Queue: Jason Chang <jasonnc@google.com>
Reviewed-by: Gavin Mak <gavinmak@google.com>
This commit is contained in:
Jason Chang 2023-08-31 17:06:36 -07:00 committed by LUCI
parent b861511db9
commit daf2ad38eb

View File

@ -22,9 +22,10 @@ import netrc
import optparse
import os
import socket
import sys
import tempfile
import time
from typing import List, NamedTuple, Set
from typing import List, NamedTuple, Set, Union
import urllib.error
import urllib.parse
import urllib.request
@ -55,6 +56,7 @@ from command import MirrorSafeCommand
from command import WORKER_BATCH_SIZE
from error import GitError
from error import RepoChangedException
from error import RepoError
from error import RepoExitError
from error import RepoUnhandledExceptionError
from error import SyncError
@ -120,7 +122,6 @@ class _FetchResult(NamedTuple):
success: bool
projects: Set[str]
errors: List[Exception]
class _FetchMainResult(NamedTuple):
@ -131,7 +132,6 @@ class _FetchMainResult(NamedTuple):
"""
all_projects: List[Project]
errors: List[Exception]
class _CheckoutOneResult(NamedTuple):
@ -163,6 +163,34 @@ class SmartSyncError(SyncError):
"""Smart sync exit error."""
class ManifestInterruptError(RepoError):
"""Aggregate Error to be logged when a user interrupts a manifest update."""
def __init__(self, output, **kwargs):
super().__init__(output, **kwargs)
self.output = output
def __str__(self):
error_type = type(self).__name__
return f"{error_type}:{self.output}"
class TeeStringIO(io.StringIO):
"""StringIO class that can write to an additional destination."""
def __init__(
self, io: Union[io.TextIOWrapper, None], *args, **kwargs
) -> None:
super().__init__(*args, **kwargs)
self.io = io
def write(self, s: str) -> int:
"""Write to additional destination."""
super().write(s)
if self.io is not None:
self.io.write(s)
class Sync(Command, MirrorSafeCommand):
COMMON = True
MULTI_MANIFEST_SUPPORT = True
@ -648,7 +676,7 @@ later is required to fix a server side protocol bug.
success = False
remote_fetched = False
errors = []
buf = io.StringIO()
buf = TeeStringIO(sys.stdout if opt.verbose else None)
try:
sync_result = project.Sync_NetworkHalf(
quiet=opt.quiet,
@ -675,7 +703,7 @@ later is required to fix a server side protocol bug.
errors.append(sync_result.error)
output = buf.getvalue()
if (opt.verbose or not success) and output:
if output and buf.io is None and not success:
print("\n" + output.rstrip())
if not success:
@ -729,13 +757,12 @@ later is required to fix a server side protocol bug.
jobs = jobs_str(len(items))
return f"{jobs} | {elapsed_str(elapsed)} {earliest_proj}"
def _Fetch(self, projects, opt, err_event, ssh_proxy):
def _Fetch(self, projects, opt, err_event, ssh_proxy, errors):
ret = True
jobs = opt.jobs_network
fetched = set()
remote_fetched = set()
errors = []
pm = Progress(
"Fetching",
len(projects),
@ -850,10 +877,10 @@ later is required to fix a server side protocol bug.
if not self.outer_client.manifest.IsArchive:
self._GCProjects(projects, opt, err_event)
return _FetchResult(ret, fetched, errors)
return _FetchResult(ret, fetched)
def _FetchMain(
self, opt, args, all_projects, err_event, ssh_proxy, manifest
self, opt, args, all_projects, err_event, ssh_proxy, manifest, errors
):
"""The main network fetch loop.
@ -869,7 +896,6 @@ later is required to fix a server side protocol bug.
List of all projects that should be checked out.
"""
rp = manifest.repoProject
errors = []
to_fetch = []
now = time.time()
@ -878,11 +904,9 @@ later is required to fix a server side protocol bug.
to_fetch.extend(all_projects)
to_fetch.sort(key=self._fetch_times.Get, reverse=True)
result = self._Fetch(to_fetch, opt, err_event, ssh_proxy)
result = self._Fetch(to_fetch, opt, err_event, ssh_proxy, errors)
success = result.success
fetched = result.projects
if result.errors:
errors.extend(result.errors)
if not success:
err_event.set()
@ -898,7 +922,7 @@ later is required to fix a server side protocol bug.
logger.error(e)
raise e
return _FetchMainResult([], errors)
return _FetchMainResult([])
# Iteratively fetch missing and/or nested unregistered submodules.
previously_missing_set = set()
@ -923,16 +947,14 @@ later is required to fix a server side protocol bug.
if previously_missing_set == missing_set:
break
previously_missing_set = missing_set
result = self._Fetch(missing, opt, err_event, ssh_proxy)
result = self._Fetch(missing, opt, err_event, ssh_proxy, errors)
success = result.success
new_fetched = result.projects
if result.errors:
errors.extend(result.errors)
if not success:
err_event.set()
fetched.update(new_fetched)
return _FetchMainResult(all_projects, errors)
return _FetchMainResult(all_projects)
def _CheckoutOne(self, detach_head, force_sync, project):
"""Checkout work tree for one project
@ -1440,7 +1462,7 @@ later is required to fix a server side protocol bug.
return manifest_name
def _UpdateAllManifestProjects(self, opt, mp, manifest_name):
def _UpdateAllManifestProjects(self, opt, mp, manifest_name, errors):
"""Fetch & update the local manifest project.
After syncing the manifest project, if the manifest has any sub
@ -1452,7 +1474,7 @@ later is required to fix a server side protocol bug.
manifest_name: Manifest file to be reloaded.
"""
if not mp.standalone_manifest_url:
self._UpdateManifestProject(opt, mp, manifest_name)
self._UpdateManifestProject(opt, mp, manifest_name, errors)
if mp.manifest.submanifests:
for submanifest in mp.manifest.submanifests.values():
@ -1465,10 +1487,10 @@ later is required to fix a server side protocol bug.
git_event_log=self.git_event_log,
)
self._UpdateAllManifestProjects(
opt, child.manifestProject, None
opt, child.manifestProject, None, errors
)
def _UpdateManifestProject(self, opt, mp, manifest_name):
def _UpdateManifestProject(self, opt, mp, manifest_name, errors):
"""Fetch & update the local manifest project.
Args:
@ -1478,21 +1500,32 @@ later is required to fix a server side protocol bug.
"""
if not opt.local_only:
start = time.time()
result = mp.Sync_NetworkHalf(
quiet=opt.quiet,
verbose=opt.verbose,
current_branch_only=self._GetCurrentBranchOnly(
opt, mp.manifest
),
force_sync=opt.force_sync,
tags=opt.tags,
optimized_fetch=opt.optimized_fetch,
retry_fetches=opt.retry_fetches,
submodules=mp.manifest.HasSubmodules,
clone_filter=mp.manifest.CloneFilter,
partial_clone_exclude=mp.manifest.PartialCloneExclude,
clone_filter_for_depth=mp.manifest.CloneFilterForDepth,
)
buf = TeeStringIO(sys.stdout)
try:
result = mp.Sync_NetworkHalf(
quiet=opt.quiet,
output_redir=buf,
verbose=opt.verbose,
current_branch_only=self._GetCurrentBranchOnly(
opt, mp.manifest
),
force_sync=opt.force_sync,
tags=opt.tags,
optimized_fetch=opt.optimized_fetch,
retry_fetches=opt.retry_fetches,
submodules=mp.manifest.HasSubmodules,
clone_filter=mp.manifest.CloneFilter,
partial_clone_exclude=mp.manifest.PartialCloneExclude,
clone_filter_for_depth=mp.manifest.CloneFilterForDepth,
)
if result.error:
errors.append(result.error)
except KeyboardInterrupt:
errors.append(
ManifestInterruptError(buf.getvalue(), project=mp.name)
)
raise
finish = time.time()
self.event_log.AddSync(
mp, event_log.TASK_SYNC_NETWORK, start, finish, result.success
@ -1664,7 +1697,7 @@ later is required to fix a server side protocol bug.
mp.ConfigureCloneFilterForDepth("blob:none")
if opt.mp_update:
self._UpdateAllManifestProjects(opt, mp, manifest_name)
self._UpdateAllManifestProjects(opt, mp, manifest_name, errors)
else:
logger.info("Skipping update of local manifest project.")
@ -1704,10 +1737,14 @@ later is required to fix a server side protocol bug.
# Initialize the socket dir once in the parent.
ssh_proxy.sock()
result = self._FetchMain(
opt, args, all_projects, err_event, ssh_proxy, manifest
opt,
args,
all_projects,
err_event,
ssh_proxy,
manifest,
errors,
)
if result.errors:
errors.extend(result.errors)
all_projects = result.all_projects
if opt.network_only: