Format codebase with black and check formatting in CQ

Apply rules set by https://gerrit-review.googlesource.com/c/git-repo/+/362954/ across the codebase and fix any lingering errors caught
by flake8. Also check black formatting in run_tests (and CQ).

Bug: b/267675342
Change-Id: I972d77649dac351150dcfeb1cd1ad0ea2efc1956
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/363474
Reviewed-by: Mike Frysinger <vapier@google.com>
Tested-by: Gavin Mak <gavinmak@google.com>
Commit-Queue: Gavin Mak <gavinmak@google.com>
This commit is contained in:
Gavin Mak 2023-03-11 06:46:20 +00:00 committed by LUCI
parent 1604cf255f
commit ea2e330e43
79 changed files with 19698 additions and 16679 deletions

View File

@ -1,5 +1,8 @@
[flake8] [flake8]
max-line-length = 80 max-line-length = 80
per-file-ignores =
# E501: line too long
tests/test_git_superproject.py: E501
extend-ignore = extend-ignore =
# E203: Whitespace before ':' # E203: Whitespace before ':'
# See https://github.com/PyCQA/pycodestyle/issues/373 # See https://github.com/PyCQA/pycodestyle/issues/373

View File

@ -17,23 +17,20 @@ import sys
import pager import pager
COLORS = {None: -1, COLORS = {
'normal': -1, None: -1,
'black': 0, "normal": -1,
'red': 1, "black": 0,
'green': 2, "red": 1,
'yellow': 3, "green": 2,
'blue': 4, "yellow": 3,
'magenta': 5, "blue": 4,
'cyan': 6, "magenta": 5,
'white': 7} "cyan": 6,
"white": 7,
}
ATTRS = {None: -1, ATTRS = {None: -1, "bold": 1, "dim": 2, "ul": 4, "blink": 5, "reverse": 7}
'bold': 1,
'dim': 2,
'ul': 4,
'blink': 5,
'reverse': 7}
RESET = "\033[m" RESET = "\033[m"
@ -56,30 +53,30 @@ def _Color(fg=None, bg=None, attr=None):
code = "\033[" code = "\033["
if attr >= 0: if attr >= 0:
code += chr(ord('0') + attr) code += chr(ord("0") + attr)
need_sep = True need_sep = True
if fg >= 0: if fg >= 0:
if need_sep: if need_sep:
code += ';' code += ";"
need_sep = True need_sep = True
if fg < 8: if fg < 8:
code += '3%c' % (ord('0') + fg) code += "3%c" % (ord("0") + fg)
else: else:
code += '38;5;%d' % fg code += "38;5;%d" % fg
if bg >= 0: if bg >= 0:
if need_sep: if need_sep:
code += ';' code += ";"
if bg < 8: if bg < 8:
code += '4%c' % (ord('0') + bg) code += "4%c" % (ord("0") + bg)
else: else:
code += '48;5;%d' % bg code += "48;5;%d" % bg
code += 'm' code += "m"
else: else:
code = '' code = ""
return code return code
@ -97,17 +94,17 @@ def SetDefaultColoring(state):
global DEFAULT global DEFAULT
state = state.lower() state = state.lower()
if state in ('auto',): if state in ("auto",):
DEFAULT = state DEFAULT = state
elif state in ('always', 'yes', 'true', True): elif state in ("always", "yes", "true", True):
DEFAULT = 'always' DEFAULT = "always"
elif state in ('never', 'no', 'false', False): elif state in ("never", "no", "false", False):
DEFAULT = 'never' DEFAULT = "never"
class Coloring(object): class Coloring(object):
def __init__(self, config, section_type): def __init__(self, config, section_type):
self._section = 'color.%s' % section_type self._section = "color.%s" % section_type
self._config = config self._config = config
self._out = sys.stdout self._out = sys.stdout
@ -115,14 +112,14 @@ class Coloring(object):
if on is None: if on is None:
on = self._config.GetString(self._section) on = self._config.GetString(self._section)
if on is None: if on is None:
on = self._config.GetString('color.ui') on = self._config.GetString("color.ui")
if on == 'auto': if on == "auto":
if pager.active or os.isatty(1): if pager.active or os.isatty(1):
self._on = True self._on = True
else: else:
self._on = False self._on = False
elif on in ('true', 'always'): elif on in ("true", "always"):
self._on = True self._on = True
else: else:
self._on = False self._on = False
@ -141,7 +138,7 @@ class Coloring(object):
self._out.flush() self._out.flush()
def nl(self): def nl(self):
self._out.write('\n') self._out.write("\n")
def printer(self, opt=None, fg=None, bg=None, attr=None): def printer(self, opt=None, fg=None, bg=None, attr=None):
s = self s = self
@ -149,6 +146,7 @@ class Coloring(object):
def f(fmt, *args): def f(fmt, *args):
s._out.write(c(fmt, *args)) s._out.write(c(fmt, *args))
return f return f
def nofmt_printer(self, opt=None, fg=None, bg=None, attr=None): def nofmt_printer(self, opt=None, fg=None, bg=None, attr=None):
@ -157,6 +155,7 @@ class Coloring(object):
def f(fmt): def f(fmt):
s._out.write(c(fmt)) s._out.write(c(fmt))
return f return f
def colorer(self, opt=None, fg=None, bg=None, attr=None): def colorer(self, opt=None, fg=None, bg=None, attr=None):
@ -165,12 +164,14 @@ class Coloring(object):
def f(fmt, *args): def f(fmt, *args):
output = fmt % args output = fmt % args
return ''.join([c, output, RESET]) return "".join([c, output, RESET])
return f return f
else: else:
def f(fmt, *args): def f(fmt, *args):
return fmt % args return fmt % args
return f return f
def nofmt_colorer(self, opt=None, fg=None, bg=None, attr=None): def nofmt_colorer(self, opt=None, fg=None, bg=None, attr=None):
@ -178,29 +179,32 @@ class Coloring(object):
c = self._parse(opt, fg, bg, attr) c = self._parse(opt, fg, bg, attr)
def f(fmt): def f(fmt):
return ''.join([c, fmt, RESET]) return "".join([c, fmt, RESET])
return f return f
else: else:
def f(fmt): def f(fmt):
return fmt return fmt
return f return f
def _parse(self, opt, fg, bg, attr): def _parse(self, opt, fg, bg, attr):
if not opt: if not opt:
return _Color(fg, bg, attr) return _Color(fg, bg, attr)
v = self._config.GetString('%s.%s' % (self._section, opt)) v = self._config.GetString("%s.%s" % (self._section, opt))
if v is None: if v is None:
return _Color(fg, bg, attr) return _Color(fg, bg, attr)
v = v.strip().lower() v = v.strip().lower()
if v == "reset": if v == "reset":
return RESET return RESET
elif v == '': elif v == "":
return _Color(fg, bg, attr) return _Color(fg, bg, attr)
have_fg = False have_fg = False
for a in v.split(' '): for a in v.split(" "):
if is_color(a): if is_color(a):
if have_fg: if have_fg:
bg = a bg = a

View File

@ -25,7 +25,7 @@ import progress
# Are we generating man-pages? # Are we generating man-pages?
GENERATE_MANPAGES = os.environ.get('_REPO_GENERATE_MANPAGES_') == ' indeed! ' GENERATE_MANPAGES = os.environ.get("_REPO_GENERATE_MANPAGES_") == " indeed! "
# Number of projects to submit to a single worker process at a time. # Number of projects to submit to a single worker process at a time.
@ -43,8 +43,7 @@ DEFAULT_LOCAL_JOBS = min(os.cpu_count(), 8)
class Command(object): class Command(object):
"""Base class for any command line action in repo. """Base class for any command line action in repo."""
"""
# Singleton for all commands to track overall repo command execution and # Singleton for all commands to track overall repo command execution and
# provide event summary to callers. Only used by sync subcommand currently. # provide event summary to callers. Only used by sync subcommand currently.
@ -52,9 +51,9 @@ class Command(object):
# NB: This is being replaced by git trace2 events. See git_trace2_event_log. # NB: This is being replaced by git trace2 events. See git_trace2_event_log.
event_log = EventLog() event_log = EventLog()
# Whether this command is a "common" one, i.e. whether the user would commonly # Whether this command is a "common" one, i.e. whether the user would
# use it or it's a more uncommon command. This is used by the help command to # commonly use it or it's a more uncommon command. This is used by the help
# show short-vs-full summaries. # command to show short-vs-full summaries.
COMMON = False COMMON = False
# Whether this command supports running in parallel. If greater than 0, # Whether this command supports running in parallel. If greater than 0,
@ -67,8 +66,16 @@ class Command(object):
# migrated subcommands can set it to False. # migrated subcommands can set it to False.
MULTI_MANIFEST_SUPPORT = True MULTI_MANIFEST_SUPPORT = True
def __init__(self, repodir=None, client=None, manifest=None, gitc_manifest=None, def __init__(
git_event_log=None, outer_client=None, outer_manifest=None): self,
repodir=None,
client=None,
manifest=None,
gitc_manifest=None,
git_event_log=None,
outer_client=None,
outer_manifest=None,
):
self.repodir = repodir self.repodir = repodir
self.client = client self.client = client
self.outer_client = outer_client or client self.outer_client = outer_client or client
@ -84,7 +91,7 @@ class Command(object):
return False return False
def ReadEnvironmentOptions(self, opts): def ReadEnvironmentOptions(self, opts):
""" Set options from environment variables. """ """Set options from environment variables."""
env_options = self._RegisteredEnvironmentOptions() env_options = self._RegisteredEnvironmentOptions()
@ -93,8 +100,8 @@ class Command(object):
opt_value = getattr(opts, opt_key) opt_value = getattr(opts, opt_key)
# If the value is set, it means the user has passed it as a command # If the value is set, it means the user has passed it as a command
# line option, and we should use that. Otherwise we can try to set it # line option, and we should use that. Otherwise we can try to set
# with the value from the corresponding environment variable. # it with the value from the corresponding environment variable.
if opt_value is not None: if opt_value is not None:
continue continue
@ -108,11 +115,13 @@ class Command(object):
def OptionParser(self): def OptionParser(self):
if self._optparse is None: if self._optparse is None:
try: try:
me = 'repo %s' % self.NAME me = "repo %s" % self.NAME
usage = self.helpUsage.strip().replace('%prog', me) usage = self.helpUsage.strip().replace("%prog", me)
except AttributeError: except AttributeError:
usage = 'repo %s' % self.NAME usage = "repo %s" % self.NAME
epilog = 'Run `repo help %s` to view the detailed manual.' % self.NAME epilog = (
"Run `repo help %s` to view the detailed manual." % self.NAME
)
self._optparse = optparse.OptionParser(usage=usage, epilog=epilog) self._optparse = optparse.OptionParser(usage=usage, epilog=epilog)
self._CommonOptions(self._optparse) self._CommonOptions(self._optparse)
self._Options(self._optparse) self._Options(self._optparse)
@ -124,35 +133,63 @@ class Command(object):
These will show up for *all* subcommands, so use sparingly. These will show up for *all* subcommands, so use sparingly.
NB: Keep in sync with repo:InitParser(). NB: Keep in sync with repo:InitParser().
""" """
g = p.add_option_group('Logging options') g = p.add_option_group("Logging options")
opts = ['-v'] if opt_v else [] opts = ["-v"] if opt_v else []
g.add_option(*opts, '--verbose', g.add_option(
dest='output_mode', action='store_true', *opts,
help='show all output') "--verbose",
g.add_option('-q', '--quiet', dest="output_mode",
dest='output_mode', action='store_false', action="store_true",
help='only show errors') help="show all output",
)
g.add_option(
"-q",
"--quiet",
dest="output_mode",
action="store_false",
help="only show errors",
)
if self.PARALLEL_JOBS is not None: if self.PARALLEL_JOBS is not None:
default = 'based on number of CPU cores' default = "based on number of CPU cores"
if not GENERATE_MANPAGES: if not GENERATE_MANPAGES:
# Only include active cpu count if we aren't generating man pages. # Only include active cpu count if we aren't generating man
default = f'%default; {default}' # pages.
default = f"%default; {default}"
p.add_option( p.add_option(
'-j', '--jobs', "-j",
type=int, default=self.PARALLEL_JOBS, "--jobs",
help=f'number of jobs to run in parallel (default: {default})') type=int,
default=self.PARALLEL_JOBS,
help=f"number of jobs to run in parallel (default: {default})",
)
m = p.add_option_group('Multi-manifest options') m = p.add_option_group("Multi-manifest options")
m.add_option('--outer-manifest', action='store_true', default=None, m.add_option(
help='operate starting at the outermost manifest') "--outer-manifest",
m.add_option('--no-outer-manifest', dest='outer_manifest', action="store_true",
action='store_false', help='do not operate on outer manifests') default=None,
m.add_option('--this-manifest-only', action='store_true', default=None, help="operate starting at the outermost manifest",
help='only operate on this (sub)manifest') )
m.add_option('--no-this-manifest-only', '--all-manifests', m.add_option(
dest='this_manifest_only', action='store_false', "--no-outer-manifest",
help='operate on this manifest and its submanifests') dest="outer_manifest",
action="store_false",
help="do not operate on outer manifests",
)
m.add_option(
"--this-manifest-only",
action="store_true",
default=None,
help="only operate on this (sub)manifest",
)
m.add_option(
"--no-this-manifest-only",
"--all-manifests",
dest="this_manifest_only",
action="store_false",
help="operate on this manifest and its submanifests",
)
def _Options(self, p): def _Options(self, p):
"""Initialize the option parser with subcommand-specific options.""" """Initialize the option parser with subcommand-specific options."""
@ -176,8 +213,7 @@ class Command(object):
return {} return {}
def Usage(self): def Usage(self):
"""Display usage and terminate. """Display usage and terminate."""
"""
self.OptionParser.print_usage() self.OptionParser.print_usage()
sys.exit(1) sys.exit(1)
@ -186,8 +222,8 @@ class Command(object):
opt.quiet = opt.output_mode is False opt.quiet = opt.output_mode is False
opt.verbose = opt.output_mode is True opt.verbose = opt.output_mode is True
if opt.outer_manifest is None: if opt.outer_manifest is None:
# By default, treat multi-manifest instances as a single manifest from # By default, treat multi-manifest instances as a single manifest
# the user's perspective. # from the user's perspective.
opt.outer_manifest = True opt.outer_manifest = True
def ValidateOptions(self, opt, args): def ValidateOptions(self, opt, args):
@ -201,12 +237,13 @@ class Command(object):
""" """
def Execute(self, opt, args): def Execute(self, opt, args):
"""Perform the action, after option parsing is complete. """Perform the action, after option parsing is complete."""
"""
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
def ExecuteInParallel(jobs, func, inputs, callback, output=None, ordered=False): def ExecuteInParallel(
jobs, func, inputs, callback, output=None, ordered=False
):
"""Helper for managing parallel execution boiler plate. """Helper for managing parallel execution boiler plate.
For subcommands that can easily split their work up. For subcommands that can easily split their work up.
@ -214,18 +251,21 @@ class Command(object):
Args: Args:
jobs: How many parallel processes to use. jobs: How many parallel processes to use.
func: The function to apply to each of the |inputs|. Usually a func: The function to apply to each of the |inputs|. Usually a
functools.partial for wrapping additional arguments. It will be run functools.partial for wrapping additional arguments. It will be
in a separate process, so it must be pickalable, so nested functions run in a separate process, so it must be pickalable, so nested
won't work. Methods on the subcommand Command class should work. functions won't work. Methods on the subcommand Command class
should work.
inputs: The list of items to process. Must be a list. inputs: The list of items to process. Must be a list.
callback: The function to pass the results to for processing. It will be callback: The function to pass the results to for processing. It
executed in the main thread and process the results of |func| as they will be executed in the main thread and process the results of
become available. Thus it may be a local nested function. Its return |func| as they become available. Thus it may be a local nested
value is passed back directly. It takes three arguments: function. Its return value is passed back directly. It takes
three arguments:
- The processing pool (or None with one job). - The processing pool (or None with one job).
- The |output| argument. - The |output| argument.
- An iterator for the results. - An iterator for the results.
output: An output manager. May be progress.Progess or color.Coloring. output: An output manager. May be progress.Progess or
color.Coloring.
ordered: Whether the jobs should be processed in order. ordered: Whether the jobs should be processed in order.
Returns: Returns:
@ -238,7 +278,11 @@ class Command(object):
else: else:
with multiprocessing.Pool(jobs) as pool: with multiprocessing.Pool(jobs) as pool:
submit = pool.imap if ordered else pool.imap_unordered submit = pool.imap if ordered else pool.imap_unordered
return callback(pool, output, submit(func, inputs, chunksize=WORKER_BATCH_SIZE)) return callback(
pool,
output,
submit(func, inputs, chunksize=WORKER_BATCH_SIZE),
)
finally: finally:
if isinstance(output, progress.Progress): if isinstance(output, progress.Progress):
output.end() output.end()
@ -253,9 +297,7 @@ class Command(object):
project = None project = None
if os.path.exists(path): if os.path.exists(path):
oldpath = None oldpath = None
while (path and while path and path != oldpath and path != manifest.topdir:
path != oldpath and
path != manifest.topdir):
try: try:
project = self._by_path[path] project = self._by_path[path]
break break
@ -274,8 +316,15 @@ class Command(object):
pass pass
return project return project
def GetProjects(self, args, manifest=None, groups='', missing_ok=False, def GetProjects(
submodules_ok=False, all_manifests=False): self,
args,
manifest=None,
groups="",
missing_ok=False,
submodules_ok=False,
all_manifests=False,
):
"""A list of projects that match the arguments. """A list of projects that match the arguments.
Args: Args:
@ -284,8 +333,9 @@ class Command(object):
groups: a string, the manifest groups in use. groups: a string, the manifest groups in use.
missing_ok: a boolean, whether to allow missing projects. missing_ok: a boolean, whether to allow missing projects.
submodules_ok: a boolean, whether to allow submodules. submodules_ok: a boolean, whether to allow submodules.
all_manifests: a boolean, if True then all manifests and submanifests are all_manifests: a boolean, if True then all manifests and
used. If False, then only the local (sub)manifest is used. submanifests are used. If False, then only the local
(sub)manifest is used.
Returns: Returns:
A list of matching Project instances. A list of matching Project instances.
@ -302,31 +352,38 @@ class Command(object):
if not groups: if not groups:
groups = manifest.GetGroupsStr() groups = manifest.GetGroupsStr()
groups = [x for x in re.split(r'[,\s]+', groups) if x] groups = [x for x in re.split(r"[,\s]+", groups) if x]
if not args: if not args:
derived_projects = {} derived_projects = {}
for project in all_projects_list: for project in all_projects_list:
if submodules_ok or project.sync_s: if submodules_ok or project.sync_s:
derived_projects.update((p.name, p) derived_projects.update(
for p in project.GetDerivedSubprojects()) (p.name, p) for p in project.GetDerivedSubprojects()
)
all_projects_list.extend(derived_projects.values()) all_projects_list.extend(derived_projects.values())
for project in all_projects_list: for project in all_projects_list:
if (missing_ok or project.Exists) and project.MatchesGroups(groups): if (missing_ok or project.Exists) and project.MatchesGroups(
groups
):
result.append(project) result.append(project)
else: else:
self._ResetPathToProjectMap(all_projects_list) self._ResetPathToProjectMap(all_projects_list)
for arg in args: for arg in args:
# We have to filter by manifest groups in case the requested project is # We have to filter by manifest groups in case the requested
# checked out multiple times or differently based on them. # project is checked out multiple times or differently based on
projects = [project # them.
projects = [
project
for project in manifest.GetProjectsWithName( for project in manifest.GetProjectsWithName(
arg, all_manifests=all_manifests) arg, all_manifests=all_manifests
if project.MatchesGroups(groups)] )
if project.MatchesGroups(groups)
]
if not projects: if not projects:
path = os.path.abspath(arg).replace('\\', '/') path = os.path.abspath(arg).replace("\\", "/")
tree = manifest tree = manifest
if all_manifests: if all_manifests:
# Look for the deepest matching submanifest. # Look for the deepest matching submanifest.
@ -335,16 +392,23 @@ class Command(object):
break break
project = self._GetProjectByPath(tree, path) project = self._GetProjectByPath(tree, path)
# If it's not a derived project, update path->project mapping and # If it's not a derived project, update path->project
# search again, as arg might actually point to a derived subproject. # mapping and search again, as arg might actually point to
if (project and not project.Derived and (submodules_ok or # a derived subproject.
project.sync_s)): if (
project
and not project.Derived
and (submodules_ok or project.sync_s)
):
search_again = False search_again = False
for subproject in project.GetDerivedSubprojects(): for subproject in project.GetDerivedSubprojects():
self._UpdatePathToProjectMap(subproject) self._UpdatePathToProjectMap(subproject)
search_again = True search_again = True
if search_again: if search_again:
project = self._GetProjectByPath(manifest, path) or project project = (
self._GetProjectByPath(manifest, path)
or project
)
if project: if project:
projects = [project] projects = [project]
@ -354,8 +418,10 @@ class Command(object):
for project in projects: for project in projects:
if not missing_ok and not project.Exists: if not missing_ok and not project.Exists:
raise NoSuchProjectError('%s (%s)' % ( raise NoSuchProjectError(
arg, project.RelPath(local=not all_manifests))) "%s (%s)"
% (arg, project.RelPath(local=not all_manifests))
)
if not project.MatchesGroups(groups): if not project.MatchesGroups(groups):
raise InvalidProjectGroupsError(arg) raise InvalidProjectGroupsError(arg)
@ -363,6 +429,7 @@ class Command(object):
def _getpath(x): def _getpath(x):
return x.relpath return x.relpath
result.sort(key=_getpath) result.sort(key=_getpath)
return result return result
@ -371,14 +438,15 @@ class Command(object):
Args: Args:
args: a list of (case-insensitive) strings, projects to search for. args: a list of (case-insensitive) strings, projects to search for.
inverse: a boolean, if True, then projects not matching any |args| are inverse: a boolean, if True, then projects not matching any |args|
returned. are returned.
all_manifests: a boolean, if True then all manifests and submanifests are all_manifests: a boolean, if True then all manifests and
used. If False, then only the local (sub)manifest is used. submanifests are used. If False, then only the local
(sub)manifest is used.
""" """
result = [] result = []
patterns = [re.compile(r'%s' % a, re.IGNORECASE) for a in args] patterns = [re.compile(r"%s" % a, re.IGNORECASE) for a in args]
for project in self.GetProjects('', all_manifests=all_manifests): for project in self.GetProjects("", all_manifests=all_manifests):
paths = [project.name, project.RelPath(local=not all_manifests)] paths = [project.name, project.RelPath(local=not all_manifests)]
for pattern in patterns: for pattern in patterns:
match = any(pattern.search(x) for x in paths) match = any(pattern.search(x) for x in paths)
@ -390,8 +458,9 @@ class Command(object):
else: else:
if inverse: if inverse:
result.append(project) result.append(project)
result.sort(key=lambda project: (project.manifest.path_prefix, result.sort(
project.relpath)) key=lambda project: (project.manifest.path_prefix, project.relpath)
)
return result return result
def ManifestList(self, opt): def ManifestList(self, opt):
@ -410,8 +479,8 @@ class Command(object):
class InteractiveCommand(Command): class InteractiveCommand(Command):
"""Command which requires user interaction on the tty and """Command which requires user interaction on the tty and must not run
must not run within a pager, even if the user asks to. within a pager, even if the user asks to.
""" """
def WantPager(self, _opt): def WantPager(self, _opt):
@ -419,8 +488,8 @@ class InteractiveCommand(Command):
class PagedCommand(Command): class PagedCommand(Command):
"""Command which defaults to output in a pager, as its """Command which defaults to output in a pager, as its display tends to be
display tends to be larger than one screen full. larger than one screen full.
""" """
def WantPager(self, _opt): def WantPager(self, _opt):
@ -428,18 +497,16 @@ class PagedCommand(Command):
class MirrorSafeCommand(object): class MirrorSafeCommand(object):
"""Command permits itself to run within a mirror, """Command permits itself to run within a mirror, and does not require a
and does not require a working directory. working directory.
""" """
class GitcAvailableCommand(object): class GitcAvailableCommand(object):
"""Command that requires GITC to be available, but does """Command that requires GITC to be available, but does not require the
not require the local client to be a GITC client. local client to be a GITC client.
""" """
class GitcClientCommand(object): class GitcClientCommand(object):
"""Command that requires the local client to be a GITC """Command that requires the local client to be a GITC client."""
client.
"""

View File

@ -36,31 +36,33 @@ class Editor(object):
@classmethod @classmethod
def _SelectEditor(cls): def _SelectEditor(cls):
e = os.getenv('GIT_EDITOR') e = os.getenv("GIT_EDITOR")
if e: if e:
return e return e
if cls.globalConfig: if cls.globalConfig:
e = cls.globalConfig.GetString('core.editor') e = cls.globalConfig.GetString("core.editor")
if e: if e:
return e return e
e = os.getenv('VISUAL') e = os.getenv("VISUAL")
if e: if e:
return e return e
e = os.getenv('EDITOR') e = os.getenv("EDITOR")
if e: if e:
return e return e
if os.getenv('TERM') == 'dumb': if os.getenv("TERM") == "dumb":
print( print(
"""No editor specified in GIT_EDITOR, core.editor, VISUAL or EDITOR. """No editor specified in GIT_EDITOR, core.editor, VISUAL or EDITOR.
Tried to fall back to vi but terminal is dumb. Please configure at Tried to fall back to vi but terminal is dumb. Please configure at
least one of these before using this command.""", file=sys.stderr) least one of these before using this command.""", # noqa: E501
file=sys.stderr,
)
sys.exit(1) sys.exit(1)
return 'vi' return "vi"
@classmethod @classmethod
def EditString(cls, data): def EditString(cls, data):
@ -76,22 +78,23 @@ least one of these before using this command.""", file=sys.stderr)
EditorError: The editor failed to run. EditorError: The editor failed to run.
""" """
editor = cls._GetEditor() editor = cls._GetEditor()
if editor == ':': if editor == ":":
return data return data
fd, path = tempfile.mkstemp() fd, path = tempfile.mkstemp()
try: try:
os.write(fd, data.encode('utf-8')) os.write(fd, data.encode("utf-8"))
os.close(fd) os.close(fd)
fd = None fd = None
if platform_utils.isWindows(): if platform_utils.isWindows():
# Split on spaces, respecting quoted strings # Split on spaces, respecting quoted strings
import shlex import shlex
args = shlex.split(editor) args = shlex.split(editor)
shell = False shell = False
elif re.compile("^.*[$ \t'].*$").match(editor): elif re.compile("^.*[$ \t'].*$").match(editor):
args = [editor + ' "$@"', 'sh'] args = [editor + ' "$@"', "sh"]
shell = True shell = True
else: else:
args = [editor] args = [editor]
@ -101,14 +104,17 @@ least one of these before using this command.""", file=sys.stderr)
try: try:
rc = subprocess.Popen(args, shell=shell).wait() rc = subprocess.Popen(args, shell=shell).wait()
except OSError as e: except OSError as e:
raise EditorError('editor failed, %s: %s %s' raise EditorError(
% (str(e), editor, path)) "editor failed, %s: %s %s" % (str(e), editor, path)
)
if rc != 0: if rc != 0:
raise EditorError('editor failed with exit status %d: %s %s' raise EditorError(
% (rc, editor, path)) "editor failed with exit status %d: %s %s"
% (rc, editor, path)
)
with open(path, mode='rb') as fd2: with open(path, mode="rb") as fd2:
return fd2.read().decode('utf-8') return fd2.read().decode("utf-8")
finally: finally:
if fd: if fd:
os.close(fd) os.close(fd)

View File

@ -14,23 +14,19 @@
class ManifestParseError(Exception): class ManifestParseError(Exception):
"""Failed to parse the manifest file. """Failed to parse the manifest file."""
"""
class ManifestInvalidRevisionError(ManifestParseError): class ManifestInvalidRevisionError(ManifestParseError):
"""The revision value in a project is incorrect. """The revision value in a project is incorrect."""
"""
class ManifestInvalidPathError(ManifestParseError): class ManifestInvalidPathError(ManifestParseError):
"""A path used in <copyfile> or <linkfile> is incorrect. """A path used in <copyfile> or <linkfile> is incorrect."""
"""
class NoManifestException(Exception): class NoManifestException(Exception):
"""The required manifest does not exist. """The required manifest does not exist."""
"""
def __init__(self, path, reason): def __init__(self, path, reason):
super().__init__(path, reason) super().__init__(path, reason)
@ -42,8 +38,7 @@ class NoManifestException(Exception):
class EditorError(Exception): class EditorError(Exception):
"""Unspecified error from the user's text editor. """Unspecified error from the user's text editor."""
"""
def __init__(self, reason): def __init__(self, reason):
super().__init__(reason) super().__init__(reason)
@ -54,8 +49,7 @@ class EditorError(Exception):
class GitError(Exception): class GitError(Exception):
"""Unspecified internal error from git. """Unspecified internal error from git."""
"""
def __init__(self, command): def __init__(self, command):
super().__init__(command) super().__init__(command)
@ -66,8 +60,7 @@ class GitError(Exception):
class UploadError(Exception): class UploadError(Exception):
"""A bundle upload to Gerrit did not succeed. """A bundle upload to Gerrit did not succeed."""
"""
def __init__(self, reason): def __init__(self, reason):
super().__init__(reason) super().__init__(reason)
@ -78,8 +71,7 @@ class UploadError(Exception):
class DownloadError(Exception): class DownloadError(Exception):
"""Cannot download a repository. """Cannot download a repository."""
"""
def __init__(self, reason): def __init__(self, reason):
super().__init__(reason) super().__init__(reason)
@ -90,8 +82,7 @@ class DownloadError(Exception):
class NoSuchProjectError(Exception): class NoSuchProjectError(Exception):
"""A specified project does not exist in the work tree. """A specified project does not exist in the work tree."""
"""
def __init__(self, name=None): def __init__(self, name=None):
super().__init__(name) super().__init__(name)
@ -99,13 +90,12 @@ class NoSuchProjectError(Exception):
def __str__(self): def __str__(self):
if self.name is None: if self.name is None:
return 'in current directory' return "in current directory"
return self.name return self.name
class InvalidProjectGroupsError(Exception): class InvalidProjectGroupsError(Exception):
"""A specified project is not suitable for the specified groups """A specified project is not suitable for the specified groups"""
"""
def __init__(self, name=None): def __init__(self, name=None):
super().__init__(name) super().__init__(name)
@ -113,7 +103,7 @@ class InvalidProjectGroupsError(Exception):
def __str__(self): def __str__(self):
if self.name is None: if self.name is None:
return 'in current directory' return "in current directory"
return self.name return self.name

View File

@ -15,9 +15,9 @@
import json import json
import multiprocessing import multiprocessing
TASK_COMMAND = 'command' TASK_COMMAND = "command"
TASK_SYNC_NETWORK = 'sync-network' TASK_SYNC_NETWORK = "sync-network"
TASK_SYNC_LOCAL = 'sync-local' TASK_SYNC_LOCAL = "sync-local"
class EventLog(object): class EventLog(object):
@ -51,8 +51,16 @@ class EventLog(object):
self._log = [] self._log = []
self._parent = None self._parent = None
def Add(self, name, task_name, start, finish=None, success=None, def Add(
try_count=1, kind='RepoOp'): self,
name,
task_name,
start,
finish=None,
success=None,
try_count=1,
kind="RepoOp",
):
"""Add an event to the log. """Add an event to the log.
Args: Args:
@ -68,15 +76,15 @@ class EventLog(object):
A dictionary of the event added to the log. A dictionary of the event added to the log.
""" """
event = { event = {
'id': (kind, _NextEventId()), "id": (kind, _NextEventId()),
'name': name, "name": name,
'task_name': task_name, "task_name": task_name,
'start_time': start, "start_time": start,
'try': try_count, "try": try_count,
} }
if self._parent: if self._parent:
event['parent'] = self._parent['id'] event["parent"] = self._parent["id"]
if success is not None or finish is not None: if success is not None or finish is not None:
self.FinishEvent(event, finish, success) self.FinishEvent(event, finish, success)
@ -100,15 +108,15 @@ class EventLog(object):
""" """
event = self.Add(project.relpath, task_name, start, finish, success) event = self.Add(project.relpath, task_name, start, finish, success)
if event is not None: if event is not None:
event['project'] = project.name event["project"] = project.name
if project.revisionExpr: if project.revisionExpr:
event['revision'] = project.revisionExpr event["revision"] = project.revisionExpr
if project.remote.url: if project.remote.url:
event['project_url'] = project.remote.url event["project_url"] = project.remote.url
if project.remote.fetchUrl: if project.remote.fetchUrl:
event['remote_url'] = project.remote.fetchUrl event["remote_url"] = project.remote.fetchUrl
try: try:
event['git_hash'] = project.GetCommitRevisionId() event["git_hash"] = project.GetCommitRevisionId()
except Exception: except Exception:
pass pass
return event return event
@ -122,7 +130,7 @@ class EventLog(object):
Returns: Returns:
status string. status string.
""" """
return 'pass' if success else 'fail' return "pass" if success else "fail"
def FinishEvent(self, event, finish, success): def FinishEvent(self, event, finish, success):
"""Finishes an incomplete event. """Finishes an incomplete event.
@ -135,8 +143,8 @@ class EventLog(object):
Returns: Returns:
A dictionary of the event added to the log. A dictionary of the event added to the log.
""" """
event['status'] = self.GetStatusString(success) event["status"] = self.GetStatusString(success)
event['finish_time'] = finish event["finish_time"] = finish
return event return event
def SetParent(self, event): def SetParent(self, event):
@ -153,14 +161,14 @@ class EventLog(object):
Args: Args:
filename: The file to write the log to. filename: The file to write the log to.
""" """
with open(filename, 'w+') as f: with open(filename, "w+") as f:
for e in self._log: for e in self._log:
json.dump(e, f, sort_keys=True) json.dump(e, f, sort_keys=True)
f.write('\n') f.write("\n")
# An integer id that is unique across this invocation of the program. # An integer id that is unique across this invocation of the program.
_EVENT_ID = multiprocessing.Value('i', 1) _EVENT_ID = multiprocessing.Value("i", 1)
def _NextEventId(): def _NextEventId():

View File

@ -27,19 +27,23 @@ def fetch_file(url, verbose=False):
The contents of the file as bytes. The contents of the file as bytes.
""" """
scheme = urlparse(url).scheme scheme = urlparse(url).scheme
if scheme == 'gs': if scheme == "gs":
cmd = ['gsutil', 'cat', url] cmd = ["gsutil", "cat", url]
try: try:
result = subprocess.run( result = subprocess.run(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
check=True) )
if result.stderr and verbose: if result.stderr and verbose:
print('warning: non-fatal error running "gsutil": %s' % result.stderr, print(
file=sys.stderr) 'warning: non-fatal error running "gsutil": %s'
% result.stderr,
file=sys.stderr,
)
return result.stdout return result.stdout
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print('fatal: error running "gsutil": %s' % e.stderr, print(
file=sys.stderr) 'fatal: error running "gsutil": %s' % e.stderr, file=sys.stderr
)
sys.exit(1) sys.exit(1)
with urlopen(url) as f: with urlopen(url) as f:
return f.read() return f.read()

View File

@ -24,7 +24,7 @@ import platform_utils
from repo_trace import REPO_TRACE, IsTrace, Trace from repo_trace import REPO_TRACE, IsTrace, Trace
from wrapper import Wrapper from wrapper import Wrapper
GIT = 'git' GIT = "git"
# NB: These do not need to be kept in sync with the repo launcher script. # NB: These do not need to be kept in sync with the repo launcher script.
# These may be much newer as it allows the repo launcher to roll between # These may be much newer as it allows the repo launcher to roll between
# different repo releases while source versions might require a newer git. # different repo releases while source versions might require a newer git.
@ -36,7 +36,7 @@ GIT = 'git'
# git-1.7 is in (EOL) Ubuntu Precise. git-1.9 is in Ubuntu Trusty. # git-1.7 is in (EOL) Ubuntu Precise. git-1.9 is in Ubuntu Trusty.
MIN_GIT_VERSION_SOFT = (1, 9, 1) MIN_GIT_VERSION_SOFT = (1, 9, 1)
MIN_GIT_VERSION_HARD = (1, 7, 2) MIN_GIT_VERSION_HARD = (1, 7, 2)
GIT_DIR = 'GIT_DIR' GIT_DIR = "GIT_DIR"
LAST_GITDIR = None LAST_GITDIR = None
LAST_CWD = None LAST_CWD = None
@ -47,17 +47,18 @@ class _GitCall(object):
def version_tuple(self): def version_tuple(self):
ret = Wrapper().ParseGitVersion() ret = Wrapper().ParseGitVersion()
if ret is None: if ret is None:
print('fatal: unable to detect git version', file=sys.stderr) print("fatal: unable to detect git version", file=sys.stderr)
sys.exit(1) sys.exit(1)
return ret return ret
def __getattr__(self, name): def __getattr__(self, name):
name = name.replace('_', '-') name = name.replace("_", "-")
def fun(*cmdv): def fun(*cmdv):
command = [name] command = [name]
command.extend(cmdv) command.extend(cmdv)
return GitCommand(None, command).Wait() == 0 return GitCommand(None, command).Wait() == 0
return fun return fun
@ -66,7 +67,7 @@ git = _GitCall()
def RepoSourceVersion(): def RepoSourceVersion():
"""Return the version of the repo.git tree.""" """Return the version of the repo.git tree."""
ver = getattr(RepoSourceVersion, 'version', None) ver = getattr(RepoSourceVersion, "version", None)
# We avoid GitCommand so we don't run into circular deps -- GitCommand needs # We avoid GitCommand so we don't run into circular deps -- GitCommand needs
# to initialize version info we provide. # to initialize version info we provide.
@ -74,17 +75,22 @@ def RepoSourceVersion():
env = GitCommand._GetBasicEnv() env = GitCommand._GetBasicEnv()
proj = os.path.dirname(os.path.abspath(__file__)) proj = os.path.dirname(os.path.abspath(__file__))
env[GIT_DIR] = os.path.join(proj, '.git') env[GIT_DIR] = os.path.join(proj, ".git")
result = subprocess.run([GIT, 'describe', HEAD], stdout=subprocess.PIPE, result = subprocess.run(
stderr=subprocess.DEVNULL, encoding='utf-8', [GIT, "describe", HEAD],
env=env, check=False) stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
encoding="utf-8",
env=env,
check=False,
)
if result.returncode == 0: if result.returncode == 0:
ver = result.stdout.strip() ver = result.stdout.strip()
if ver.startswith('v'): if ver.startswith("v"):
ver = ver[1:] ver = ver[1:]
else: else:
ver = 'unknown' ver = "unknown"
setattr(RepoSourceVersion, 'version', ver) setattr(RepoSourceVersion, "version", ver)
return ver return ver
@ -105,14 +111,14 @@ class UserAgent(object):
"""The operating system name.""" """The operating system name."""
if self._os is None: if self._os is None:
os_name = sys.platform os_name = sys.platform
if os_name.lower().startswith('linux'): if os_name.lower().startswith("linux"):
os_name = 'Linux' os_name = "Linux"
elif os_name == 'win32': elif os_name == "win32":
os_name = 'Win32' os_name = "Win32"
elif os_name == 'cygwin': elif os_name == "cygwin":
os_name = 'Cygwin' os_name = "Cygwin"
elif os_name == 'darwin': elif os_name == "darwin":
os_name = 'Darwin' os_name = "Darwin"
self._os = os_name self._os = os_name
return self._os return self._os
@ -122,11 +128,14 @@ class UserAgent(object):
"""The UA when connecting directly from repo.""" """The UA when connecting directly from repo."""
if self._repo_ua is None: if self._repo_ua is None:
py_version = sys.version_info py_version = sys.version_info
self._repo_ua = 'git-repo/%s (%s) git/%s Python/%d.%d.%d' % ( self._repo_ua = "git-repo/%s (%s) git/%s Python/%d.%d.%d" % (
RepoSourceVersion(), RepoSourceVersion(),
self.os, self.os,
git.version_tuple().full, git.version_tuple().full,
py_version.major, py_version.minor, py_version.micro) py_version.major,
py_version.minor,
py_version.micro,
)
return self._repo_ua return self._repo_ua
@ -134,10 +143,11 @@ class UserAgent(object):
def git(self): def git(self):
"""The UA when running git.""" """The UA when running git."""
if self._git_ua is None: if self._git_ua is None:
self._git_ua = 'git/%s (%s) git-repo/%s' % ( self._git_ua = "git/%s (%s) git-repo/%s" % (
git.version_tuple().full, git.version_tuple().full,
self.os, self.os,
RepoSourceVersion()) RepoSourceVersion(),
)
return self._git_ua return self._git_ua
@ -145,15 +155,17 @@ class UserAgent(object):
user_agent = UserAgent() user_agent = UserAgent()
def git_require(min_version, fail=False, msg=''): def git_require(min_version, fail=False, msg=""):
git_version = git.version_tuple() git_version = git.version_tuple()
if min_version <= git_version: if min_version <= git_version:
return True return True
if fail: if fail:
need = '.'.join(map(str, min_version)) need = ".".join(map(str, min_version))
if msg: if msg:
msg = ' for ' + msg msg = " for " + msg
print('fatal: git %s or later required%s' % (need, msg), file=sys.stderr) print(
"fatal: git %s or later required%s" % (need, msg), file=sys.stderr
)
sys.exit(1) sys.exit(1)
return False return False
@ -164,40 +176,44 @@ def _build_env(
disable_editor: Optional[bool] = False, disable_editor: Optional[bool] = False,
ssh_proxy: Optional[Any] = None, ssh_proxy: Optional[Any] = None,
gitdir: Optional[str] = None, gitdir: Optional[str] = None,
objdir: Optional[str] = None objdir: Optional[str] = None,
): ):
"""Constucts an env dict for command execution.""" """Constucts an env dict for command execution."""
assert _kwargs_only == (), '_build_env only accepts keyword arguments.' assert _kwargs_only == (), "_build_env only accepts keyword arguments."
env = GitCommand._GetBasicEnv() env = GitCommand._GetBasicEnv()
if disable_editor: if disable_editor:
env['GIT_EDITOR'] = ':' env["GIT_EDITOR"] = ":"
if ssh_proxy: if ssh_proxy:
env['REPO_SSH_SOCK'] = ssh_proxy.sock() env["REPO_SSH_SOCK"] = ssh_proxy.sock()
env['GIT_SSH'] = ssh_proxy.proxy env["GIT_SSH"] = ssh_proxy.proxy
env['GIT_SSH_VARIANT'] = 'ssh' env["GIT_SSH_VARIANT"] = "ssh"
if 'http_proxy' in env and 'darwin' == sys.platform: if "http_proxy" in env and "darwin" == sys.platform:
s = "'http.proxy=%s'" % (env['http_proxy'],) s = "'http.proxy=%s'" % (env["http_proxy"],)
p = env.get('GIT_CONFIG_PARAMETERS') p = env.get("GIT_CONFIG_PARAMETERS")
if p is not None: if p is not None:
s = p + ' ' + s s = p + " " + s
env['GIT_CONFIG_PARAMETERS'] = s env["GIT_CONFIG_PARAMETERS"] = s
if 'GIT_ALLOW_PROTOCOL' not in env: if "GIT_ALLOW_PROTOCOL" not in env:
env['GIT_ALLOW_PROTOCOL'] = ( env[
'file:git:http:https:ssh:persistent-http:persistent-https:sso:rpc') "GIT_ALLOW_PROTOCOL"
env['GIT_HTTP_USER_AGENT'] = user_agent.git ] = "file:git:http:https:ssh:persistent-http:persistent-https:sso:rpc"
env["GIT_HTTP_USER_AGENT"] = user_agent.git
if objdir: if objdir:
# Set to the place we want to save the objects. # Set to the place we want to save the objects.
env['GIT_OBJECT_DIRECTORY'] = objdir env["GIT_OBJECT_DIRECTORY"] = objdir
alt_objects = os.path.join(gitdir, 'objects') if gitdir else None alt_objects = os.path.join(gitdir, "objects") if gitdir else None
if alt_objects and os.path.realpath(alt_objects) != os.path.realpath(objdir): if alt_objects and os.path.realpath(alt_objects) != os.path.realpath(
# Allow git to search the original place in case of local or unique refs objdir
# that git will attempt to resolve even if we aren't fetching them. ):
env['GIT_ALTERNATE_OBJECT_DIRECTORIES'] = alt_objects # Allow git to search the original place in case of local or unique
# refs that git will attempt to resolve even if we aren't fetching
# them.
env["GIT_ALTERNATE_OBJECT_DIRECTORIES"] = alt_objects
if bare and gitdir is not None: if bare and gitdir is not None:
env[GIT_DIR] = gitdir env[GIT_DIR] = gitdir
@ -207,7 +223,8 @@ def _build_env(
class GitCommand(object): class GitCommand(object):
"""Wrapper around a single git invocation.""" """Wrapper around a single git invocation."""
def __init__(self, def __init__(
self,
project, project,
cmdv, cmdv,
bare=False, bare=False,
@ -219,8 +236,8 @@ class GitCommand(object):
ssh_proxy=None, ssh_proxy=None,
cwd=None, cwd=None,
gitdir=None, gitdir=None,
objdir=None): objdir=None,
):
if project: if project:
if not cwd: if not cwd:
cwd = project.worktree cwd = project.worktree
@ -230,9 +247,9 @@ class GitCommand(object):
# Git on Windows wants its paths only using / for reliability. # Git on Windows wants its paths only using / for reliability.
if platform_utils.isWindows(): if platform_utils.isWindows():
if objdir: if objdir:
objdir = objdir.replace('\\', '/') objdir = objdir.replace("\\", "/")
if gitdir: if gitdir:
gitdir = gitdir.replace('\\', '/') gitdir = gitdir.replace("\\", "/")
env = _build_env( env = _build_env(
disable_editor=disable_editor, disable_editor=disable_editor,
@ -247,63 +264,75 @@ class GitCommand(object):
cwd = None cwd = None
command.append(cmdv[0]) command.append(cmdv[0])
# Need to use the --progress flag for fetch/clone so output will be # Need to use the --progress flag for fetch/clone so output will be
# displayed as by default git only does progress output if stderr is a TTY. # displayed as by default git only does progress output if stderr is a
if sys.stderr.isatty() and cmdv[0] in ('fetch', 'clone'): # TTY.
if '--progress' not in cmdv and '--quiet' not in cmdv: if sys.stderr.isatty() and cmdv[0] in ("fetch", "clone"):
command.append('--progress') if "--progress" not in cmdv and "--quiet" not in cmdv:
command.append("--progress")
command.extend(cmdv[1:]) command.extend(cmdv[1:])
stdin = subprocess.PIPE if input else None stdin = subprocess.PIPE if input else None
stdout = subprocess.PIPE if capture_stdout else None stdout = subprocess.PIPE if capture_stdout else None
stderr = (subprocess.STDOUT if merge_output else stderr = (
(subprocess.PIPE if capture_stderr else None)) subprocess.STDOUT
if merge_output
else (subprocess.PIPE if capture_stderr else None)
)
dbg = '' dbg = ""
if IsTrace(): if IsTrace():
global LAST_CWD global LAST_CWD
global LAST_GITDIR global LAST_GITDIR
if cwd and LAST_CWD != cwd: if cwd and LAST_CWD != cwd:
if LAST_GITDIR or LAST_CWD: if LAST_GITDIR or LAST_CWD:
dbg += '\n' dbg += "\n"
dbg += ': cd %s\n' % cwd dbg += ": cd %s\n" % cwd
LAST_CWD = cwd LAST_CWD = cwd
if GIT_DIR in env and LAST_GITDIR != env[GIT_DIR]: if GIT_DIR in env and LAST_GITDIR != env[GIT_DIR]:
if LAST_GITDIR or LAST_CWD: if LAST_GITDIR or LAST_CWD:
dbg += '\n' dbg += "\n"
dbg += ': export GIT_DIR=%s\n' % env[GIT_DIR] dbg += ": export GIT_DIR=%s\n" % env[GIT_DIR]
LAST_GITDIR = env[GIT_DIR] LAST_GITDIR = env[GIT_DIR]
if 'GIT_OBJECT_DIRECTORY' in env: if "GIT_OBJECT_DIRECTORY" in env:
dbg += ': export GIT_OBJECT_DIRECTORY=%s\n' % env['GIT_OBJECT_DIRECTORY'] dbg += (
if 'GIT_ALTERNATE_OBJECT_DIRECTORIES' in env: ": export GIT_OBJECT_DIRECTORY=%s\n"
dbg += ': export GIT_ALTERNATE_OBJECT_DIRECTORIES=%s\n' % ( % env["GIT_OBJECT_DIRECTORY"]
env['GIT_ALTERNATE_OBJECT_DIRECTORIES']) )
if "GIT_ALTERNATE_OBJECT_DIRECTORIES" in env:
dbg += ": export GIT_ALTERNATE_OBJECT_DIRECTORIES=%s\n" % (
env["GIT_ALTERNATE_OBJECT_DIRECTORIES"]
)
dbg += ': ' dbg += ": "
dbg += ' '.join(command) dbg += " ".join(command)
if stdin == subprocess.PIPE: if stdin == subprocess.PIPE:
dbg += ' 0<|' dbg += " 0<|"
if stdout == subprocess.PIPE: if stdout == subprocess.PIPE:
dbg += ' 1>|' dbg += " 1>|"
if stderr == subprocess.PIPE: if stderr == subprocess.PIPE:
dbg += ' 2>|' dbg += " 2>|"
elif stderr == subprocess.STDOUT: elif stderr == subprocess.STDOUT:
dbg += ' 2>&1' dbg += " 2>&1"
with Trace('git command %s %s with debug: %s', LAST_GITDIR, command, dbg): with Trace(
"git command %s %s with debug: %s", LAST_GITDIR, command, dbg
):
try: try:
p = subprocess.Popen(command, p = subprocess.Popen(
command,
cwd=cwd, cwd=cwd,
env=env, env=env,
encoding='utf-8', encoding="utf-8",
errors='backslashreplace', errors="backslashreplace",
stdin=stdin, stdin=stdin,
stdout=stdout, stdout=stdout,
stderr=stderr) stderr=stderr,
)
except Exception as e: except Exception as e:
raise GitError('%s: %s' % (command[1], e)) raise GitError("%s: %s" % (command[1], e))
if ssh_proxy: if ssh_proxy:
ssh_proxy.add_client(p) ssh_proxy.add_client(p)
@ -324,13 +353,15 @@ class GitCommand(object):
This is guaranteed to be side-effect free. This is guaranteed to be side-effect free.
""" """
env = os.environ.copy() env = os.environ.copy()
for key in (REPO_TRACE, for key in (
REPO_TRACE,
GIT_DIR, GIT_DIR,
'GIT_ALTERNATE_OBJECT_DIRECTORIES', "GIT_ALTERNATE_OBJECT_DIRECTORIES",
'GIT_OBJECT_DIRECTORY', "GIT_OBJECT_DIRECTORY",
'GIT_WORK_TREE', "GIT_WORK_TREE",
'GIT_GRAFT_FILE', "GIT_GRAFT_FILE",
'GIT_INDEX_FILE'): "GIT_INDEX_FILE",
):
env.pop(key, None) env.pop(key, None)
return env return env

View File

@ -34,9 +34,9 @@ from git_refs import R_CHANGES, R_HEADS, R_TAGS
# Prefix that is prepended to all the keys of SyncAnalysisState's data # Prefix that is prepended to all the keys of SyncAnalysisState's data
# that is saved in the config. # that is saved in the config.
SYNC_STATE_PREFIX = 'repo.syncstate.' SYNC_STATE_PREFIX = "repo.syncstate."
ID_RE = re.compile(r'^[0-9a-f]{40}$') ID_RE = re.compile(r"^[0-9a-f]{40}$")
REVIEW_CACHE = dict() REVIEW_CACHE = dict()
@ -58,19 +58,19 @@ def IsImmutable(rev):
def _key(name): def _key(name):
parts = name.split('.') parts = name.split(".")
if len(parts) < 2: if len(parts) < 2:
return name.lower() return name.lower()
parts[0] = parts[0].lower() parts[0] = parts[0].lower()
parts[-1] = parts[-1].lower() parts[-1] = parts[-1].lower()
return '.'.join(parts) return ".".join(parts)
class GitConfig(object): class GitConfig(object):
_ForUser = None _ForUser = None
_ForSystem = None _ForSystem = None
_SYSTEM_CONFIG = '/etc/gitconfig' _SYSTEM_CONFIG = "/etc/gitconfig"
@classmethod @classmethod
def ForSystem(cls): def ForSystem(cls):
@ -86,12 +86,11 @@ class GitConfig(object):
@staticmethod @staticmethod
def _getUserConfig(): def _getUserConfig():
return os.path.expanduser('~/.gitconfig') return os.path.expanduser("~/.gitconfig")
@classmethod @classmethod
def ForRepository(cls, gitdir, defaults=None): def ForRepository(cls, gitdir, defaults=None):
return cls(configfile=os.path.join(gitdir, 'config'), return cls(configfile=os.path.join(gitdir, "config"), defaults=defaults)
defaults=defaults)
def __init__(self, configfile, defaults=None, jsonFile=None): def __init__(self, configfile, defaults=None, jsonFile=None):
self.file = configfile self.file = configfile
@ -105,15 +104,15 @@ class GitConfig(object):
if self._json is None: if self._json is None:
self._json = os.path.join( self._json = os.path.join(
os.path.dirname(self.file), os.path.dirname(self.file),
'.repo_' + os.path.basename(self.file) + '.json') ".repo_" + os.path.basename(self.file) + ".json",
)
def ClearCache(self): def ClearCache(self):
"""Clear the in-memory cache of config.""" """Clear the in-memory cache of config."""
self._cache_dict = None self._cache_dict = None
def Has(self, name, include_defaults=True): def Has(self, name, include_defaults=True):
"""Return true if this configuration file has the key. """Return true if this configuration file has the key."""
"""
if _key(name) in self._cache: if _key(name) in self._cache:
return True return True
if include_defaults and self.defaults: if include_defaults and self.defaults:
@ -138,26 +137,28 @@ class GitConfig(object):
v = v.strip() v = v.strip()
mult = 1 mult = 1
if v.endswith('k'): if v.endswith("k"):
v = v[:-1] v = v[:-1]
mult = 1024 mult = 1024
elif v.endswith('m'): elif v.endswith("m"):
v = v[:-1] v = v[:-1]
mult = 1024 * 1024 mult = 1024 * 1024
elif v.endswith('g'): elif v.endswith("g"):
v = v[:-1] v = v[:-1]
mult = 1024 * 1024 * 1024 mult = 1024 * 1024 * 1024
base = 10 base = 10
if v.startswith('0x'): if v.startswith("0x"):
base = 16 base = 16
try: try:
return int(v, base=base) * mult return int(v, base=base) * mult
except ValueError: except ValueError:
print( print(
f"warning: expected {name} to represent an integer, got {v} instead", f"warning: expected {name} to represent an integer, got {v} "
file=sys.stderr) "instead",
file=sys.stderr,
)
return None return None
def DumpConfigDict(self): def DumpConfigDict(self):
@ -177,26 +178,30 @@ class GitConfig(object):
def GetBoolean(self, name: str) -> Union[str, None]: def GetBoolean(self, name: str) -> Union[str, None]:
"""Returns a boolean from the configuration file. """Returns a boolean from the configuration file.
None : The value was not defined, or is not a boolean.
True : The value was set to true or yes. Returns:
None: The value was not defined, or is not a boolean.
True: The value was set to true or yes.
False: The value was set to false or no. False: The value was set to false or no.
""" """
v = self.GetString(name) v = self.GetString(name)
if v is None: if v is None:
return None return None
v = v.lower() v = v.lower()
if v in ('true', 'yes'): if v in ("true", "yes"):
return True return True
if v in ('false', 'no'): if v in ("false", "no"):
return False return False
print(f"warning: expected {name} to represent a boolean, got {v} instead", print(
file=sys.stderr) f"warning: expected {name} to represent a boolean, got {v} instead",
file=sys.stderr,
)
return None return None
def SetBoolean(self, name, value): def SetBoolean(self, name, value):
"""Set the truthy value for a key.""" """Set the truthy value for a key."""
if value is not None: if value is not None:
value = 'true' if value else 'false' value = "true" if value else "false"
self.SetString(name, value) self.SetString(name, value)
def GetString(self, name: str, all_keys: bool = False) -> Union[str, None]: def GetString(self, name: str, all_keys: bool = False) -> Union[str, None]:
@ -240,7 +245,7 @@ class GitConfig(object):
if value is None: if value is None:
if old: if old:
del self._cache[key] del self._cache[key]
self._do('--unset-all', name) self._do("--unset-all", name)
elif isinstance(value, list): elif isinstance(value, list):
if len(value) == 0: if len(value) == 0:
@ -251,17 +256,16 @@ class GitConfig(object):
elif old != value: elif old != value:
self._cache[key] = list(value) self._cache[key] = list(value)
self._do('--replace-all', name, value[0]) self._do("--replace-all", name, value[0])
for i in range(1, len(value)): for i in range(1, len(value)):
self._do('--add', name, value[i]) self._do("--add", name, value[i])
elif len(old) != 1 or old[0] != value: elif len(old) != 1 or old[0] != value:
self._cache[key] = [value] self._cache[key] = [value]
self._do('--replace-all', name, value) self._do("--replace-all", name, value)
def GetRemote(self, name): def GetRemote(self, name):
"""Get the remote.$name.* configuration values as an object. """Get the remote.$name.* configuration values as an object."""
"""
try: try:
r = self._remotes[name] r = self._remotes[name]
except KeyError: except KeyError:
@ -270,8 +274,7 @@ class GitConfig(object):
return r return r
def GetBranch(self, name): def GetBranch(self, name):
"""Get the branch.$name.* configuration values as an object. """Get the branch.$name.* configuration values as an object."""
"""
try: try:
b = self._branches[name] b = self._branches[name]
except KeyError: except KeyError:
@ -281,14 +284,20 @@ class GitConfig(object):
def GetSyncAnalysisStateData(self): def GetSyncAnalysisStateData(self):
"""Returns data to be logged for the analysis of sync performance.""" """Returns data to be logged for the analysis of sync performance."""
return {k: v for k, v in self.DumpConfigDict().items() if k.startswith(SYNC_STATE_PREFIX)} return {
k: v
for k, v in self.DumpConfigDict().items()
if k.startswith(SYNC_STATE_PREFIX)
}
def UpdateSyncAnalysisState(self, options, superproject_logging_data): def UpdateSyncAnalysisState(self, options, superproject_logging_data):
"""Update Config's SYNC_STATE_PREFIX* data with the latest sync data. """Update Config's SYNC_STATE_PREFIX* data with the latest sync data.
Args: Args:
options: Options passed to sync returned from optparse. See _Options(). options: Options passed to sync returned from optparse. See
superproject_logging_data: A dictionary of superproject data that is to be logged. _Options().
superproject_logging_data: A dictionary of superproject data that is
to be logged.
Returns: Returns:
SyncAnalysisState object. SyncAnalysisState object.
@ -296,25 +305,22 @@ class GitConfig(object):
return SyncAnalysisState(self, options, superproject_logging_data) return SyncAnalysisState(self, options, superproject_logging_data)
def GetSubSections(self, section): def GetSubSections(self, section):
"""List all subsection names matching $section.*.* """List all subsection names matching $section.*.*"""
"""
return self._sections.get(section, set()) return self._sections.get(section, set())
def HasSection(self, section, subsection=''): def HasSection(self, section, subsection=""):
"""Does at least one key in section.subsection exist? """Does at least one key in section.subsection exist?"""
"""
try: try:
return subsection in self._sections[section] return subsection in self._sections[section]
except KeyError: except KeyError:
return False return False
def UrlInsteadOf(self, url): def UrlInsteadOf(self, url):
"""Resolve any url.*.insteadof references. """Resolve any url.*.insteadof references."""
""" for new_url in self.GetSubSections("url"):
for new_url in self.GetSubSections('url'): for old_url in self.GetString("url.%s.insteadof" % new_url, True):
for old_url in self.GetString('url.%s.insteadof' % new_url, True):
if old_url is not None and url.startswith(old_url): if old_url is not None and url.startswith(old_url):
return new_url + url[len(old_url):] return new_url + url[len(old_url) :]
return url return url
@property @property
@ -323,13 +329,13 @@ class GitConfig(object):
if d is None: if d is None:
d = {} d = {}
for name in self._cache.keys(): for name in self._cache.keys():
p = name.split('.') p = name.split(".")
if 2 == len(p): if 2 == len(p):
section = p[0] section = p[0]
subsect = '' subsect = ""
else: else:
section = p[0] section = p[0]
subsect = '.'.join(p[1:-1]) subsect = ".".join(p[1:-1])
if section not in d: if section not in d:
d[section] = set() d[section] = set()
d[section].add(subsect) d[section].add(subsect)
@ -357,7 +363,7 @@ class GitConfig(object):
except OSError: except OSError:
return None return None
try: try:
with Trace(': parsing %s', self.file): with Trace(": parsing %s", self.file):
with open(self._json) as fd: with open(self._json) as fd:
return json.load(fd) return json.load(fd)
except (IOError, ValueError): except (IOError, ValueError):
@ -366,7 +372,7 @@ class GitConfig(object):
def _SaveJson(self, cache): def _SaveJson(self, cache):
try: try:
with open(self._json, 'w') as fd: with open(self._json, "w") as fd:
json.dump(cache, fd, indent=2) json.dump(cache, fd, indent=2)
except (IOError, TypeError): except (IOError, TypeError):
platform_utils.remove(self._json, missing_ok=True) platform_utils.remove(self._json, missing_ok=True)
@ -382,10 +388,10 @@ class GitConfig(object):
if not os.path.exists(self.file): if not os.path.exists(self.file):
return c return c
d = self._do('--null', '--list') d = self._do("--null", "--list")
for line in d.rstrip('\0').split('\0'): for line in d.rstrip("\0").split("\0"):
if '\n' in line: if "\n" in line:
key, val = line.split('\n', 1) key, val = line.split("\n", 1)
else: else:
key = line key = line
val = None val = None
@ -399,19 +405,16 @@ class GitConfig(object):
def _do(self, *args): def _do(self, *args):
if self.file == self._SYSTEM_CONFIG: if self.file == self._SYSTEM_CONFIG:
command = ['config', '--system', '--includes'] command = ["config", "--system", "--includes"]
else: else:
command = ['config', '--file', self.file, '--includes'] command = ["config", "--file", self.file, "--includes"]
command.extend(args) command.extend(args)
p = GitCommand(None, p = GitCommand(None, command, capture_stdout=True, capture_stderr=True)
command,
capture_stdout=True,
capture_stderr=True)
if p.Wait() == 0: if p.Wait() == 0:
return p.stdout return p.stdout
else: else:
raise GitError('git config %s: %s' % (str(args), p.stderr)) raise GitError("git config %s: %s" % (str(args), p.stderr))
class RepoConfig(GitConfig): class RepoConfig(GitConfig):
@ -419,8 +422,8 @@ class RepoConfig(GitConfig):
@staticmethod @staticmethod
def _getUserConfig(): def _getUserConfig():
repo_config_dir = os.getenv('REPO_CONFIG_DIR', os.path.expanduser('~')) repo_config_dir = os.getenv("REPO_CONFIG_DIR", os.path.expanduser("~"))
return os.path.join(repo_config_dir, '.repoconfig/config') return os.path.join(repo_config_dir, ".repoconfig/config")
class RefSpec(object): class RefSpec(object):
@ -433,8 +436,8 @@ class RefSpec(object):
@classmethod @classmethod
def FromString(cls, rs): def FromString(cls, rs):
lhs, rhs = rs.split(':', 2) lhs, rhs = rs.split(":", 2)
if lhs.startswith('+'): if lhs.startswith("+"):
lhs = lhs[1:] lhs = lhs[1:]
forced = True forced = True
else: else:
@ -450,7 +453,7 @@ class RefSpec(object):
if self.src: if self.src:
if rev == self.src: if rev == self.src:
return True return True
if self.src.endswith('/*') and rev.startswith(self.src[:-1]): if self.src.endswith("/*") and rev.startswith(self.src[:-1]):
return True return True
return False return False
@ -458,28 +461,28 @@ class RefSpec(object):
if self.dst: if self.dst:
if ref == self.dst: if ref == self.dst:
return True return True
if self.dst.endswith('/*') and ref.startswith(self.dst[:-1]): if self.dst.endswith("/*") and ref.startswith(self.dst[:-1]):
return True return True
return False return False
def MapSource(self, rev): def MapSource(self, rev):
if self.src.endswith('/*'): if self.src.endswith("/*"):
return self.dst[:-1] + rev[len(self.src) - 1:] return self.dst[:-1] + rev[len(self.src) - 1 :]
return self.dst return self.dst
def __str__(self): def __str__(self):
s = '' s = ""
if self.forced: if self.forced:
s += '+' s += "+"
if self.src: if self.src:
s += self.src s += self.src
if self.dst: if self.dst:
s += ':' s += ":"
s += self.dst s += self.dst
return s return s
URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/') URI_ALL = re.compile(r"^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/")
def GetSchemeFromUrl(url): def GetSchemeFromUrl(url):
@ -491,23 +494,27 @@ def GetSchemeFromUrl(url):
@contextlib.contextmanager @contextlib.contextmanager
def GetUrlCookieFile(url, quiet): def GetUrlCookieFile(url, quiet):
if url.startswith('persistent-'): if url.startswith("persistent-"):
try: try:
p = subprocess.Popen( p = subprocess.Popen(
['git-remote-persistent-https', '-print_config', url], ["git-remote-persistent-https", "-print_config", url],
stdin=subprocess.PIPE, stdout=subprocess.PIPE, stdin=subprocess.PIPE,
stderr=subprocess.PIPE) stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
try: try:
cookieprefix = 'http.cookiefile=' cookieprefix = "http.cookiefile="
proxyprefix = 'http.proxy=' proxyprefix = "http.proxy="
cookiefile = None cookiefile = None
proxy = None proxy = None
for line in p.stdout: for line in p.stdout:
line = line.strip().decode('utf-8') line = line.strip().decode("utf-8")
if line.startswith(cookieprefix): if line.startswith(cookieprefix):
cookiefile = os.path.expanduser(line[len(cookieprefix):]) cookiefile = os.path.expanduser(
line[len(cookieprefix) :]
)
if line.startswith(proxyprefix): if line.startswith(proxyprefix):
proxy = line[len(proxyprefix):] proxy = line[len(proxyprefix) :]
# Leave subprocess open, as cookie file may be transient. # Leave subprocess open, as cookie file may be transient.
if cookiefile or proxy: if cookiefile or proxy:
yield cookiefile, proxy yield cookiefile, proxy
@ -515,8 +522,8 @@ def GetUrlCookieFile(url, quiet):
finally: finally:
p.stdin.close() p.stdin.close()
if p.wait(): if p.wait():
err_msg = p.stderr.read().decode('utf-8') err_msg = p.stderr.read().decode("utf-8")
if ' -print_config' in err_msg: if " -print_config" in err_msg:
pass # Persistent proxy doesn't support -print_config. pass # Persistent proxy doesn't support -print_config.
elif not quiet: elif not quiet:
print(err_msg, file=sys.stderr) print(err_msg, file=sys.stderr)
@ -524,30 +531,30 @@ def GetUrlCookieFile(url, quiet):
if e.errno == errno.ENOENT: if e.errno == errno.ENOENT:
pass # No persistent proxy. pass # No persistent proxy.
raise raise
cookiefile = GitConfig.ForUser().GetString('http.cookiefile') cookiefile = GitConfig.ForUser().GetString("http.cookiefile")
if cookiefile: if cookiefile:
cookiefile = os.path.expanduser(cookiefile) cookiefile = os.path.expanduser(cookiefile)
yield cookiefile, None yield cookiefile, None
class Remote(object): class Remote(object):
"""Configuration options related to a remote. """Configuration options related to a remote."""
"""
def __init__(self, config, name): def __init__(self, config, name):
self._config = config self._config = config
self.name = name self.name = name
self.url = self._Get('url') self.url = self._Get("url")
self.pushUrl = self._Get('pushurl') self.pushUrl = self._Get("pushurl")
self.review = self._Get('review') self.review = self._Get("review")
self.projectname = self._Get('projectname') self.projectname = self._Get("projectname")
self.fetch = list(map(RefSpec.FromString, self.fetch = list(
self._Get('fetch', all_keys=True))) map(RefSpec.FromString, self._Get("fetch", all_keys=True))
)
self._review_url = None self._review_url = None
def _InsteadOf(self): def _InsteadOf(self):
globCfg = GitConfig.ForUser() globCfg = GitConfig.ForUser()
urlList = globCfg.GetSubSections('url') urlList = globCfg.GetSubSections("url")
longest = "" longest = ""
longestUrl = "" longestUrl = ""
@ -556,8 +563,9 @@ class Remote(object):
insteadOfList = globCfg.GetString(key, all_keys=True) insteadOfList = globCfg.GetString(key, all_keys=True)
for insteadOf in insteadOfList: for insteadOf in insteadOfList:
if (self.url.startswith(insteadOf) if self.url.startswith(insteadOf) and len(insteadOf) > len(
and len(insteadOf) > len(longest)): longest
):
longest = insteadOf longest = insteadOf
longestUrl = url longestUrl = url
@ -590,71 +598,78 @@ class Remote(object):
return None return None
u = self.review u = self.review
if u.startswith('persistent-'): if u.startswith("persistent-"):
u = u[len('persistent-'):] u = u[len("persistent-") :]
if u.split(':')[0] not in ('http', 'https', 'sso', 'ssh'): if u.split(":")[0] not in ("http", "https", "sso", "ssh"):
u = 'http://%s' % u u = "http://%s" % u
if u.endswith('/Gerrit'): if u.endswith("/Gerrit"):
u = u[:len(u) - len('/Gerrit')] u = u[: len(u) - len("/Gerrit")]
if u.endswith('/ssh_info'): if u.endswith("/ssh_info"):
u = u[:len(u) - len('/ssh_info')] u = u[: len(u) - len("/ssh_info")]
if not u.endswith('/'): if not u.endswith("/"):
u += '/' u += "/"
http_url = u http_url = u
if u in REVIEW_CACHE: if u in REVIEW_CACHE:
self._review_url = REVIEW_CACHE[u] self._review_url = REVIEW_CACHE[u]
elif 'REPO_HOST_PORT_INFO' in os.environ: elif "REPO_HOST_PORT_INFO" in os.environ:
host, port = os.environ['REPO_HOST_PORT_INFO'].split() host, port = os.environ["REPO_HOST_PORT_INFO"].split()
self._review_url = self._SshReviewUrl(userEmail, host, port) self._review_url = self._SshReviewUrl(userEmail, host, port)
REVIEW_CACHE[u] = self._review_url REVIEW_CACHE[u] = self._review_url
elif u.startswith('sso:') or u.startswith('ssh:'): elif u.startswith("sso:") or u.startswith("ssh:"):
self._review_url = u # Assume it's right self._review_url = u # Assume it's right
REVIEW_CACHE[u] = self._review_url REVIEW_CACHE[u] = self._review_url
elif 'REPO_IGNORE_SSH_INFO' in os.environ: elif "REPO_IGNORE_SSH_INFO" in os.environ:
self._review_url = http_url self._review_url = http_url
REVIEW_CACHE[u] = self._review_url REVIEW_CACHE[u] = self._review_url
else: else:
try: try:
info_url = u + 'ssh_info' info_url = u + "ssh_info"
if not validate_certs: if not validate_certs:
context = ssl._create_unverified_context() context = ssl._create_unverified_context()
info = urllib.request.urlopen(info_url, context=context).read() info = urllib.request.urlopen(
info_url, context=context
).read()
else: else:
info = urllib.request.urlopen(info_url).read() info = urllib.request.urlopen(info_url).read()
if info == b'NOT_AVAILABLE' or b'<' in info: if info == b"NOT_AVAILABLE" or b"<" in info:
# If `info` contains '<', we assume the server gave us some sort # If `info` contains '<', we assume the server gave us
# of HTML response back, like maybe a login page. # some sort of HTML response back, like maybe a login
# page.
# #
# Assume HTTP if SSH is not enabled or ssh_info doesn't look right. # Assume HTTP if SSH is not enabled or ssh_info doesn't
# look right.
self._review_url = http_url self._review_url = http_url
else: else:
info = info.decode('utf-8') info = info.decode("utf-8")
host, port = info.split() host, port = info.split()
self._review_url = self._SshReviewUrl(userEmail, host, port) self._review_url = self._SshReviewUrl(
userEmail, host, port
)
except urllib.error.HTTPError as e: except urllib.error.HTTPError as e:
raise UploadError('%s: %s' % (self.review, str(e))) raise UploadError("%s: %s" % (self.review, str(e)))
except urllib.error.URLError as e: except urllib.error.URLError as e:
raise UploadError('%s: %s' % (self.review, str(e))) raise UploadError("%s: %s" % (self.review, str(e)))
except HTTPException as e: except HTTPException as e:
raise UploadError('%s: %s' % (self.review, e.__class__.__name__)) raise UploadError(
"%s: %s" % (self.review, e.__class__.__name__)
)
REVIEW_CACHE[u] = self._review_url REVIEW_CACHE[u] = self._review_url
return self._review_url + self.projectname return self._review_url + self.projectname
def _SshReviewUrl(self, userEmail, host, port): def _SshReviewUrl(self, userEmail, host, port):
username = self._config.GetString('review.%s.username' % self.review) username = self._config.GetString("review.%s.username" % self.review)
if username is None: if username is None:
username = userEmail.split('@')[0] username = userEmail.split("@")[0]
return 'ssh://%s@%s:%s/' % (username, host, port) return "ssh://%s@%s:%s/" % (username, host, port)
def ToLocal(self, rev): def ToLocal(self, rev):
"""Convert a remote revision string to something we have locally. """Convert a remote revision string to something we have locally."""
""" if self.name == "." or IsId(rev):
if self.name == '.' or IsId(rev):
return rev return rev
if not rev.startswith('refs/'): if not rev.startswith("refs/"):
rev = R_HEADS + rev rev = R_HEADS + rev
for spec in self.fetch: for spec in self.fetch:
@ -664,57 +679,55 @@ class Remote(object):
if not rev.startswith(R_HEADS): if not rev.startswith(R_HEADS):
return rev return rev
raise GitError('%s: remote %s does not have %s' % raise GitError(
(self.projectname, self.name, rev)) "%s: remote %s does not have %s"
% (self.projectname, self.name, rev)
)
def WritesTo(self, ref): def WritesTo(self, ref):
"""True if the remote stores to the tracking ref. """True if the remote stores to the tracking ref."""
"""
for spec in self.fetch: for spec in self.fetch:
if spec.DestMatches(ref): if spec.DestMatches(ref):
return True return True
return False return False
def ResetFetch(self, mirror=False): def ResetFetch(self, mirror=False):
"""Set the fetch refspec to its default value. """Set the fetch refspec to its default value."""
"""
if mirror: if mirror:
dst = 'refs/heads/*' dst = "refs/heads/*"
else: else:
dst = 'refs/remotes/%s/*' % self.name dst = "refs/remotes/%s/*" % self.name
self.fetch = [RefSpec(True, 'refs/heads/*', dst)] self.fetch = [RefSpec(True, "refs/heads/*", dst)]
def Save(self): def Save(self):
"""Save this remote to the configuration. """Save this remote to the configuration."""
""" self._Set("url", self.url)
self._Set('url', self.url)
if self.pushUrl is not None: if self.pushUrl is not None:
self._Set('pushurl', self.pushUrl + '/' + self.projectname) self._Set("pushurl", self.pushUrl + "/" + self.projectname)
else: else:
self._Set('pushurl', self.pushUrl) self._Set("pushurl", self.pushUrl)
self._Set('review', self.review) self._Set("review", self.review)
self._Set('projectname', self.projectname) self._Set("projectname", self.projectname)
self._Set('fetch', list(map(str, self.fetch))) self._Set("fetch", list(map(str, self.fetch)))
def _Set(self, key, value): def _Set(self, key, value):
key = 'remote.%s.%s' % (self.name, key) key = "remote.%s.%s" % (self.name, key)
return self._config.SetString(key, value) return self._config.SetString(key, value)
def _Get(self, key, all_keys=False): def _Get(self, key, all_keys=False):
key = 'remote.%s.%s' % (self.name, key) key = "remote.%s.%s" % (self.name, key)
return self._config.GetString(key, all_keys=all_keys) return self._config.GetString(key, all_keys=all_keys)
class Branch(object): class Branch(object):
"""Configuration options related to a single branch. """Configuration options related to a single branch."""
"""
def __init__(self, config, name): def __init__(self, config, name):
self._config = config self._config = config
self.name = name self.name = name
self.merge = self._Get('merge') self.merge = self._Get("merge")
r = self._Get('remote') r = self._Get("remote")
if r: if r:
self.remote = self._config.GetRemote(r) self.remote = self._config.GetRemote(r)
else: else:
@ -722,36 +735,34 @@ class Branch(object):
@property @property
def LocalMerge(self): def LocalMerge(self):
"""Convert the merge spec to a local name. """Convert the merge spec to a local name."""
"""
if self.remote and self.merge: if self.remote and self.merge:
return self.remote.ToLocal(self.merge) return self.remote.ToLocal(self.merge)
return None return None
def Save(self): def Save(self):
"""Save this branch back into the configuration. """Save this branch back into the configuration."""
""" if self._config.HasSection("branch", self.name):
if self._config.HasSection('branch', self.name):
if self.remote: if self.remote:
self._Set('remote', self.remote.name) self._Set("remote", self.remote.name)
else: else:
self._Set('remote', None) self._Set("remote", None)
self._Set('merge', self.merge) self._Set("merge", self.merge)
else: else:
with open(self._config.file, 'a') as fd: with open(self._config.file, "a") as fd:
fd.write('[branch "%s"]\n' % self.name) fd.write('[branch "%s"]\n' % self.name)
if self.remote: if self.remote:
fd.write('\tremote = %s\n' % self.remote.name) fd.write("\tremote = %s\n" % self.remote.name)
if self.merge: if self.merge:
fd.write('\tmerge = %s\n' % self.merge) fd.write("\tmerge = %s\n" % self.merge)
def _Set(self, key, value): def _Set(self, key, value):
key = 'branch.%s.%s' % (self.name, key) key = "branch.%s.%s" % (self.name, key)
return self._config.SetString(key, value) return self._config.SetString(key, value)
def _Get(self, key, all_keys=False): def _Get(self, key, all_keys=False):
key = 'branch.%s.%s' % (self.name, key) key = "branch.%s.%s" % (self.name, key)
return self._config.GetString(key, all_keys=all_keys) return self._config.GetString(key, all_keys=all_keys)
@ -760,6 +771,7 @@ class SyncAnalysisState:
This object is versioned. This object is versioned.
""" """
def __init__(self, config, options, superproject_logging_data): def __init__(self, config, options, superproject_logging_data):
"""Initializes SyncAnalysisState. """Initializes SyncAnalysisState.
@ -773,23 +785,30 @@ class SyncAnalysisState:
Args: Args:
config: GitConfig object to store all options. config: GitConfig object to store all options.
options: Options passed to sync returned from optparse. See _Options(). options: Options passed to sync returned from optparse. See
superproject_logging_data: A dictionary of superproject data that is to be logged. _Options().
superproject_logging_data: A dictionary of superproject data that is
to be logged.
""" """
self._config = config self._config = config
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
self._Set('main.synctime', now.isoformat() + 'Z') self._Set("main.synctime", now.isoformat() + "Z")
self._Set('main.version', '1') self._Set("main.version", "1")
self._Set('sys.argv', sys.argv) self._Set("sys.argv", sys.argv)
for key, value in superproject_logging_data.items(): for key, value in superproject_logging_data.items():
self._Set(f'superproject.{key}', value) self._Set(f"superproject.{key}", value)
for key, value in options.__dict__.items(): for key, value in options.__dict__.items():
self._Set(f'options.{key}', value) self._Set(f"options.{key}", value)
config_items = config.DumpConfigDict().items() config_items = config.DumpConfigDict().items()
EXTRACT_NAMESPACES = {'repo', 'branch', 'remote'} EXTRACT_NAMESPACES = {"repo", "branch", "remote"}
self._SetDictionary({k: v for k, v in config_items self._SetDictionary(
if not k.startswith(SYNC_STATE_PREFIX) and {
k.split('.', 1)[0] in EXTRACT_NAMESPACES}) k: v
for k, v in config_items
if not k.startswith(SYNC_STATE_PREFIX)
and k.split(".", 1)[0] in EXTRACT_NAMESPACES
}
)
def _SetDictionary(self, data): def _SetDictionary(self, data):
"""Save all key/value pairs of |data| dictionary. """Save all key/value pairs of |data| dictionary.
@ -807,13 +826,14 @@ class SyncAnalysisState:
Args: Args:
key: Name of the key. key: Name of the key.
value: |value| could be of any type. If it is 'bool', it will be saved value: |value| could be of any type. If it is 'bool', it will be
as a Boolean and for all other types, it will be saved as a String. saved as a Boolean and for all other types, it will be saved as
a String.
""" """
if value is None: if value is None:
return return
sync_key = f'{SYNC_STATE_PREFIX}{key}' sync_key = f"{SYNC_STATE_PREFIX}{key}"
sync_key = sync_key.replace('_', '') sync_key = sync_key.replace("_", "")
if isinstance(value, str): if isinstance(value, str):
self._config.SetString(sync_key, value) self._config.SetString(sync_key, value)
elif isinstance(value, bool): elif isinstance(value, bool):

View File

@ -16,14 +16,14 @@ import os
from repo_trace import Trace from repo_trace import Trace
import platform_utils import platform_utils
HEAD = 'HEAD' HEAD = "HEAD"
R_CHANGES = 'refs/changes/' R_CHANGES = "refs/changes/"
R_HEADS = 'refs/heads/' R_HEADS = "refs/heads/"
R_TAGS = 'refs/tags/' R_TAGS = "refs/tags/"
R_PUB = 'refs/published/' R_PUB = "refs/published/"
R_WORKTREE = 'refs/worktree/' R_WORKTREE = "refs/worktree/"
R_WORKTREE_M = R_WORKTREE + 'm/' R_WORKTREE_M = R_WORKTREE + "m/"
R_M = 'refs/remotes/m/' R_M = "refs/remotes/m/"
class GitRefs(object): class GitRefs(object):
@ -42,7 +42,7 @@ class GitRefs(object):
try: try:
return self.all[name] return self.all[name]
except KeyError: except KeyError:
return '' return ""
def deleted(self, name): def deleted(self, name):
if self._phyref is not None: if self._phyref is not None:
@ -60,31 +60,32 @@ class GitRefs(object):
self._EnsureLoaded() self._EnsureLoaded()
return self._symref[name] return self._symref[name]
except KeyError: except KeyError:
return '' return ""
def _EnsureLoaded(self): def _EnsureLoaded(self):
if self._phyref is None or self._NeedUpdate(): if self._phyref is None or self._NeedUpdate():
self._LoadAll() self._LoadAll()
def _NeedUpdate(self): def _NeedUpdate(self):
with Trace(': scan refs %s', self._gitdir): with Trace(": scan refs %s", self._gitdir):
for name, mtime in self._mtime.items(): for name, mtime in self._mtime.items():
try: try:
if mtime != os.path.getmtime(os.path.join(self._gitdir, name)): if mtime != os.path.getmtime(
os.path.join(self._gitdir, name)
):
return True return True
except OSError: except OSError:
return True return True
return False return False
def _LoadAll(self): def _LoadAll(self):
with Trace(': load refs %s', self._gitdir): with Trace(": load refs %s", self._gitdir):
self._phyref = {} self._phyref = {}
self._symref = {} self._symref = {}
self._mtime = {} self._mtime = {}
self._ReadPackedRefs() self._ReadPackedRefs()
self._ReadLoose('refs/') self._ReadLoose("refs/")
self._ReadLoose1(os.path.join(self._gitdir, HEAD), HEAD) self._ReadLoose1(os.path.join(self._gitdir, HEAD), HEAD)
scan = self._symref scan = self._symref
@ -100,9 +101,9 @@ class GitRefs(object):
attempts += 1 attempts += 1
def _ReadPackedRefs(self): def _ReadPackedRefs(self):
path = os.path.join(self._gitdir, 'packed-refs') path = os.path.join(self._gitdir, "packed-refs")
try: try:
fd = open(path, 'r') fd = open(path, "r")
mtime = os.path.getmtime(path) mtime = os.path.getmtime(path)
except IOError: except IOError:
return return
@ -111,33 +112,33 @@ class GitRefs(object):
try: try:
for line in fd: for line in fd:
line = str(line) line = str(line)
if line[0] == '#': if line[0] == "#":
continue continue
if line[0] == '^': if line[0] == "^":
continue continue
line = line[:-1] line = line[:-1]
p = line.split(' ') p = line.split(" ")
ref_id = p[0] ref_id = p[0]
name = p[1] name = p[1]
self._phyref[name] = ref_id self._phyref[name] = ref_id
finally: finally:
fd.close() fd.close()
self._mtime['packed-refs'] = mtime self._mtime["packed-refs"] = mtime
def _ReadLoose(self, prefix): def _ReadLoose(self, prefix):
base = os.path.join(self._gitdir, prefix) base = os.path.join(self._gitdir, prefix)
for name in platform_utils.listdir(base): for name in platform_utils.listdir(base):
p = os.path.join(base, name) p = os.path.join(base, name)
# We don't implement the full ref validation algorithm, just the simple # We don't implement the full ref validation algorithm, just the
# rules that would show up in local filesystems. # simple rules that would show up in local filesystems.
# https://git-scm.com/docs/git-check-ref-format # https://git-scm.com/docs/git-check-ref-format
if name.startswith('.') or name.endswith('.lock'): if name.startswith(".") or name.endswith(".lock"):
pass pass
elif platform_utils.isdir(p): elif platform_utils.isdir(p):
self._mtime[prefix] = os.path.getmtime(base) self._mtime[prefix] = os.path.getmtime(base)
self._ReadLoose(prefix + name + '/') self._ReadLoose(prefix + name + "/")
else: else:
self._ReadLoose1(p, prefix + name) self._ReadLoose1(p, prefix + name)
@ -157,7 +158,7 @@ class GitRefs(object):
return return
ref_id = ref_id[:-1] ref_id = ref_id[:-1]
if ref_id.startswith('ref: '): if ref_id.startswith("ref: "):
self._symref[name] = ref_id[5:] self._symref[name] = ref_id[5:]
else: else:
self._phyref[name] = ref_id self._phyref[name] = ref_id

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Provide functionality to get all projects and their commit ids from Superproject. """Provide functionality to get projects and their commit ids from Superproject.
For more information on superproject, check out: For more information on superproject, check out:
https://en.wikibooks.org/wiki/Git/Submodules_and_Superprojects https://en.wikibooks.org/wiki/Git/Submodules_and_Superprojects
@ -33,8 +33,8 @@ from git_command import git_require, GitCommand
from git_config import RepoConfig from git_config import RepoConfig
from git_refs import GitRefs from git_refs import GitRefs
_SUPERPROJECT_GIT_NAME = 'superproject.git' _SUPERPROJECT_GIT_NAME = "superproject.git"
_SUPERPROJECT_MANIFEST_NAME = 'superproject_override.xml' _SUPERPROJECT_MANIFEST_NAME = "superproject_override.xml"
class SyncResult(NamedTuple): class SyncResult(NamedTuple):
@ -71,8 +71,15 @@ class Superproject(object):
lookup of commit ids for all projects. It contains _project_commit_ids which lookup of commit ids for all projects. It contains _project_commit_ids which
is a dictionary with project/commit id entries. is a dictionary with project/commit id entries.
""" """
def __init__(self, manifest, name, remote, revision,
superproject_dir='exp-superproject'): def __init__(
self,
manifest,
name,
remote,
revision,
superproject_dir="exp-superproject",
):
"""Initializes superproject. """Initializes superproject.
Args: Args:
@ -90,19 +97,23 @@ class Superproject(object):
self.revision = self._branch = revision self.revision = self._branch = revision
self._repodir = manifest.repodir self._repodir = manifest.repodir
self._superproject_dir = superproject_dir self._superproject_dir = superproject_dir
self._superproject_path = manifest.SubmanifestInfoDir(manifest.path_prefix, self._superproject_path = manifest.SubmanifestInfoDir(
superproject_dir) manifest.path_prefix, superproject_dir
self._manifest_path = os.path.join(self._superproject_path, )
_SUPERPROJECT_MANIFEST_NAME) self._manifest_path = os.path.join(
git_name = hashlib.md5(remote.name.encode('utf8')).hexdigest() + '-' self._superproject_path, _SUPERPROJECT_MANIFEST_NAME
)
git_name = hashlib.md5(remote.name.encode("utf8")).hexdigest() + "-"
self._remote_url = remote.url self._remote_url = remote.url
self._work_git_name = git_name + _SUPERPROJECT_GIT_NAME self._work_git_name = git_name + _SUPERPROJECT_GIT_NAME
self._work_git = os.path.join(self._superproject_path, self._work_git_name) self._work_git = os.path.join(
self._superproject_path, self._work_git_name
)
# The following are command arguemnts, rather than superproject attributes, # The following are command arguemnts, rather than superproject
# and were included here originally. They should eventually become # attributes, and were included here originally. They should eventually
# arguments that are passed down from the public methods, instead of being # become arguments that are passed down from the public methods, instead
# treated as attributes. # of being treated as attributes.
self._git_event_log = None self._git_event_log = None
self._quiet = False self._quiet = False
self._print_messages = False self._print_messages = False
@ -123,26 +134,30 @@ class Superproject(object):
@property @property
def manifest_path(self): def manifest_path(self):
"""Returns the manifest path if the path exists or None.""" """Returns the manifest path if the path exists or None."""
return self._manifest_path if os.path.exists(self._manifest_path) else None return (
self._manifest_path if os.path.exists(self._manifest_path) else None
)
def _LogMessage(self, fmt, *inputs): def _LogMessage(self, fmt, *inputs):
"""Logs message to stderr and _git_event_log.""" """Logs message to stderr and _git_event_log."""
message = f'{self._LogMessagePrefix()} {fmt.format(*inputs)}' message = f"{self._LogMessagePrefix()} {fmt.format(*inputs)}"
if self._print_messages: if self._print_messages:
print(message, file=sys.stderr) print(message, file=sys.stderr)
self._git_event_log.ErrorEvent(message, fmt) self._git_event_log.ErrorEvent(message, fmt)
def _LogMessagePrefix(self): def _LogMessagePrefix(self):
"""Returns the prefix string to be logged in each log message""" """Returns the prefix string to be logged in each log message"""
return f'repo superproject branch: {self._branch} url: {self._remote_url}' return (
f"repo superproject branch: {self._branch} url: {self._remote_url}"
)
def _LogError(self, fmt, *inputs): def _LogError(self, fmt, *inputs):
"""Logs error message to stderr and _git_event_log.""" """Logs error message to stderr and _git_event_log."""
self._LogMessage(f'error: {fmt}', *inputs) self._LogMessage(f"error: {fmt}", *inputs)
def _LogWarning(self, fmt, *inputs): def _LogWarning(self, fmt, *inputs):
"""Logs warning message to stderr and _git_event_log.""" """Logs warning message to stderr and _git_event_log."""
self._LogMessage(f'warning: {fmt}', *inputs) self._LogMessage(f"warning: {fmt}", *inputs)
def _Init(self): def _Init(self):
"""Sets up a local Git repository to get a copy of a superproject. """Sets up a local Git repository to get a copy of a superproject.
@ -153,56 +168,84 @@ class Superproject(object):
if not os.path.exists(self._superproject_path): if not os.path.exists(self._superproject_path):
os.mkdir(self._superproject_path) os.mkdir(self._superproject_path)
if not self._quiet and not os.path.exists(self._work_git): if not self._quiet and not os.path.exists(self._work_git):
print('%s: Performing initial setup for superproject; this might take ' print(
'several minutes.' % self._work_git) "%s: Performing initial setup for superproject; this might "
cmd = ['init', '--bare', self._work_git_name] "take several minutes." % self._work_git
p = GitCommand(None, )
cmd = ["init", "--bare", self._work_git_name]
p = GitCommand(
None,
cmd, cmd,
cwd=self._superproject_path, cwd=self._superproject_path,
capture_stdout=True, capture_stdout=True,
capture_stderr=True) capture_stderr=True,
)
retval = p.Wait() retval = p.Wait()
if retval: if retval:
self._LogWarning('git init call failed, command: git {}, ' self._LogWarning(
'return code: {}, stderr: {}', cmd, retval, p.stderr) "git init call failed, command: git {}, "
"return code: {}, stderr: {}",
cmd,
retval,
p.stderr,
)
return False return False
return True return True
def _Fetch(self): def _Fetch(self):
"""Fetches a local copy of a superproject for the manifest based on |_remote_url|. """Fetches a superproject for the manifest based on |_remote_url|.
This runs git fetch which stores a local copy the superproject.
Returns: Returns:
True if fetch is successful, or False. True if fetch is successful, or False.
""" """
if not os.path.exists(self._work_git): if not os.path.exists(self._work_git):
self._LogWarning('git fetch missing directory: {}', self._work_git) self._LogWarning("git fetch missing directory: {}", self._work_git)
return False return False
if not git_require((2, 28, 0)): if not git_require((2, 28, 0)):
self._LogWarning('superproject requires a git version 2.28 or later') self._LogWarning(
"superproject requires a git version 2.28 or later"
)
return False return False
cmd = ['fetch', self._remote_url, '--depth', '1', '--force', '--no-tags', cmd = [
'--filter', 'blob:none'] "fetch",
self._remote_url,
"--depth",
"1",
"--force",
"--no-tags",
"--filter",
"blob:none",
]
# Check if there is a local ref that we can pass to --negotiation-tip. # Check if there is a local ref that we can pass to --negotiation-tip.
# If this is the first fetch, it does not exist yet. # If this is the first fetch, it does not exist yet.
# We use --negotiation-tip to speed up the fetch. Superproject branches do # We use --negotiation-tip to speed up the fetch. Superproject branches
# not share commits. So this lets git know it only needs to send commits # do not share commits. So this lets git know it only needs to send
# reachable from the specified local refs. # commits reachable from the specified local refs.
rev_commit = GitRefs(self._work_git).get(f'refs/heads/{self.revision}') rev_commit = GitRefs(self._work_git).get(f"refs/heads/{self.revision}")
if rev_commit: if rev_commit:
cmd.extend(['--negotiation-tip', rev_commit]) cmd.extend(["--negotiation-tip", rev_commit])
if self._branch: if self._branch:
cmd += [self._branch + ':' + self._branch] cmd += [self._branch + ":" + self._branch]
p = GitCommand(None, p = GitCommand(
None,
cmd, cmd,
cwd=self._work_git, cwd=self._work_git,
capture_stdout=True, capture_stdout=True,
capture_stderr=True) capture_stderr=True,
)
retval = p.Wait() retval = p.Wait()
if retval: if retval:
self._LogWarning('git fetch call failed, command: git {}, ' self._LogWarning(
'return code: {}, stderr: {}', cmd, retval, p.stderr) "git fetch call failed, command: git {}, "
"return code: {}, stderr: {}",
cmd,
retval,
p.stderr,
)
return False return False
return True return True
@ -215,23 +258,32 @@ class Superproject(object):
data: data returned from 'git ls-tree ...' instead of None. data: data returned from 'git ls-tree ...' instead of None.
""" """
if not os.path.exists(self._work_git): if not os.path.exists(self._work_git):
self._LogWarning('git ls-tree missing directory: {}', self._work_git) self._LogWarning(
"git ls-tree missing directory: {}", self._work_git
)
return None return None
data = None data = None
branch = 'HEAD' if not self._branch else self._branch branch = "HEAD" if not self._branch else self._branch
cmd = ['ls-tree', '-z', '-r', branch] cmd = ["ls-tree", "-z", "-r", branch]
p = GitCommand(None, p = GitCommand(
None,
cmd, cmd,
cwd=self._work_git, cwd=self._work_git,
capture_stdout=True, capture_stdout=True,
capture_stderr=True) capture_stderr=True,
)
retval = p.Wait() retval = p.Wait()
if retval == 0: if retval == 0:
data = p.stdout data = p.stdout
else: else:
self._LogWarning('git ls-tree call failed, command: git {}, ' self._LogWarning(
'return code: {}, stderr: {}', cmd, retval, p.stderr) "git ls-tree call failed, command: git {}, "
"return code: {}, stderr: {}",
cmd,
retval,
p.stderr,
)
return data return data
def Sync(self, git_event_log): def Sync(self, git_event_log):
@ -245,16 +297,20 @@ class Superproject(object):
""" """
self._git_event_log = git_event_log self._git_event_log = git_event_log
if not self._manifest.superproject: if not self._manifest.superproject:
self._LogWarning('superproject tag is not defined in manifest: {}', self._LogWarning(
self._manifest.manifestFile) "superproject tag is not defined in manifest: {}",
self._manifest.manifestFile,
)
return SyncResult(False, False) return SyncResult(False, False)
_PrintBetaNotice() _PrintBetaNotice()
should_exit = True should_exit = True
if not self._remote_url: if not self._remote_url:
self._LogWarning('superproject URL is not defined in manifest: {}', self._LogWarning(
self._manifest.manifestFile) "superproject URL is not defined in manifest: {}",
self._manifest.manifestFile,
)
return SyncResult(False, should_exit) return SyncResult(False, should_exit)
if not self._Init(): if not self._Init():
@ -262,11 +318,15 @@ class Superproject(object):
if not self._Fetch(): if not self._Fetch():
return SyncResult(False, should_exit) return SyncResult(False, should_exit)
if not self._quiet: if not self._quiet:
print('%s: Initial setup for superproject completed.' % self._work_git) print(
"%s: Initial setup for superproject completed." % self._work_git
)
return SyncResult(True, False) return SyncResult(True, False)
def _GetAllProjectsCommitIds(self): def _GetAllProjectsCommitIds(self):
"""Get commit ids for all projects from superproject and save them in _project_commit_ids. """Get commit ids for all projects from superproject and save them.
Commit ids are saved in _project_commit_ids.
Returns: Returns:
CommitIdsResult CommitIdsResult
@ -277,21 +337,24 @@ class Superproject(object):
data = self._LsTree() data = self._LsTree()
if not data: if not data:
self._LogWarning('git ls-tree failed to return data for manifest: {}', self._LogWarning(
self._manifest.manifestFile) "git ls-tree failed to return data for manifest: {}",
self._manifest.manifestFile,
)
return CommitIdsResult(None, True) return CommitIdsResult(None, True)
# Parse lines like the following to select lines starting with '160000' and # Parse lines like the following to select lines starting with '160000'
# build a dictionary with project path (last element) and its commit id (3rd element). # and build a dictionary with project path (last element) and its commit
# id (3rd element).
# #
# 160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00 # 160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00
# 120000 blob acc2cbdf438f9d2141f0ae424cec1d8fc4b5d97f\tbootstrap.bash\x00 # 120000 blob acc2cbdf438f9d2141f0ae424cec1d8fc4b5d97f\tbootstrap.bash\x00 # noqa: E501
commit_ids = {} commit_ids = {}
for line in data.split('\x00'): for line in data.split("\x00"):
ls_data = line.split(None, 3) ls_data = line.split(None, 3)
if not ls_data: if not ls_data:
break break
if ls_data[0] == '160000': if ls_data[0] == "160000":
commit_ids[ls_data[3]] = ls_data[2] commit_ids[ls_data[3]] = ls_data[2]
self._project_commit_ids = commit_ids self._project_commit_ids = commit_ids
@ -301,20 +364,23 @@ class Superproject(object):
"""Writes manifest to a file. """Writes manifest to a file.
Returns: Returns:
manifest_path: Path name of the file into which manifest is written instead of None. manifest_path: Path name of the file into which manifest is written
instead of None.
""" """
if not os.path.exists(self._superproject_path): if not os.path.exists(self._superproject_path):
self._LogWarning('missing superproject directory: {}', self._superproject_path) self._LogWarning(
"missing superproject directory: {}", self._superproject_path
)
return None return None
manifest_str = self._manifest.ToXml(groups=self._manifest.GetGroupsStr(), manifest_str = self._manifest.ToXml(
omit_local=True).toxml() groups=self._manifest.GetGroupsStr(), omit_local=True
).toxml()
manifest_path = self._manifest_path manifest_path = self._manifest_path
try: try:
with open(manifest_path, 'w', encoding='utf-8') as fp: with open(manifest_path, "w", encoding="utf-8") as fp:
fp.write(manifest_str) fp.write(manifest_str)
except IOError as e: except IOError as e:
self._LogError('cannot write manifest to : {} {}', self._LogError("cannot write manifest to : {} {}", manifest_path, e)
manifest_path, e)
return None return None
return manifest_path return manifest_path
@ -366,9 +432,12 @@ class Superproject(object):
# If superproject doesn't have a commit id for a project, then report an # If superproject doesn't have a commit id for a project, then report an
# error event and continue as if do not use superproject is specified. # error event and continue as if do not use superproject is specified.
if projects_missing_commit_ids: if projects_missing_commit_ids:
self._LogWarning('please file a bug using {} to report missing ' self._LogWarning(
'commit_ids for: {}', self._manifest.contactinfo.bugurl, "please file a bug using {} to report missing "
projects_missing_commit_ids) "commit_ids for: {}",
self._manifest.contactinfo.bugurl,
projects_missing_commit_ids,
)
return UpdateProjectsResult(None, False) return UpdateProjectsResult(None, False)
for project in projects: for project in projects:
@ -382,8 +451,11 @@ class Superproject(object):
@functools.lru_cache(maxsize=10) @functools.lru_cache(maxsize=10)
def _PrintBetaNotice(): def _PrintBetaNotice():
"""Print the notice of beta status.""" """Print the notice of beta status."""
print('NOTICE: --use-superproject is in beta; report any issues to the ' print(
'address described in `repo version`', file=sys.stderr) "NOTICE: --use-superproject is in beta; report any issues to the "
"address described in `repo version`",
file=sys.stderr,
)
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
@ -392,25 +464,39 @@ def _UseSuperprojectFromConfiguration():
user_cfg = RepoConfig.ForUser() user_cfg = RepoConfig.ForUser()
time_now = int(time.time()) time_now = int(time.time())
user_value = user_cfg.GetBoolean('repo.superprojectChoice') user_value = user_cfg.GetBoolean("repo.superprojectChoice")
if user_value is not None: if user_value is not None:
user_expiration = user_cfg.GetInt('repo.superprojectChoiceExpire') user_expiration = user_cfg.GetInt("repo.superprojectChoiceExpire")
if user_expiration is None or user_expiration <= 0 or user_expiration >= time_now: if (
# TODO(b/190688390) - Remove prompt when we are comfortable with the new user_expiration is None
# default value. or user_expiration <= 0
or user_expiration >= time_now
):
# TODO(b/190688390) - Remove prompt when we are comfortable with the
# new default value.
if user_value: if user_value:
print(('You are currently enrolled in Git submodules experiment ' print(
'(go/android-submodules-quickstart). Use --no-use-superproject ' (
'to override.\n'), file=sys.stderr) "You are currently enrolled in Git submodules "
"experiment (go/android-submodules-quickstart). Use "
"--no-use-superproject to override.\n"
),
file=sys.stderr,
)
else: else:
print(('You are not currently enrolled in Git submodules experiment ' print(
'(go/android-submodules-quickstart). Use --use-superproject ' (
'to override.\n'), file=sys.stderr) "You are not currently enrolled in Git submodules "
"experiment (go/android-submodules-quickstart). Use "
"--use-superproject to override.\n"
),
file=sys.stderr,
)
return user_value return user_value
# We don't have an unexpired choice, ask for one. # We don't have an unexpired choice, ask for one.
system_cfg = RepoConfig.ForSystem() system_cfg = RepoConfig.ForSystem()
system_value = system_cfg.GetBoolean('repo.superprojectChoice') system_value = system_cfg.GetBoolean("repo.superprojectChoice")
if system_value: if system_value:
# The system configuration is proposing that we should enable the # The system configuration is proposing that we should enable the
# use of superproject. Treat the user as enrolled for two weeks. # use of superproject. Treat the user as enrolled for two weeks.
@ -419,11 +505,15 @@ def _UseSuperprojectFromConfiguration():
# default value. # default value.
userchoice = True userchoice = True
time_choiceexpire = time_now + (86400 * 14) time_choiceexpire = time_now + (86400 * 14)
user_cfg.SetString('repo.superprojectChoiceExpire', str(time_choiceexpire)) user_cfg.SetString(
user_cfg.SetBoolean('repo.superprojectChoice', userchoice) "repo.superprojectChoiceExpire", str(time_choiceexpire)
print('You are automatically enrolled in Git submodules experiment ' )
'(go/android-submodules-quickstart) for another two weeks.\n', user_cfg.SetBoolean("repo.superprojectChoice", userchoice)
file=sys.stderr) print(
"You are automatically enrolled in Git submodules experiment "
"(go/android-submodules-quickstart) for another two weeks.\n",
file=sys.stderr,
)
return True return True
# For all other cases, we would not use superproject by default. # For all other cases, we would not use superproject by default.

View File

@ -48,7 +48,8 @@ class EventLog(object):
Each entry contains the following common keys: Each entry contains the following common keys:
- event: The event name - event: The event name
- sid: session-id - Unique string to allow process instance to be identified. - sid: session-id - Unique string to allow process instance to be
identified.
- thread: The thread name. - thread: The thread name.
- time: is the UTC time of the event. - time: is the UTC time of the event.
@ -60,7 +61,7 @@ class EventLog(object):
"""Initializes the event log.""" """Initializes the event log."""
self._log = [] self._log = []
# Try to get session-id (sid) from environment (setup in repo launcher). # Try to get session-id (sid) from environment (setup in repo launcher).
KEY = 'GIT_TRACE2_PARENT_SID' KEY = "GIT_TRACE2_PARENT_SID"
if env is None: if env is None:
env = os.environ env = os.environ
@ -69,11 +70,14 @@ class EventLog(object):
# Save both our sid component and the complete sid. # Save both our sid component and the complete sid.
# We use our sid component (self._sid) as the unique filename prefix and # We use our sid component (self._sid) as the unique filename prefix and
# the full sid (self._full_sid) in the log itself. # the full sid (self._full_sid) in the log itself.
self._sid = 'repo-%s-P%08x' % (now.strftime('%Y%m%dT%H%M%SZ'), os.getpid()) self._sid = "repo-%s-P%08x" % (
now.strftime("%Y%m%dT%H%M%SZ"),
os.getpid(),
)
parent_sid = env.get(KEY) parent_sid = env.get(KEY)
# Append our sid component to the parent sid (if it exists). # Append our sid component to the parent sid (if it exists).
if parent_sid is not None: if parent_sid is not None:
self._full_sid = parent_sid + '/' + self._sid self._full_sid = parent_sid + "/" + self._sid
else: else:
self._full_sid = self._sid self._full_sid = self._sid
@ -93,13 +97,13 @@ class EventLog(object):
def _AddVersionEvent(self): def _AddVersionEvent(self):
"""Adds a 'version' event at the beginning of current log.""" """Adds a 'version' event at the beginning of current log."""
version_event = self._CreateEventDict('version') version_event = self._CreateEventDict("version")
version_event['evt'] = "2" version_event["evt"] = "2"
version_event['exe'] = RepoSourceVersion() version_event["exe"] = RepoSourceVersion()
self._log.insert(0, version_event) self._log.insert(0, version_event)
def _CreateEventDict(self, event_name): def _CreateEventDict(self, event_name):
"""Returns a dictionary with the common keys/values for git trace2 events. """Returns a dictionary with common keys/values for git trace2 events.
Args: Args:
event_name: The event name. event_name: The event name.
@ -108,16 +112,16 @@ class EventLog(object):
Dictionary with the common event fields populated. Dictionary with the common event fields populated.
""" """
return { return {
'event': event_name, "event": event_name,
'sid': self._full_sid, "sid": self._full_sid,
'thread': threading.current_thread().name, "thread": threading.current_thread().name,
'time': datetime.datetime.utcnow().isoformat() + 'Z', "time": datetime.datetime.utcnow().isoformat() + "Z",
} }
def StartEvent(self): def StartEvent(self):
"""Append a 'start' event to the current log.""" """Append a 'start' event to the current log."""
start_event = self._CreateEventDict('start') start_event = self._CreateEventDict("start")
start_event['argv'] = sys.argv start_event["argv"] = sys.argv
self._log.append(start_event) self._log.append(start_event)
def ExitEvent(self, result): def ExitEvent(self, result):
@ -126,12 +130,12 @@ class EventLog(object):
Args: Args:
result: Exit code of the event result: Exit code of the event
""" """
exit_event = self._CreateEventDict('exit') exit_event = self._CreateEventDict("exit")
# Consider 'None' success (consistent with event_log result handling). # Consider 'None' success (consistent with event_log result handling).
if result is None: if result is None:
result = 0 result = 0
exit_event['code'] = result exit_event["code"] = result
self._log.append(exit_event) self._log.append(exit_event)
def CommandEvent(self, name, subcommands): def CommandEvent(self, name, subcommands):
@ -141,9 +145,9 @@ class EventLog(object):
name: Name of the primary command (ex: repo, git) name: Name of the primary command (ex: repo, git)
subcommands: List of the sub-commands (ex: version, init, sync) subcommands: List of the sub-commands (ex: version, init, sync)
""" """
command_event = self._CreateEventDict('command') command_event = self._CreateEventDict("command")
command_event['name'] = name command_event["name"] = name
command_event['subcommands'] = subcommands command_event["subcommands"] = subcommands
self._log.append(command_event) self._log.append(command_event)
def LogConfigEvents(self, config, event_dict_name): def LogConfigEvents(self, config, event_dict_name):
@ -151,33 +155,36 @@ class EventLog(object):
Args: Args:
config: Configuration dictionary. config: Configuration dictionary.
event_dict_name: Name of the event dictionary for items to be logged under. event_dict_name: Name of the event dictionary for items to be logged
under.
""" """
for param, value in config.items(): for param, value in config.items():
event = self._CreateEventDict(event_dict_name) event = self._CreateEventDict(event_dict_name)
event['param'] = param event["param"] = param
event['value'] = value event["value"] = value
self._log.append(event) self._log.append(event)
def DefParamRepoEvents(self, config): def DefParamRepoEvents(self, config):
"""Append a 'def_param' event for each repo.* config key to the current log. """Append 'def_param' events for repo config keys to the current log.
This appends one event for each repo.* config key.
Args: Args:
config: Repo configuration dictionary config: Repo configuration dictionary
""" """
# Only output the repo.* config parameters. # Only output the repo.* config parameters.
repo_config = {k: v for k, v in config.items() if k.startswith('repo.')} repo_config = {k: v for k, v in config.items() if k.startswith("repo.")}
self.LogConfigEvents(repo_config, 'def_param') self.LogConfigEvents(repo_config, "def_param")
def GetDataEventName(self, value): def GetDataEventName(self, value):
"""Returns 'data-json' if the value is an array else returns 'data'.""" """Returns 'data-json' if the value is an array else returns 'data'."""
return 'data-json' if value[0] == '[' and value[-1] == ']' else 'data' return "data-json" if value[0] == "[" and value[-1] == "]" else "data"
def LogDataConfigEvents(self, config, prefix): def LogDataConfigEvents(self, config, prefix):
"""Append a 'data' event for each config key/value in |config| to the current log. """Append a 'data' event for each entry in |config| to the current log.
For each keyX and valueX of the config, "key" field of the event is '|prefix|/keyX' For each keyX and valueX of the config, "key" field of the event is
and the "value" of the "key" field is valueX. '|prefix|/keyX' and the "value" of the "key" field is valueX.
Args: Args:
config: Configuration dictionary. config: Configuration dictionary.
@ -185,15 +192,15 @@ class EventLog(object):
""" """
for key, value in config.items(): for key, value in config.items():
event = self._CreateEventDict(self.GetDataEventName(value)) event = self._CreateEventDict(self.GetDataEventName(value))
event['key'] = f'{prefix}/{key}' event["key"] = f"{prefix}/{key}"
event['value'] = value event["value"] = value
self._log.append(event) self._log.append(event)
def ErrorEvent(self, msg, fmt): def ErrorEvent(self, msg, fmt):
"""Append a 'error' event to the current log.""" """Append a 'error' event to the current log."""
error_event = self._CreateEventDict('error') error_event = self._CreateEventDict("error")
error_event['msg'] = msg error_event["msg"] = msg
error_event['fmt'] = fmt error_event["fmt"] = fmt
self._log.append(error_event) self._log.append(error_event)
def _GetEventTargetPath(self): def _GetEventTargetPath(self):
@ -203,38 +210,48 @@ class EventLog(object):
path: git config's 'trace2.eventtarget' path if it exists, or None path: git config's 'trace2.eventtarget' path if it exists, or None
""" """
path = None path = None
cmd = ['config', '--get', 'trace2.eventtarget'] cmd = ["config", "--get", "trace2.eventtarget"]
# TODO(https://crbug.com/gerrit/13706): Use GitConfig when it supports # TODO(https://crbug.com/gerrit/13706): Use GitConfig when it supports
# system git config variables. # system git config variables.
p = GitCommand(None, cmd, capture_stdout=True, capture_stderr=True, p = GitCommand(
bare=True) None, cmd, capture_stdout=True, capture_stderr=True, bare=True
)
retval = p.Wait() retval = p.Wait()
if retval == 0: if retval == 0:
# Strip trailing carriage-return in path. # Strip trailing carriage-return in path.
path = p.stdout.rstrip('\n') path = p.stdout.rstrip("\n")
elif retval != 1: elif retval != 1:
# `git config --get` is documented to produce an exit status of `1` if # `git config --get` is documented to produce an exit status of `1`
# the requested variable is not present in the configuration. Report any # if the requested variable is not present in the configuration.
# other return value as an error. # Report any other return value as an error.
print("repo: error: 'git config --get' call failed with return code: %r, stderr: %r" % ( print(
retval, p.stderr), file=sys.stderr) "repo: error: 'git config --get' call failed with return code: "
"%r, stderr: %r" % (retval, p.stderr),
file=sys.stderr,
)
return path return path
def _WriteLog(self, write_fn): def _WriteLog(self, write_fn):
"""Writes the log out using a provided writer function. """Writes the log out using a provided writer function.
Generate compact JSON output for each item in the log, and write it using Generate compact JSON output for each item in the log, and write it
write_fn. using write_fn.
Args: Args:
write_fn: A function that accepts byts and writes them to a destination. write_fn: A function that accepts byts and writes them to a
destination.
""" """
for e in self._log: for e in self._log:
# Dump in compact encoding mode. # Dump in compact encoding mode.
# See 'Compact encoding' in Python docs: # See 'Compact encoding' in Python docs:
# https://docs.python.org/3/library/json.html#module-json # https://docs.python.org/3/library/json.html#module-json
write_fn(json.dumps(e, indent=None, separators=(',', ':')).encode('utf-8') + b'\n') write_fn(
json.dumps(e, indent=None, separators=(",", ":")).encode(
"utf-8"
)
+ b"\n"
)
def Write(self, path=None): def Write(self, path=None):
"""Writes the log out to a file or socket. """Writes the log out to a file or socket.
@ -246,16 +263,19 @@ class EventLog(object):
(exclusive writable) file. (exclusive writable) file.
Args: Args:
path: Path to where logs should be written. The path may have a prefix of path: Path to where logs should be written. The path may have a
the form "af_unix:[{stream|dgram}:]", in which case the path is prefix of the form "af_unix:[{stream|dgram}:]", in which case
treated as a Unix domain socket. See the path is treated as a Unix domain socket. See
https://git-scm.com/docs/api-trace2#_enabling_a_target for details. https://git-scm.com/docs/api-trace2#_enabling_a_target for
details.
Returns: Returns:
log_path: Path to the log file or socket if log is written, otherwise None log_path: Path to the log file or socket if log is written,
otherwise None
""" """
log_path = None log_path = None
# If no logging path is specified, get the path from 'trace2.eventtarget'. # If no logging path is specified, get the path from
# 'trace2.eventtarget'.
if path is None: if path is None:
path = self._GetEventTargetPath() path = self._GetEventTargetPath()
@ -266,22 +286,22 @@ class EventLog(object):
path_is_socket = False path_is_socket = False
socket_type = None socket_type = None
if isinstance(path, str): if isinstance(path, str):
parts = path.split(':', 1) parts = path.split(":", 1)
if parts[0] == 'af_unix' and len(parts) == 2: if parts[0] == "af_unix" and len(parts) == 2:
path_is_socket = True path_is_socket = True
path = parts[1] path = parts[1]
parts = path.split(':', 1) parts = path.split(":", 1)
if parts[0] == 'stream' and len(parts) == 2: if parts[0] == "stream" and len(parts) == 2:
socket_type = socket.SOCK_STREAM socket_type = socket.SOCK_STREAM
path = parts[1] path = parts[1]
elif parts[0] == 'dgram' and len(parts) == 2: elif parts[0] == "dgram" and len(parts) == 2:
socket_type = socket.SOCK_DGRAM socket_type = socket.SOCK_DGRAM
path = parts[1] path = parts[1]
else: else:
# Get absolute path. # Get absolute path.
path = os.path.abspath(os.path.expanduser(path)) path = os.path.abspath(os.path.expanduser(path))
else: else:
raise TypeError('path: str required but got %s.' % type(path)) raise TypeError("path: str required but got %s." % type(path))
# Git trace2 requires a directory to write log to. # Git trace2 requires a directory to write log to.
@ -292,40 +312,59 @@ class EventLog(object):
if path_is_socket: if path_is_socket:
if socket_type == socket.SOCK_STREAM or socket_type is None: if socket_type == socket.SOCK_STREAM or socket_type is None:
try: try:
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: with socket.socket(
socket.AF_UNIX, socket.SOCK_STREAM
) as sock:
sock.connect(path) sock.connect(path)
self._WriteLog(sock.sendall) self._WriteLog(sock.sendall)
return f'af_unix:stream:{path}' return f"af_unix:stream:{path}"
except OSError as err: except OSError as err:
# If we tried to connect to a DGRAM socket using STREAM, ignore the # If we tried to connect to a DGRAM socket using STREAM,
# attempt and continue to DGRAM below. Otherwise, issue a warning. # ignore the attempt and continue to DGRAM below. Otherwise,
# issue a warning.
if err.errno != errno.EPROTOTYPE: if err.errno != errno.EPROTOTYPE:
print(f'repo: warning: git trace2 logging failed: {err}', file=sys.stderr) print(
f"repo: warning: git trace2 logging failed: {err}",
file=sys.stderr,
)
return None return None
if socket_type == socket.SOCK_DGRAM or socket_type is None: if socket_type == socket.SOCK_DGRAM or socket_type is None:
try: try:
with socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) as sock: with socket.socket(
socket.AF_UNIX, socket.SOCK_DGRAM
) as sock:
self._WriteLog(lambda bs: sock.sendto(bs, path)) self._WriteLog(lambda bs: sock.sendto(bs, path))
return f'af_unix:dgram:{path}' return f"af_unix:dgram:{path}"
except OSError as err: except OSError as err:
print(f'repo: warning: git trace2 logging failed: {err}', file=sys.stderr) print(
f"repo: warning: git trace2 logging failed: {err}",
file=sys.stderr,
)
return None return None
# Tried to open a socket but couldn't connect (SOCK_STREAM) or write # Tried to open a socket but couldn't connect (SOCK_STREAM) or write
# (SOCK_DGRAM). # (SOCK_DGRAM).
print('repo: warning: git trace2 logging failed: could not write to socket', file=sys.stderr) print(
"repo: warning: git trace2 logging failed: could not write to "
"socket",
file=sys.stderr,
)
return None return None
# Path is an absolute path # Path is an absolute path
# Use NamedTemporaryFile to generate a unique filename as required by git trace2. # Use NamedTemporaryFile to generate a unique filename as required by
# git trace2.
try: try:
with tempfile.NamedTemporaryFile(mode='xb', prefix=self._sid, dir=path, with tempfile.NamedTemporaryFile(
delete=False) as f: mode="xb", prefix=self._sid, dir=path, delete=False
# TODO(https://crbug.com/gerrit/13706): Support writing events as they ) as f:
# occur. # TODO(https://crbug.com/gerrit/13706): Support writing events
# as they occur.
self._WriteLog(f.write) self._WriteLog(f.write)
log_path = f.name log_path = f.name
except FileExistsError as err: except FileExistsError as err:
print('repo: warning: git trace2 logging failed: %r' % err, print(
file=sys.stderr) "repo: warning: git trace2 logging failed: %r" % err,
file=sys.stderr,
)
return None return None
return log_path return log_path

View File

@ -39,9 +39,10 @@ def _get_project_revision(args):
"""Worker for _set_project_revisions to lookup one project remote.""" """Worker for _set_project_revisions to lookup one project remote."""
(i, url, expr) = args (i, url, expr) = args
gitcmd = git_command.GitCommand( gitcmd = git_command.GitCommand(
None, ['ls-remote', url, expr], capture_stdout=True, cwd='/tmp') None, ["ls-remote", url, expr], capture_stdout=True, cwd="/tmp"
)
rc = gitcmd.Wait() rc = gitcmd.Wait()
return (i, rc, gitcmd.stdout.split('\t', 1)[0]) return (i, rc, gitcmd.stdout.split("\t", 1)[0])
def _set_project_revisions(projects): def _set_project_revisions(projects):
@ -54,25 +55,33 @@ def _set_project_revisions(projects):
Args: Args:
projects: List of project objects to set the revionExpr for. projects: List of project objects to set the revionExpr for.
""" """
# Retrieve the commit id for each project based off of it's current # Retrieve the commit id for each project based off of its current
# revisionExpr and it is not already a commit id. # revisionExpr and it is not already a commit id.
with multiprocessing.Pool(NUM_BATCH_RETRIEVE_REVISIONID) as pool: with multiprocessing.Pool(NUM_BATCH_RETRIEVE_REVISIONID) as pool:
results_iter = pool.imap_unordered( results_iter = pool.imap_unordered(
_get_project_revision, _get_project_revision,
((i, project.remote.url, project.revisionExpr) (
(i, project.remote.url, project.revisionExpr)
for i, project in enumerate(projects) for i, project in enumerate(projects)
if not git_config.IsId(project.revisionExpr)), if not git_config.IsId(project.revisionExpr)
chunksize=8) ),
for (i, rc, revisionExpr) in results_iter: chunksize=8,
)
for i, rc, revisionExpr in results_iter:
project = projects[i] project = projects[i]
if rc: if rc:
print('FATAL: Failed to retrieve revisionExpr for %s' % project.name) print(
"FATAL: Failed to retrieve revisionExpr for %s"
% project.name
)
pool.terminate() pool.terminate()
sys.exit(1) sys.exit(1)
if not revisionExpr: if not revisionExpr:
pool.terminate() pool.terminate()
raise ManifestParseError('Invalid SHA-1 revision project %s (%s)' % raise ManifestParseError(
(project.remote.url, project.revisionExpr)) "Invalid SHA-1 revision project %s (%s)"
% (project.remote.url, project.revisionExpr)
)
project.revisionExpr = revisionExpr project.revisionExpr = revisionExpr
@ -85,12 +94,14 @@ def generate_gitc_manifest(gitc_manifest, manifest, paths=None):
paths: List of project paths we want to update. paths: List of project paths we want to update.
""" """
print('Generating GITC Manifest by fetching revision SHAs for each ' print(
'project.') "Generating GITC Manifest by fetching revision SHAs for each "
"project."
)
if paths is None: if paths is None:
paths = list(manifest.paths.keys()) paths = list(manifest.paths.keys())
groups = [x for x in re.split(r'[,\s]+', manifest.GetGroupsStr()) if x] groups = [x for x in re.split(r"[,\s]+", manifest.GetGroupsStr()) if x]
# Convert the paths to projects, and filter them to the matched groups. # Convert the paths to projects, and filter them to the matched groups.
projects = [manifest.paths[p] for p in paths] projects = [manifest.paths[p] for p in paths]
@ -105,12 +116,12 @@ def generate_gitc_manifest(gitc_manifest, manifest, paths=None):
proj.upstream = proj.revisionExpr proj.upstream = proj.revisionExpr
if path not in gitc_manifest.paths: if path not in gitc_manifest.paths:
# Any new projects need their first revision, even if we weren't asked # Any new projects need their first revision, even if we weren't
# for them. # asked for them.
projects.append(proj) projects.append(proj)
elif path not in paths: elif path not in paths:
# And copy revisions from the previous manifest if we're not updating # And copy revisions from the previous manifest if we're not
# them now. # updating them now.
gitc_proj = gitc_manifest.paths[path] gitc_proj = gitc_manifest.paths[path]
if gitc_proj.old_revision: if gitc_proj.old_revision:
proj.revisionExpr = None proj.revisionExpr = None
@ -123,8 +134,8 @@ def generate_gitc_manifest(gitc_manifest, manifest, paths=None):
if gitc_manifest is not None: if gitc_manifest is not None:
for path, proj in gitc_manifest.paths.items(): for path, proj in gitc_manifest.paths.items():
if proj.old_revision and path in paths: if proj.old_revision and path in paths:
# If we updated a project that has been started, keep the old-revision # If we updated a project that has been started, keep the
# updated. # old-revision updated.
repo_proj = manifest.paths[path] repo_proj = manifest.paths[path]
repo_proj.old_revision = repo_proj.revisionExpr repo_proj.old_revision = repo_proj.revisionExpr
repo_proj.revisionExpr = None repo_proj.revisionExpr = None
@ -147,8 +158,8 @@ def save_manifest(manifest, client_dir=None):
if not client_dir: if not client_dir:
manifest_file = manifest.manifestFile manifest_file = manifest.manifestFile
else: else:
manifest_file = os.path.join(client_dir, '.manifest') manifest_file = os.path.join(client_dir, ".manifest")
with open(manifest_file, 'w') as f: with open(manifest_file, "w") as f:
manifest.Save(f, groups=manifest.GetGroupsStr()) manifest.Save(f, groups=manifest.GetGroupsStr())
# TODO(sbasi/jorg): Come up with a solution to remove the sleep below. # TODO(sbasi/jorg): Come up with a solution to remove the sleep below.
# Give the GITC filesystem time to register the manifest changes. # Give the GITC filesystem time to register the manifest changes.

296
hooks.py
View File

@ -28,8 +28,9 @@ from git_refs import HEAD
class RepoHook(object): class RepoHook(object):
"""A RepoHook contains information about a script to run as a hook. """A RepoHook contains information about a script to run as a hook.
Hooks are used to run a python script before running an upload (for instance, Hooks are used to run a python script before running an upload (for
to run presubmit checks). Eventually, we may have hooks for other actions. instance, to run presubmit checks). Eventually, we may have hooks for other
actions.
This shouldn't be confused with files in the 'repo/hooks' directory. Those This shouldn't be confused with files in the 'repo/hooks' directory. Those
files are copied into each '.git/hooks' folder for each project. Repo-level files are copied into each '.git/hooks' folder for each project. Repo-level
@ -52,7 +53,8 @@ class RepoHook(object):
Invalid Invalid
""" """
def __init__(self, def __init__(
self,
hook_type, hook_type,
hooks_project, hooks_project,
repo_topdir, repo_topdir,
@ -60,7 +62,8 @@ class RepoHook(object):
bypass_hooks=False, bypass_hooks=False,
allow_all_hooks=False, allow_all_hooks=False,
ignore_hooks=False, ignore_hooks=False,
abort_if_user_denies=False): abort_if_user_denies=False,
):
"""RepoHook constructor. """RepoHook constructor.
Params: Params:
@ -78,8 +81,8 @@ class RepoHook(object):
bypass_hooks: If True, then 'Do not run the hook'. bypass_hooks: If True, then 'Do not run the hook'.
allow_all_hooks: If True, then 'Run the hook without prompting'. allow_all_hooks: If True, then 'Run the hook without prompting'.
ignore_hooks: If True, then 'Do not abort action if hooks fail'. ignore_hooks: If True, then 'Do not abort action if hooks fail'.
abort_if_user_denies: If True, we'll abort running the hook if the user abort_if_user_denies: If True, we'll abort running the hook if the
doesn't allow us to run the hook. user doesn't allow us to run the hook.
""" """
self._hook_type = hook_type self._hook_type = hook_type
self._hooks_project = hooks_project self._hooks_project = hooks_project
@ -92,77 +95,82 @@ class RepoHook(object):
# Store the full path to the script for convenience. # Store the full path to the script for convenience.
if self._hooks_project: if self._hooks_project:
self._script_fullpath = os.path.join(self._hooks_project.worktree, self._script_fullpath = os.path.join(
self._hook_type + '.py') self._hooks_project.worktree, self._hook_type + ".py"
)
else: else:
self._script_fullpath = None self._script_fullpath = None
def _GetHash(self): def _GetHash(self):
"""Return a hash of the contents of the hooks directory. """Return a hash of the contents of the hooks directory.
We'll just use git to do this. This hash has the property that if anything We'll just use git to do this. This hash has the property that if
changes in the directory we will return a different has. anything changes in the directory we will return a different has.
SECURITY CONSIDERATION: SECURITY CONSIDERATION:
This hash only represents the contents of files in the hook directory, not This hash only represents the contents of files in the hook
any other files imported or called by hooks. Changes to imported files directory, not any other files imported or called by hooks. Changes
can change the script behavior without affecting the hash. to imported files can change the script behavior without affecting
the hash.
Returns: Returns:
A string representing the hash. This will always be ASCII so that it can A string representing the hash. This will always be ASCII so that
be printed to the user easily. it can be printed to the user easily.
""" """
assert self._hooks_project, "Must have hooks to calculate their hash." assert self._hooks_project, "Must have hooks to calculate their hash."
# We will use the work_git object rather than just calling GetRevisionId(). # We will use the work_git object rather than just calling
# That gives us a hash of the latest checked in version of the files that # GetRevisionId(). That gives us a hash of the latest checked in version
# the user will actually be executing. Specifically, GetRevisionId() # of the files that the user will actually be executing. Specifically,
# doesn't appear to change even if a user checks out a different version # GetRevisionId() doesn't appear to change even if a user checks out a
# of the hooks repo (via git checkout) nor if a user commits their own revs. # different version of the hooks repo (via git checkout) nor if a user
# commits their own revs.
# #
# NOTE: Local (non-committed) changes will not be factored into this hash. # NOTE: Local (non-committed) changes will not be factored into this
# I think this is OK, since we're really only worried about warning the user # hash. I think this is OK, since we're really only worried about
# about upstream changes. # warning the user about upstream changes.
return self._hooks_project.work_git.rev_parse(HEAD) return self._hooks_project.work_git.rev_parse(HEAD)
def _GetMustVerb(self): def _GetMustVerb(self):
"""Return 'must' if the hook is required; 'should' if not.""" """Return 'must' if the hook is required; 'should' if not."""
if self._abort_if_user_denies: if self._abort_if_user_denies:
return 'must' return "must"
else: else:
return 'should' return "should"
def _CheckForHookApproval(self): def _CheckForHookApproval(self):
"""Check to see whether this hook has been approved. """Check to see whether this hook has been approved.
We'll accept approval of manifest URLs if they're using secure transports. We'll accept approval of manifest URLs if they're using secure
This way the user can say they trust the manifest hoster. For insecure transports. This way the user can say they trust the manifest hoster.
hosts, we fall back to checking the hash of the hooks repo. For insecure hosts, we fall back to checking the hash of the hooks repo.
Note that we ask permission for each individual hook even though we use Note that we ask permission for each individual hook even though we use
the hash of all hooks when detecting changes. We'd like the user to be the hash of all hooks when detecting changes. We'd like the user to be
able to approve / deny each hook individually. We only use the hash of all able to approve / deny each hook individually. We only use the hash of
hooks because there is no other easy way to detect changes to local imports. all hooks because there is no other easy way to detect changes to local
imports.
Returns: Returns:
True if this hook is approved to run; False otherwise. True if this hook is approved to run; False otherwise.
Raises: Raises:
HookError: Raised if the user doesn't approve and abort_if_user_denies HookError: Raised if the user doesn't approve and
was passed to the consturctor. abort_if_user_denies was passed to the consturctor.
""" """
if self._ManifestUrlHasSecureScheme(): if self._ManifestUrlHasSecureScheme():
return self._CheckForHookApprovalManifest() return self._CheckForHookApprovalManifest()
else: else:
return self._CheckForHookApprovalHash() return self._CheckForHookApprovalHash()
def _CheckForHookApprovalHelper(self, subkey, new_val, main_prompt, def _CheckForHookApprovalHelper(
changed_prompt): self, subkey, new_val, main_prompt, changed_prompt
):
"""Check for approval for a particular attribute and hook. """Check for approval for a particular attribute and hook.
Args: Args:
subkey: The git config key under [repo.hooks.<hook_type>] to store the subkey: The git config key under [repo.hooks.<hook_type>] to store
last approved string. the last approved string.
new_val: The new value to compare against the last approved one. new_val: The new value to compare against the last approved one.
main_prompt: Message to display to the user to ask for approval. main_prompt: Message to display to the user to ask for approval.
changed_prompt: Message explaining why we're re-asking for approval. changed_prompt: Message explaining why we're re-asking for approval.
@ -171,11 +179,11 @@ class RepoHook(object):
True if this hook is approved to run; False otherwise. True if this hook is approved to run; False otherwise.
Raises: Raises:
HookError: Raised if the user doesn't approve and abort_if_user_denies HookError: Raised if the user doesn't approve and
was passed to the consturctor. abort_if_user_denies was passed to the consturctor.
""" """
hooks_config = self._hooks_project.config hooks_config = self._hooks_project.config
git_approval_key = 'repo.hooks.%s.%s' % (self._hook_type, subkey) git_approval_key = "repo.hooks.%s.%s" % (self._hook_type, subkey)
# Get the last value that the user approved for this hook; may be None. # Get the last value that the user approved for this hook; may be None.
old_val = hooks_config.GetString(git_approval_key) old_val = hooks_config.GetString(git_approval_key)
@ -186,35 +194,44 @@ class RepoHook(object):
# Approval matched. We're done. # Approval matched. We're done.
return True return True
else: else:
# Give the user a reason why we're prompting, since they last told # Give the user a reason why we're prompting, since they last
# us to "never ask again". # told us to "never ask again".
prompt = 'WARNING: %s\n\n' % (changed_prompt,) prompt = "WARNING: %s\n\n" % (changed_prompt,)
else: else:
prompt = '' prompt = ""
# Prompt the user if we're not on a tty; on a tty we'll assume "no". # Prompt the user if we're not on a tty; on a tty we'll assume "no".
if sys.stdout.isatty(): if sys.stdout.isatty():
prompt += main_prompt + ' (yes/always/NO)? ' prompt += main_prompt + " (yes/always/NO)? "
response = input(prompt).lower() response = input(prompt).lower()
print() print()
# User is doing a one-time approval. # User is doing a one-time approval.
if response in ('y', 'yes'): if response in ("y", "yes"):
return True return True
elif response == 'always': elif response == "always":
hooks_config.SetString(git_approval_key, new_val) hooks_config.SetString(git_approval_key, new_val)
return True return True
# For anything else, we'll assume no approval. # For anything else, we'll assume no approval.
if self._abort_if_user_denies: if self._abort_if_user_denies:
raise HookError('You must allow the %s hook or use --no-verify.' % raise HookError(
self._hook_type) "You must allow the %s hook or use --no-verify."
% self._hook_type
)
return False return False
def _ManifestUrlHasSecureScheme(self): def _ManifestUrlHasSecureScheme(self):
"""Check if the URI for the manifest is a secure transport.""" """Check if the URI for the manifest is a secure transport."""
secure_schemes = ('file', 'https', 'ssh', 'persistent-https', 'sso', 'rpc') secure_schemes = (
"file",
"https",
"ssh",
"persistent-https",
"sso",
"rpc",
)
parse_results = urllib.parse.urlparse(self._manifest_url) parse_results = urllib.parse.urlparse(self._manifest_url)
return parse_results.scheme in secure_schemes return parse_results.scheme in secure_schemes
@ -225,10 +242,12 @@ class RepoHook(object):
True if this hook is approved to run; False otherwise. True if this hook is approved to run; False otherwise.
""" """
return self._CheckForHookApprovalHelper( return self._CheckForHookApprovalHelper(
'approvedmanifest', "approvedmanifest",
self._manifest_url, self._manifest_url,
'Run hook scripts from %s' % (self._manifest_url,), "Run hook scripts from %s" % (self._manifest_url,),
'Manifest URL has changed since %s was allowed.' % (self._hook_type,)) "Manifest URL has changed since %s was allowed."
% (self._hook_type,),
)
def _CheckForHookApprovalHash(self): def _CheckForHookApprovalHash(self):
"""Check whether the user has approved the hooks repo. """Check whether the user has approved the hooks repo.
@ -236,15 +255,18 @@ class RepoHook(object):
Returns: Returns:
True if this hook is approved to run; False otherwise. True if this hook is approved to run; False otherwise.
""" """
prompt = ('Repo %s run the script:\n' prompt = (
' %s\n' "Repo %s run the script:\n"
'\n' " %s\n"
'Do you want to allow this script to run') "\n"
"Do you want to allow this script to run"
)
return self._CheckForHookApprovalHelper( return self._CheckForHookApprovalHelper(
'approvedhash', "approvedhash",
self._GetHash(), self._GetHash(),
prompt % (self._GetMustVerb(), self._script_fullpath), prompt % (self._GetMustVerb(), self._script_fullpath),
'Scripts have changed since %s was allowed.' % (self._hook_type,)) "Scripts have changed since %s was allowed." % (self._hook_type,),
)
@staticmethod @staticmethod
def _ExtractInterpFromShebang(data): def _ExtractInterpFromShebang(data):
@ -256,8 +278,8 @@ class RepoHook(object):
data: The file content of the script. data: The file content of the script.
Returns: Returns:
The basename of the main script interpreter, or None if a shebang is not The basename of the main script interpreter, or None if a shebang is
used or could not be parsed out. not used or could not be parsed out.
""" """
firstline = data.splitlines()[:1] firstline = data.splitlines()[:1]
if not firstline: if not firstline:
@ -265,13 +287,13 @@ class RepoHook(object):
# The format here can be tricky. # The format here can be tricky.
shebang = firstline[0].strip() shebang = firstline[0].strip()
m = re.match(r'^#!\s*([^\s]+)(?:\s+([^\s]+))?', shebang) m = re.match(r"^#!\s*([^\s]+)(?:\s+([^\s]+))?", shebang)
if not m: if not m:
return None return None
# If the using `env`, find the target program. # If the using `env`, find the target program.
interp = m.group(1) interp = m.group(1)
if os.path.basename(interp) == 'env': if os.path.basename(interp) == "env":
interp = m.group(2) interp = m.group(2)
return interp return interp
@ -300,18 +322,18 @@ data = open(path).read()
exec(compile(data, path, 'exec'), context) exec(compile(data, path, 'exec'), context)
context['main'](**kwargs) context['main'](**kwargs)
""" % { """ % {
'path': self._script_fullpath, "path": self._script_fullpath,
'kwargs': json.dumps(kwargs), "kwargs": json.dumps(kwargs),
'context': json.dumps(context), "context": json.dumps(context),
} }
# We pass the script via stdin to avoid OS argv limits. It also makes # We pass the script via stdin to avoid OS argv limits. It also makes
# unhandled exception tracebacks less verbose/confusing for users. # unhandled exception tracebacks less verbose/confusing for users.
cmd = [interp, '-c', 'import sys; exec(sys.stdin.read())'] cmd = [interp, "-c", "import sys; exec(sys.stdin.read())"]
proc = subprocess.Popen(cmd, stdin=subprocess.PIPE) proc = subprocess.Popen(cmd, stdin=subprocess.PIPE)
proc.communicate(input=script.encode('utf-8')) proc.communicate(input=script.encode("utf-8"))
if proc.returncode: if proc.returncode:
raise HookError('Failed to run %s hook.' % (self._hook_type,)) raise HookError("Failed to run %s hook." % (self._hook_type,))
def _ExecuteHookViaImport(self, data, context, **kwargs): def _ExecuteHookViaImport(self, data, context, **kwargs):
"""Execute the hook code in |data| directly. """Execute the hook code in |data| directly.
@ -327,23 +349,27 @@ context['main'](**kwargs)
# Exec, storing global context in the context dict. We catch exceptions # Exec, storing global context in the context dict. We catch exceptions
# and convert to a HookError w/ just the failing traceback. # and convert to a HookError w/ just the failing traceback.
try: try:
exec(compile(data, self._script_fullpath, 'exec'), context) exec(compile(data, self._script_fullpath, "exec"), context)
except Exception: except Exception:
raise HookError('%s\nFailed to import %s hook; see traceback above.' % raise HookError(
(traceback.format_exc(), self._hook_type)) "%s\nFailed to import %s hook; see traceback above."
% (traceback.format_exc(), self._hook_type)
)
# Running the script should have defined a main() function. # Running the script should have defined a main() function.
if 'main' not in context: if "main" not in context:
raise HookError('Missing main() in: "%s"' % self._script_fullpath) raise HookError('Missing main() in: "%s"' % self._script_fullpath)
# Call the main function in the hook. If the hook should cause the # Call the main function in the hook. If the hook should cause the
# build to fail, it will raise an Exception. We'll catch that convert # build to fail, it will raise an Exception. We'll catch that convert
# to a HookError w/ just the failing traceback. # to a HookError w/ just the failing traceback.
try: try:
context['main'](**kwargs) context["main"](**kwargs)
except Exception: except Exception:
raise HookError('%s\nFailed to run main() for %s hook; see traceback ' raise HookError(
'above.' % (traceback.format_exc(), self._hook_type)) "%s\nFailed to run main() for %s hook; see traceback "
"above." % (traceback.format_exc(), self._hook_type)
)
def _ExecuteHook(self, **kwargs): def _ExecuteHook(self, **kwargs):
"""Actually execute the given hook. """Actually execute the given hook.
@ -351,9 +377,9 @@ context['main'](**kwargs)
This will run the hook's 'main' function in our python interpreter. This will run the hook's 'main' function in our python interpreter.
Args: Args:
kwargs: Keyword arguments to pass to the hook. These are often specific kwargs: Keyword arguments to pass to the hook. These are often
to the hook type. For instance, pre-upload hooks will contain specific to the hook type. For instance, pre-upload hooks will
a project_list. contain a project_list.
""" """
# Keep sys.path and CWD stashed away so that we can always restore them # Keep sys.path and CWD stashed away so that we can always restore them
# upon function exit. # upon function exit.
@ -370,17 +396,18 @@ context['main'](**kwargs)
sys.path = [os.path.dirname(self._script_fullpath)] + sys.path[1:] sys.path = [os.path.dirname(self._script_fullpath)] + sys.path[1:]
# Initial global context for the hook to run within. # Initial global context for the hook to run within.
context = {'__file__': self._script_fullpath} context = {"__file__": self._script_fullpath}
# Add 'hook_should_take_kwargs' to the arguments to be passed to main. # Add 'hook_should_take_kwargs' to the arguments to be passed to
# We don't actually want hooks to define their main with this argument-- # main. We don't actually want hooks to define their main with this
# it's there to remind them that their hook should always take **kwargs. # argument--it's there to remind them that their hook should always
# take **kwargs.
# For instance, a pre-upload hook should be defined like: # For instance, a pre-upload hook should be defined like:
# def main(project_list, **kwargs): # def main(project_list, **kwargs):
# #
# This allows us to later expand the API without breaking old hooks. # This allows us to later expand the API without breaking old hooks.
kwargs = kwargs.copy() kwargs = kwargs.copy()
kwargs['hook_should_take_kwargs'] = True kwargs["hook_should_take_kwargs"] = True
# See what version of python the hook has been written against. # See what version of python the hook has been written against.
data = open(self._script_fullpath).read() data = open(self._script_fullpath).read()
@ -388,18 +415,20 @@ context['main'](**kwargs)
reexec = False reexec = False
if interp: if interp:
prog = os.path.basename(interp) prog = os.path.basename(interp)
if prog.startswith('python2') and sys.version_info.major != 2: if prog.startswith("python2") and sys.version_info.major != 2:
reexec = True reexec = True
elif prog.startswith('python3') and sys.version_info.major == 2: elif prog.startswith("python3") and sys.version_info.major == 2:
reexec = True reexec = True
# Attempt to execute the hooks through the requested version of Python. # Attempt to execute the hooks through the requested version of
# Python.
if reexec: if reexec:
try: try:
self._ExecuteHookViaReexec(interp, context, **kwargs) self._ExecuteHookViaReexec(interp, context, **kwargs)
except OSError as e: except OSError as e:
if e.errno == errno.ENOENT: if e.errno == errno.ENOENT:
# We couldn't find the interpreter, so fallback to importing. # We couldn't find the interpreter, so fallback to
# importing.
reexec = False reexec = False
else: else:
raise raise
@ -415,7 +444,9 @@ context['main'](**kwargs)
def _CheckHook(self): def _CheckHook(self):
# Bail with a nice error if we can't find the hook. # Bail with a nice error if we can't find the hook.
if not os.path.isfile(self._script_fullpath): if not os.path.isfile(self._script_fullpath):
raise HookError('Couldn\'t find repo hook: %s' % self._script_fullpath) raise HookError(
"Couldn't find repo hook: %s" % self._script_fullpath
)
def Run(self, **kwargs): def Run(self, **kwargs):
"""Run the hook. """Run the hook.
@ -424,27 +455,30 @@ context['main'](**kwargs)
this particular hook is not enabled), this is a no-op. this particular hook is not enabled), this is a no-op.
Args: Args:
user_allows_all_hooks: If True, we will never prompt about running the user_allows_all_hooks: If True, we will never prompt about running
hook--we'll just assume it's OK to run it. the hook--we'll just assume it's OK to run it.
kwargs: Keyword arguments to pass to the hook. These are often specific kwargs: Keyword arguments to pass to the hook. These are often
to the hook type. For instance, pre-upload hooks will contain specific to the hook type. For instance, pre-upload hooks will
a project_list. contain a project_list.
Returns: Returns:
True: On success or ignore hooks by user-request True: On success or ignore hooks by user-request
False: The hook failed. The caller should respond with aborting the action. False: The hook failed. The caller should respond with aborting the
Some examples in which False is returned: action. Some examples in which False is returned:
* Finding the hook failed while it was enabled, or * Finding the hook failed while it was enabled, or
* the user declined to run a required hook (from _CheckForHookApproval) * the user declined to run a required hook (from
_CheckForHookApproval)
In all these cases the user did not pass the proper arguments to In all these cases the user did not pass the proper arguments to
ignore the result through the option combinations as listed in ignore the result through the option combinations as listed in
AddHookOptionGroup(). AddHookOptionGroup().
""" """
# Do not do anything in case bypass_hooks is set, or # Do not do anything in case bypass_hooks is set, or
# no-op if there is no hooks project or if hook is disabled. # no-op if there is no hooks project or if hook is disabled.
if (self._bypass_hooks or if (
not self._hooks_project or self._bypass_hooks
self._hook_type not in self._hooks_project.enabled_repo_hooks): or not self._hooks_project
or self._hook_type not in self._hooks_project.enabled_repo_hooks
):
return True return True
passed = True passed = True
@ -457,15 +491,21 @@ context['main'](**kwargs)
self._ExecuteHook(**kwargs) self._ExecuteHook(**kwargs)
except SystemExit as e: except SystemExit as e:
passed = False passed = False
print('ERROR: %s hooks exited with exit code: %s' % (self._hook_type, str(e)), print(
file=sys.stderr) "ERROR: %s hooks exited with exit code: %s"
% (self._hook_type, str(e)),
file=sys.stderr,
)
except HookError as e: except HookError as e:
passed = False passed = False
print('ERROR: %s' % str(e), file=sys.stderr) print("ERROR: %s" % str(e), file=sys.stderr)
if not passed and self._ignore_hooks: if not passed and self._ignore_hooks:
print('\nWARNING: %s hooks failed, but continuing anyways.' % self._hook_type, print(
file=sys.stderr) "\nWARNING: %s hooks failed, but continuing anyways."
% self._hook_type,
file=sys.stderr,
)
passed = True passed = True
return passed return passed
@ -478,16 +518,20 @@ context['main'](**kwargs)
manifest: The current active manifest for this command from which we manifest: The current active manifest for this command from which we
extract a couple of fields. extract a couple of fields.
opt: Contains the commandline options for the action of this hook. opt: Contains the commandline options for the action of this hook.
It should contain the options added by AddHookOptionGroup() in which It should contain the options added by AddHookOptionGroup() in
we are interested in RepoHook execution. which we are interested in RepoHook execution.
""" """
for key in ('bypass_hooks', 'allow_all_hooks', 'ignore_hooks'): for key in ("bypass_hooks", "allow_all_hooks", "ignore_hooks"):
kwargs.setdefault(key, getattr(opt, key)) kwargs.setdefault(key, getattr(opt, key))
kwargs.update({ kwargs.update(
'hooks_project': manifest.repo_hooks_project, {
'repo_topdir': manifest.topdir, "hooks_project": manifest.repo_hooks_project,
'manifest_url': manifest.manifestProject.GetRemote('origin').url, "repo_topdir": manifest.topdir,
}) "manifest_url": manifest.manifestProject.GetRemote(
"origin"
).url,
}
)
return cls(*args, **kwargs) return cls(*args, **kwargs)
@staticmethod @staticmethod
@ -497,13 +541,21 @@ context['main'](**kwargs)
# Note that verify and no-verify are NOT opposites of each other, which # Note that verify and no-verify are NOT opposites of each other, which
# is why they store to different locations. We are using them to match # is why they store to different locations. We are using them to match
# 'git commit' syntax. # 'git commit' syntax.
group = parser.add_option_group(name + ' hooks') group = parser.add_option_group(name + " hooks")
group.add_option('--no-verify', group.add_option(
dest='bypass_hooks', action='store_true', "--no-verify",
help='Do not run the %s hook.' % name) dest="bypass_hooks",
group.add_option('--verify', action="store_true",
dest='allow_all_hooks', action='store_true', help="Do not run the %s hook." % name,
help='Run the %s hook without prompting.' % name) )
group.add_option('--ignore-hooks', group.add_option(
action='store_true', "--verify",
help='Do not abort if %s hooks fail.' % name) dest="allow_all_hooks",
action="store_true",
help="Run the %s hook without prompting." % name,
)
group.add_option(
"--ignore-hooks",
action="store_true",
help="Do not abort if %s hooks fail." % name,
)

468
main.py
View File

@ -74,64 +74,109 @@ MIN_PYTHON_VERSION_SOFT = (3, 6)
MIN_PYTHON_VERSION_HARD = (3, 6) MIN_PYTHON_VERSION_HARD = (3, 6)
if sys.version_info.major < 3: if sys.version_info.major < 3:
print('repo: error: Python 2 is no longer supported; ' print(
'Please upgrade to Python {}.{}+.'.format(*MIN_PYTHON_VERSION_SOFT), "repo: error: Python 2 is no longer supported; "
file=sys.stderr) "Please upgrade to Python {}.{}+.".format(*MIN_PYTHON_VERSION_SOFT),
file=sys.stderr,
)
sys.exit(1) sys.exit(1)
else: else:
if sys.version_info < MIN_PYTHON_VERSION_HARD: if sys.version_info < MIN_PYTHON_VERSION_HARD:
print('repo: error: Python 3 version is too old; ' print(
'Please upgrade to Python {}.{}+.'.format(*MIN_PYTHON_VERSION_SOFT), "repo: error: Python 3 version is too old; "
file=sys.stderr) "Please upgrade to Python {}.{}+.".format(*MIN_PYTHON_VERSION_SOFT),
file=sys.stderr,
)
sys.exit(1) sys.exit(1)
elif sys.version_info < MIN_PYTHON_VERSION_SOFT: elif sys.version_info < MIN_PYTHON_VERSION_SOFT:
print('repo: warning: your Python 3 version is no longer supported; ' print(
'Please upgrade to Python {}.{}+.'.format(*MIN_PYTHON_VERSION_SOFT), "repo: warning: your Python 3 version is no longer supported; "
file=sys.stderr) "Please upgrade to Python {}.{}+.".format(*MIN_PYTHON_VERSION_SOFT),
file=sys.stderr,
)
global_options = optparse.OptionParser( global_options = optparse.OptionParser(
usage='repo [-p|--paginate|--no-pager] COMMAND [ARGS]', usage="repo [-p|--paginate|--no-pager] COMMAND [ARGS]",
add_help_option=False) add_help_option=False,
global_options.add_option('-h', '--help', action='store_true', )
help='show this help message and exit') global_options.add_option(
global_options.add_option('--help-all', action='store_true', "-h", "--help", action="store_true", help="show this help message and exit"
help='show this help message with all subcommands and exit') )
global_options.add_option('-p', '--paginate', global_options.add_option(
dest='pager', action='store_true', "--help-all",
help='display command output in the pager') action="store_true",
global_options.add_option('--no-pager', help="show this help message with all subcommands and exit",
dest='pager', action='store_false', )
help='disable the pager') global_options.add_option(
global_options.add_option('--color', "-p",
choices=('auto', 'always', 'never'), default=None, "--paginate",
help='control color usage: auto, always, never') dest="pager",
global_options.add_option('--trace', action="store_true",
dest='trace', action='store_true', help="display command output in the pager",
help='trace git command execution (REPO_TRACE=1)') )
global_options.add_option('--trace-to-stderr', global_options.add_option(
dest='trace_to_stderr', action='store_true', "--no-pager", dest="pager", action="store_false", help="disable the pager"
help='trace outputs go to stderr in addition to .repo/TRACE_FILE') )
global_options.add_option('--trace-python', global_options.add_option(
dest='trace_python', action='store_true', "--color",
help='trace python command execution') choices=("auto", "always", "never"),
global_options.add_option('--time', default=None,
dest='time', action='store_true', help="control color usage: auto, always, never",
help='time repo command execution') )
global_options.add_option('--version', global_options.add_option(
dest='show_version', action='store_true', "--trace",
help='display this version of repo') dest="trace",
global_options.add_option('--show-toplevel', action="store_true",
action='store_true', help="trace git command execution (REPO_TRACE=1)",
help='display the path of the top-level directory of ' )
'the repo client checkout') global_options.add_option(
global_options.add_option('--event-log', "--trace-to-stderr",
dest='event_log', action='store', dest="trace_to_stderr",
help='filename of event log to append timeline to') action="store_true",
global_options.add_option('--git-trace2-event-log', action='store', help="trace outputs go to stderr in addition to .repo/TRACE_FILE",
help='directory to write git trace2 event log to') )
global_options.add_option('--submanifest-path', action='store', global_options.add_option(
metavar='REL_PATH', help='submanifest path') "--trace-python",
dest="trace_python",
action="store_true",
help="trace python command execution",
)
global_options.add_option(
"--time",
dest="time",
action="store_true",
help="time repo command execution",
)
global_options.add_option(
"--version",
dest="show_version",
action="store_true",
help="display this version of repo",
)
global_options.add_option(
"--show-toplevel",
action="store_true",
help="display the path of the top-level directory of "
"the repo client checkout",
)
global_options.add_option(
"--event-log",
dest="event_log",
action="store",
help="filename of event log to append timeline to",
)
global_options.add_option(
"--git-trace2-event-log",
action="store",
help="directory to write git trace2 event log to",
)
global_options.add_option(
"--submanifest-path",
action="store",
metavar="REL_PATH",
help="submanifest path",
)
class _Repo(object): class _Repo(object):
@ -144,13 +189,15 @@ class _Repo(object):
global_options.print_help() global_options.print_help()
print() print()
if short: if short:
commands = ' '.join(sorted(self.commands)) commands = " ".join(sorted(self.commands))
wrapped_commands = textwrap.wrap(commands, width=77) wrapped_commands = textwrap.wrap(commands, width=77)
print('Available commands:\n %s' % ('\n '.join(wrapped_commands),)) print(
print('\nRun `repo help <command>` for command-specific details.') "Available commands:\n %s" % ("\n ".join(wrapped_commands),)
print('Bug reports:', Wrapper().BUG_URL) )
print("\nRun `repo help <command>` for command-specific details.")
print("Bug reports:", Wrapper().BUG_URL)
else: else:
cmd = self.commands['help']() cmd = self.commands["help"]()
if all_commands: if all_commands:
cmd.PrintAllCommandsBody() cmd.PrintAllCommandsBody()
else: else:
@ -159,10 +206,10 @@ class _Repo(object):
def _ParseArgs(self, argv): def _ParseArgs(self, argv):
"""Parse the main `repo` command line options.""" """Parse the main `repo` command line options."""
for i, arg in enumerate(argv): for i, arg in enumerate(argv):
if not arg.startswith('-'): if not arg.startswith("-"):
name = arg name = arg
glob = argv[:i] glob = argv[:i]
argv = argv[i + 1:] argv = argv[i + 1 :]
break break
else: else:
name = None name = None
@ -182,14 +229,14 @@ class _Repo(object):
if name in self.commands: if name in self.commands:
return name, [] return name, []
key = 'alias.%s' % (name,) key = "alias.%s" % (name,)
alias = RepoConfig.ForRepository(self.repodir).GetString(key) alias = RepoConfig.ForRepository(self.repodir).GetString(key)
if alias is None: if alias is None:
alias = RepoConfig.ForUser().GetString(key) alias = RepoConfig.ForUser().GetString(key)
if alias is None: if alias is None:
return name, [] return name, []
args = alias.strip().split(' ', 1) args = alias.strip().split(" ", 1)
name = args[0] name = args[0]
if len(args) == 2: if len(args) == 2:
args = shlex.split(args[1]) args = shlex.split(args[1])
@ -207,7 +254,7 @@ class _Repo(object):
return 0 return 0
elif gopts.show_version: elif gopts.show_version:
# Always allow global --version regardless of subcommand validity. # Always allow global --version regardless of subcommand validity.
name = 'version' name = "version"
elif gopts.show_toplevel: elif gopts.show_toplevel:
print(os.path.dirname(self.repodir)) print(os.path.dirname(self.repodir))
return 0 return 0
@ -217,12 +264,20 @@ class _Repo(object):
return 1 return 1
run = lambda: self._RunLong(name, gopts, argv) or 0 run = lambda: self._RunLong(name, gopts, argv) or 0
with Trace('starting new command: %s', ', '.join([name] + argv), with Trace(
first_trace=True): "starting new command: %s",
", ".join([name] + argv),
first_trace=True,
):
if gopts.trace_python: if gopts.trace_python:
import trace import trace
tracer = trace.Trace(count=False, trace=True, timing=True,
ignoredirs=set(sys.path[1:])) tracer = trace.Trace(
count=False,
trace=True,
timing=True,
ignoredirs=set(sys.path[1:]),
)
result = tracer.runfunc(run) result = tracer.runfunc(run)
else: else:
result = run() result = run()
@ -237,9 +292,11 @@ class _Repo(object):
outer_client = RepoClient(self.repodir) outer_client = RepoClient(self.repodir)
repo_client = outer_client repo_client = outer_client
if gopts.submanifest_path: if gopts.submanifest_path:
repo_client = RepoClient(self.repodir, repo_client = RepoClient(
self.repodir,
submanifest_path=gopts.submanifest_path, submanifest_path=gopts.submanifest_path,
outer_client=outer_client) outer_client=outer_client,
)
gitc_manifest = None gitc_manifest = None
gitc_client_name = gitc_utils.parse_clientdir(os.getcwd()) gitc_client_name = gitc_utils.parse_clientdir(os.getcwd())
if gitc_client_name: if gitc_client_name:
@ -254,37 +311,50 @@ class _Repo(object):
outer_client=outer_client, outer_client=outer_client,
outer_manifest=outer_client.manifest, outer_manifest=outer_client.manifest,
gitc_manifest=gitc_manifest, gitc_manifest=gitc_manifest,
git_event_log=git_trace2_event_log) git_event_log=git_trace2_event_log,
)
except KeyError: except KeyError:
print("repo: '%s' is not a repo command. See 'repo help'." % name, print(
file=sys.stderr) "repo: '%s' is not a repo command. See 'repo help'." % name,
file=sys.stderr,
)
return 1 return 1
Editor.globalConfig = cmd.client.globalConfig Editor.globalConfig = cmd.client.globalConfig
if not isinstance(cmd, MirrorSafeCommand) and cmd.manifest.IsMirror: if not isinstance(cmd, MirrorSafeCommand) and cmd.manifest.IsMirror:
print("fatal: '%s' requires a working directory" % name, print(
file=sys.stderr) "fatal: '%s' requires a working directory" % name,
file=sys.stderr,
)
return 1 return 1
if isinstance(cmd, GitcAvailableCommand) and not gitc_utils.get_gitc_manifest_dir(): if (
print("fatal: '%s' requires GITC to be available" % name, isinstance(cmd, GitcAvailableCommand)
file=sys.stderr) and not gitc_utils.get_gitc_manifest_dir()
):
print(
"fatal: '%s' requires GITC to be available" % name,
file=sys.stderr,
)
return 1 return 1
if isinstance(cmd, GitcClientCommand) and not gitc_client_name: if isinstance(cmd, GitcClientCommand) and not gitc_client_name:
print("fatal: '%s' requires a GITC client" % name, print("fatal: '%s' requires a GITC client" % name, file=sys.stderr)
file=sys.stderr)
return 1 return 1
try: try:
copts, cargs = cmd.OptionParser.parse_args(argv) copts, cargs = cmd.OptionParser.parse_args(argv)
copts = cmd.ReadEnvironmentOptions(copts) copts = cmd.ReadEnvironmentOptions(copts)
except NoManifestException as e: except NoManifestException as e:
print('error: in `%s`: %s' % (' '.join([name] + argv), str(e)), print(
file=sys.stderr) "error: in `%s`: %s" % (" ".join([name] + argv), str(e)),
print('error: manifest missing or unreadable -- please run init', file=sys.stderr,
file=sys.stderr) )
print(
"error: manifest missing or unreadable -- please run init",
file=sys.stderr,
)
return 1 return 1
if gopts.pager is not False and not isinstance(cmd, InteractiveCommand): if gopts.pager is not False and not isinstance(cmd, InteractiveCommand):
@ -292,7 +362,7 @@ class _Repo(object):
if gopts.pager: if gopts.pager:
use_pager = True use_pager = True
else: else:
use_pager = config.GetBoolean('pager.%s' % name) use_pager = config.GetBoolean("pager.%s" % name)
if use_pager is None: if use_pager is None:
use_pager = cmd.WantPager(copts) use_pager = cmd.WantPager(copts)
if use_pager: if use_pager:
@ -302,7 +372,7 @@ class _Repo(object):
cmd_event = cmd.event_log.Add(name, event_log.TASK_COMMAND, start) cmd_event = cmd.event_log.Add(name, event_log.TASK_COMMAND, start)
cmd.event_log.SetParent(cmd_event) cmd.event_log.SetParent(cmd_event)
git_trace2_event_log.StartEvent() git_trace2_event_log.StartEvent()
git_trace2_event_log.CommandEvent(name='repo', subcommands=[name]) git_trace2_event_log.CommandEvent(name="repo", subcommands=[name])
try: try:
cmd.CommonValidateOptions(copts, cargs) cmd.CommonValidateOptions(copts, cargs)
@ -314,10 +384,10 @@ class _Repo(object):
result = cmd.Execute(copts, cargs) result = cmd.Execute(copts, cargs)
elif outer_manifest and repo_client.manifest.is_submanifest: elif outer_manifest and repo_client.manifest.is_submanifest:
# The command does not support multi-manifest, we are using a # The command does not support multi-manifest, we are using a
# submanifest, and the command line is for the outermost manifest. # submanifest, and the command line is for the outermost
# Re-run using the outermost manifest, which will recurse through the # manifest. Re-run using the outermost manifest, which will
# submanifests. # recurse through the submanifests.
gopts.submanifest_path = '' gopts.submanifest_path = ""
result = self._Run(name, gopts, argv) result = self._Run(name, gopts, argv)
else: else:
# No multi-manifest support. Run the command in the current # No multi-manifest support. Run the command in the current
@ -327,36 +397,52 @@ class _Repo(object):
spec = submanifest.ToSubmanifestSpec() spec = submanifest.ToSubmanifestSpec()
gopts.submanifest_path = submanifest.repo_client.path_prefix gopts.submanifest_path = submanifest.repo_client.path_prefix
child_argv = argv[:] child_argv = argv[:]
child_argv.append('--no-outer-manifest') child_argv.append("--no-outer-manifest")
# Not all subcommands support the 3 manifest options, so only add them # Not all subcommands support the 3 manifest options, so
# if the original command includes them. # only add them if the original command includes them.
if hasattr(copts, 'manifest_url'): if hasattr(copts, "manifest_url"):
child_argv.extend(['--manifest-url', spec.manifestUrl]) child_argv.extend(["--manifest-url", spec.manifestUrl])
if hasattr(copts, 'manifest_name'): if hasattr(copts, "manifest_name"):
child_argv.extend(['--manifest-name', spec.manifestName]) child_argv.extend(
if hasattr(copts, 'manifest_branch'): ["--manifest-name", spec.manifestName]
child_argv.extend(['--manifest-branch', spec.revision]) )
if hasattr(copts, "manifest_branch"):
child_argv.extend(["--manifest-branch", spec.revision])
result = self._Run(name, gopts, child_argv) or result result = self._Run(name, gopts, child_argv) or result
except (DownloadError, ManifestInvalidRevisionError, except (
NoManifestException) as e: DownloadError,
print('error: in `%s`: %s' % (' '.join([name] + argv), str(e)), ManifestInvalidRevisionError,
file=sys.stderr) NoManifestException,
) as e:
print(
"error: in `%s`: %s" % (" ".join([name] + argv), str(e)),
file=sys.stderr,
)
if isinstance(e, NoManifestException): if isinstance(e, NoManifestException):
print('error: manifest missing or unreadable -- please run init', print(
file=sys.stderr) "error: manifest missing or unreadable -- please run init",
file=sys.stderr,
)
result = 1 result = 1
except NoSuchProjectError as e: except NoSuchProjectError as e:
if e.name: if e.name:
print('error: project %s not found' % e.name, file=sys.stderr) print("error: project %s not found" % e.name, file=sys.stderr)
else: else:
print('error: no project in current directory', file=sys.stderr) print("error: no project in current directory", file=sys.stderr)
result = 1 result = 1
except InvalidProjectGroupsError as e: except InvalidProjectGroupsError as e:
if e.name: if e.name:
print('error: project group must be enabled for project %s' % e.name, file=sys.stderr) print(
"error: project group must be enabled for project %s"
% e.name,
file=sys.stderr,
)
else: else:
print('error: project group must be enabled for the project in the current directory', print(
file=sys.stderr) "error: project group must be enabled for the project in "
"the current directory",
file=sys.stderr,
)
result = 1 result = 1
except SystemExit as e: except SystemExit as e:
if e.code: if e.code:
@ -369,20 +455,27 @@ class _Repo(object):
minutes, seconds = divmod(remainder, 60) minutes, seconds = divmod(remainder, 60)
if gopts.time: if gopts.time:
if hours == 0: if hours == 0:
print('real\t%dm%.3fs' % (minutes, seconds), file=sys.stderr) print(
"real\t%dm%.3fs" % (minutes, seconds), file=sys.stderr
)
else: else:
print('real\t%dh%dm%.3fs' % (hours, minutes, seconds), print(
file=sys.stderr) "real\t%dh%dm%.3fs" % (hours, minutes, seconds),
file=sys.stderr,
)
cmd.event_log.FinishEvent(cmd_event, finish, cmd.event_log.FinishEvent(
result is None or result == 0) cmd_event, finish, result is None or result == 0
)
git_trace2_event_log.DefParamRepoEvents( git_trace2_event_log.DefParamRepoEvents(
cmd.manifest.manifestProject.config.DumpConfigDict()) cmd.manifest.manifestProject.config.DumpConfigDict()
)
git_trace2_event_log.ExitEvent(result) git_trace2_event_log.ExitEvent(result)
if gopts.event_log: if gopts.event_log:
cmd.event_log.Write(os.path.abspath( cmd.event_log.Write(
os.path.expanduser(gopts.event_log))) os.path.abspath(os.path.expanduser(gopts.event_log))
)
git_trace2_event_log.Write(gopts.git_trace2_event_log) git_trace2_event_log.Write(gopts.git_trace2_event_log)
return result return result
@ -392,29 +485,31 @@ def _CheckWrapperVersion(ver_str, repo_path):
"""Verify the repo launcher is new enough for this checkout. """Verify the repo launcher is new enough for this checkout.
Args: Args:
ver_str: The version string passed from the repo launcher when it ran us. ver_str: The version string passed from the repo launcher when it ran
us.
repo_path: The path to the repo launcher that loaded us. repo_path: The path to the repo launcher that loaded us.
""" """
# Refuse to work with really old wrapper versions. We don't test these, # Refuse to work with really old wrapper versions. We don't test these,
# so might as well require a somewhat recent sane version. # so might as well require a somewhat recent sane version.
# v1.15 of the repo launcher was released in ~Mar 2012. # v1.15 of the repo launcher was released in ~Mar 2012.
MIN_REPO_VERSION = (1, 15) MIN_REPO_VERSION = (1, 15)
min_str = '.'.join(str(x) for x in MIN_REPO_VERSION) min_str = ".".join(str(x) for x in MIN_REPO_VERSION)
if not repo_path: if not repo_path:
repo_path = '~/bin/repo' repo_path = "~/bin/repo"
if not ver_str: if not ver_str:
print('no --wrapper-version argument', file=sys.stderr) print("no --wrapper-version argument", file=sys.stderr)
sys.exit(1) sys.exit(1)
# Pull out the version of the repo launcher we know about to compare. # Pull out the version of the repo launcher we know about to compare.
exp = Wrapper().VERSION exp = Wrapper().VERSION
ver = tuple(map(int, ver_str.split('.'))) ver = tuple(map(int, ver_str.split(".")))
exp_str = '.'.join(map(str, exp)) exp_str = ".".join(map(str, exp))
if ver < MIN_REPO_VERSION: if ver < MIN_REPO_VERSION:
print(""" print(
"""
repo: error: repo: error:
!!! Your version of repo %s is too old. !!! Your version of repo %s is too old.
!!! We need at least version %s. !!! We need at least version %s.
@ -422,29 +517,42 @@ repo: error:
!!! You must upgrade before you can continue: !!! You must upgrade before you can continue:
cp %s %s cp %s %s
""" % (ver_str, min_str, exp_str, WrapperPath(), repo_path), file=sys.stderr) """
% (ver_str, min_str, exp_str, WrapperPath(), repo_path),
file=sys.stderr,
)
sys.exit(1) sys.exit(1)
if exp > ver: if exp > ver:
print('\n... A new version of repo (%s) is available.' % (exp_str,), print(
file=sys.stderr) "\n... A new version of repo (%s) is available." % (exp_str,),
file=sys.stderr,
)
if os.access(repo_path, os.W_OK): if os.access(repo_path, os.W_OK):
print("""\ print(
"""\
... You should upgrade soon: ... You should upgrade soon:
cp %s %s cp %s %s
""" % (WrapperPath(), repo_path), file=sys.stderr) """
% (WrapperPath(), repo_path),
file=sys.stderr,
)
else: else:
print("""\ print(
"""\
... New version is available at: %s ... New version is available at: %s
... The launcher is run from: %s ... The launcher is run from: %s
!!! The launcher is not writable. Please talk to your sysadmin or distro !!! The launcher is not writable. Please talk to your sysadmin or distro
!!! to get an update installed. !!! to get an update installed.
""" % (WrapperPath(), repo_path), file=sys.stderr) """
% (WrapperPath(), repo_path),
file=sys.stderr,
)
def _CheckRepoDir(repo_dir): def _CheckRepoDir(repo_dir):
if not repo_dir: if not repo_dir:
print('no --repo-dir argument', file=sys.stderr) print("no --repo-dir argument", file=sys.stderr)
sys.exit(1) sys.exit(1)
@ -452,10 +560,10 @@ def _PruneOptions(argv, opt):
i = 0 i = 0
while i < len(argv): while i < len(argv):
a = argv[i] a = argv[i]
if a == '--': if a == "--":
break break
if a.startswith('--'): if a.startswith("--"):
eq = a.find('=') eq = a.find("=")
if eq > 0: if eq > 0:
a = a[0:eq] a = a[0:eq]
if not opt.has_option(a): if not opt.has_option(a):
@ -466,11 +574,11 @@ def _PruneOptions(argv, opt):
class _UserAgentHandler(urllib.request.BaseHandler): class _UserAgentHandler(urllib.request.BaseHandler):
def http_request(self, req): def http_request(self, req):
req.add_header('User-Agent', user_agent.repo) req.add_header("User-Agent", user_agent.repo)
return req return req
def https_request(self, req): def https_request(self, req):
req.add_header('User-Agent', user_agent.repo) req.add_header("User-Agent", user_agent.repo)
return req return req
@ -481,7 +589,7 @@ def _AddPasswordFromUserInput(handler, msg, req):
if user is None: if user is None:
print(msg) print(msg)
try: try:
user = input('User: ') user = input("User: ")
password = getpass.getpass() password = getpass.getpass()
except KeyboardInterrupt: except KeyboardInterrupt:
return return
@ -492,23 +600,28 @@ class _BasicAuthHandler(urllib.request.HTTPBasicAuthHandler):
def http_error_401(self, req, fp, code, msg, headers): def http_error_401(self, req, fp, code, msg, headers):
_AddPasswordFromUserInput(self, msg, req) _AddPasswordFromUserInput(self, msg, req)
return urllib.request.HTTPBasicAuthHandler.http_error_401( return urllib.request.HTTPBasicAuthHandler.http_error_401(
self, req, fp, code, msg, headers) self, req, fp, code, msg, headers
)
def http_error_auth_reqed(self, authreq, host, req, headers): def http_error_auth_reqed(self, authreq, host, req, headers):
try: try:
old_add_header = req.add_header old_add_header = req.add_header
def _add_header(name, val): def _add_header(name, val):
val = val.replace('\n', '') val = val.replace("\n", "")
old_add_header(name, val) old_add_header(name, val)
req.add_header = _add_header req.add_header = _add_header
return urllib.request.AbstractBasicAuthHandler.http_error_auth_reqed( return (
self, authreq, host, req, headers) urllib.request.AbstractBasicAuthHandler.http_error_auth_reqed(
self, authreq, host, req, headers
)
)
except Exception: except Exception:
reset = getattr(self, 'reset_retry_count', None) reset = getattr(self, "reset_retry_count", None)
if reset is not None: if reset is not None:
reset() reset()
elif getattr(self, 'retried', None): elif getattr(self, "retried", None):
self.retried = 0 self.retried = 0
raise raise
@ -517,23 +630,28 @@ class _DigestAuthHandler(urllib.request.HTTPDigestAuthHandler):
def http_error_401(self, req, fp, code, msg, headers): def http_error_401(self, req, fp, code, msg, headers):
_AddPasswordFromUserInput(self, msg, req) _AddPasswordFromUserInput(self, msg, req)
return urllib.request.HTTPDigestAuthHandler.http_error_401( return urllib.request.HTTPDigestAuthHandler.http_error_401(
self, req, fp, code, msg, headers) self, req, fp, code, msg, headers
)
def http_error_auth_reqed(self, auth_header, host, req, headers): def http_error_auth_reqed(self, auth_header, host, req, headers):
try: try:
old_add_header = req.add_header old_add_header = req.add_header
def _add_header(name, val): def _add_header(name, val):
val = val.replace('\n', '') val = val.replace("\n", "")
old_add_header(name, val) old_add_header(name, val)
req.add_header = _add_header req.add_header = _add_header
return urllib.request.AbstractDigestAuthHandler.http_error_auth_reqed( return (
self, auth_header, host, req, headers) urllib.request.AbstractDigestAuthHandler.http_error_auth_reqed(
self, auth_header, host, req, headers
)
)
except Exception: except Exception:
reset = getattr(self, 'reset_retry_count', None) reset = getattr(self, "reset_retry_count", None)
if reset is not None: if reset is not None:
reset() reset()
elif getattr(self, 'retried', None): elif getattr(self, "retried", None):
self.retried = 0 self.retried = 0
raise raise
@ -546,7 +664,9 @@ class _KerberosAuthHandler(urllib.request.BaseHandler):
def http_error_401(self, req, fp, code, msg, headers): def http_error_401(self, req, fp, code, msg, headers):
host = req.get_host() host = req.get_host()
retry = self.http_error_auth_reqed('www-authenticate', host, req, headers) retry = self.http_error_auth_reqed(
"www-authenticate", host, req, headers
)
return retry return retry
def http_error_auth_reqed(self, auth_header, host, req, headers): def http_error_auth_reqed(self, auth_header, host, req, headers):
@ -555,8 +675,13 @@ class _KerberosAuthHandler(urllib.request.BaseHandler):
authdata = self._negotiate_get_authdata(auth_header, headers) authdata = self._negotiate_get_authdata(auth_header, headers)
if self.retried > 3: if self.retried > 3:
raise urllib.request.HTTPError(req.get_full_url(), 401, raise urllib.request.HTTPError(
"Negotiate auth failed", headers, None) req.get_full_url(),
401,
"Negotiate auth failed",
headers,
None,
)
else: else:
self.retried += 1 self.retried += 1
@ -564,7 +689,7 @@ class _KerberosAuthHandler(urllib.request.BaseHandler):
if neghdr is None: if neghdr is None:
return None return None
req.add_unredirected_header('Authorization', neghdr) req.add_unredirected_header("Authorization", neghdr)
response = self.parent.open(req) response = self.parent.open(req)
srvauth = self._negotiate_get_authdata(auth_header, response.info()) srvauth = self._negotiate_get_authdata(auth_header, response.info())
@ -627,8 +752,8 @@ def init_http():
n = netrc.netrc() n = netrc.netrc()
for host in n.hosts: for host in n.hosts:
p = n.hosts[host] p = n.hosts[host]
mgr.add_password(p[1], 'http://%s/' % host, p[0], p[2]) mgr.add_password(p[1], "http://%s/" % host, p[0], p[2])
mgr.add_password(p[1], 'https://%s/' % host, p[0], p[2]) mgr.add_password(p[1], "https://%s/" % host, p[0], p[2])
except netrc.NetrcParseError: except netrc.NetrcParseError:
pass pass
except IOError: except IOError:
@ -638,10 +763,12 @@ def init_http():
if kerberos: if kerberos:
handlers.append(_KerberosAuthHandler()) handlers.append(_KerberosAuthHandler())
if 'http_proxy' in os.environ: if "http_proxy" in os.environ:
url = os.environ['http_proxy'] url = os.environ["http_proxy"]
handlers.append(urllib.request.ProxyHandler({'http': url, 'https': url})) handlers.append(
if 'REPO_CURL_VERBOSE' in os.environ: urllib.request.ProxyHandler({"http": url, "https": url})
)
if "REPO_CURL_VERBOSE" in os.environ:
handlers.append(urllib.request.HTTPHandler(debuglevel=1)) handlers.append(urllib.request.HTTPHandler(debuglevel=1))
handlers.append(urllib.request.HTTPSHandler(debuglevel=1)) handlers.append(urllib.request.HTTPSHandler(debuglevel=1))
urllib.request.install_opener(urllib.request.build_opener(*handlers)) urllib.request.install_opener(urllib.request.build_opener(*handlers))
@ -651,12 +778,17 @@ def _Main(argv):
result = 0 result = 0
opt = optparse.OptionParser(usage="repo wrapperinfo -- ...") opt = optparse.OptionParser(usage="repo wrapperinfo -- ...")
opt.add_option("--repo-dir", dest="repodir", opt.add_option("--repo-dir", dest="repodir", help="path to .repo/")
help="path to .repo/") opt.add_option(
opt.add_option("--wrapper-version", dest="wrapper_version", "--wrapper-version",
help="version of the wrapper script") dest="wrapper_version",
opt.add_option("--wrapper-path", dest="wrapper_path", help="version of the wrapper script",
help="location of the wrapper script") )
opt.add_option(
"--wrapper-path",
dest="wrapper_path",
help="location of the wrapper script",
)
_PruneOptions(argv, opt) _PruneOptions(argv, opt)
opt, argv = opt.parse_args(argv) opt, argv = opt.parse_args(argv)
@ -680,10 +812,10 @@ def _Main(argv):
result = repo._Run(name, gopts, argv) or 0 result = repo._Run(name, gopts, argv) or 0
except KeyboardInterrupt: except KeyboardInterrupt:
print('aborted by user', file=sys.stderr) print("aborted by user", file=sys.stderr)
result = 1 result = 1
except ManifestParseError as mpe: except ManifestParseError as mpe:
print('fatal: %s' % mpe, file=sys.stderr) print("fatal: %s" % mpe, file=sys.stderr)
result = 1 result = 1
except RepoChangedException as rce: except RepoChangedException as rce:
# If repo changed, re-exec ourselves. # If repo changed, re-exec ourselves.
@ -693,13 +825,13 @@ def _Main(argv):
try: try:
os.execv(sys.executable, [__file__] + argv) os.execv(sys.executable, [__file__] + argv)
except OSError as e: except OSError as e:
print('fatal: cannot restart repo after upgrade', file=sys.stderr) print("fatal: cannot restart repo after upgrade", file=sys.stderr)
print('fatal: %s' % e, file=sys.stderr) print("fatal: %s" % e, file=sys.stderr)
result = 128 result = 128
TerminatePager() TerminatePager()
sys.exit(result) sys.exit(result)
if __name__ == '__main__': if __name__ == "__main__":
_Main(sys.argv[1:]) _Main(sys.argv[1:])

File diff suppressed because it is too large Load Diff

View File

@ -29,7 +29,7 @@ def RunPager(globalConfig):
if not os.isatty(0) or not os.isatty(1): if not os.isatty(0) or not os.isatty(1):
return return
pager = _SelectPager(globalConfig) pager = _SelectPager(globalConfig)
if pager == '' or pager == 'cat': if pager == "" or pager == "cat":
return return
if platform_utils.isWindows(): if platform_utils.isWindows():
@ -46,8 +46,8 @@ def TerminatePager():
pager_process.stdin.close() pager_process.stdin.close()
pager_process.wait() pager_process.wait()
pager_process = None pager_process = None
# Restore initial stdout/err in case there is more output in this process # Restore initial stdout/err in case there is more output in this
# after shutting down the pager process # process after shutting down the pager process.
sys.stdout = old_stdout sys.stdout = old_stdout
sys.stderr = old_stderr sys.stderr = old_stderr
@ -55,10 +55,11 @@ def TerminatePager():
def _PipePager(pager): def _PipePager(pager):
global pager_process, old_stdout, old_stderr global pager_process, old_stdout, old_stderr
assert pager_process is None, "Only one active pager process at a time" assert pager_process is None, "Only one active pager process at a time"
# Create pager process, piping stdout/err into its stdin # Create pager process, piping stdout/err into its stdin.
try: try:
pager_process = subprocess.Popen([pager], stdin=subprocess.PIPE, stdout=sys.stdout, pager_process = subprocess.Popen(
stderr=sys.stderr) [pager], stdin=subprocess.PIPE, stdout=sys.stdout, stderr=sys.stderr
)
except FileNotFoundError: except FileNotFoundError:
sys.exit(f'fatal: cannot start pager "{pager}"') sys.exit(f'fatal: cannot start pager "{pager}"')
old_stdout = sys.stdout old_stdout = sys.stdout
@ -72,7 +73,6 @@ def _ForkPager(pager):
# This process turns into the pager; a child it forks will # This process turns into the pager; a child it forks will
# do the real processing and output back to the pager. This # do the real processing and output back to the pager. This
# is necessary to keep the pager in control of the tty. # is necessary to keep the pager in control of the tty.
#
try: try:
r, w = os.pipe() r, w = os.pipe()
pid = os.fork() pid = os.fork()
@ -96,32 +96,31 @@ def _ForkPager(pager):
def _SelectPager(globalConfig): def _SelectPager(globalConfig):
try: try:
return os.environ['GIT_PAGER'] return os.environ["GIT_PAGER"]
except KeyError: except KeyError:
pass pass
pager = globalConfig.GetString('core.pager') pager = globalConfig.GetString("core.pager")
if pager: if pager:
return pager return pager
try: try:
return os.environ['PAGER'] return os.environ["PAGER"]
except KeyError: except KeyError:
pass pass
return 'less' return "less"
def _BecomePager(pager): def _BecomePager(pager):
# Delaying execution of the pager until we have output # Delaying execution of the pager until we have output
# ready works around a long-standing bug in popularly # ready works around a long-standing bug in popularly
# available versions of 'less', a better 'more'. # available versions of 'less', a better 'more'.
#
_a, _b, _c = select.select([0], [], [0]) _a, _b, _c = select.select([0], [], [0])
os.environ['LESS'] = 'FRSX' os.environ["LESS"] = "FRSX"
try: try:
os.execvp(pager, [pager]) os.execvp(pager, [pager])
except OSError: except OSError:
os.execv('/bin/sh', ['sh', '-c', pager]) os.execv("/bin/sh", ["sh", "-c", pager])

View File

@ -20,7 +20,7 @@ import stat
def isWindows(): def isWindows():
""" Returns True when running with the native port of Python for Windows, """Returns True when running with the native port of Python for Windows,
False when running on any other platform (including the Cygwin port of False when running on any other platform (including the Cygwin port of
Python). Python).
""" """
@ -30,18 +30,24 @@ def isWindows():
def symlink(source, link_name): def symlink(source, link_name):
"""Creates a symbolic link pointing to source named link_name. """Creates a symbolic link pointing to source named link_name.
Note: On Windows, source must exist on disk, as the implementation needs Note: On Windows, source must exist on disk, as the implementation needs
to know whether to create a "File" or a "Directory" symbolic link. to know whether to create a "File" or a "Directory" symbolic link.
""" """
if isWindows(): if isWindows():
import platform_utils_win32 import platform_utils_win32
source = _validate_winpath(source) source = _validate_winpath(source)
link_name = _validate_winpath(link_name) link_name = _validate_winpath(link_name)
target = os.path.join(os.path.dirname(link_name), source) target = os.path.join(os.path.dirname(link_name), source)
if isdir(target): if isdir(target):
platform_utils_win32.create_dirsymlink(_makelongpath(source), link_name) platform_utils_win32.create_dirsymlink(
_makelongpath(source), link_name
)
else: else:
platform_utils_win32.create_filesymlink(_makelongpath(source), link_name) platform_utils_win32.create_filesymlink(
_makelongpath(source), link_name
)
else: else:
return os.symlink(source, link_name) return os.symlink(source, link_name)
@ -50,8 +56,10 @@ def _validate_winpath(path):
path = os.path.normpath(path) path = os.path.normpath(path)
if _winpath_is_valid(path): if _winpath_is_valid(path):
return path return path
raise ValueError("Path \"%s\" must be a relative path or an absolute " raise ValueError(
"path starting with a drive letter".format(path)) 'Path "{}" must be a relative path or an absolute '
"path starting with a drive letter".format(path)
)
def _winpath_is_valid(path): def _winpath_is_valid(path):
@ -77,16 +85,17 @@ def _makelongpath(path):
MAX_PATH limit. MAX_PATH limit.
""" """
if isWindows(): if isWindows():
# Note: MAX_PATH is 260, but, for directories, the maximum value is actually 246. # Note: MAX_PATH is 260, but, for directories, the maximum value is
# actually 246.
if len(path) < 246: if len(path) < 246:
return path return path
if path.startswith(u"\\\\?\\"): if path.startswith("\\\\?\\"):
return path return path
if not os.path.isabs(path): if not os.path.isabs(path):
return path return path
# Append prefix and ensure unicode so that the special longpath syntax # Append prefix and ensure unicode so that the special longpath syntax
# is supported by underlying Win32 API calls # is supported by underlying Win32 API calls
return u"\\\\?\\" + os.path.normpath(path) return "\\\\?\\" + os.path.normpath(path)
else: else:
return path return path
@ -94,7 +103,8 @@ def _makelongpath(path):
def rmtree(path, ignore_errors=False): def rmtree(path, ignore_errors=False):
"""shutil.rmtree(path) wrapper with support for long paths on Windows. """shutil.rmtree(path) wrapper with support for long paths on Windows.
Availability: Unix, Windows.""" Availability: Unix, Windows.
"""
onerror = None onerror = None
if isWindows(): if isWindows():
path = _makelongpath(path) path = _makelongpath(path)
@ -103,7 +113,7 @@ def rmtree(path, ignore_errors=False):
def handle_rmtree_error(function, path, excinfo): def handle_rmtree_error(function, path, excinfo):
# Allow deleting read-only files # Allow deleting read-only files.
os.chmod(path, stat.S_IWRITE) os.chmod(path, stat.S_IWRITE)
function(path) function(path)
@ -111,7 +121,8 @@ def handle_rmtree_error(function, path, excinfo):
def rename(src, dst): def rename(src, dst):
"""os.rename(src, dst) wrapper with support for long paths on Windows. """os.rename(src, dst) wrapper with support for long paths on Windows.
Availability: Unix, Windows.""" Availability: Unix, Windows.
"""
if isWindows(): if isWindows():
# On Windows, rename fails if destination exists, see # On Windows, rename fails if destination exists, see
# https://docs.python.org/2/library/os.html#os.rename # https://docs.python.org/2/library/os.html#os.rename
@ -132,7 +143,8 @@ def remove(path, missing_ok=False):
allows deleting read-only files on Windows, with support for long paths and allows deleting read-only files on Windows, with support for long paths and
for deleting directory symbolic links. for deleting directory symbolic links.
Availability: Unix, Windows.""" Availability: Unix, Windows.
"""
longpath = _makelongpath(path) if isWindows() else path longpath = _makelongpath(path) if isWindows() else path
try: try:
os.remove(longpath) os.remove(longpath)
@ -181,7 +193,9 @@ def _walk_windows_impl(top, topdown, onerror, followlinks):
for name in dirs: for name in dirs:
new_path = os.path.join(top, name) new_path = os.path.join(top, name)
if followlinks or not islink(new_path): if followlinks or not islink(new_path):
for x in _walk_windows_impl(new_path, topdown, onerror, followlinks): for x in _walk_windows_impl(
new_path, topdown, onerror, followlinks
):
yield x yield x
if not topdown: if not topdown:
yield top, dirs, nondirs yield top, dirs, nondirs
@ -218,6 +232,7 @@ def islink(path):
""" """
if isWindows(): if isWindows():
import platform_utils_win32 import platform_utils_win32
return platform_utils_win32.islink(_makelongpath(path)) return platform_utils_win32.islink(_makelongpath(path))
else: else:
return os.path.islink(path) return os.path.islink(path)
@ -233,6 +248,7 @@ def readlink(path):
""" """
if isWindows(): if isWindows():
import platform_utils_win32 import platform_utils_win32
return platform_utils_win32.readlink(_makelongpath(path)) return platform_utils_win32.readlink(_makelongpath(path))
else: else:
return os.readlink(path) return os.readlink(path)
@ -250,10 +266,12 @@ def realpath(path):
for c in range(0, 100): # Avoid cycles for c in range(0, 100): # Avoid cycles
if islink(current_path): if islink(current_path):
target = readlink(current_path) target = readlink(current_path)
current_path = os.path.join(os.path.dirname(current_path), target) current_path = os.path.join(
os.path.dirname(current_path), target
)
else: else:
basename = os.path.basename(current_path) basename = os.path.basename(current_path)
if basename == '': if basename == "":
path_tail.append(current_path) path_tail.append(current_path)
break break
path_tail.append(basename) path_tail.append(basename)

View File

@ -19,7 +19,7 @@ from ctypes import c_buffer, c_ubyte, Structure, Union, byref
from ctypes.wintypes import BOOL, BOOLEAN, LPCWSTR, DWORD, HANDLE from ctypes.wintypes import BOOL, BOOLEAN, LPCWSTR, DWORD, HANDLE
from ctypes.wintypes import WCHAR, USHORT, LPVOID, ULONG, LPDWORD from ctypes.wintypes import WCHAR, USHORT, LPVOID, ULONG, LPDWORD
kernel32 = WinDLL('kernel32', use_last_error=True) kernel32 = WinDLL("kernel32", use_last_error=True)
UCHAR = c_ubyte UCHAR = c_ubyte
@ -31,14 +31,17 @@ ERROR_PRIVILEGE_NOT_HELD = 1314
# Win32 API entry points # Win32 API entry points
CreateSymbolicLinkW = kernel32.CreateSymbolicLinkW CreateSymbolicLinkW = kernel32.CreateSymbolicLinkW
CreateSymbolicLinkW.restype = BOOLEAN CreateSymbolicLinkW.restype = BOOLEAN
CreateSymbolicLinkW.argtypes = (LPCWSTR, # lpSymlinkFileName In CreateSymbolicLinkW.argtypes = (
LPCWSTR, # lpSymlinkFileName In
LPCWSTR, # lpTargetFileName In LPCWSTR, # lpTargetFileName In
DWORD) # dwFlags In DWORD, # dwFlags In
)
# Symbolic link creation flags # Symbolic link creation flags
SYMBOLIC_LINK_FLAG_FILE = 0x00 SYMBOLIC_LINK_FLAG_FILE = 0x00
SYMBOLIC_LINK_FLAG_DIRECTORY = 0x01 SYMBOLIC_LINK_FLAG_DIRECTORY = 0x01
# symlink support for CreateSymbolicLink() starting with Windows 10 (1703, v10.0.14972) # symlink support for CreateSymbolicLink() starting with Windows 10 (1703,
# v10.0.14972)
SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE = 0x02 SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE = 0x02
GetFileAttributesW = kernel32.GetFileAttributesW GetFileAttributesW = kernel32.GetFileAttributesW
@ -50,13 +53,15 @@ FILE_ATTRIBUTE_REPARSE_POINT = 0x00400
CreateFileW = kernel32.CreateFileW CreateFileW = kernel32.CreateFileW
CreateFileW.restype = HANDLE CreateFileW.restype = HANDLE
CreateFileW.argtypes = (LPCWSTR, # lpFileName In CreateFileW.argtypes = (
LPCWSTR, # lpFileName In
DWORD, # dwDesiredAccess In DWORD, # dwDesiredAccess In
DWORD, # dwShareMode In DWORD, # dwShareMode In
LPVOID, # lpSecurityAttributes In_opt LPVOID, # lpSecurityAttributes In_opt
DWORD, # dwCreationDisposition In DWORD, # dwCreationDisposition In
DWORD, # dwFlagsAndAttributes In DWORD, # dwFlagsAndAttributes In
HANDLE) # hTemplateFile In_opt HANDLE, # hTemplateFile In_opt
)
CloseHandle = kernel32.CloseHandle CloseHandle = kernel32.CloseHandle
CloseHandle.restype = BOOL CloseHandle.restype = BOOL
@ -69,14 +74,16 @@ FILE_FLAG_OPEN_REPARSE_POINT = 0x00200000
DeviceIoControl = kernel32.DeviceIoControl DeviceIoControl = kernel32.DeviceIoControl
DeviceIoControl.restype = BOOL DeviceIoControl.restype = BOOL
DeviceIoControl.argtypes = (HANDLE, # hDevice In DeviceIoControl.argtypes = (
HANDLE, # hDevice In
DWORD, # dwIoControlCode In DWORD, # dwIoControlCode In
LPVOID, # lpInBuffer In_opt LPVOID, # lpInBuffer In_opt
DWORD, # nInBufferSize In DWORD, # nInBufferSize In
LPVOID, # lpOutBuffer Out_opt LPVOID, # lpOutBuffer Out_opt
DWORD, # nOutBufferSize In DWORD, # nOutBufferSize In
LPDWORD, # lpBytesReturned Out_opt LPDWORD, # lpBytesReturned Out_opt
LPVOID) # lpOverlapped Inout_opt LPVOID, # lpOverlapped Inout_opt
)
# Device I/O control flags and options # Device I/O control flags and options
FSCTL_GET_REPARSE_POINT = 0x000900A8 FSCTL_GET_REPARSE_POINT = 0x000900A8
@ -86,16 +93,18 @@ MAXIMUM_REPARSE_DATA_BUFFER_SIZE = 0x4000
class GENERIC_REPARSE_BUFFER(Structure): class GENERIC_REPARSE_BUFFER(Structure):
_fields_ = (('DataBuffer', UCHAR * 1),) _fields_ = (("DataBuffer", UCHAR * 1),)
class SYMBOLIC_LINK_REPARSE_BUFFER(Structure): class SYMBOLIC_LINK_REPARSE_BUFFER(Structure):
_fields_ = (('SubstituteNameOffset', USHORT), _fields_ = (
('SubstituteNameLength', USHORT), ("SubstituteNameOffset", USHORT),
('PrintNameOffset', USHORT), ("SubstituteNameLength", USHORT),
('PrintNameLength', USHORT), ("PrintNameOffset", USHORT),
('Flags', ULONG), ("PrintNameLength", USHORT),
('PathBuffer', WCHAR * 1)) ("Flags", ULONG),
("PathBuffer", WCHAR * 1),
)
@property @property
def PrintName(self): def PrintName(self):
@ -105,11 +114,13 @@ class SYMBOLIC_LINK_REPARSE_BUFFER(Structure):
class MOUNT_POINT_REPARSE_BUFFER(Structure): class MOUNT_POINT_REPARSE_BUFFER(Structure):
_fields_ = (('SubstituteNameOffset', USHORT), _fields_ = (
('SubstituteNameLength', USHORT), ("SubstituteNameOffset", USHORT),
('PrintNameOffset', USHORT), ("SubstituteNameLength", USHORT),
('PrintNameLength', USHORT), ("PrintNameOffset", USHORT),
('PathBuffer', WCHAR * 1)) ("PrintNameLength", USHORT),
("PathBuffer", WCHAR * 1),
)
@property @property
def PrintName(self): def PrintName(self):
@ -120,14 +131,19 @@ class MOUNT_POINT_REPARSE_BUFFER(Structure):
class REPARSE_DATA_BUFFER(Structure): class REPARSE_DATA_BUFFER(Structure):
class REPARSE_BUFFER(Union): class REPARSE_BUFFER(Union):
_fields_ = (('SymbolicLinkReparseBuffer', SYMBOLIC_LINK_REPARSE_BUFFER), _fields_ = (
('MountPointReparseBuffer', MOUNT_POINT_REPARSE_BUFFER), ("SymbolicLinkReparseBuffer", SYMBOLIC_LINK_REPARSE_BUFFER),
('GenericReparseBuffer', GENERIC_REPARSE_BUFFER)) ("MountPointReparseBuffer", MOUNT_POINT_REPARSE_BUFFER),
_fields_ = (('ReparseTag', ULONG), ("GenericReparseBuffer", GENERIC_REPARSE_BUFFER),
('ReparseDataLength', USHORT), )
('Reserved', USHORT),
('ReparseBuffer', REPARSE_BUFFER)) _fields_ = (
_anonymous_ = ('ReparseBuffer',) ("ReparseTag", ULONG),
("ReparseDataLength", USHORT),
("Reserved", USHORT),
("ReparseBuffer", REPARSE_BUFFER),
)
_anonymous_ = ("ReparseBuffer",)
def create_filesymlink(source, link_name): def create_filesymlink(source, link_name):
@ -136,25 +152,27 @@ def create_filesymlink(source, link_name):
def create_dirsymlink(source, link_name): def create_dirsymlink(source, link_name):
"""Creates a Windows directory symbolic link source pointing to link_name. """Creates a Windows directory symbolic link source pointing to link_name.""" # noqa: E501
"""
_create_symlink(source, link_name, SYMBOLIC_LINK_FLAG_DIRECTORY) _create_symlink(source, link_name, SYMBOLIC_LINK_FLAG_DIRECTORY)
def _create_symlink(source, link_name, dwFlags): def _create_symlink(source, link_name, dwFlags):
if not CreateSymbolicLinkW(link_name, source, if not CreateSymbolicLinkW(
dwFlags | SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE): link_name,
# See https://github.com/golang/go/pull/24307/files#diff-b87bc12e4da2497308f9ef746086e4f0 source,
# "the unprivileged create flag is unsupported below Windows 10 (1703, v10.0.14972). dwFlags | SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE,
# retry without it." ):
# See https://github.com/golang/go/pull/24307/files#diff-b87bc12e4da2497308f9ef746086e4f0 # noqa: E501
# "the unprivileged create flag is unsupported below Windows 10 (1703,
# v10.0.14972). retry without it."
if not CreateSymbolicLinkW(link_name, source, dwFlags): if not CreateSymbolicLinkW(link_name, source, dwFlags):
code = get_last_error() code = get_last_error()
error_desc = FormatError(code).strip() error_desc = FormatError(code).strip()
if code == ERROR_PRIVILEGE_NOT_HELD: if code == ERROR_PRIVILEGE_NOT_HELD:
raise OSError(errno.EPERM, error_desc, link_name) raise OSError(errno.EPERM, error_desc, link_name)
_raise_winerror( _raise_winerror(
code, code, 'Error creating symbolic link "{}"'.format(link_name)
'Error creating symbolic link \"%s\"'.format(link_name)) )
def islink(path): def islink(path):
@ -165,45 +183,48 @@ def islink(path):
def readlink(path): def readlink(path):
reparse_point_handle = CreateFileW(path, reparse_point_handle = CreateFileW(
path,
0, 0,
0, 0,
None, None,
OPEN_EXISTING, OPEN_EXISTING,
FILE_FLAG_OPEN_REPARSE_POINT | FILE_FLAG_OPEN_REPARSE_POINT | FILE_FLAG_BACKUP_SEMANTICS,
FILE_FLAG_BACKUP_SEMANTICS, None,
None) )
if reparse_point_handle == INVALID_HANDLE_VALUE: if reparse_point_handle == INVALID_HANDLE_VALUE:
_raise_winerror( _raise_winerror(
get_last_error(), get_last_error(), 'Error opening symbolic link "{}"'.format(path)
'Error opening symbolic link \"%s\"'.format(path)) )
target_buffer = c_buffer(MAXIMUM_REPARSE_DATA_BUFFER_SIZE) target_buffer = c_buffer(MAXIMUM_REPARSE_DATA_BUFFER_SIZE)
n_bytes_returned = DWORD() n_bytes_returned = DWORD()
io_result = DeviceIoControl(reparse_point_handle, io_result = DeviceIoControl(
reparse_point_handle,
FSCTL_GET_REPARSE_POINT, FSCTL_GET_REPARSE_POINT,
None, None,
0, 0,
target_buffer, target_buffer,
len(target_buffer), len(target_buffer),
byref(n_bytes_returned), byref(n_bytes_returned),
None) None,
)
CloseHandle(reparse_point_handle) CloseHandle(reparse_point_handle)
if not io_result: if not io_result:
_raise_winerror( _raise_winerror(
get_last_error(), get_last_error(), 'Error reading symbolic link "{}"'.format(path)
'Error reading symbolic link \"%s\"'.format(path)) )
rdb = REPARSE_DATA_BUFFER.from_buffer(target_buffer) rdb = REPARSE_DATA_BUFFER.from_buffer(target_buffer)
if rdb.ReparseTag == IO_REPARSE_TAG_SYMLINK: if rdb.ReparseTag == IO_REPARSE_TAG_SYMLINK:
return rdb.SymbolicLinkReparseBuffer.PrintName return rdb.SymbolicLinkReparseBuffer.PrintName
elif rdb.ReparseTag == IO_REPARSE_TAG_MOUNT_POINT: elif rdb.ReparseTag == IO_REPARSE_TAG_MOUNT_POINT:
return rdb.MountPointReparseBuffer.PrintName return rdb.MountPointReparseBuffer.PrintName
# Unsupported reparse point type # Unsupported reparse point type.
_raise_winerror( _raise_winerror(
ERROR_NOT_SUPPORTED, ERROR_NOT_SUPPORTED, 'Error reading symbolic link "{}"'.format(path)
'Error reading symbolic link \"%s\"'.format(path)) )
def _raise_winerror(code, error_desc): def _raise_winerror(code, error_desc):
win_error_desc = FormatError(code).strip() win_error_desc = FormatError(code).strip()
error_desc = "%s: %s".format(error_desc, win_error_desc) error_desc = "{0}: {1}".format(error_desc, win_error_desc)
raise WinError(code, error_desc) raise WinError(code, error_desc)

View File

@ -22,12 +22,12 @@ _NOT_TTY = not os.isatty(2)
# This will erase all content in the current line (wherever the cursor is). # This will erase all content in the current line (wherever the cursor is).
# It does not move the cursor, so this is usually followed by \r to move to # It does not move the cursor, so this is usually followed by \r to move to
# column 0. # column 0.
CSI_ERASE_LINE = '\x1b[2K' CSI_ERASE_LINE = "\x1b[2K"
# This will erase all content in the current line after the cursor. This is # This will erase all content in the current line after the cursor. This is
# useful for partial updates & progress messages as the terminal can display # useful for partial updates & progress messages as the terminal can display
# it better. # it better.
CSI_ERASE_LINE_AFTER = '\x1b[K' CSI_ERASE_LINE_AFTER = "\x1b[K"
def duration_str(total): def duration_str(total):
@ -38,17 +38,24 @@ def duration_str(total):
""" """
hours, rem = divmod(total, 3600) hours, rem = divmod(total, 3600)
mins, secs = divmod(rem, 60) mins, secs = divmod(rem, 60)
ret = '%.3fs' % (secs,) ret = "%.3fs" % (secs,)
if mins: if mins:
ret = '%im%s' % (mins, ret) ret = "%im%s" % (mins, ret)
if hours: if hours:
ret = '%ih%s' % (hours, ret) ret = "%ih%s" % (hours, ret)
return ret return ret
class Progress(object): class Progress(object):
def __init__(self, title, total=0, units='', print_newline=False, delay=True, def __init__(
quiet=False): self,
title,
total=0,
units="",
print_newline=False,
delay=True,
quiet=False,
):
self._title = title self._title = title
self._total = total self._total = total
self._done = 0 self._done = 0
@ -71,13 +78,13 @@ class Progress(object):
self._active += 1 self._active += 1
if not self._show_jobs: if not self._show_jobs:
self._show_jobs = self._active > 1 self._show_jobs = self._active > 1
self.update(inc=0, msg='started ' + name) self.update(inc=0, msg="started " + name)
def finish(self, name): def finish(self, name):
self.update(msg='finished ' + name) self.update(msg="finished " + name)
self._active -= 1 self._active -= 1
def update(self, inc=1, msg=''): def update(self, inc=1, msg=""):
self._done += inc self._done += inc
if _NOT_TTY or IsTraceToStderr(): if _NOT_TTY or IsTraceToStderr():
@ -90,26 +97,35 @@ class Progress(object):
return return
if self._total <= 0: if self._total <= 0:
sys.stderr.write('\r%s: %d,%s' % ( sys.stderr.write(
self._title, "\r%s: %d,%s" % (self._title, self._done, CSI_ERASE_LINE_AFTER)
self._done, )
CSI_ERASE_LINE_AFTER))
sys.stderr.flush() sys.stderr.flush()
else: else:
p = (100 * self._done) / self._total p = (100 * self._done) / self._total
if self._show_jobs: if self._show_jobs:
jobs = '[%d job%s] ' % (self._active, 's' if self._active > 1 else '') jobs = "[%d job%s] " % (
self._active,
"s" if self._active > 1 else "",
)
else: else:
jobs = '' jobs = ""
sys.stderr.write('\r%s: %2d%% %s(%d%s/%d%s)%s%s%s%s' % ( sys.stderr.write(
"\r%s: %2d%% %s(%d%s/%d%s)%s%s%s%s"
% (
self._title, self._title,
p, p,
jobs, jobs,
self._done, self._units, self._done,
self._total, self._units, self._units,
' ' if msg else '', msg, self._total,
self._units,
" " if msg else "",
msg,
CSI_ERASE_LINE_AFTER, CSI_ERASE_LINE_AFTER,
'\n' if self._print_newline else '')) "\n" if self._print_newline else "",
)
)
sys.stderr.flush() sys.stderr.flush()
def end(self): def end(self):
@ -118,19 +134,24 @@ class Progress(object):
duration = duration_str(time() - self._start) duration = duration_str(time() - self._start)
if self._total <= 0: if self._total <= 0:
sys.stderr.write('\r%s: %d, done in %s%s\n' % ( sys.stderr.write(
self._title, "\r%s: %d, done in %s%s\n"
self._done, % (self._title, self._done, duration, CSI_ERASE_LINE_AFTER)
duration, )
CSI_ERASE_LINE_AFTER))
sys.stderr.flush() sys.stderr.flush()
else: else:
p = (100 * self._done) / self._total p = (100 * self._done) / self._total
sys.stderr.write('\r%s: %3d%% (%d%s/%d%s), done in %s%s\n' % ( sys.stderr.write(
"\r%s: %3d%% (%d%s/%d%s), done in %s%s\n"
% (
self._title, self._title,
p, p,
self._done, self._units, self._done,
self._total, self._units, self._units,
self._total,
self._units,
duration, duration,
CSI_ERASE_LINE_AFTER)) CSI_ERASE_LINE_AFTER,
)
)
sys.stderr.flush() sys.stderr.flush()

2424
project.py

File diff suppressed because it is too large Load Diff

View File

@ -29,42 +29,55 @@ import util
def sign(opts): def sign(opts):
"""Sign the launcher!""" """Sign the launcher!"""
output = '' output = ""
for key in opts.keys: for key in opts.keys:
# We use ! at the end of the key so that gpg uses this specific key. # We use ! at the end of the key so that gpg uses this specific key.
# Otherwise it uses the key as a lookup into the overall key and uses the # Otherwise it uses the key as a lookup into the overall key and uses
# default signing key. i.e. It will see that KEYID_RSA is a subkey of # the default signing key. i.e. It will see that KEYID_RSA is a subkey
# another key, and use the primary key to sign instead of the subkey. # of another key, and use the primary key to sign instead of the subkey.
cmd = ['gpg', '--homedir', opts.gpgdir, '-u', f'{key}!', '--batch', '--yes', cmd = [
'--armor', '--detach-sign', '--output', '-', opts.launcher] "gpg",
ret = util.run(opts, cmd, encoding='utf-8', stdout=subprocess.PIPE) "--homedir",
opts.gpgdir,
"-u",
f"{key}!",
"--batch",
"--yes",
"--armor",
"--detach-sign",
"--output",
"-",
opts.launcher,
]
ret = util.run(opts, cmd, encoding="utf-8", stdout=subprocess.PIPE)
output += ret.stdout output += ret.stdout
# Save the combined signatures into one file. # Save the combined signatures into one file.
with open(f'{opts.launcher}.asc', 'w', encoding='utf-8') as fp: with open(f"{opts.launcher}.asc", "w", encoding="utf-8") as fp:
fp.write(output) fp.write(output)
def check(opts): def check(opts):
"""Check the signature.""" """Check the signature."""
util.run(opts, ['gpg', '--verify', f'{opts.launcher}.asc']) util.run(opts, ["gpg", "--verify", f"{opts.launcher}.asc"])
def get_version(opts): def get_version(opts):
"""Get the version from |launcher|.""" """Get the version from |launcher|."""
# Make sure we don't search $PATH when signing the "repo" file in the cwd. # Make sure we don't search $PATH when signing the "repo" file in the cwd.
launcher = os.path.join('.', opts.launcher) launcher = os.path.join(".", opts.launcher)
cmd = [launcher, '--version'] cmd = [launcher, "--version"]
ret = util.run(opts, cmd, encoding='utf-8', stdout=subprocess.PIPE) ret = util.run(opts, cmd, encoding="utf-8", stdout=subprocess.PIPE)
m = re.search(r'repo launcher version ([0-9.]+)', ret.stdout) m = re.search(r"repo launcher version ([0-9.]+)", ret.stdout)
if not m: if not m:
sys.exit(f'{opts.launcher}: unable to detect repo version') sys.exit(f"{opts.launcher}: unable to detect repo version")
return m.group(1) return m.group(1)
def postmsg(opts, version): def postmsg(opts, version):
"""Helpful info to show at the end for release manager.""" """Helpful info to show at the end for release manager."""
print(f""" print(
f"""
Repo launcher bucket: Repo launcher bucket:
gs://git-repo-downloads/ gs://git-repo-downloads/
@ -81,24 +94,39 @@ NB: If a rollback is necessary, the GS bucket archives old versions, and may be
gsutil ls -la gs://git-repo-downloads/repo gs://git-repo-downloads/repo.asc gsutil ls -la gs://git-repo-downloads/repo gs://git-repo-downloads/repo.asc
gsutil cp -a public-read gs://git-repo-downloads/repo#<unique id> gs://git-repo-downloads/repo gsutil cp -a public-read gs://git-repo-downloads/repo#<unique id> gs://git-repo-downloads/repo
gsutil cp -a public-read gs://git-repo-downloads/repo.asc#<unique id> gs://git-repo-downloads/repo.asc gsutil cp -a public-read gs://git-repo-downloads/repo.asc#<unique id> gs://git-repo-downloads/repo.asc
""") """ # noqa: E501
)
def get_parser(): def get_parser():
"""Get a CLI parser.""" """Get a CLI parser."""
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('-n', '--dry-run', parser.add_argument(
dest='dryrun', action='store_true', "-n",
help='show everything that would be done') "--dry-run",
parser.add_argument('--gpgdir', dest="dryrun",
default=os.path.join(util.HOMEDIR, '.gnupg', 'repo'), action="store_true",
help='path to dedicated gpg dir with release keys ' help="show everything that would be done",
'(default: ~/.gnupg/repo/)') )
parser.add_argument('--keyid', dest='keys', default=[], action='append', parser.add_argument(
help='alternative signing keys to use') "--gpgdir",
parser.add_argument('launcher', default=os.path.join(util.HOMEDIR, ".gnupg", "repo"),
default=os.path.join(util.TOPDIR, 'repo'), nargs='?', help="path to dedicated gpg dir with release keys "
help='the launcher script to sign') "(default: ~/.gnupg/repo/)",
)
parser.add_argument(
"--keyid",
dest="keys",
default=[],
action="append",
help="alternative signing keys to use",
)
parser.add_argument(
"launcher",
default=os.path.join(util.TOPDIR, "repo"),
nargs="?",
help="the launcher script to sign",
)
return parser return parser
@ -108,18 +136,20 @@ def main(argv):
opts = parser.parse_args(argv) opts = parser.parse_args(argv)
if not os.path.exists(opts.gpgdir): if not os.path.exists(opts.gpgdir):
parser.error(f'--gpgdir does not exist: {opts.gpgdir}') parser.error(f"--gpgdir does not exist: {opts.gpgdir}")
if not os.path.exists(opts.launcher): if not os.path.exists(opts.launcher):
parser.error(f'launcher does not exist: {opts.launcher}') parser.error(f"launcher does not exist: {opts.launcher}")
opts.launcher = os.path.relpath(opts.launcher) opts.launcher = os.path.relpath(opts.launcher)
print(f'Signing "{opts.launcher}" launcher script and saving to ' print(
f'"{opts.launcher}.asc"') f'Signing "{opts.launcher}" launcher script and saving to '
f'"{opts.launcher}.asc"'
)
if opts.keys: if opts.keys:
print(f'Using custom keys to sign: {" ".join(opts.keys)}') print(f'Using custom keys to sign: {" ".join(opts.keys)}')
else: else:
print('Using official Repo release keys to sign') print("Using official Repo release keys to sign")
opts.keys = [util.KEYID_DSA, util.KEYID_RSA, util.KEYID_ECC] opts.keys = [util.KEYID_DSA, util.KEYID_RSA, util.KEYID_ECC]
util.import_release_key(opts) util.import_release_key(opts)
@ -131,5 +161,5 @@ def main(argv):
return 0 return 0
if __name__ == '__main__': if __name__ == "__main__":
sys.exit(main(sys.argv[1:])) sys.exit(main(sys.argv[1:]))

View File

@ -35,7 +35,7 @@ import util
KEYID = util.KEYID_DSA KEYID = util.KEYID_DSA
# Regular expression to validate tag names. # Regular expression to validate tag names.
RE_VALID_TAG = r'^v([0-9]+[.])+[0-9]+$' RE_VALID_TAG = r"^v([0-9]+[.])+[0-9]+$"
def sign(opts): def sign(opts):
@ -44,11 +44,20 @@ def sign(opts):
# Otherwise it uses the key as a lookup into the overall key and uses the # Otherwise it uses the key as a lookup into the overall key and uses the
# default signing key. i.e. It will see that KEYID_RSA is a subkey of # default signing key. i.e. It will see that KEYID_RSA is a subkey of
# another key, and use the primary key to sign instead of the subkey. # another key, and use the primary key to sign instead of the subkey.
cmd = ['git', 'tag', '-s', opts.tag, '-u', f'{opts.key}!', cmd = [
'-m', f'repo {opts.tag}', opts.commit] "git",
"tag",
"-s",
opts.tag,
"-u",
f"{opts.key}!",
"-m",
f"repo {opts.tag}",
opts.commit,
]
key = 'GNUPGHOME' key = "GNUPGHOME"
print('+', f'export {key}="{opts.gpgdir}"') print("+", f'export {key}="{opts.gpgdir}"')
oldvalue = os.getenv(key) oldvalue = os.getenv(key)
os.putenv(key, opts.gpgdir) os.putenv(key, opts.gpgdir)
util.run(opts, cmd) util.run(opts, cmd)
@ -60,21 +69,27 @@ def sign(opts):
def check(opts): def check(opts):
"""Check the signature.""" """Check the signature."""
util.run(opts, ['git', 'tag', '--verify', opts.tag]) util.run(opts, ["git", "tag", "--verify", opts.tag])
def postmsg(opts): def postmsg(opts):
"""Helpful info to show at the end for release manager.""" """Helpful info to show at the end for release manager."""
cmd = ['git', 'rev-parse', 'remotes/origin/stable'] cmd = ["git", "rev-parse", "remotes/origin/stable"]
ret = util.run(opts, cmd, encoding='utf-8', stdout=subprocess.PIPE) ret = util.run(opts, cmd, encoding="utf-8", stdout=subprocess.PIPE)
current_release = ret.stdout.strip() current_release = ret.stdout.strip()
cmd = ['git', 'log', '--format=%h (%aN) %s', '--no-merges', cmd = [
f'remotes/origin/stable..{opts.tag}'] "git",
ret = util.run(opts, cmd, encoding='utf-8', stdout=subprocess.PIPE) "log",
"--format=%h (%aN) %s",
"--no-merges",
f"remotes/origin/stable..{opts.tag}",
]
ret = util.run(opts, cmd, encoding="utf-8", stdout=subprocess.PIPE)
shortlog = ret.stdout.strip() shortlog = ret.stdout.strip()
print(f""" print(
f"""
Here's the short log since the last release. Here's the short log since the last release.
{shortlog} {shortlog}
@ -84,29 +99,39 @@ NB: People will start upgrading to this version immediately.
To roll back a release: To roll back a release:
git push origin --force {current_release}:stable -n git push origin --force {current_release}:stable -n
""") """
)
def get_parser(): def get_parser():
"""Get a CLI parser.""" """Get a CLI parser."""
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description=__doc__, description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter) formatter_class=argparse.RawDescriptionHelpFormatter,
parser.add_argument('-n', '--dry-run', )
dest='dryrun', action='store_true', parser.add_argument(
help='show everything that would be done') "-n",
parser.add_argument('--gpgdir', "--dry-run",
default=os.path.join(util.HOMEDIR, '.gnupg', 'repo'), dest="dryrun",
help='path to dedicated gpg dir with release keys ' action="store_true",
'(default: ~/.gnupg/repo/)') help="show everything that would be done",
parser.add_argument('-f', '--force', action='store_true', )
help='force signing of any tag') parser.add_argument(
parser.add_argument('--keyid', dest='key', "--gpgdir",
help='alternative signing key to use') default=os.path.join(util.HOMEDIR, ".gnupg", "repo"),
parser.add_argument('tag', help="path to dedicated gpg dir with release keys "
help='the tag to create (e.g. "v2.0")') "(default: ~/.gnupg/repo/)",
parser.add_argument('commit', default='HEAD', nargs='?', )
help='the commit to tag') parser.add_argument(
"-f", "--force", action="store_true", help="force signing of any tag"
)
parser.add_argument(
"--keyid", dest="key", help="alternative signing key to use"
)
parser.add_argument("tag", help='the tag to create (e.g. "v2.0")')
parser.add_argument(
"commit", default="HEAD", nargs="?", help="the commit to tag"
)
return parser return parser
@ -116,16 +141,18 @@ def main(argv):
opts = parser.parse_args(argv) opts = parser.parse_args(argv)
if not os.path.exists(opts.gpgdir): if not os.path.exists(opts.gpgdir):
parser.error(f'--gpgdir does not exist: {opts.gpgdir}') parser.error(f"--gpgdir does not exist: {opts.gpgdir}")
if not opts.force and not re.match(RE_VALID_TAG, opts.tag): if not opts.force and not re.match(RE_VALID_TAG, opts.tag):
parser.error(f'tag "{opts.tag}" does not match regex "{RE_VALID_TAG}"; ' parser.error(
'use --force to sign anyways') f'tag "{opts.tag}" does not match regex "{RE_VALID_TAG}"; '
"use --force to sign anyways"
)
if opts.key: if opts.key:
print(f'Using custom key to sign: {opts.key}') print(f"Using custom key to sign: {opts.key}")
else: else:
print('Using official Repo release key to sign') print("Using official Repo release key to sign")
opts.key = KEYID opts.key = KEYID
util.import_release_key(opts) util.import_release_key(opts)
@ -136,5 +163,5 @@ def main(argv):
return 0 return 0
if __name__ == '__main__': if __name__ == "__main__":
sys.exit(main(sys.argv[1:])) sys.exit(main(sys.argv[1:]))

View File

@ -29,72 +29,106 @@ import sys
import tempfile import tempfile
TOPDIR = Path(__file__).resolve().parent.parent TOPDIR = Path(__file__).resolve().parent.parent
MANDIR = TOPDIR.joinpath('man') MANDIR = TOPDIR.joinpath("man")
# Load repo local modules. # Load repo local modules.
sys.path.insert(0, str(TOPDIR)) sys.path.insert(0, str(TOPDIR))
from git_command import RepoSourceVersion from git_command import RepoSourceVersion
import subcmds import subcmds
def worker(cmd, **kwargs): def worker(cmd, **kwargs):
subprocess.run(cmd, **kwargs) subprocess.run(cmd, **kwargs)
def main(argv): def main(argv):
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
opts = parser.parse_args(argv) parser.parse_args(argv)
if not shutil.which('help2man'): if not shutil.which("help2man"):
sys.exit('Please install help2man to continue.') sys.exit("Please install help2man to continue.")
# Let repo know we're generating man pages so it can avoid some dynamic # Let repo know we're generating man pages so it can avoid some dynamic
# behavior (like probing active number of CPUs). We use a weird name & # behavior (like probing active number of CPUs). We use a weird name &
# value to make it less likely for users to set this var themselves. # value to make it less likely for users to set this var themselves.
os.environ['_REPO_GENERATE_MANPAGES_'] = ' indeed! ' os.environ["_REPO_GENERATE_MANPAGES_"] = " indeed! "
# "repo branch" is an alias for "repo branches". # "repo branch" is an alias for "repo branches".
del subcmds.all_commands['branch'] del subcmds.all_commands["branch"]
(MANDIR / 'repo-branch.1').write_text('.so man1/repo-branches.1') (MANDIR / "repo-branch.1").write_text(".so man1/repo-branches.1")
version = RepoSourceVersion() version = RepoSourceVersion()
cmdlist = [['help2man', '-N', '-n', f'repo {cmd} - manual page for repo {cmd}', cmdlist = [
'-S', f'repo {cmd}', '-m', 'Repo Manual', f'--version-string={version}', [
'-o', MANDIR.joinpath(f'repo-{cmd}.1.tmp'), './repo', "help2man",
'-h', f'help {cmd}'] for cmd in subcmds.all_commands] "-N",
cmdlist.append(['help2man', '-N', '-n', 'repository management tool built on top of git', "-n",
'-S', 'repo', '-m', 'Repo Manual', f'--version-string={version}', f"repo {cmd} - manual page for repo {cmd}",
'-o', MANDIR.joinpath('repo.1.tmp'), './repo', "-S",
'-h', '--help-all']) f"repo {cmd}",
"-m",
"Repo Manual",
f"--version-string={version}",
"-o",
MANDIR.joinpath(f"repo-{cmd}.1.tmp"),
"./repo",
"-h",
f"help {cmd}",
]
for cmd in subcmds.all_commands
]
cmdlist.append(
[
"help2man",
"-N",
"-n",
"repository management tool built on top of git",
"-S",
"repo",
"-m",
"Repo Manual",
f"--version-string={version}",
"-o",
MANDIR.joinpath("repo.1.tmp"),
"./repo",
"-h",
"--help-all",
]
)
with tempfile.TemporaryDirectory() as tempdir: with tempfile.TemporaryDirectory() as tempdir:
tempdir = Path(tempdir) tempdir = Path(tempdir)
repo_dir = tempdir / '.repo' repo_dir = tempdir / ".repo"
repo_dir.mkdir() repo_dir.mkdir()
(repo_dir / 'repo').symlink_to(TOPDIR) (repo_dir / "repo").symlink_to(TOPDIR)
# Create a repo wrapper using the active Python executable. We can't pass # Create a repo wrapper using the active Python executable. We can't
# this directly to help2man as it's too simple, so insert it via shebang. # pass this directly to help2man as it's too simple, so insert it via
data = (TOPDIR / 'repo').read_text(encoding='utf-8') # shebang.
tempbin = tempdir / 'repo' data = (TOPDIR / "repo").read_text(encoding="utf-8")
tempbin.write_text(f'#!{sys.executable}\n' + data, encoding='utf-8') tempbin = tempdir / "repo"
tempbin.write_text(f"#!{sys.executable}\n" + data, encoding="utf-8")
tempbin.chmod(0o755) tempbin.chmod(0o755)
# Run all cmd in parallel, and wait for them to finish. # Run all cmd in parallel, and wait for them to finish.
with multiprocessing.Pool() as pool: with multiprocessing.Pool() as pool:
pool.map(partial(worker, cwd=tempdir, check=True), cmdlist) pool.map(partial(worker, cwd=tempdir, check=True), cmdlist)
for tmp_path in MANDIR.glob('*.1.tmp'): for tmp_path in MANDIR.glob("*.1.tmp"):
path = tmp_path.parent / tmp_path.stem path = tmp_path.parent / tmp_path.stem
old_data = path.read_text() if path.exists() else '' old_data = path.read_text() if path.exists() else ""
data = tmp_path.read_text() data = tmp_path.read_text()
tmp_path.unlink() tmp_path.unlink()
data = replace_regex(data) data = replace_regex(data)
# If the only thing that changed was the date, don't refresh. This avoids # If the only thing that changed was the date, don't refresh. This
# a lot of noise when only one file actually updates. # avoids a lot of noise when only one file actually updates.
old_data = re.sub(r'^(\.TH REPO "1" ")([^"]+)', r'\1', old_data, flags=re.M) old_data = re.sub(
new_data = re.sub(r'^(\.TH REPO "1" ")([^"]+)', r'\1', data, flags=re.M) r'^(\.TH REPO "1" ")([^"]+)', r"\1", old_data, flags=re.M
)
new_data = re.sub(r'^(\.TH REPO "1" ")([^"]+)', r"\1", data, flags=re.M)
if old_data != new_data: if old_data != new_data:
path.write_text(data) path.write_text(data)
@ -109,10 +143,10 @@ def replace_regex(data):
Updated manpage text. Updated manpage text.
""" """
regex = ( regex = (
(r'(It was generated by help2man) [0-9.]+', r'\g<1>.'), (r"(It was generated by help2man) [0-9.]+", r"\g<1>."),
(r'^\033\[[0-9;]*m([^\033]*)\033\[m', r'\g<1>'), (r"^\033\[[0-9;]*m([^\033]*)\033\[m", r"\g<1>"),
(r'^\.IP\n(.*:)\n', r'.SS \g<1>\n'), (r"^\.IP\n(.*:)\n", r".SS \g<1>\n"),
(r'^\.PP\nDescription', r'.SH DETAILS'), (r"^\.PP\nDescription", r".SH DETAILS"),
) )
for pattern, replacement in regex: for pattern, replacement in regex:
data = re.sub(pattern, replacement, data, flags=re.M) data = re.sub(pattern, replacement, data, flags=re.M)

View File

@ -20,54 +20,60 @@ import subprocess
import sys import sys
assert sys.version_info >= (3, 6), 'This module requires Python 3.6+' assert sys.version_info >= (3, 6), "This module requires Python 3.6+"
TOPDIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) TOPDIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
HOMEDIR = os.path.expanduser('~') HOMEDIR = os.path.expanduser("~")
# These are the release keys we sign with. # These are the release keys we sign with.
KEYID_DSA = '8BB9AD793E8E6153AF0F9A4416530D5E920F5C65' KEYID_DSA = "8BB9AD793E8E6153AF0F9A4416530D5E920F5C65"
KEYID_RSA = 'A34A13BE8E76BFF46A0C022DA2E75A824AAB9624' KEYID_RSA = "A34A13BE8E76BFF46A0C022DA2E75A824AAB9624"
KEYID_ECC = 'E1F9040D7A3F6DAFAC897CD3D3B95DA243E48A39' KEYID_ECC = "E1F9040D7A3F6DAFAC897CD3D3B95DA243E48A39"
def cmdstr(cmd): def cmdstr(cmd):
"""Get a nicely quoted shell command.""" """Get a nicely quoted shell command."""
ret = [] ret = []
for arg in cmd: for arg in cmd:
if not re.match(r'^[a-zA-Z0-9/_.=-]+$', arg): if not re.match(r"^[a-zA-Z0-9/_.=-]+$", arg):
arg = f'"{arg}"' arg = f'"{arg}"'
ret.append(arg) ret.append(arg)
return ' '.join(ret) return " ".join(ret)
def run(opts, cmd, check=True, **kwargs): def run(opts, cmd, check=True, **kwargs):
"""Helper around subprocess.run to include logging.""" """Helper around subprocess.run to include logging."""
print('+', cmdstr(cmd)) print("+", cmdstr(cmd))
if opts.dryrun: if opts.dryrun:
cmd = ['true', '--'] + cmd cmd = ["true", "--"] + cmd
try: try:
return subprocess.run(cmd, check=check, **kwargs) return subprocess.run(cmd, check=check, **kwargs)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print(f'aborting: {e}', file=sys.stderr) print(f"aborting: {e}", file=sys.stderr)
sys.exit(1) sys.exit(1)
def import_release_key(opts): def import_release_key(opts):
"""Import the public key of the official release repo signing key.""" """Import the public key of the official release repo signing key."""
# Extract the key from our repo launcher. # Extract the key from our repo launcher.
launcher = getattr(opts, 'launcher', os.path.join(TOPDIR, 'repo')) launcher = getattr(opts, "launcher", os.path.join(TOPDIR, "repo"))
print(f'Importing keys from "{launcher}" launcher script') print(f'Importing keys from "{launcher}" launcher script')
with open(launcher, encoding='utf-8') as fp: with open(launcher, encoding="utf-8") as fp:
data = fp.read() data = fp.read()
keys = re.findall( keys = re.findall(
r'\n-----BEGIN PGP PUBLIC KEY BLOCK-----\n[^-]*' r"\n-----BEGIN PGP PUBLIC KEY BLOCK-----\n[^-]*"
r'\n-----END PGP PUBLIC KEY BLOCK-----\n', data, flags=re.M) r"\n-----END PGP PUBLIC KEY BLOCK-----\n",
run(opts, ['gpg', '--import'], input='\n'.join(keys).encode('utf-8')) data,
flags=re.M,
)
run(opts, ["gpg", "--import"], input="\n".join(keys).encode("utf-8"))
print('Marking keys as fully trusted') print("Marking keys as fully trusted")
run(opts, ['gpg', '--import-ownertrust'], run(
input=f'{KEYID_DSA}:6:\n'.encode('utf-8')) opts,
["gpg", "--import-ownertrust"],
input=f"{KEYID_DSA}:6:\n".encode("utf-8"),
)

View File

@ -29,15 +29,15 @@ from contextlib import ContextDecorator
import platform_utils import platform_utils
# Env var to implicitly turn on tracing. # Env var to implicitly turn on tracing.
REPO_TRACE = 'REPO_TRACE' REPO_TRACE = "REPO_TRACE"
# Temporarily set tracing to always on unless user expicitly sets to 0. # Temporarily set tracing to always on unless user expicitly sets to 0.
_TRACE = os.environ.get(REPO_TRACE) != '0' _TRACE = os.environ.get(REPO_TRACE) != "0"
_TRACE_TO_STDERR = False _TRACE_TO_STDERR = False
_TRACE_FILE = None _TRACE_FILE = None
_TRACE_FILE_NAME = 'TRACE_FILE' _TRACE_FILE_NAME = "TRACE_FILE"
_MAX_SIZE = 70 # in MiB _MAX_SIZE = 70 # in MiB
_NEW_COMMAND_SEP = '+++++++++++++++NEW COMMAND+++++++++++++++++++' _NEW_COMMAND_SEP = "+++++++++++++++NEW COMMAND+++++++++++++++++++"
def IsTraceToStderr(): def IsTraceToStderr():
@ -73,7 +73,7 @@ class Trace(ContextDecorator):
def _time(self): def _time(self):
"""Generate nanoseconds of time in a py3.6 safe way""" """Generate nanoseconds of time in a py3.6 safe way"""
return int(time.time() * 1e+9) return int(time.time() * 1e9)
def __init__(self, fmt, *args, first_trace=False, quiet=True): def __init__(self, fmt, *args, first_trace=False, quiet=True):
"""Initialize the object. """Initialize the object.
@ -93,15 +93,17 @@ class Trace(ContextDecorator):
if first_trace: if first_trace:
_ClearOldTraces() _ClearOldTraces()
self._trace_msg = f'{_NEW_COMMAND_SEP} {self._trace_msg}' self._trace_msg = f"{_NEW_COMMAND_SEP} {self._trace_msg}"
def __enter__(self): def __enter__(self):
if not IsTrace(): if not IsTrace():
return self return self
print_msg = f'PID: {os.getpid()} START: {self._time()} :{self._trace_msg}\n' print_msg = (
f"PID: {os.getpid()} START: {self._time()} :{self._trace_msg}\n"
)
with open(_TRACE_FILE, 'a') as f: with open(_TRACE_FILE, "a") as f:
print(print_msg, file=f) print(print_msg, file=f)
if _TRACE_TO_STDERR: if _TRACE_TO_STDERR:
@ -113,9 +115,11 @@ class Trace(ContextDecorator):
if not IsTrace(): if not IsTrace():
return False return False
print_msg = f'PID: {os.getpid()} END: {self._time()} :{self._trace_msg}\n' print_msg = (
f"PID: {os.getpid()} END: {self._time()} :{self._trace_msg}\n"
)
with open(_TRACE_FILE, 'a') as f: with open(_TRACE_FILE, "a") as f:
print(print_msg, file=f) print(print_msg, file=f)
if _TRACE_TO_STDERR: if _TRACE_TO_STDERR:
@ -130,14 +134,14 @@ def _GetTraceFile(quiet):
repo_dir = os.path.dirname(os.path.dirname(__file__)) repo_dir = os.path.dirname(os.path.dirname(__file__))
trace_file = os.path.join(repo_dir, _TRACE_FILE_NAME) trace_file = os.path.join(repo_dir, _TRACE_FILE_NAME)
if not quiet: if not quiet:
print(f'Trace outputs in {trace_file}', file=sys.stderr) print(f"Trace outputs in {trace_file}", file=sys.stderr)
return trace_file return trace_file
def _ClearOldTraces(): def _ClearOldTraces():
"""Clear the oldest commands if trace file is too big.""" """Clear the oldest commands if trace file is too big."""
try: try:
with open(_TRACE_FILE, 'r', errors='ignore') as f: with open(_TRACE_FILE, "r", errors="ignore") as f:
if os.path.getsize(f.name) / (1024 * 1024) <= _MAX_SIZE: if os.path.getsize(f.name) / (1024 * 1024) <= _MAX_SIZE:
return return
trace_lines = f.readlines() trace_lines = f.readlines()
@ -146,21 +150,21 @@ def _ClearOldTraces():
while sum(len(x) for x in trace_lines) / (1024 * 1024) > _MAX_SIZE: while sum(len(x) for x in trace_lines) / (1024 * 1024) > _MAX_SIZE:
for i, line in enumerate(trace_lines): for i, line in enumerate(trace_lines):
if 'END:' in line and _NEW_COMMAND_SEP in line: if "END:" in line and _NEW_COMMAND_SEP in line:
trace_lines = trace_lines[i + 1:] trace_lines = trace_lines[i + 1 :]
break break
else: else:
# The last chunk is bigger than _MAX_SIZE, so just throw everything away. # The last chunk is bigger than _MAX_SIZE, so just throw everything
# away.
trace_lines = [] trace_lines = []
while trace_lines and trace_lines[-1] == '\n': while trace_lines and trace_lines[-1] == "\n":
trace_lines = trace_lines[:-1] trace_lines = trace_lines[:-1]
# Write to a temporary file with a unique name in the same filesystem # Write to a temporary file with a unique name in the same filesystem
# before replacing the original trace file. # before replacing the original trace file.
temp_dir, temp_prefix = os.path.split(_TRACE_FILE) temp_dir, temp_prefix = os.path.split(_TRACE_FILE)
with tempfile.NamedTemporaryFile('w', with tempfile.NamedTemporaryFile(
dir=temp_dir, "w", dir=temp_dir, prefix=temp_prefix, delete=False
prefix=temp_prefix, ) as f:
delete=False) as f:
f.writelines(trace_lines) f.writelines(trace_lines)
platform_utils.rename(f.name, _TRACE_FILE) platform_utils.rename(f.name, _TRACE_FILE)

View File

@ -13,10 +13,28 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Wrapper to run pytest with the right settings.""" """Wrapper to run black and pytest with the right settings."""
import os
import subprocess
import sys import sys
import pytest import pytest
if __name__ == '__main__':
sys.exit(pytest.main(sys.argv[1:])) def run_black():
"""Returns the exit code of running `black --check`."""
dirpath = os.path.dirname(os.path.realpath(__file__))
return subprocess.run(
[sys.executable, "-m", "black", "--check", dirpath], check=False
).returncode
def main(argv):
"""The main entry."""
black_ret = 0 if argv else run_black()
pytest_ret = pytest.main(argv)
return 0 if not black_ret and not pytest_ret else 1
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))

View File

@ -26,8 +26,8 @@ wheel: <
# Required by pytest==6.2.2 # Required by pytest==6.2.2
wheel: < wheel: <
name: "infra/python/wheels/packaging-py2_py3" name: "infra/python/wheels/packaging-py3"
version: "version:16.8" version: "version:23.0"
> >
# Required by pytest==6.2.2 # Required by pytest==6.2.2
@ -59,3 +59,44 @@ wheel: <
name: "infra/python/wheels/six-py2_py3" name: "infra/python/wheels/six-py2_py3"
version: "version:1.16.0" version: "version:1.16.0"
> >
wheel: <
name: "infra/python/wheels/black-py3"
version: "version:23.1.0"
>
# Required by black==23.1.0
wheel: <
name: "infra/python/wheels/mypy-extensions-py3"
version: "version:0.4.3"
>
# Required by black==23.1.0
wheel: <
name: "infra/python/wheels/tomli-py3"
version: "version:2.0.1"
>
# Required by black==23.1.0
wheel: <
name: "infra/python/wheels/platformdirs-py3"
version: "version:2.5.2"
>
# Required by black==23.1.0
wheel: <
name: "infra/python/wheels/pathspec-py3"
version: "version:0.9.0"
>
# Required by black==23.1.0
wheel: <
name: "infra/python/wheels/typing-extensions-py3"
version: "version:4.3.0"
>
# Required by black==23.1.0
wheel: <
name: "infra/python/wheels/click-py3"
version: "version:8.0.3"
>

View File

@ -23,39 +23,39 @@ TOPDIR = os.path.dirname(os.path.abspath(__file__))
# Rip out the first intro paragraph. # Rip out the first intro paragraph.
with open(os.path.join(TOPDIR, 'README.md')) as fp: with open(os.path.join(TOPDIR, "README.md")) as fp:
lines = fp.read().splitlines()[2:] lines = fp.read().splitlines()[2:]
end = lines.index('') end = lines.index("")
long_description = ' '.join(lines[0:end]) long_description = " ".join(lines[0:end])
# https://packaging.python.org/tutorials/packaging-projects/ # https://packaging.python.org/tutorials/packaging-projects/
setuptools.setup( setuptools.setup(
name='repo', name="repo",
version='2', version="2",
maintainer='Various', maintainer="Various",
maintainer_email='repo-discuss@googlegroups.com', maintainer_email="repo-discuss@googlegroups.com",
description='Repo helps manage many Git repositories', description="Repo helps manage many Git repositories",
long_description=long_description, long_description=long_description,
long_description_content_type='text/plain', long_description_content_type="text/plain",
url='https://gerrit.googlesource.com/git-repo/', url="https://gerrit.googlesource.com/git-repo/",
project_urls={ project_urls={
'Bug Tracker': 'https://bugs.chromium.org/p/gerrit/issues/list?q=component:Applications%3Erepo', "Bug Tracker": "https://bugs.chromium.org/p/gerrit/issues/list?q=component:Applications%3Erepo", # noqa: E501
}, },
# https://pypi.org/classifiers/ # https://pypi.org/classifiers/
classifiers=[ classifiers=[
'Development Status :: 6 - Mature', "Development Status :: 6 - Mature",
'Environment :: Console', "Environment :: Console",
'Intended Audience :: Developers', "Intended Audience :: Developers",
'License :: OSI Approved :: Apache Software License', "License :: OSI Approved :: Apache Software License",
'Natural Language :: English', "Natural Language :: English",
'Operating System :: MacOS :: MacOS X', "Operating System :: MacOS :: MacOS X",
'Operating System :: Microsoft :: Windows :: Windows 10', "Operating System :: Microsoft :: Windows :: Windows 10",
'Operating System :: POSIX :: Linux', "Operating System :: POSIX :: Linux",
'Programming Language :: Python :: 3', "Programming Language :: Python :: 3",
'Programming Language :: Python :: 3 :: Only', "Programming Language :: Python :: 3 :: Only",
'Topic :: Software Development :: Version Control :: Git', "Topic :: Software Development :: Version Control :: Git",
], ],
python_requires='>=3.6', python_requires=">=3.6",
packages=['subcmds'], packages=["subcmds"],
) )

122
ssh.py
View File

@ -28,21 +28,23 @@ import platform_utils
from repo_trace import Trace from repo_trace import Trace
PROXY_PATH = os.path.join(os.path.dirname(__file__), 'git_ssh') PROXY_PATH = os.path.join(os.path.dirname(__file__), "git_ssh")
def _run_ssh_version(): def _run_ssh_version():
"""run ssh -V to display the version number""" """run ssh -V to display the version number"""
return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode() return subprocess.check_output(
["ssh", "-V"], stderr=subprocess.STDOUT
).decode()
def _parse_ssh_version(ver_str=None): def _parse_ssh_version(ver_str=None):
"""parse a ssh version string into a tuple""" """parse a ssh version string into a tuple"""
if ver_str is None: if ver_str is None:
ver_str = _run_ssh_version() ver_str = _run_ssh_version()
m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str) m = re.match(r"^OpenSSH_([0-9.]+)(p[0-9]+)?\s", ver_str)
if m: if m:
return tuple(int(x) for x in m.group(1).split('.')) return tuple(int(x) for x in m.group(1).split("."))
else: else:
return () return ()
@ -53,15 +55,15 @@ def version():
try: try:
return _parse_ssh_version() return _parse_ssh_version()
except FileNotFoundError: except FileNotFoundError:
print('fatal: ssh not installed', file=sys.stderr) print("fatal: ssh not installed", file=sys.stderr)
sys.exit(1) sys.exit(1)
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
print('fatal: unable to detect ssh version', file=sys.stderr) print("fatal: unable to detect ssh version", file=sys.stderr)
sys.exit(1) sys.exit(1)
URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):') URI_SCP = re.compile(r"^([^@:]*@?[^:/]{1,}):")
URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/') URI_ALL = re.compile(r"^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/")
class ProxyManager: class ProxyManager:
@ -70,8 +72,8 @@ class ProxyManager:
This will take care of sharing state between multiprocessing children, and This will take care of sharing state between multiprocessing children, and
make sure that if we crash, we don't leak any of the ssh sessions. make sure that if we crash, we don't leak any of the ssh sessions.
The code should work with a single-process scenario too, and not add too much The code should work with a single-process scenario too, and not add too
overhead due to the manager. much overhead due to the manager.
""" """
# Path to the ssh program to run which will pass our master settings along. # Path to the ssh program to run which will pass our master settings along.
@ -81,16 +83,17 @@ class ProxyManager:
def __init__(self, manager): def __init__(self, manager):
# Protect access to the list of active masters. # Protect access to the list of active masters.
self._lock = multiprocessing.Lock() self._lock = multiprocessing.Lock()
# List of active masters (pid). These will be spawned on demand, and we are # List of active masters (pid). These will be spawned on demand, and we
# responsible for shutting them all down at the end. # are responsible for shutting them all down at the end.
self._masters = manager.list() self._masters = manager.list()
# Set of active masters indexed by "host:port" information. # Set of active masters indexed by "host:port" information.
# The value isn't used, but multiprocessing doesn't provide a set class. # The value isn't used, but multiprocessing doesn't provide a set class.
self._master_keys = manager.dict() self._master_keys = manager.dict()
# Whether ssh masters are known to be broken, so we give up entirely. # Whether ssh masters are known to be broken, so we give up entirely.
self._master_broken = manager.Value('b', False) self._master_broken = manager.Value("b", False)
# List of active ssh sesssions. Clients will be added & removed as # List of active ssh sesssions. Clients will be added & removed as
# connections finish, so this list is just for safety & cleanup if we crash. # connections finish, so this list is just for safety & cleanup if we
# crash.
self._clients = manager.list() self._clients = manager.list()
# Path to directory for holding master sockets. # Path to directory for holding master sockets.
self._sock_path = None self._sock_path = None
@ -132,7 +135,7 @@ class ProxyManager:
while True: while True:
try: try:
procs.pop(0) procs.pop(0)
except: except: # noqa: E722
break break
def close(self): def close(self):
@ -155,64 +158,71 @@ class ProxyManager:
If one doesn't exist already, we'll create it. If one doesn't exist already, we'll create it.
We won't grab any locks, so the caller has to do that. This helps keep the We won't grab any locks, so the caller has to do that. This helps keep
business logic of actually creating the master separate from grabbing locks. the business logic of actually creating the master separate from
grabbing locks.
""" """
# Check to see whether we already think that the master is running; if we # Check to see whether we already think that the master is running; if
# think it's already running, return right away. # we think it's already running, return right away.
if port is not None: if port is not None:
key = '%s:%s' % (host, port) key = "%s:%s" % (host, port)
else: else:
key = host key = host
if key in self._master_keys: if key in self._master_keys:
return True return True
if self._master_broken.value or 'GIT_SSH' in os.environ: if self._master_broken.value or "GIT_SSH" in os.environ:
# Failed earlier, so don't retry. # Failed earlier, so don't retry.
return False return False
# We will make two calls to ssh; this is the common part of both calls. # We will make two calls to ssh; this is the common part of both calls.
command_base = ['ssh', '-o', 'ControlPath %s' % self.sock(), host] command_base = ["ssh", "-o", "ControlPath %s" % self.sock(), host]
if port is not None: if port is not None:
command_base[1:1] = ['-p', str(port)] command_base[1:1] = ["-p", str(port)]
# Since the key wasn't in _master_keys, we think that master isn't running. # Since the key wasn't in _master_keys, we think that master isn't
# ...but before actually starting a master, we'll double-check. This can # running... but before actually starting a master, we'll double-check.
# be important because we can't tell that that 'git@myhost.com' is the same # This can be important because we can't tell that that 'git@myhost.com'
# as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file. # is the same as 'myhost.com' where "User git" is setup in the user's
check_command = command_base + ['-O', 'check'] # ~/.ssh/config file.
with Trace('Call to ssh (check call): %s', ' '.join(check_command)): check_command = command_base + ["-O", "check"]
with Trace("Call to ssh (check call): %s", " ".join(check_command)):
try: try:
check_process = subprocess.Popen(check_command, check_process = subprocess.Popen(
check_command,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE) stderr=subprocess.PIPE,
)
check_process.communicate() # read output, but ignore it... check_process.communicate() # read output, but ignore it...
isnt_running = check_process.wait() isnt_running = check_process.wait()
if not isnt_running: if not isnt_running:
# Our double-check found that the master _was_ infact running. Add to # Our double-check found that the master _was_ infact
# the list of keys. # running. Add to the list of keys.
self._master_keys[key] = True self._master_keys[key] = True
return True return True
except Exception: except Exception:
# Ignore excpetions. We we will fall back to the normal command and # Ignore excpetions. We we will fall back to the normal command
# print to the log there. # and print to the log there.
pass pass
command = command_base[:1] + ['-M', '-N'] + command_base[1:] command = command_base[:1] + ["-M", "-N"] + command_base[1:]
p = None p = None
try: try:
with Trace('Call to ssh: %s', ' '.join(command)): with Trace("Call to ssh: %s", " ".join(command)):
p = subprocess.Popen(command) p = subprocess.Popen(command)
except Exception as e: except Exception as e:
self._master_broken.value = True self._master_broken.value = True
print('\nwarn: cannot enable ssh control master for %s:%s\n%s' print(
% (host, port, str(e)), file=sys.stderr) "\nwarn: cannot enable ssh control master for %s:%s\n%s"
% (host, port, str(e)),
file=sys.stderr,
)
return False return False
time.sleep(1) time.sleep(1)
ssh_died = (p.poll() is not None) ssh_died = p.poll() is not None
if ssh_died: if ssh_died:
return False return False
@ -227,29 +237,29 @@ class ProxyManager:
This will obtain any necessary locks to avoid inter-process races. This will obtain any necessary locks to avoid inter-process races.
""" """
# Bail before grabbing the lock if we already know that we aren't going to # Bail before grabbing the lock if we already know that we aren't going
# try creating new masters below. # to try creating new masters below.
if sys.platform in ('win32', 'cygwin'): if sys.platform in ("win32", "cygwin"):
return False return False
# Acquire the lock. This is needed to prevent opening multiple masters for # Acquire the lock. This is needed to prevent opening multiple masters
# the same host when we're running "repo sync -jN" (for N > 1) _and_ the # for the same host when we're running "repo sync -jN" (for N > 1) _and_
# manifest <remote fetch="ssh://xyz"> specifies a different host from the # the manifest <remote fetch="ssh://xyz"> specifies a different host
# one that was passed to repo init. # from the one that was passed to repo init.
with self._lock: with self._lock:
return self._open_unlocked(host, port) return self._open_unlocked(host, port)
def preconnect(self, url): def preconnect(self, url):
"""If |uri| will create a ssh connection, setup the ssh master for it.""" """If |uri| will create a ssh connection, setup the ssh master for it.""" # noqa: E501
m = URI_ALL.match(url) m = URI_ALL.match(url)
if m: if m:
scheme = m.group(1) scheme = m.group(1)
host = m.group(2) host = m.group(2)
if ':' in host: if ":" in host:
host, port = host.split(':') host, port = host.split(":")
else: else:
port = None port = None
if scheme in ('ssh', 'git+ssh', 'ssh+git'): if scheme in ("ssh", "git+ssh", "ssh+git"):
return self._open(host, port) return self._open(host, port)
return False return False
@ -268,14 +278,14 @@ class ProxyManager:
if self._sock_path is None: if self._sock_path is None:
if not create: if not create:
return None return None
tmp_dir = '/tmp' tmp_dir = "/tmp"
if not os.path.exists(tmp_dir): if not os.path.exists(tmp_dir):
tmp_dir = tempfile.gettempdir() tmp_dir = tempfile.gettempdir()
if version() < (6, 7): if version() < (6, 7):
tokens = '%r@%h:%p' tokens = "%r@%h:%p"
else: else:
tokens = '%C' # hash of %l%h%p%r tokens = "%C" # hash of %l%h%p%r
self._sock_path = os.path.join( self._sock_path = os.path.join(
tempfile.mkdtemp('', 'ssh-', tmp_dir), tempfile.mkdtemp("", "ssh-", tmp_dir), "master-" + tokens
'master-' + tokens) )
return self._sock_path return self._sock_path

View File

@ -19,31 +19,29 @@ all_commands = {}
my_dir = os.path.dirname(__file__) my_dir = os.path.dirname(__file__)
for py in os.listdir(my_dir): for py in os.listdir(my_dir):
if py == '__init__.py': if py == "__init__.py":
continue continue
if py.endswith('.py'): if py.endswith(".py"):
name = py[:-3] name = py[:-3]
clsn = name.capitalize() clsn = name.capitalize()
while clsn.find('_') > 0: while clsn.find("_") > 0:
h = clsn.index('_') h = clsn.index("_")
clsn = clsn[0:h] + clsn[h + 1:].capitalize() clsn = clsn[0:h] + clsn[h + 1 :].capitalize()
mod = __import__(__name__, mod = __import__(__name__, globals(), locals(), ["%s" % name])
globals(),
locals(),
['%s' % name])
mod = getattr(mod, name) mod = getattr(mod, name)
try: try:
cmd = getattr(mod, clsn) cmd = getattr(mod, clsn)
except AttributeError: except AttributeError:
raise SyntaxError('%s/%s does not define class %s' % ( raise SyntaxError(
__name__, py, clsn)) "%s/%s does not define class %s" % (__name__, py, clsn)
)
name = name.replace('_', '-') name = name.replace("_", "-")
cmd.NAME = name cmd.NAME = name
all_commands[name] = cmd all_commands[name] = cmd
# Add 'branch' as an alias for 'branches'. # Add 'branch' as an alias for 'branches'.
all_commands['branch'] = all_commands['branches'] all_commands["branch"] = all_commands["branches"]

View File

@ -36,9 +36,12 @@ It is equivalent to "git branch -D <branchname>".
PARALLEL_JOBS = DEFAULT_LOCAL_JOBS PARALLEL_JOBS = DEFAULT_LOCAL_JOBS
def _Options(self, p): def _Options(self, p):
p.add_option('--all', p.add_option(
dest='all', action='store_true', "--all",
help='delete all branches in all projects') dest="all",
action="store_true",
help="delete all branches in all projects",
)
def ValidateOptions(self, opt, args): def ValidateOptions(self, opt, args):
if not opt.all and not args: if not opt.all and not args:
@ -46,7 +49,7 @@ It is equivalent to "git branch -D <branchname>".
if not opt.all: if not opt.all:
nb = args[0] nb = args[0]
if not git.check_ref_format('heads/%s' % nb): if not git.check_ref_format("heads/%s" % nb):
self.OptionParser.error("'%s' is not a valid branch name" % nb) self.OptionParser.error("'%s' is not a valid branch name" % nb)
else: else:
args.insert(0, "'All local branches'") args.insert(0, "'All local branches'")
@ -69,11 +72,13 @@ It is equivalent to "git branch -D <branchname>".
nb = args[0] nb = args[0]
err = defaultdict(list) err = defaultdict(list)
success = defaultdict(list) success = defaultdict(list)
all_projects = self.GetProjects(args[1:], all_manifests=not opt.this_manifest_only) all_projects = self.GetProjects(
args[1:], all_manifests=not opt.this_manifest_only
)
_RelPath = lambda p: p.RelPath(local=opt.this_manifest_only) _RelPath = lambda p: p.RelPath(local=opt.this_manifest_only)
def _ProcessResults(_pool, pm, states): def _ProcessResults(_pool, pm, states):
for (results, project) in states: for results, project in states:
for branch, status in results.items(): for branch, status in results.items():
if status: if status:
success[branch].append(project) success[branch].append(project)
@ -86,30 +91,46 @@ It is equivalent to "git branch -D <branchname>".
functools.partial(self._ExecuteOne, opt.all, nb), functools.partial(self._ExecuteOne, opt.all, nb),
all_projects, all_projects,
callback=_ProcessResults, callback=_ProcessResults,
output=Progress('Abandon %s' % (nb,), len(all_projects), quiet=opt.quiet)) output=Progress(
"Abandon %s" % (nb,), len(all_projects), quiet=opt.quiet
),
)
width = max(itertools.chain( width = max(
[25], (len(x) for x in itertools.chain(success, err)))) itertools.chain(
[25], (len(x) for x in itertools.chain(success, err))
)
)
if err: if err:
for br in err.keys(): for br in err.keys():
err_msg = "error: cannot abandon %s" % br err_msg = "error: cannot abandon %s" % br
print(err_msg, file=sys.stderr) print(err_msg, file=sys.stderr)
for proj in err[br]: for proj in err[br]:
print(' ' * len(err_msg) + " | %s" % _RelPath(proj), file=sys.stderr) print(
" " * len(err_msg) + " | %s" % _RelPath(proj),
file=sys.stderr,
)
sys.exit(1) sys.exit(1)
elif not success: elif not success:
print('error: no project has local branch(es) : %s' % nb, print(
file=sys.stderr) "error: no project has local branch(es) : %s" % nb,
file=sys.stderr,
)
sys.exit(1) sys.exit(1)
else: else:
# Everything below here is displaying status. # Everything below here is displaying status.
if opt.quiet: if opt.quiet:
return return
print('Abandoned branches:') print("Abandoned branches:")
for br in success.keys(): for br in success.keys():
if len(all_projects) > 1 and len(all_projects) == len(success[br]): if len(all_projects) > 1 and len(all_projects) == len(
success[br]
):
result = "all project" result = "all project"
else: else:
result = "%s" % ( result = "%s" % (
('\n' + ' ' * width + '| ').join(_RelPath(p) for p in success[br])) ("\n" + " " * width + "| ").join(
print("%s%s| %s\n" % (br, ' ' * (width - len(br)), result)) _RelPath(p) for p in success[br]
)
)
print("%s%s| %s\n" % (br, " " * (width - len(br)), result))

View File

@ -21,10 +21,10 @@ from command import Command, DEFAULT_LOCAL_JOBS
class BranchColoring(Coloring): class BranchColoring(Coloring):
def __init__(self, config): def __init__(self, config):
Coloring.__init__(self, config, 'branch') Coloring.__init__(self, config, "branch")
self.current = self.printer('current', fg='green') self.current = self.printer("current", fg="green")
self.local = self.printer('local') self.local = self.printer("local")
self.notinproject = self.printer('notinproject', fg='red') self.notinproject = self.printer("notinproject", fg="red")
class BranchInfo(object): class BranchInfo(object):
@ -98,7 +98,9 @@ is shown, then the branch appears in all projects.
PARALLEL_JOBS = DEFAULT_LOCAL_JOBS PARALLEL_JOBS = DEFAULT_LOCAL_JOBS
def Execute(self, opt, args): def Execute(self, opt, args):
projects = self.GetProjects(args, all_manifests=not opt.this_manifest_only) projects = self.GetProjects(
args, all_manifests=not opt.this_manifest_only
)
out = BranchColoring(self.manifest.manifestProject.config) out = BranchColoring(self.manifest.manifestProject.config)
all_branches = {} all_branches = {}
project_cnt = len(projects) project_cnt = len(projects)
@ -113,12 +115,13 @@ is shown, then the branch appears in all projects.
opt.jobs, opt.jobs,
expand_project_to_branches, expand_project_to_branches,
projects, projects,
callback=_ProcessResults) callback=_ProcessResults,
)
names = sorted(all_branches) names = sorted(all_branches)
if not names: if not names:
print(' (no branches)', file=sys.stderr) print(" (no branches)", file=sys.stderr)
return return
width = 25 width = 25
@ -131,21 +134,21 @@ is shown, then the branch appears in all projects.
in_cnt = len(i.projects) in_cnt = len(i.projects)
if i.IsCurrent: if i.IsCurrent:
current = '*' current = "*"
hdr = out.current hdr = out.current
else: else:
current = ' ' current = " "
hdr = out.local hdr = out.local
if i.IsPublishedEqual: if i.IsPublishedEqual:
published = 'P' published = "P"
elif i.IsPublished: elif i.IsPublished:
published = 'p' published = "p"
else: else:
published = ' ' published = " "
hdr('%c%c %-*s' % (current, published, width, name)) hdr("%c%c %-*s" % (current, published, width, name))
out.write(' |') out.write(" |")
_RelPath = lambda p: p.RelPath(local=opt.this_manifest_only) _RelPath = lambda p: p.RelPath(local=opt.this_manifest_only)
if in_cnt < project_cnt: if in_cnt < project_cnt:
@ -153,7 +156,7 @@ is shown, then the branch appears in all projects.
paths = [] paths = []
non_cur_paths = [] non_cur_paths = []
if i.IsSplitCurrent or (in_cnt <= project_cnt - in_cnt): if i.IsSplitCurrent or (in_cnt <= project_cnt - in_cnt):
in_type = 'in' in_type = "in"
for b in i.projects: for b in i.projects:
relpath = _RelPath(b.project) relpath = _RelPath(b.project)
if not i.IsSplitCurrent or b.current: if not i.IsSplitCurrent or b.current:
@ -162,7 +165,7 @@ is shown, then the branch appears in all projects.
non_cur_paths.append(relpath) non_cur_paths.append(relpath)
else: else:
fmt = out.notinproject fmt = out.notinproject
in_type = 'not in' in_type = "not in"
have = set() have = set()
for b in i.projects: for b in i.projects:
have.add(_RelPath(b.project)) have.add(_RelPath(b.project))
@ -170,22 +173,22 @@ is shown, then the branch appears in all projects.
if _RelPath(p) not in have: if _RelPath(p) not in have:
paths.append(_RelPath(p)) paths.append(_RelPath(p))
s = ' %s %s' % (in_type, ', '.join(paths)) s = " %s %s" % (in_type, ", ".join(paths))
if not i.IsSplitCurrent and (width + 7 + len(s) < 80): if not i.IsSplitCurrent and (width + 7 + len(s) < 80):
fmt = out.current if i.IsCurrent else fmt fmt = out.current if i.IsCurrent else fmt
fmt(s) fmt(s)
else: else:
fmt(' %s:' % in_type) fmt(" %s:" % in_type)
fmt = out.current if i.IsCurrent else out.write fmt = out.current if i.IsCurrent else out.write
for p in paths: for p in paths:
out.nl() out.nl()
fmt(width * ' ' + ' %s' % p) fmt(width * " " + " %s" % p)
fmt = out.write fmt = out.write
for p in non_cur_paths: for p in non_cur_paths:
out.nl() out.nl()
fmt(width * ' ' + ' %s' % p) fmt(width * " " + " %s" % p)
else: else:
out.write(' in all projects') out.write(" in all projects")
out.nl() out.nl()

View File

@ -47,7 +47,9 @@ The command is equivalent to:
nb = args[0] nb = args[0]
err = [] err = []
success = [] success = []
all_projects = self.GetProjects(args[1:], all_manifests=not opt.this_manifest_only) all_projects = self.GetProjects(
args[1:], all_manifests=not opt.this_manifest_only
)
def _ProcessResults(_pool, pm, results): def _ProcessResults(_pool, pm, results):
for status, project in results: for status, project in results:
@ -63,13 +65,18 @@ The command is equivalent to:
functools.partial(self._ExecuteOne, nb), functools.partial(self._ExecuteOne, nb),
all_projects, all_projects,
callback=_ProcessResults, callback=_ProcessResults,
output=Progress('Checkout %s' % (nb,), len(all_projects), quiet=opt.quiet)) output=Progress(
"Checkout %s" % (nb,), len(all_projects), quiet=opt.quiet
),
)
if err: if err:
for p in err: for p in err:
print("error: %s/: cannot checkout %s" % (p.relpath, nb), print(
file=sys.stderr) "error: %s/: cannot checkout %s" % (p.relpath, nb),
file=sys.stderr,
)
sys.exit(1) sys.exit(1)
elif not success: elif not success:
print('error: no project has branch %s' % nb, file=sys.stderr) print("error: no project has branch %s" % nb, file=sys.stderr)
sys.exit(1) sys.exit(1)

View File

@ -17,7 +17,7 @@ import sys
from command import Command from command import Command
from git_command import GitCommand from git_command import GitCommand
CHANGE_ID_RE = re.compile(r'^\s*Change-Id: I([0-9a-f]{40})\s*$') CHANGE_ID_RE = re.compile(r"^\s*Change-Id: I([0-9a-f]{40})\s*$")
class CherryPick(Command): class CherryPick(Command):
@ -39,25 +39,31 @@ change id will be added.
def Execute(self, opt, args): def Execute(self, opt, args):
reference = args[0] reference = args[0]
p = GitCommand(None, p = GitCommand(
['rev-parse', '--verify', reference], None,
["rev-parse", "--verify", reference],
capture_stdout=True, capture_stdout=True,
capture_stderr=True) capture_stderr=True,
)
if p.Wait() != 0: if p.Wait() != 0:
print(p.stderr, file=sys.stderr) print(p.stderr, file=sys.stderr)
sys.exit(1) sys.exit(1)
sha1 = p.stdout.strip() sha1 = p.stdout.strip()
p = GitCommand(None, ['cat-file', 'commit', sha1], capture_stdout=True) p = GitCommand(None, ["cat-file", "commit", sha1], capture_stdout=True)
if p.Wait() != 0: if p.Wait() != 0:
print("error: Failed to retrieve old commit message", file=sys.stderr) print(
"error: Failed to retrieve old commit message", file=sys.stderr
)
sys.exit(1) sys.exit(1)
old_msg = self._StripHeader(p.stdout) old_msg = self._StripHeader(p.stdout)
p = GitCommand(None, p = GitCommand(
['cherry-pick', sha1], None,
["cherry-pick", sha1],
capture_stdout=True, capture_stdout=True,
capture_stderr=True) capture_stderr=True,
)
status = p.Wait() status = p.Wait()
if p.stdout: if p.stdout:
@ -70,17 +76,22 @@ change id will be added.
# commit message. # commit message.
new_msg = self._Reformat(old_msg, sha1) new_msg = self._Reformat(old_msg, sha1)
p = GitCommand(None, ['commit', '--amend', '-F', '-'], p = GitCommand(
None,
["commit", "--amend", "-F", "-"],
input=new_msg, input=new_msg,
capture_stdout=True, capture_stdout=True,
capture_stderr=True) capture_stderr=True,
)
if p.Wait() != 0: if p.Wait() != 0:
print("error: Failed to update commit message", file=sys.stderr) print("error: Failed to update commit message", file=sys.stderr)
sys.exit(1) sys.exit(1)
else: else:
print('NOTE: When committing (please see above) and editing the commit ' print(
'message, please remove the old Change-Id-line and add:') "NOTE: When committing (please see above) and editing the "
"commit message, please remove the old Change-Id-line and add:"
)
print(self._GetReference(sha1), file=sys.stderr) print(self._GetReference(sha1), file=sys.stderr)
print(file=sys.stderr) print(file=sys.stderr)
@ -92,7 +103,7 @@ change id will be added.
def _StripHeader(self, commit_msg): def _StripHeader(self, commit_msg):
lines = commit_msg.splitlines() lines = commit_msg.splitlines()
return "\n".join(lines[lines.index("") + 1:]) return "\n".join(lines[lines.index("") + 1 :])
def _Reformat(self, old_msg, sha1): def _Reformat(self, old_msg, sha1):
new_msg = [] new_msg = []
@ -101,7 +112,7 @@ change id will be added.
if not self._IsChangeId(line): if not self._IsChangeId(line):
new_msg.append(line) new_msg.append(line)
# Add a blank line between the message and the change id/reference # Add a blank line between the message and the change id/reference.
try: try:
if new_msg[-1].strip() != "": if new_msg[-1].strip() != "":
new_msg.append("") new_msg.append("")

View File

@ -31,9 +31,13 @@ to the Unix 'patch' command.
PARALLEL_JOBS = DEFAULT_LOCAL_JOBS PARALLEL_JOBS = DEFAULT_LOCAL_JOBS
def _Options(self, p): def _Options(self, p):
p.add_option('-u', '--absolute', p.add_option(
dest='absolute', action='store_true', "-u",
help='paths are relative to the repository root') "--absolute",
dest="absolute",
action="store_true",
help="paths are relative to the repository root",
)
def _ExecuteOne(self, absolute, local, project): def _ExecuteOne(self, absolute, local, project):
"""Obtains the diff for a specific project. """Obtains the diff for a specific project.
@ -41,8 +45,8 @@ to the Unix 'patch' command.
Args: Args:
absolute: Paths are relative to the root. absolute: Paths are relative to the root.
local: a boolean, if True, the path is relative to the local local: a boolean, if True, the path is relative to the local
(sub)manifest. If false, the path is relative to the (sub)manifest. If false, the path is relative to the outermost
outermost manifest. manifest.
project: Project to get status of. project: Project to get status of.
Returns: Returns:
@ -53,20 +57,25 @@ to the Unix 'patch' command.
return (ret, buf.getvalue()) return (ret, buf.getvalue())
def Execute(self, opt, args): def Execute(self, opt, args):
all_projects = self.GetProjects(args, all_manifests=not opt.this_manifest_only) all_projects = self.GetProjects(
args, all_manifests=not opt.this_manifest_only
)
def _ProcessResults(_pool, _output, results): def _ProcessResults(_pool, _output, results):
ret = 0 ret = 0
for (state, output) in results: for state, output in results:
if output: if output:
print(output, end='') print(output, end="")
if not state: if not state:
ret = 1 ret = 1
return ret return ret
return self.ExecuteInParallel( return self.ExecuteInParallel(
opt.jobs, opt.jobs,
functools.partial(self._ExecuteOne, opt.absolute, opt.this_manifest_only), functools.partial(
self._ExecuteOne, opt.absolute, opt.this_manifest_only
),
all_projects, all_projects,
callback=_ProcessResults, callback=_ProcessResults,
ordered=True) ordered=True,
)

View File

@ -23,7 +23,7 @@ class _Coloring(Coloring):
class Diffmanifests(PagedCommand): class Diffmanifests(PagedCommand):
""" A command to see logs in projects represented by manifests """A command to see logs in projects represented by manifests
This is used to see deeper differences between manifests. Where a simple This is used to see deeper differences between manifests. Where a simple
diff would only show a diff of sha1s for example, this command will display diff would only show a diff of sha1s for example, this command will display
@ -66,145 +66,188 @@ synced and their revisions won't be found.
""" """
def _Options(self, p): def _Options(self, p):
p.add_option('--raw', p.add_option(
dest='raw', action='store_true', "--raw", dest="raw", action="store_true", help="display raw diff"
help='display raw diff') )
p.add_option('--no-color', p.add_option(
dest='color', action='store_false', default=True, "--no-color",
help='does not display the diff in color') dest="color",
p.add_option('--pretty-format', action="store_false",
dest='pretty_format', action='store', default=True,
metavar='<FORMAT>', help="does not display the diff in color",
help='print the log using a custom git pretty format string') )
p.add_option(
"--pretty-format",
dest="pretty_format",
action="store",
metavar="<FORMAT>",
help="print the log using a custom git pretty format string",
)
def _printRawDiff(self, diff, pretty_format=None, local=False): def _printRawDiff(self, diff, pretty_format=None, local=False):
_RelPath = lambda p: p.RelPath(local=local) _RelPath = lambda p: p.RelPath(local=local)
for project in diff['added']: for project in diff["added"]:
self.printText("A %s %s" % (_RelPath(project), project.revisionExpr)) self.printText(
"A %s %s" % (_RelPath(project), project.revisionExpr)
)
self.out.nl() self.out.nl()
for project in diff['removed']: for project in diff["removed"]:
self.printText("R %s %s" % (_RelPath(project), project.revisionExpr)) self.printText(
"R %s %s" % (_RelPath(project), project.revisionExpr)
)
self.out.nl() self.out.nl()
for project, otherProject in diff['changed']: for project, otherProject in diff["changed"]:
self.printText("C %s %s %s" % (_RelPath(project), project.revisionExpr, self.printText(
otherProject.revisionExpr)) "C %s %s %s"
% (
_RelPath(project),
project.revisionExpr,
otherProject.revisionExpr,
)
)
self.out.nl() self.out.nl()
self._printLogs(project, otherProject, raw=True, color=False, pretty_format=pretty_format) self._printLogs(
project,
otherProject,
raw=True,
color=False,
pretty_format=pretty_format,
)
for project, otherProject in diff['unreachable']: for project, otherProject in diff["unreachable"]:
self.printText("U %s %s %s" % (_RelPath(project), project.revisionExpr, self.printText(
otherProject.revisionExpr)) "U %s %s %s"
% (
_RelPath(project),
project.revisionExpr,
otherProject.revisionExpr,
)
)
self.out.nl() self.out.nl()
def _printDiff(self, diff, color=True, pretty_format=None, local=False): def _printDiff(self, diff, color=True, pretty_format=None, local=False):
_RelPath = lambda p: p.RelPath(local=local) _RelPath = lambda p: p.RelPath(local=local)
if diff['added']: if diff["added"]:
self.out.nl() self.out.nl()
self.printText('added projects : \n') self.printText("added projects : \n")
self.out.nl() self.out.nl()
for project in diff['added']: for project in diff["added"]:
self.printProject('\t%s' % (_RelPath(project))) self.printProject("\t%s" % (_RelPath(project)))
self.printText(' at revision ') self.printText(" at revision ")
self.printRevision(project.revisionExpr) self.printRevision(project.revisionExpr)
self.out.nl() self.out.nl()
if diff['removed']: if diff["removed"]:
self.out.nl() self.out.nl()
self.printText('removed projects : \n') self.printText("removed projects : \n")
self.out.nl() self.out.nl()
for project in diff['removed']: for project in diff["removed"]:
self.printProject('\t%s' % (_RelPath(project))) self.printProject("\t%s" % (_RelPath(project)))
self.printText(' at revision ') self.printText(" at revision ")
self.printRevision(project.revisionExpr) self.printRevision(project.revisionExpr)
self.out.nl() self.out.nl()
if diff['missing']: if diff["missing"]:
self.out.nl() self.out.nl()
self.printText('missing projects : \n') self.printText("missing projects : \n")
self.out.nl() self.out.nl()
for project in diff['missing']: for project in diff["missing"]:
self.printProject('\t%s' % (_RelPath(project))) self.printProject("\t%s" % (_RelPath(project)))
self.printText(' at revision ') self.printText(" at revision ")
self.printRevision(project.revisionExpr) self.printRevision(project.revisionExpr)
self.out.nl() self.out.nl()
if diff['changed']: if diff["changed"]:
self.out.nl() self.out.nl()
self.printText('changed projects : \n') self.printText("changed projects : \n")
self.out.nl() self.out.nl()
for project, otherProject in diff['changed']: for project, otherProject in diff["changed"]:
self.printProject('\t%s' % (_RelPath(project))) self.printProject("\t%s" % (_RelPath(project)))
self.printText(' changed from ') self.printText(" changed from ")
self.printRevision(project.revisionExpr) self.printRevision(project.revisionExpr)
self.printText(' to ') self.printText(" to ")
self.printRevision(otherProject.revisionExpr) self.printRevision(otherProject.revisionExpr)
self.out.nl() self.out.nl()
self._printLogs(project, otherProject, raw=False, color=color, self._printLogs(
pretty_format=pretty_format) project,
otherProject,
raw=False,
color=color,
pretty_format=pretty_format,
)
self.out.nl() self.out.nl()
if diff['unreachable']: if diff["unreachable"]:
self.out.nl() self.out.nl()
self.printText('projects with unreachable revisions : \n') self.printText("projects with unreachable revisions : \n")
self.out.nl() self.out.nl()
for project, otherProject in diff['unreachable']: for project, otherProject in diff["unreachable"]:
self.printProject('\t%s ' % (_RelPath(project))) self.printProject("\t%s " % (_RelPath(project)))
self.printRevision(project.revisionExpr) self.printRevision(project.revisionExpr)
self.printText(' or ') self.printText(" or ")
self.printRevision(otherProject.revisionExpr) self.printRevision(otherProject.revisionExpr)
self.printText(' not found') self.printText(" not found")
self.out.nl() self.out.nl()
def _printLogs(self, project, otherProject, raw=False, color=True, def _printLogs(
pretty_format=None): self, project, otherProject, raw=False, color=True, pretty_format=None
):
logs = project.getAddedAndRemovedLogs(otherProject, logs = project.getAddedAndRemovedLogs(
otherProject,
oneline=(pretty_format is None), oneline=(pretty_format is None),
color=color, color=color,
pretty_format=pretty_format) pretty_format=pretty_format,
if logs['removed']: )
removedLogs = logs['removed'].split('\n') if logs["removed"]:
removedLogs = logs["removed"].split("\n")
for log in removedLogs: for log in removedLogs:
if log.strip(): if log.strip():
if raw: if raw:
self.printText(' R ' + log) self.printText(" R " + log)
self.out.nl() self.out.nl()
else: else:
self.printRemoved('\t\t[-] ') self.printRemoved("\t\t[-] ")
self.printText(log) self.printText(log)
self.out.nl() self.out.nl()
if logs['added']: if logs["added"]:
addedLogs = logs['added'].split('\n') addedLogs = logs["added"].split("\n")
for log in addedLogs: for log in addedLogs:
if log.strip(): if log.strip():
if raw: if raw:
self.printText(' A ' + log) self.printText(" A " + log)
self.out.nl() self.out.nl()
else: else:
self.printAdded('\t\t[+] ') self.printAdded("\t\t[+] ")
self.printText(log) self.printText(log)
self.out.nl() self.out.nl()
def ValidateOptions(self, opt, args): def ValidateOptions(self, opt, args):
if not args or len(args) > 2: if not args or len(args) > 2:
self.OptionParser.error('missing manifests to diff') self.OptionParser.error("missing manifests to diff")
if opt.this_manifest_only is False: if opt.this_manifest_only is False:
raise self.OptionParser.error( raise self.OptionParser.error(
'`diffmanifest` only supports the current tree') "`diffmanifest` only supports the current tree"
)
def Execute(self, opt, args): def Execute(self, opt, args):
self.out = _Coloring(self.client.globalConfig) self.out = _Coloring(self.client.globalConfig)
self.printText = self.out.nofmt_printer('text') self.printText = self.out.nofmt_printer("text")
if opt.color: if opt.color:
self.printProject = self.out.nofmt_printer('project', attr='bold') self.printProject = self.out.nofmt_printer("project", attr="bold")
self.printAdded = self.out.nofmt_printer('green', fg='green', attr='bold') self.printAdded = self.out.nofmt_printer(
self.printRemoved = self.out.nofmt_printer('red', fg='red', attr='bold') "green", fg="green", attr="bold"
self.printRevision = self.out.nofmt_printer('revision', fg='yellow') )
self.printRemoved = self.out.nofmt_printer(
"red", fg="red", attr="bold"
)
self.printRevision = self.out.nofmt_printer("revision", fg="yellow")
else: else:
self.printProject = self.printAdded = self.printRemoved = self.printRevision = self.printText self.printProject = (
self.printAdded
) = self.printRemoved = self.printRevision = self.printText
manifest1 = RepoClient(self.repodir) manifest1 = RepoClient(self.repodir)
manifest1.Override(args[0], load_local_manifests=False) manifest1.Override(args[0], load_local_manifests=False)
@ -216,8 +259,15 @@ synced and their revisions won't be found.
diff = manifest1.projectsDiff(manifest2) diff = manifest1.projectsDiff(manifest2)
if opt.raw: if opt.raw:
self._printRawDiff(diff, pretty_format=opt.pretty_format, self._printRawDiff(
local=opt.this_manifest_only) diff,
pretty_format=opt.pretty_format,
local=opt.this_manifest_only,
)
else: else:
self._printDiff(diff, color=opt.color, pretty_format=opt.pretty_format, self._printDiff(
local=opt.this_manifest_only) diff,
color=opt.color,
pretty_format=opt.pretty_format,
local=opt.this_manifest_only,
)

View File

@ -18,7 +18,7 @@ import sys
from command import Command from command import Command
from error import GitError, NoSuchProjectError from error import GitError, NoSuchProjectError
CHANGE_RE = re.compile(r'^([1-9][0-9]*)(?:[/\.-]([1-9][0-9]*))?$') CHANGE_RE = re.compile(r"^([1-9][0-9]*)(?:[/\.-]([1-9][0-9]*))?$")
class Download(Command): class Download(Command):
@ -34,19 +34,34 @@ If no project is specified try to use current directory as a project.
""" """
def _Options(self, p): def _Options(self, p):
p.add_option('-b', '--branch', p.add_option("-b", "--branch", help="create a new branch first")
help='create a new branch first') p.add_option(
p.add_option('-c', '--cherry-pick', "-c",
dest='cherrypick', action='store_true', "--cherry-pick",
help="cherry-pick instead of checkout") dest="cherrypick",
p.add_option('-x', '--record-origin', action='store_true', action="store_true",
help='pass -x when cherry-picking') help="cherry-pick instead of checkout",
p.add_option('-r', '--revert', )
dest='revert', action='store_true', p.add_option(
help="revert instead of checkout") "-x",
p.add_option('-f', '--ff-only', "--record-origin",
dest='ffonly', action='store_true', action="store_true",
help="force fast-forward merge") help="pass -x when cherry-picking",
)
p.add_option(
"-r",
"--revert",
dest="revert",
action="store_true",
help="revert instead of checkout",
)
p.add_option(
"-f",
"--ff-only",
dest="ffonly",
action="store_true",
help="force fast-forward merge",
)
def _ParseChangeIds(self, opt, args): def _ParseChangeIds(self, opt, args):
if not args: if not args:
@ -60,16 +75,16 @@ If no project is specified try to use current directory as a project.
if m: if m:
if not project: if not project:
project = self.GetProjects(".")[0] project = self.GetProjects(".")[0]
print('Defaulting to cwd project', project.name) print("Defaulting to cwd project", project.name)
chg_id = int(m.group(1)) chg_id = int(m.group(1))
if m.group(2): if m.group(2):
ps_id = int(m.group(2)) ps_id = int(m.group(2))
else: else:
ps_id = 1 ps_id = 1
refs = 'refs/changes/%2.2d/%d/' % (chg_id % 100, chg_id) refs = "refs/changes/%2.2d/%d/" % (chg_id % 100, chg_id)
output = project._LsRemote(refs + '*') output = project._LsRemote(refs + "*")
if output: if output:
regex = refs + r'(\d+)' regex = refs + r"(\d+)"
rcomp = re.compile(regex, re.I) rcomp = re.compile(regex, re.I)
for line in output.splitlines(): for line in output.splitlines():
match = rcomp.search(line) match = rcomp.search(line)
@ -77,73 +92,99 @@ If no project is specified try to use current directory as a project.
ps_id = max(int(match.group(1)), ps_id) ps_id = max(int(match.group(1)), ps_id)
to_get.append((project, chg_id, ps_id)) to_get.append((project, chg_id, ps_id))
else: else:
projects = self.GetProjects([a], all_manifests=not opt.this_manifest_only) projects = self.GetProjects(
[a], all_manifests=not opt.this_manifest_only
)
if len(projects) > 1: if len(projects) > 1:
# If the cwd is one of the projects, assume they want that. # If the cwd is one of the projects, assume they want that.
try: try:
project = self.GetProjects('.')[0] project = self.GetProjects(".")[0]
except NoSuchProjectError: except NoSuchProjectError:
project = None project = None
if project not in projects: if project not in projects:
print('error: %s matches too many projects; please re-run inside ' print(
'the project checkout.' % (a,), file=sys.stderr) "error: %s matches too many projects; please "
"re-run inside the project checkout." % (a,),
file=sys.stderr,
)
for project in projects: for project in projects:
print(' %s/ @ %s' % (project.RelPath(local=opt.this_manifest_only), print(
project.revisionExpr), file=sys.stderr) " %s/ @ %s"
% (
project.RelPath(
local=opt.this_manifest_only
),
project.revisionExpr,
),
file=sys.stderr,
)
sys.exit(1) sys.exit(1)
else: else:
project = projects[0] project = projects[0]
print('Defaulting to cwd project', project.name) print("Defaulting to cwd project", project.name)
return to_get return to_get
def ValidateOptions(self, opt, args): def ValidateOptions(self, opt, args):
if opt.record_origin: if opt.record_origin:
if not opt.cherrypick: if not opt.cherrypick:
self.OptionParser.error('-x only makes sense with --cherry-pick') self.OptionParser.error(
"-x only makes sense with --cherry-pick"
)
if opt.ffonly: if opt.ffonly:
self.OptionParser.error('-x and --ff are mutually exclusive options') self.OptionParser.error(
"-x and --ff are mutually exclusive options"
)
def Execute(self, opt, args): def Execute(self, opt, args):
for project, change_id, ps_id in self._ParseChangeIds(opt, args): for project, change_id, ps_id in self._ParseChangeIds(opt, args):
dl = project.DownloadPatchSet(change_id, ps_id) dl = project.DownloadPatchSet(change_id, ps_id)
if not dl: if not dl:
print('[%s] change %d/%d not found' print(
"[%s] change %d/%d not found"
% (project.name, change_id, ps_id), % (project.name, change_id, ps_id),
file=sys.stderr) file=sys.stderr,
)
sys.exit(1) sys.exit(1)
if not opt.revert and not dl.commits: if not opt.revert and not dl.commits:
print('[%s] change %d/%d has already been merged' print(
"[%s] change %d/%d has already been merged"
% (project.name, change_id, ps_id), % (project.name, change_id, ps_id),
file=sys.stderr) file=sys.stderr,
)
continue continue
if len(dl.commits) > 1: if len(dl.commits) > 1:
print('[%s] %d/%d depends on %d unmerged changes:' print(
"[%s] %d/%d depends on %d unmerged changes:"
% (project.name, change_id, ps_id, len(dl.commits)), % (project.name, change_id, ps_id, len(dl.commits)),
file=sys.stderr) file=sys.stderr,
)
for c in dl.commits: for c in dl.commits:
print(' %s' % (c), file=sys.stderr) print(" %s" % (c), file=sys.stderr)
if opt.cherrypick: if opt.cherrypick:
mode = 'cherry-pick' mode = "cherry-pick"
elif opt.revert: elif opt.revert:
mode = 'revert' mode = "revert"
elif opt.ffonly: elif opt.ffonly:
mode = 'fast-forward merge' mode = "fast-forward merge"
else: else:
mode = 'checkout' mode = "checkout"
# We'll combine the branch+checkout operation, but all the rest need a # We'll combine the branch+checkout operation, but all the rest need
# dedicated branch start. # a dedicated branch start.
if opt.branch and mode != 'checkout': if opt.branch and mode != "checkout":
project.StartBranch(opt.branch) project.StartBranch(opt.branch)
try: try:
if opt.cherrypick: if opt.cherrypick:
project._CherryPick(dl.commit, ffonly=opt.ffonly, project._CherryPick(
record_origin=opt.record_origin) dl.commit,
ffonly=opt.ffonly,
record_origin=opt.record_origin,
)
elif opt.revert: elif opt.revert:
project._Revert(dl.commit) project._Revert(dl.commit)
elif opt.ffonly: elif opt.ffonly:
@ -155,6 +196,9 @@ If no project is specified try to use current directory as a project.
project._Checkout(dl.commit) project._Checkout(dl.commit)
except GitError: except GitError:
print('[%s] Could not complete the %s of %s' print(
% (project.name, mode, dl.commit), file=sys.stderr) "[%s] Could not complete the %s of %s"
% (project.name, mode, dl.commit),
file=sys.stderr,
)
sys.exit(1) sys.exit(1)

View File

@ -23,21 +23,26 @@ import sys
import subprocess import subprocess
from color import Coloring from color import Coloring
from command import DEFAULT_LOCAL_JOBS, Command, MirrorSafeCommand, WORKER_BATCH_SIZE from command import (
DEFAULT_LOCAL_JOBS,
Command,
MirrorSafeCommand,
WORKER_BATCH_SIZE,
)
from error import ManifestInvalidRevisionError from error import ManifestInvalidRevisionError
_CAN_COLOR = [ _CAN_COLOR = [
'branch', "branch",
'diff', "diff",
'grep', "grep",
'log', "log",
] ]
class ForallColoring(Coloring): class ForallColoring(Coloring):
def __init__(self, config): def __init__(self, config):
Coloring.__init__(self, config, 'forall') Coloring.__init__(self, config, "forall")
self.project = self.printer('project', attr='bold') self.project = self.printer("project", attr="bold")
class Forall(Command, MirrorSafeCommand): class Forall(Command, MirrorSafeCommand):
@ -134,35 +139,61 @@ without iterating through the remaining projects.
del parser.rargs[0] del parser.rargs[0]
def _Options(self, p): def _Options(self, p):
p.add_option('-r', '--regex', p.add_option(
dest='regex', action='store_true', "-r",
help='execute the command only on projects matching regex or wildcard expression') "--regex",
p.add_option('-i', '--inverse-regex', dest="regex",
dest='inverse_regex', action='store_true', action="store_true",
help='execute the command only on projects not matching regex or ' help="execute the command only on projects matching regex or "
'wildcard expression') "wildcard expression",
p.add_option('-g', '--groups', )
dest='groups', p.add_option(
help='execute the command only on projects matching the specified groups') "-i",
p.add_option('-c', '--command', "--inverse-regex",
help='command (and arguments) to execute', dest="inverse_regex",
dest='command', action="store_true",
action='callback', help="execute the command only on projects not matching regex or "
callback=self._cmd_option) "wildcard expression",
p.add_option('-e', '--abort-on-errors', )
dest='abort_on_errors', action='store_true', p.add_option(
help='abort if a command exits unsuccessfully') "-g",
p.add_option('--ignore-missing', action='store_true', "--groups",
help='silently skip & do not exit non-zero due missing ' dest="groups",
'checkouts') help="execute the command only on projects matching the specified "
"groups",
)
p.add_option(
"-c",
"--command",
help="command (and arguments) to execute",
dest="command",
action="callback",
callback=self._cmd_option,
)
p.add_option(
"-e",
"--abort-on-errors",
dest="abort_on_errors",
action="store_true",
help="abort if a command exits unsuccessfully",
)
p.add_option(
"--ignore-missing",
action="store_true",
help="silently skip & do not exit non-zero due missing "
"checkouts",
)
g = p.get_option_group('--quiet') g = p.get_option_group("--quiet")
g.add_option('-p', g.add_option(
dest='project_header', action='store_true', "-p",
help='show project headers before output') dest="project_header",
p.add_option('--interactive', action="store_true",
action='store_true', help="show project headers before output",
help='force interactive usage') )
p.add_option(
"--interactive", action="store_true", help="force interactive usage"
)
def WantPager(self, opt): def WantPager(self, opt):
return opt.project_header and opt.jobs == 1 return opt.project_header and opt.jobs == 1
@ -176,44 +207,45 @@ without iterating through the remaining projects.
all_trees = not opt.this_manifest_only all_trees = not opt.this_manifest_only
shell = True shell = True
if re.compile(r'^[a-z0-9A-Z_/\.-]+$').match(cmd[0]): if re.compile(r"^[a-z0-9A-Z_/\.-]+$").match(cmd[0]):
shell = False shell = False
if shell: if shell:
cmd.append(cmd[0]) cmd.append(cmd[0])
cmd.extend(opt.command[1:]) cmd.extend(opt.command[1:])
# Historically, forall operated interactively, and in serial. If the user # Historically, forall operated interactively, and in serial. If the
# has selected 1 job, then default to interacive mode. # user has selected 1 job, then default to interacive mode.
if opt.jobs == 1: if opt.jobs == 1:
opt.interactive = True opt.interactive = True
if opt.project_header \ if opt.project_header and not shell and cmd[0] == "git":
and not shell \
and cmd[0] == 'git':
# If this is a direct git command that can enable colorized # If this is a direct git command that can enable colorized
# output and the user prefers coloring, add --color into the # output and the user prefers coloring, add --color into the
# command line because we are going to wrap the command into # command line because we are going to wrap the command into
# a pipe and git won't know coloring should activate. # a pipe and git won't know coloring should activate.
# #
for cn in cmd[1:]: for cn in cmd[1:]:
if not cn.startswith('-'): if not cn.startswith("-"):
break break
else: else:
cn = None cn = None
if cn and cn in _CAN_COLOR: if cn and cn in _CAN_COLOR:
class ColorCmd(Coloring): class ColorCmd(Coloring):
def __init__(self, config, cmd): def __init__(self, config, cmd):
Coloring.__init__(self, config, cmd) Coloring.__init__(self, config, cmd)
if ColorCmd(self.manifest.manifestProject.config, cn).is_on: if ColorCmd(self.manifest.manifestProject.config, cn).is_on:
cmd.insert(cmd.index(cn) + 1, '--color') cmd.insert(cmd.index(cn) + 1, "--color")
mirror = self.manifest.IsMirror mirror = self.manifest.IsMirror
rc = 0 rc = 0
smart_sync_manifest_name = "smart_sync_override.xml" smart_sync_manifest_name = "smart_sync_override.xml"
smart_sync_manifest_path = os.path.join( smart_sync_manifest_path = os.path.join(
self.manifest.manifestProject.worktree, smart_sync_manifest_name) self.manifest.manifestProject.worktree, smart_sync_manifest_name
)
if os.path.isfile(smart_sync_manifest_path): if os.path.isfile(smart_sync_manifest_path):
self.manifest.Override(smart_sync_manifest_path) self.manifest.Override(smart_sync_manifest_path)
@ -221,49 +253,59 @@ without iterating through the remaining projects.
if opt.regex: if opt.regex:
projects = self.FindProjects(args, all_manifests=all_trees) projects = self.FindProjects(args, all_manifests=all_trees)
elif opt.inverse_regex: elif opt.inverse_regex:
projects = self.FindProjects(args, inverse=True, all_manifests=all_trees) projects = self.FindProjects(
args, inverse=True, all_manifests=all_trees
)
else: else:
projects = self.GetProjects(args, groups=opt.groups, all_manifests=all_trees) projects = self.GetProjects(
args, groups=opt.groups, all_manifests=all_trees
)
os.environ['REPO_COUNT'] = str(len(projects)) os.environ["REPO_COUNT"] = str(len(projects))
try: try:
config = self.manifest.manifestProject.config config = self.manifest.manifestProject.config
with multiprocessing.Pool(opt.jobs, InitWorker) as pool: with multiprocessing.Pool(opt.jobs, InitWorker) as pool:
results_it = pool.imap( results_it = pool.imap(
functools.partial(DoWorkWrapper, mirror, opt, cmd, shell, config), functools.partial(
DoWorkWrapper, mirror, opt, cmd, shell, config
),
enumerate(projects), enumerate(projects),
chunksize=WORKER_BATCH_SIZE) chunksize=WORKER_BATCH_SIZE,
)
first = True first = True
for (r, output) in results_it: for r, output in results_it:
if output: if output:
if first: if first:
first = False first = False
elif opt.project_header: elif opt.project_header:
print() print()
# To simplify the DoWorkWrapper, take care of automatic newlines. # To simplify the DoWorkWrapper, take care of automatic
end = '\n' # newlines.
if output[-1] == '\n': end = "\n"
end = '' if output[-1] == "\n":
end = ""
print(output, end=end) print(output, end=end)
rc = rc or r rc = rc or r
if r != 0 and opt.abort_on_errors: if r != 0 and opt.abort_on_errors:
raise Exception('Aborting due to previous error') raise Exception("Aborting due to previous error")
except (KeyboardInterrupt, WorkerKeyboardInterrupt): except (KeyboardInterrupt, WorkerKeyboardInterrupt):
# Catch KeyboardInterrupt raised inside and outside of workers # Catch KeyboardInterrupt raised inside and outside of workers
rc = rc or errno.EINTR rc = rc or errno.EINTR
except Exception as e: except Exception as e:
# Catch any other exceptions raised # Catch any other exceptions raised
print('forall: unhandled error, terminating the pool: %s: %s' % print(
(type(e).__name__, e), "forall: unhandled error, terminating the pool: %s: %s"
file=sys.stderr) % (type(e).__name__, e),
rc = rc or getattr(e, 'errno', 1) file=sys.stderr,
)
rc = rc or getattr(e, "errno", 1)
if rc != 0: if rc != 0:
sys.exit(rc) sys.exit(rc)
class WorkerKeyboardInterrupt(Exception): class WorkerKeyboardInterrupt(Exception):
""" Keyboard interrupt exception for worker processes. """ """Keyboard interrupt exception for worker processes."""
def InitWorker(): def InitWorker():
@ -271,18 +313,18 @@ def InitWorker():
def DoWorkWrapper(mirror, opt, cmd, shell, config, args): def DoWorkWrapper(mirror, opt, cmd, shell, config, args):
""" A wrapper around the DoWork() method. """A wrapper around the DoWork() method.
Catch the KeyboardInterrupt exceptions here and re-raise them as a different, Catch the KeyboardInterrupt exceptions here and re-raise them as a
``Exception``-based exception to stop it flooding the console with stacktraces different, ``Exception``-based exception to stop it flooding the console
and making the parent hang indefinitely. with stacktraces and making the parent hang indefinitely.
""" """
cnt, project = args cnt, project = args
try: try:
return DoWork(project, mirror, opt, cmd, shell, cnt, config) return DoWork(project, mirror, opt, cmd, shell, cnt, config)
except KeyboardInterrupt: except KeyboardInterrupt:
print('%s: Worker interrupted' % project.name) print("%s: Worker interrupted" % project.name)
raise WorkerKeyboardInterrupt() raise WorkerKeyboardInterrupt()
@ -291,30 +333,31 @@ def DoWork(project, mirror, opt, cmd, shell, cnt, config):
def setenv(name, val): def setenv(name, val):
if val is None: if val is None:
val = '' val = ""
env[name] = val env[name] = val
setenv('REPO_PROJECT', project.name) setenv("REPO_PROJECT", project.name)
setenv('REPO_OUTERPATH', project.manifest.path_prefix) setenv("REPO_OUTERPATH", project.manifest.path_prefix)
setenv('REPO_INNERPATH', project.relpath) setenv("REPO_INNERPATH", project.relpath)
setenv('REPO_PATH', project.RelPath(local=opt.this_manifest_only)) setenv("REPO_PATH", project.RelPath(local=opt.this_manifest_only))
setenv('REPO_REMOTE', project.remote.name) setenv("REPO_REMOTE", project.remote.name)
try: try:
# If we aren't in a fully synced state and we don't have the ref the manifest # If we aren't in a fully synced state and we don't have the ref the
# wants, then this will fail. Ignore it for the purposes of this code. # manifest wants, then this will fail. Ignore it for the purposes of
lrev = '' if mirror else project.GetRevisionId() # this code.
lrev = "" if mirror else project.GetRevisionId()
except ManifestInvalidRevisionError: except ManifestInvalidRevisionError:
lrev = '' lrev = ""
setenv('REPO_LREV', lrev) setenv("REPO_LREV", lrev)
setenv('REPO_RREV', project.revisionExpr) setenv("REPO_RREV", project.revisionExpr)
setenv('REPO_UPSTREAM', project.upstream) setenv("REPO_UPSTREAM", project.upstream)
setenv('REPO_DEST_BRANCH', project.dest_branch) setenv("REPO_DEST_BRANCH", project.dest_branch)
setenv('REPO_I', str(cnt + 1)) setenv("REPO_I", str(cnt + 1))
for annotation in project.annotations: for annotation in project.annotations:
setenv("REPO__%s" % (annotation.name), annotation.value) setenv("REPO__%s" % (annotation.name), annotation.value)
if mirror: if mirror:
setenv('GIT_DIR', project.gitdir) setenv("GIT_DIR", project.gitdir)
cwd = project.gitdir cwd = project.gitdir
else: else:
cwd = project.worktree cwd = project.worktree
@ -323,12 +366,13 @@ def DoWork(project, mirror, opt, cmd, shell, cnt, config):
# Allow the user to silently ignore missing checkouts so they can run on # Allow the user to silently ignore missing checkouts so they can run on
# partial checkouts (good for infra recovery tools). # partial checkouts (good for infra recovery tools).
if opt.ignore_missing: if opt.ignore_missing:
return (0, '') return (0, "")
output = '' output = ""
if ((opt.project_header and opt.verbose) if (opt.project_header and opt.verbose) or not opt.project_header:
or not opt.project_header): output = "skipping %s/" % project.RelPath(
output = 'skipping %s/' % project.RelPath(local=opt.this_manifest_only) local=opt.this_manifest_only
)
return (1, output) return (1, output)
if opt.verbose: if opt.verbose:
@ -339,9 +383,17 @@ def DoWork(project, mirror, opt, cmd, shell, cnt, config):
stdin = None if opt.interactive else subprocess.DEVNULL stdin = None if opt.interactive else subprocess.DEVNULL
result = subprocess.run( result = subprocess.run(
cmd, cwd=cwd, shell=shell, env=env, check=False, cmd,
encoding='utf-8', errors='replace', cwd=cwd,
stdin=stdin, stdout=subprocess.PIPE, stderr=stderr) shell=shell,
env=env,
check=False,
encoding="utf-8",
errors="replace",
stdin=stdin,
stdout=subprocess.PIPE,
stderr=stderr,
)
output = result.stdout output = result.stdout
if opt.project_header: if opt.project_header:
@ -352,8 +404,10 @@ def DoWork(project, mirror, opt, cmd, shell, cnt, config):
if mirror: if mirror:
project_header_path = project.name project_header_path = project.name
else: else:
project_header_path = project.RelPath(local=opt.this_manifest_only) project_header_path = project.RelPath(
out.project('project %s/' % project_header_path) local=opt.this_manifest_only
)
out.project("project %s/" % project_header_path)
out.nl() out.nl()
buf.write(output) buf.write(output)
output = buf.getvalue() output = buf.getvalue()

View File

@ -31,16 +31,22 @@ and all locally downloaded sources.
""" """
def _Options(self, p): def _Options(self, p):
p.add_option('-f', '--force', p.add_option(
dest='force', action='store_true', "-f",
help='force the deletion (no prompt)') "--force",
dest="force",
action="store_true",
help="force the deletion (no prompt)",
)
def Execute(self, opt, args): def Execute(self, opt, args):
if not opt.force: if not opt.force:
prompt = ('This will delete GITC client: %s\nAre you sure? (yes/no) ' % prompt = (
self.gitc_manifest.gitc_client_name) "This will delete GITC client: %s\nAre you sure? (yes/no) "
% self.gitc_manifest.gitc_client_name
)
response = input(prompt).lower() response = input(prompt).lower()
if not response == 'yes': if not response == "yes":
print('Response was not "yes"\n Exiting...') print('Response was not "yes"\n Exiting...')
sys.exit(1) sys.exit(1)
platform_utils.rmtree(self.gitc_manifest.gitc_client_dir) platform_utils.rmtree(self.gitc_manifest.gitc_client_dir)

View File

@ -52,25 +52,36 @@ use for this GITC client.
def Execute(self, opt, args): def Execute(self, opt, args):
gitc_client = gitc_utils.parse_clientdir(os.getcwd()) gitc_client = gitc_utils.parse_clientdir(os.getcwd())
if not gitc_client or (opt.gitc_client and gitc_client != opt.gitc_client): if not gitc_client or (
print('fatal: Please update your repo command. See go/gitc for instructions.', opt.gitc_client and gitc_client != opt.gitc_client
file=sys.stderr) ):
print(
"fatal: Please update your repo command. See go/gitc for "
"instructions.",
file=sys.stderr,
)
sys.exit(1) sys.exit(1)
self.client_dir = os.path.join(gitc_utils.get_gitc_manifest_dir(), self.client_dir = os.path.join(
gitc_client) gitc_utils.get_gitc_manifest_dir(), gitc_client
)
super().Execute(opt, args) super().Execute(opt, args)
manifest_file = self.manifest.manifestFile manifest_file = self.manifest.manifestFile
if opt.manifest_file: if opt.manifest_file:
if not os.path.exists(opt.manifest_file): if not os.path.exists(opt.manifest_file):
print('fatal: Specified manifest file %s does not exist.' % print(
opt.manifest_file) "fatal: Specified manifest file %s does not exist."
% opt.manifest_file
)
sys.exit(1) sys.exit(1)
manifest_file = opt.manifest_file manifest_file = opt.manifest_file
manifest = GitcManifest(self.repodir, os.path.join(self.client_dir, manifest = GitcManifest(
'.manifest')) self.repodir, os.path.join(self.client_dir, ".manifest")
)
manifest.Override(manifest_file) manifest.Override(manifest_file)
gitc_utils.generate_gitc_manifest(None, manifest) gitc_utils.generate_gitc_manifest(None, manifest)
print('Please run `cd %s` to view your GITC client.' % print(
os.path.join(wrapper.Wrapper().GITC_FS_ROOT_DIR, gitc_client)) "Please run `cd %s` to view your GITC client."
% os.path.join(wrapper.Wrapper().GITC_FS_ROOT_DIR, gitc_client)
)

View File

@ -23,9 +23,9 @@ from git_command import GitCommand
class GrepColoring(Coloring): class GrepColoring(Coloring):
def __init__(self, config): def __init__(self, config):
Coloring.__init__(self, config, 'grep') Coloring.__init__(self, config, "grep")
self.project = self.printer('project', attr='bold') self.project = self.printer("project", attr="bold")
self.fail = self.printer('fail', fg='red') self.fail = self.printer("fail", fg="red")
class Grep(PagedCommand): class Grep(PagedCommand):
@ -66,15 +66,15 @@ contain a line that matches both expressions:
@staticmethod @staticmethod
def _carry_option(_option, opt_str, value, parser): def _carry_option(_option, opt_str, value, parser):
pt = getattr(parser.values, 'cmd_argv', None) pt = getattr(parser.values, "cmd_argv", None)
if pt is None: if pt is None:
pt = [] pt = []
setattr(parser.values, 'cmd_argv', pt) setattr(parser.values, "cmd_argv", pt)
if opt_str == '-(': if opt_str == "-(":
pt.append('(') pt.append("(")
elif opt_str == '-)': elif opt_str == "-)":
pt.append(')') pt.append(")")
else: else:
pt.append(opt_str) pt.append(opt_str)
@ -86,86 +86,167 @@ contain a line that matches both expressions:
super()._CommonOptions(p, opt_v=False) super()._CommonOptions(p, opt_v=False)
def _Options(self, p): def _Options(self, p):
g = p.add_option_group('Sources') g = p.add_option_group("Sources")
g.add_option('--cached', g.add_option(
action='callback', callback=self._carry_option, "--cached",
help='Search the index, instead of the work tree') action="callback",
g.add_option('-r', '--revision', callback=self._carry_option,
dest='revision', action='append', metavar='TREEish', help="Search the index, instead of the work tree",
help='Search TREEish, instead of the work tree') )
g.add_option(
"-r",
"--revision",
dest="revision",
action="append",
metavar="TREEish",
help="Search TREEish, instead of the work tree",
)
g = p.add_option_group('Pattern') g = p.add_option_group("Pattern")
g.add_option('-e', g.add_option(
action='callback', callback=self._carry_option, "-e",
metavar='PATTERN', type='str', action="callback",
help='Pattern to search for') callback=self._carry_option,
g.add_option('-i', '--ignore-case', metavar="PATTERN",
action='callback', callback=self._carry_option, type="str",
help='Ignore case differences') help="Pattern to search for",
g.add_option('-a', '--text', )
action='callback', callback=self._carry_option, g.add_option(
help="Process binary files as if they were text") "-i",
g.add_option('-I', "--ignore-case",
action='callback', callback=self._carry_option, action="callback",
help="Don't match the pattern in binary files") callback=self._carry_option,
g.add_option('-w', '--word-regexp', help="Ignore case differences",
action='callback', callback=self._carry_option, )
help='Match the pattern only at word boundaries') g.add_option(
g.add_option('-v', '--invert-match', "-a",
action='callback', callback=self._carry_option, "--text",
help='Select non-matching lines') action="callback",
g.add_option('-G', '--basic-regexp', callback=self._carry_option,
action='callback', callback=self._carry_option, help="Process binary files as if they were text",
help='Use POSIX basic regexp for patterns (default)') )
g.add_option('-E', '--extended-regexp', g.add_option(
action='callback', callback=self._carry_option, "-I",
help='Use POSIX extended regexp for patterns') action="callback",
g.add_option('-F', '--fixed-strings', callback=self._carry_option,
action='callback', callback=self._carry_option, help="Don't match the pattern in binary files",
help='Use fixed strings (not regexp) for pattern') )
g.add_option(
"-w",
"--word-regexp",
action="callback",
callback=self._carry_option,
help="Match the pattern only at word boundaries",
)
g.add_option(
"-v",
"--invert-match",
action="callback",
callback=self._carry_option,
help="Select non-matching lines",
)
g.add_option(
"-G",
"--basic-regexp",
action="callback",
callback=self._carry_option,
help="Use POSIX basic regexp for patterns (default)",
)
g.add_option(
"-E",
"--extended-regexp",
action="callback",
callback=self._carry_option,
help="Use POSIX extended regexp for patterns",
)
g.add_option(
"-F",
"--fixed-strings",
action="callback",
callback=self._carry_option,
help="Use fixed strings (not regexp) for pattern",
)
g = p.add_option_group('Pattern Grouping') g = p.add_option_group("Pattern Grouping")
g.add_option('--all-match', g.add_option(
action='callback', callback=self._carry_option, "--all-match",
help='Limit match to lines that have all patterns') action="callback",
g.add_option('--and', '--or', '--not', callback=self._carry_option,
action='callback', callback=self._carry_option, help="Limit match to lines that have all patterns",
help='Boolean operators to combine patterns') )
g.add_option('-(', '-)', g.add_option(
action='callback', callback=self._carry_option, "--and",
help='Boolean operator grouping') "--or",
"--not",
action="callback",
callback=self._carry_option,
help="Boolean operators to combine patterns",
)
g.add_option(
"-(",
"-)",
action="callback",
callback=self._carry_option,
help="Boolean operator grouping",
)
g = p.add_option_group('Output') g = p.add_option_group("Output")
g.add_option('-n', g.add_option(
action='callback', callback=self._carry_option, "-n",
help='Prefix the line number to matching lines') action="callback",
g.add_option('-C', callback=self._carry_option,
action='callback', callback=self._carry_option, help="Prefix the line number to matching lines",
metavar='CONTEXT', type='str', )
help='Show CONTEXT lines around match') g.add_option(
g.add_option('-B', "-C",
action='callback', callback=self._carry_option, action="callback",
metavar='CONTEXT', type='str', callback=self._carry_option,
help='Show CONTEXT lines before match') metavar="CONTEXT",
g.add_option('-A', type="str",
action='callback', callback=self._carry_option, help="Show CONTEXT lines around match",
metavar='CONTEXT', type='str', )
help='Show CONTEXT lines after match') g.add_option(
g.add_option('-l', '--name-only', '--files-with-matches', "-B",
action='callback', callback=self._carry_option, action="callback",
help='Show only file names containing matching lines') callback=self._carry_option,
g.add_option('-L', '--files-without-match', metavar="CONTEXT",
action='callback', callback=self._carry_option, type="str",
help='Show only file names not containing matching lines') help="Show CONTEXT lines before match",
)
g.add_option(
"-A",
action="callback",
callback=self._carry_option,
metavar="CONTEXT",
type="str",
help="Show CONTEXT lines after match",
)
g.add_option(
"-l",
"--name-only",
"--files-with-matches",
action="callback",
callback=self._carry_option,
help="Show only file names containing matching lines",
)
g.add_option(
"-L",
"--files-without-match",
action="callback",
callback=self._carry_option,
help="Show only file names not containing matching lines",
)
def _ExecuteOne(self, cmd_argv, project): def _ExecuteOne(self, cmd_argv, project):
"""Process one project.""" """Process one project."""
try: try:
p = GitCommand(project, p = GitCommand(
project,
cmd_argv, cmd_argv,
bare=False, bare=False,
capture_stdout=True, capture_stdout=True,
capture_stderr=True) capture_stderr=True,
)
except GitError as e: except GitError as e:
return (project, -1, None, str(e)) return (project, -1, None, str(e))
@ -181,42 +262,42 @@ contain a line that matches both expressions:
for project, rc, stdout, stderr in results: for project, rc, stdout, stderr in results:
if rc < 0: if rc < 0:
git_failed = True git_failed = True
out.project('--- project %s ---' % _RelPath(project)) out.project("--- project %s ---" % _RelPath(project))
out.nl() out.nl()
out.fail('%s', stderr) out.fail("%s", stderr)
out.nl() out.nl()
continue continue
if rc: if rc:
# no results # no results
if stderr: if stderr:
if have_rev and 'fatal: ambiguous argument' in stderr: if have_rev and "fatal: ambiguous argument" in stderr:
bad_rev = True bad_rev = True
else: else:
out.project('--- project %s ---' % _RelPath(project)) out.project("--- project %s ---" % _RelPath(project))
out.nl() out.nl()
out.fail('%s', stderr.strip()) out.fail("%s", stderr.strip())
out.nl() out.nl()
continue continue
have_match = True have_match = True
# We cut the last element, to avoid a blank line. # We cut the last element, to avoid a blank line.
r = stdout.split('\n') r = stdout.split("\n")
r = r[0:-1] r = r[0:-1]
if have_rev and full_name: if have_rev and full_name:
for line in r: for line in r:
rev, line = line.split(':', 1) rev, line = line.split(":", 1)
out.write("%s", rev) out.write("%s", rev)
out.write(':') out.write(":")
out.project(_RelPath(project)) out.project(_RelPath(project))
out.write('/') out.write("/")
out.write("%s", line) out.write("%s", line)
out.nl() out.nl()
elif full_name: elif full_name:
for line in r: for line in r:
out.project(_RelPath(project)) out.project(_RelPath(project))
out.write('/') out.write("/")
out.write("%s", line) out.write("%s", line)
out.nl() out.nl()
else: else:
@ -228,41 +309,49 @@ contain a line that matches both expressions:
def Execute(self, opt, args): def Execute(self, opt, args):
out = GrepColoring(self.manifest.manifestProject.config) out = GrepColoring(self.manifest.manifestProject.config)
cmd_argv = ['grep'] cmd_argv = ["grep"]
if out.is_on: if out.is_on:
cmd_argv.append('--color') cmd_argv.append("--color")
cmd_argv.extend(getattr(opt, 'cmd_argv', [])) cmd_argv.extend(getattr(opt, "cmd_argv", []))
if '-e' not in cmd_argv: if "-e" not in cmd_argv:
if not args: if not args:
self.Usage() self.Usage()
cmd_argv.append('-e') cmd_argv.append("-e")
cmd_argv.append(args[0]) cmd_argv.append(args[0])
args = args[1:] args = args[1:]
projects = self.GetProjects(args, all_manifests=not opt.this_manifest_only) projects = self.GetProjects(
args, all_manifests=not opt.this_manifest_only
)
full_name = False full_name = False
if len(projects) > 1: if len(projects) > 1:
cmd_argv.append('--full-name') cmd_argv.append("--full-name")
full_name = True full_name = True
have_rev = False have_rev = False
if opt.revision: if opt.revision:
if '--cached' in cmd_argv: if "--cached" in cmd_argv:
print('fatal: cannot combine --cached and --revision', file=sys.stderr) print(
"fatal: cannot combine --cached and --revision",
file=sys.stderr,
)
sys.exit(1) sys.exit(1)
have_rev = True have_rev = True
cmd_argv.extend(opt.revision) cmd_argv.extend(opt.revision)
cmd_argv.append('--') cmd_argv.append("--")
git_failed, bad_rev, have_match = self.ExecuteInParallel( git_failed, bad_rev, have_match = self.ExecuteInParallel(
opt.jobs, opt.jobs,
functools.partial(self._ExecuteOne, cmd_argv), functools.partial(self._ExecuteOne, cmd_argv),
projects, projects,
callback=functools.partial(self._ProcessResults, full_name, have_rev, opt), callback=functools.partial(
self._ProcessResults, full_name, have_rev, opt
),
output=out, output=out,
ordered=True) ordered=True,
)
if git_failed: if git_failed:
sys.exit(1) sys.exit(1)

View File

@ -18,7 +18,12 @@ import textwrap
from subcmds import all_commands from subcmds import all_commands
from color import Coloring from color import Coloring
from command import PagedCommand, MirrorSafeCommand, GitcAvailableCommand, GitcClientCommand from command import (
PagedCommand,
MirrorSafeCommand,
GitcAvailableCommand,
GitcClientCommand,
)
import gitc_utils import gitc_utils
from wrapper import Wrapper from wrapper import Wrapper
@ -38,37 +43,41 @@ Displays detailed usage information about a command.
maxlen = 0 maxlen = 0
for name in commandNames: for name in commandNames:
maxlen = max(maxlen, len(name)) maxlen = max(maxlen, len(name))
fmt = ' %%-%ds %%s' % maxlen fmt = " %%-%ds %%s" % maxlen
for name in commandNames: for name in commandNames:
command = all_commands[name]() command = all_commands[name]()
try: try:
summary = command.helpSummary.strip() summary = command.helpSummary.strip()
except AttributeError: except AttributeError:
summary = '' summary = ""
print(fmt % (name, summary)) print(fmt % (name, summary))
def _PrintAllCommands(self): def _PrintAllCommands(self):
print('usage: repo COMMAND [ARGS]') print("usage: repo COMMAND [ARGS]")
self.PrintAllCommandsBody() self.PrintAllCommandsBody()
def PrintAllCommandsBody(self): def PrintAllCommandsBody(self):
print('The complete list of recognized repo commands is:') print("The complete list of recognized repo commands is:")
commandNames = list(sorted(all_commands)) commandNames = list(sorted(all_commands))
self._PrintCommands(commandNames) self._PrintCommands(commandNames)
print("See 'repo help <command>' for more information on a " print(
'specific command.') "See 'repo help <command>' for more information on a "
print('Bug reports:', Wrapper().BUG_URL) "specific command."
)
print("Bug reports:", Wrapper().BUG_URL)
def _PrintCommonCommands(self): def _PrintCommonCommands(self):
print('usage: repo COMMAND [ARGS]') print("usage: repo COMMAND [ARGS]")
self.PrintCommonCommandsBody() self.PrintCommonCommandsBody()
def PrintCommonCommandsBody(self): def PrintCommonCommandsBody(self):
print('The most commonly used repo commands are:') print("The most commonly used repo commands are:")
def gitc_supported(cmd): def gitc_supported(cmd):
if not isinstance(cmd, GitcAvailableCommand) and not isinstance(cmd, GitcClientCommand): if not isinstance(cmd, GitcAvailableCommand) and not isinstance(
cmd, GitcClientCommand
):
return True return True
if self.client.isGitcClient: if self.client.isGitcClient:
return True return True
@ -78,21 +87,29 @@ Displays detailed usage information about a command.
return True return True
return False return False
commandNames = list(sorted([name commandNames = list(
sorted(
[
name
for name, command in all_commands.items() for name, command in all_commands.items()
if command.COMMON and gitc_supported(command)])) if command.COMMON and gitc_supported(command)
]
)
)
self._PrintCommands(commandNames) self._PrintCommands(commandNames)
print( print(
"See 'repo help <command>' for more information on a specific command.\n" "See 'repo help <command>' for more information on a specific "
"See 'repo help --all' for a complete list of recognized commands.") "command.\nSee 'repo help --all' for a complete list of recognized "
print('Bug reports:', Wrapper().BUG_URL) "commands."
)
print("Bug reports:", Wrapper().BUG_URL)
def _PrintCommandHelp(self, cmd, header_prefix=''): def _PrintCommandHelp(self, cmd, header_prefix=""):
class _Out(Coloring): class _Out(Coloring):
def __init__(self, gc): def __init__(self, gc):
Coloring.__init__(self, gc, 'help') Coloring.__init__(self, gc, "help")
self.heading = self.printer('heading', attr='bold') self.heading = self.printer("heading", attr="bold")
self._first = True self._first = True
def _PrintSection(self, heading, bodyAttr): def _PrintSection(self, heading, bodyAttr):
@ -100,61 +117,72 @@ Displays detailed usage information about a command.
body = getattr(cmd, bodyAttr) body = getattr(cmd, bodyAttr)
except AttributeError: except AttributeError:
return return
if body == '' or body is None: if body == "" or body is None:
return return
if not self._first: if not self._first:
self.nl() self.nl()
self._first = False self._first = False
self.heading('%s%s', header_prefix, heading) self.heading("%s%s", header_prefix, heading)
self.nl() self.nl()
self.nl() self.nl()
me = 'repo %s' % cmd.NAME me = "repo %s" % cmd.NAME
body = body.strip() body = body.strip()
body = body.replace('%prog', me) body = body.replace("%prog", me)
# Extract the title, but skip any trailing {#anchors}. # Extract the title, but skip any trailing {#anchors}.
asciidoc_hdr = re.compile(r'^\n?#+ ([^{]+)(\{#.+\})?$') asciidoc_hdr = re.compile(r"^\n?#+ ([^{]+)(\{#.+\})?$")
for para in body.split("\n\n"): for para in body.split("\n\n"):
if para.startswith(' '): if para.startswith(" "):
self.write('%s', para) self.write("%s", para)
self.nl() self.nl()
self.nl() self.nl()
continue continue
m = asciidoc_hdr.match(para) m = asciidoc_hdr.match(para)
if m: if m:
self.heading('%s%s', header_prefix, m.group(1)) self.heading("%s%s", header_prefix, m.group(1))
self.nl() self.nl()
self.nl() self.nl()
continue continue
lines = textwrap.wrap(para.replace(' ', ' '), width=80, lines = textwrap.wrap(
break_long_words=False, break_on_hyphens=False) para.replace(" ", " "),
width=80,
break_long_words=False,
break_on_hyphens=False,
)
for line in lines: for line in lines:
self.write('%s', line) self.write("%s", line)
self.nl() self.nl()
self.nl() self.nl()
out = _Out(self.client.globalConfig) out = _Out(self.client.globalConfig)
out._PrintSection('Summary', 'helpSummary') out._PrintSection("Summary", "helpSummary")
cmd.OptionParser.print_help() cmd.OptionParser.print_help()
out._PrintSection('Description', 'helpDescription') out._PrintSection("Description", "helpDescription")
def _PrintAllCommandHelp(self): def _PrintAllCommandHelp(self):
for name in sorted(all_commands): for name in sorted(all_commands):
cmd = all_commands[name](manifest=self.manifest) cmd = all_commands[name](manifest=self.manifest)
self._PrintCommandHelp(cmd, header_prefix='[%s] ' % (name,)) self._PrintCommandHelp(cmd, header_prefix="[%s] " % (name,))
def _Options(self, p): def _Options(self, p):
p.add_option('-a', '--all', p.add_option(
dest='show_all', action='store_true', "-a",
help='show the complete list of commands') "--all",
p.add_option('--help-all', dest="show_all",
dest='show_all_help', action='store_true', action="store_true",
help='show the --help of all commands') help="show the complete list of commands",
)
p.add_option(
"--help-all",
dest="show_all_help",
action="store_true",
help="show the --help of all commands",
)
def Execute(self, opt, args): def Execute(self, opt, args):
if len(args) == 0: if len(args) == 0:
@ -171,7 +199,9 @@ Displays detailed usage information about a command.
try: try:
cmd = all_commands[name](manifest=self.manifest) cmd = all_commands[name](manifest=self.manifest)
except KeyError: except KeyError:
print("repo: '%s' is not a repo command." % name, file=sys.stderr) print(
"repo: '%s' is not a repo command." % name, file=sys.stderr
)
sys.exit(1) sys.exit(1)
self._PrintCommandHelp(cmd) self._PrintCommandHelp(cmd)

View File

@ -26,38 +26,62 @@ class _Coloring(Coloring):
class Info(PagedCommand): class Info(PagedCommand):
COMMON = True COMMON = True
helpSummary = "Get info on the manifest branch, current branch or unmerged branches" helpSummary = (
"Get info on the manifest branch, current branch or unmerged branches"
)
helpUsage = "%prog [-dl] [-o [-c]] [<project>...]" helpUsage = "%prog [-dl] [-o [-c]] [<project>...]"
def _Options(self, p): def _Options(self, p):
p.add_option('-d', '--diff', p.add_option(
dest='all', action='store_true', "-d",
help="show full info and commit diff including remote branches") "--diff",
p.add_option('-o', '--overview', dest="all",
dest='overview', action='store_true', action="store_true",
help='show overview of all local commits') help="show full info and commit diff including remote branches",
p.add_option('-c', '--current-branch', )
dest="current_branch", action="store_true", p.add_option(
help="consider only checked out branches") "-o",
p.add_option('--no-current-branch', "--overview",
dest='current_branch', action='store_false', dest="overview",
help='consider all local branches') action="store_true",
help="show overview of all local commits",
)
p.add_option(
"-c",
"--current-branch",
dest="current_branch",
action="store_true",
help="consider only checked out branches",
)
p.add_option(
"--no-current-branch",
dest="current_branch",
action="store_false",
help="consider all local branches",
)
# Turn this into a warning & remove this someday. # Turn this into a warning & remove this someday.
p.add_option('-b', p.add_option(
dest='current_branch', action='store_true', "-b",
help=optparse.SUPPRESS_HELP) dest="current_branch",
p.add_option('-l', '--local-only', action="store_true",
dest="local", action="store_true", help=optparse.SUPPRESS_HELP,
help="disable all remote operations") )
p.add_option(
"-l",
"--local-only",
dest="local",
action="store_true",
help="disable all remote operations",
)
def Execute(self, opt, args): def Execute(self, opt, args):
self.out = _Coloring(self.client.globalConfig) self.out = _Coloring(self.client.globalConfig)
self.heading = self.out.printer('heading', attr='bold') self.heading = self.out.printer("heading", attr="bold")
self.headtext = self.out.nofmt_printer('headtext', fg='yellow') self.headtext = self.out.nofmt_printer("headtext", fg="yellow")
self.redtext = self.out.printer('redtext', fg='red') self.redtext = self.out.printer("redtext", fg="red")
self.sha = self.out.printer("sha", fg='yellow') self.sha = self.out.printer("sha", fg="yellow")
self.text = self.out.nofmt_printer('text') self.text = self.out.nofmt_printer("text")
self.dimtext = self.out.printer('dimtext', attr='dim') self.dimtext = self.out.printer("dimtext", attr="dim")
self.opt = opt self.opt = opt
@ -108,7 +132,7 @@ class Info(PagedCommand):
currentBranch = p.CurrentBranch currentBranch = p.CurrentBranch
if currentBranch: if currentBranch:
self.heading('Current branch: ') self.heading("Current branch: ")
self.headtext(currentBranch) self.headtext(currentBranch)
self.out.nl() self.out.nl()
@ -135,26 +159,28 @@ class Info(PagedCommand):
if not self.opt.local: if not self.opt.local:
project.Sync_NetworkHalf(quiet=True, current_branch_only=True) project.Sync_NetworkHalf(quiet=True, current_branch_only=True)
branch = self.manifest.manifestProject.config.GetBranch('default').merge branch = self.manifest.manifestProject.config.GetBranch("default").merge
if branch.startswith(R_HEADS): if branch.startswith(R_HEADS):
branch = branch[len(R_HEADS):] branch = branch[len(R_HEADS) :]
logTarget = R_M + branch logTarget = R_M + branch
bareTmp = project.bare_git._bare bareTmp = project.bare_git._bare
project.bare_git._bare = False project.bare_git._bare = False
localCommits = project.bare_git.rev_list( localCommits = project.bare_git.rev_list(
'--abbrev=8', "--abbrev=8",
'--abbrev-commit', "--abbrev-commit",
'--pretty=oneline', "--pretty=oneline",
logTarget + "..", logTarget + "..",
'--') "--",
)
originCommits = project.bare_git.rev_list( originCommits = project.bare_git.rev_list(
'--abbrev=8', "--abbrev=8",
'--abbrev-commit', "--abbrev-commit",
'--pretty=oneline', "--pretty=oneline",
".." + logTarget, ".." + logTarget,
'--') "--",
)
project.bare_git._bare = bareTmp project.bare_git._bare = bareTmp
self.heading("Local Commits: ") self.heading("Local Commits: ")
@ -182,9 +208,10 @@ class Info(PagedCommand):
def _printCommitOverview(self, opt, args): def _printCommitOverview(self, opt, args):
all_branches = [] all_branches = []
for project in self.GetProjects(args, all_manifests=not opt.this_manifest_only): for project in self.GetProjects(
br = [project.GetUploadableBranch(x) args, all_manifests=not opt.this_manifest_only
for x in project.GetBranches()] ):
br = [project.GetUploadableBranch(x) for x in project.GetBranches()]
br = [x for x in br if x] br = [x for x in br if x]
if self.opt.current_branch: if self.opt.current_branch:
br = [x for x in br if x.name == project.CurrentBranch] br = [x for x in br if x.name == project.CurrentBranch]
@ -194,7 +221,7 @@ class Info(PagedCommand):
return return
self.out.nl() self.out.nl()
self.heading('Projects Overview') self.heading("Projects Overview")
project = None project = None
for branch in all_branches: for branch in all_branches:
@ -206,17 +233,21 @@ class Info(PagedCommand):
commits = branch.commits commits = branch.commits
date = branch.date date = branch.date
self.text('%s %-33s (%2d commit%s, %s)' % ( self.text(
branch.name == project.CurrentBranch and '*' or ' ', "%s %-33s (%2d commit%s, %s)"
% (
branch.name == project.CurrentBranch and "*" or " ",
branch.name, branch.name,
len(commits), len(commits),
len(commits) != 1 and 's' or '', len(commits) != 1 and "s" or "",
date)) date,
)
)
self.out.nl() self.out.nl()
for commit in commits: for commit in commits:
split = commit.split() split = commit.split()
self.text('{0:38}{1} '.format('', '-')) self.text("{0:38}{1} ".format("", "-"))
self.sha(split[0] + " ") self.sha(split[0] + " ")
self.text(" ".join(split[1:])) self.text(" ".join(split[1:]))
self.out.nl() self.out.nl()

View File

@ -82,20 +82,38 @@ to update the working directory files.
def _Options(self, p, gitc_init=False): def _Options(self, p, gitc_init=False):
Wrapper().InitParser(p, gitc_init=gitc_init) Wrapper().InitParser(p, gitc_init=gitc_init)
m = p.add_option_group('Multi-manifest') m = p.add_option_group("Multi-manifest")
m.add_option('--outer-manifest', action='store_true', default=True, m.add_option(
help='operate starting at the outermost manifest') "--outer-manifest",
m.add_option('--no-outer-manifest', dest='outer_manifest', action="store_true",
action='store_false', help='do not operate on outer manifests') default=True,
m.add_option('--this-manifest-only', action='store_true', default=None, help="operate starting at the outermost manifest",
help='only operate on this (sub)manifest') )
m.add_option('--no-this-manifest-only', '--all-manifests', m.add_option(
dest='this_manifest_only', action='store_false', "--no-outer-manifest",
help='operate on this manifest and its submanifests') dest="outer_manifest",
action="store_false",
help="do not operate on outer manifests",
)
m.add_option(
"--this-manifest-only",
action="store_true",
default=None,
help="only operate on this (sub)manifest",
)
m.add_option(
"--no-this-manifest-only",
"--all-manifests",
dest="this_manifest_only",
action="store_false",
help="operate on this manifest and its submanifests",
)
def _RegisteredEnvironmentOptions(self): def _RegisteredEnvironmentOptions(self):
return {'REPO_MANIFEST_URL': 'manifest_url', return {
'REPO_MIRROR_LOCATION': 'reference'} "REPO_MANIFEST_URL": "manifest_url",
"REPO_MIRROR_LOCATION": "reference",
}
def _SyncManifest(self, opt): def _SyncManifest(self, opt):
"""Call manifestProject.Sync with arguments from opt. """Call manifestProject.Sync with arguments from opt.
@ -130,13 +148,14 @@ to update the working directory files.
tags=opt.tags, tags=opt.tags,
depth=opt.depth, depth=opt.depth,
git_event_log=self.git_event_log, git_event_log=self.git_event_log,
manifest_name=opt.manifest_name): manifest_name=opt.manifest_name,
):
sys.exit(1) sys.exit(1)
def _Prompt(self, prompt, value): def _Prompt(self, prompt, value):
print('%-10s [%s]: ' % (prompt, value), end='', flush=True) print("%-10s [%s]: " % (prompt, value), end="", flush=True)
a = sys.stdin.readline().strip() a = sys.stdin.readline().strip()
if a == '': if a == "":
return value return value
return a return a
@ -145,18 +164,26 @@ to update the working directory files.
mp = self.manifest.manifestProject mp = self.manifest.manifestProject
# If we don't have local settings, get from global. # If we don't have local settings, get from global.
if not mp.config.Has('user.name') or not mp.config.Has('user.email'): if not mp.config.Has("user.name") or not mp.config.Has("user.email"):
if not gc.Has('user.name') or not gc.Has('user.email'): if not gc.Has("user.name") or not gc.Has("user.email"):
return True return True
mp.config.SetString('user.name', gc.GetString('user.name')) mp.config.SetString("user.name", gc.GetString("user.name"))
mp.config.SetString('user.email', gc.GetString('user.email')) mp.config.SetString("user.email", gc.GetString("user.email"))
if not opt.quiet and not existing_checkout or opt.verbose: if not opt.quiet and not existing_checkout or opt.verbose:
print() print()
print('Your identity is: %s <%s>' % (mp.config.GetString('user.name'), print(
mp.config.GetString('user.email'))) "Your identity is: %s <%s>"
print("If you want to change this, please re-run 'repo init' with --config-name") % (
mp.config.GetString("user.name"),
mp.config.GetString("user.email"),
)
)
print(
"If you want to change this, please re-run 'repo init' with "
"--config-name"
)
return False return False
def _ConfigureUser(self, opt): def _ConfigureUser(self, opt):
@ -165,25 +192,25 @@ to update the working directory files.
while True: while True:
if not opt.quiet: if not opt.quiet:
print() print()
name = self._Prompt('Your Name', mp.UserName) name = self._Prompt("Your Name", mp.UserName)
email = self._Prompt('Your Email', mp.UserEmail) email = self._Prompt("Your Email", mp.UserEmail)
if not opt.quiet: if not opt.quiet:
print() print()
print('Your identity is: %s <%s>' % (name, email)) print("Your identity is: %s <%s>" % (name, email))
print('is this correct [y/N]? ', end='', flush=True) print("is this correct [y/N]? ", end="", flush=True)
a = sys.stdin.readline().strip().lower() a = sys.stdin.readline().strip().lower()
if a in ('yes', 'y', 't', 'true'): if a in ("yes", "y", "t", "true"):
break break
if name != mp.UserName: if name != mp.UserName:
mp.config.SetString('user.name', name) mp.config.SetString("user.name", name)
if email != mp.UserEmail: if email != mp.UserEmail:
mp.config.SetString('user.email', email) mp.config.SetString("user.email", email)
def _HasColorSet(self, gc): def _HasColorSet(self, gc):
for n in ['ui', 'diff', 'status']: for n in ["ui", "diff", "status"]:
if gc.Has('color.%s' % n): if gc.Has("color.%s" % n):
return True return True
return False return False
@ -194,92 +221,112 @@ to update the working directory files.
class _Test(Coloring): class _Test(Coloring):
def __init__(self): def __init__(self):
Coloring.__init__(self, gc, 'test color display') Coloring.__init__(self, gc, "test color display")
self._on = True self._on = True
out = _Test() out = _Test()
print() print()
print("Testing colorized output (for 'repo diff', 'repo status'):") print("Testing colorized output (for 'repo diff', 'repo status'):")
for c in ['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan']: for c in ["black", "red", "green", "yellow", "blue", "magenta", "cyan"]:
out.write(' ') out.write(" ")
out.printer(fg=c)(' %-6s ', c) out.printer(fg=c)(" %-6s ", c)
out.write(' ') out.write(" ")
out.printer(fg='white', bg='black')(' %s ' % 'white') out.printer(fg="white", bg="black")(" %s " % "white")
out.nl() out.nl()
for c in ['bold', 'dim', 'ul', 'reverse']: for c in ["bold", "dim", "ul", "reverse"]:
out.write(' ') out.write(" ")
out.printer(fg='black', attr=c)(' %-6s ', c) out.printer(fg="black", attr=c)(" %-6s ", c)
out.nl() out.nl()
print('Enable color display in this user account (y/N)? ', end='', flush=True) print(
"Enable color display in this user account (y/N)? ",
end="",
flush=True,
)
a = sys.stdin.readline().strip().lower() a = sys.stdin.readline().strip().lower()
if a in ('y', 'yes', 't', 'true', 'on'): if a in ("y", "yes", "t", "true", "on"):
gc.SetString('color.ui', 'auto') gc.SetString("color.ui", "auto")
def _DisplayResult(self): def _DisplayResult(self):
if self.manifest.IsMirror: if self.manifest.IsMirror:
init_type = 'mirror ' init_type = "mirror "
else: else:
init_type = '' init_type = ""
print() print()
print('repo %shas been initialized in %s' % (init_type, self.manifest.topdir)) print(
"repo %shas been initialized in %s"
% (init_type, self.manifest.topdir)
)
current_dir = os.getcwd() current_dir = os.getcwd()
if current_dir != self.manifest.topdir: if current_dir != self.manifest.topdir:
print('If this is not the directory in which you want to initialize ' print(
'repo, please run:') "If this is not the directory in which you want to initialize "
print(' rm -r %s' % os.path.join(self.manifest.topdir, '.repo')) "repo, please run:"
print('and try again.') )
print(" rm -r %s" % os.path.join(self.manifest.topdir, ".repo"))
print("and try again.")
def ValidateOptions(self, opt, args): def ValidateOptions(self, opt, args):
if opt.reference: if opt.reference:
opt.reference = os.path.expanduser(opt.reference) opt.reference = os.path.expanduser(opt.reference)
# Check this here, else manifest will be tagged "not new" and init won't be # Check this here, else manifest will be tagged "not new" and init won't
# possible anymore without removing the .repo/manifests directory. # be possible anymore without removing the .repo/manifests directory.
if opt.mirror: if opt.mirror:
if opt.archive: if opt.archive:
self.OptionParser.error('--mirror and --archive cannot be used ' self.OptionParser.error(
'together.') "--mirror and --archive cannot be used " "together."
)
if opt.use_superproject is not None: if opt.use_superproject is not None:
self.OptionParser.error('--mirror and --use-superproject cannot be ' self.OptionParser.error(
'used together.') "--mirror and --use-superproject cannot be "
"used together."
)
if opt.archive and opt.use_superproject is not None: if opt.archive and opt.use_superproject is not None:
self.OptionParser.error('--archive and --use-superproject cannot be used ' self.OptionParser.error(
'together.') "--archive and --use-superproject cannot be used " "together."
)
if opt.standalone_manifest and (opt.manifest_branch or if opt.standalone_manifest and (
opt.manifest_name != 'default.xml'): opt.manifest_branch or opt.manifest_name != "default.xml"
self.OptionParser.error('--manifest-branch and --manifest-name cannot' ):
' be used with --standalone-manifest.') self.OptionParser.error(
"--manifest-branch and --manifest-name cannot"
" be used with --standalone-manifest."
)
if args: if args:
if opt.manifest_url: if opt.manifest_url:
self.OptionParser.error( self.OptionParser.error(
'--manifest-url option and URL argument both specified: only use ' "--manifest-url option and URL argument both specified: "
'one to select the manifest URL.') "only use one to select the manifest URL."
)
opt.manifest_url = args.pop(0) opt.manifest_url = args.pop(0)
if args: if args:
self.OptionParser.error('too many arguments to init') self.OptionParser.error("too many arguments to init")
def Execute(self, opt, args): def Execute(self, opt, args):
git_require(MIN_GIT_VERSION_HARD, fail=True) git_require(MIN_GIT_VERSION_HARD, fail=True)
if not git_require(MIN_GIT_VERSION_SOFT): if not git_require(MIN_GIT_VERSION_SOFT):
print('repo: warning: git-%s+ will soon be required; please upgrade your ' print(
'version of git to maintain support.' "repo: warning: git-%s+ will soon be required; please upgrade "
% ('.'.join(str(x) for x in MIN_GIT_VERSION_SOFT),), "your version of git to maintain support."
file=sys.stderr) % (".".join(str(x) for x in MIN_GIT_VERSION_SOFT),),
file=sys.stderr,
)
rp = self.manifest.repoProject rp = self.manifest.repoProject
# Handle new --repo-url requests. # Handle new --repo-url requests.
if opt.repo_url: if opt.repo_url:
remote = rp.GetRemote('origin') remote = rp.GetRemote("origin")
remote.url = opt.repo_url remote.url = opt.repo_url
remote.Save() remote.Save()
@ -288,30 +335,43 @@ to update the working directory files.
wrapper = Wrapper() wrapper = Wrapper()
try: try:
remote_ref, rev = wrapper.check_repo_rev( remote_ref, rev = wrapper.check_repo_rev(
rp.gitdir, opt.repo_rev, repo_verify=opt.repo_verify, quiet=opt.quiet) rp.gitdir,
opt.repo_rev,
repo_verify=opt.repo_verify,
quiet=opt.quiet,
)
except wrapper.CloneFailure: except wrapper.CloneFailure:
print('fatal: double check your --repo-rev setting.', file=sys.stderr) print(
"fatal: double check your --repo-rev setting.",
file=sys.stderr,
)
sys.exit(1) sys.exit(1)
branch = rp.GetBranch('default') branch = rp.GetBranch("default")
branch.merge = remote_ref branch.merge = remote_ref
rp.work_git.reset('--hard', rev) rp.work_git.reset("--hard", rev)
branch.Save() branch.Save()
if opt.worktree: if opt.worktree:
# Older versions of git supported worktree, but had dangerous gc bugs. # Older versions of git supported worktree, but had dangerous gc
git_require((2, 15, 0), fail=True, msg='git gc worktree corruption') # bugs.
git_require((2, 15, 0), fail=True, msg="git gc worktree corruption")
# Provide a short notice that we're reinitializing an existing checkout. # Provide a short notice that we're reinitializing an existing checkout.
# Sometimes developers might not realize that they're in one, or that # Sometimes developers might not realize that they're in one, or that
# repo doesn't do nested checkouts. # repo doesn't do nested checkouts.
existing_checkout = self.manifest.manifestProject.Exists existing_checkout = self.manifest.manifestProject.Exists
if not opt.quiet and existing_checkout: if not opt.quiet and existing_checkout:
print('repo: reusing existing repo client checkout in', self.manifest.topdir) print(
"repo: reusing existing repo client checkout in",
self.manifest.topdir,
)
self._SyncManifest(opt) self._SyncManifest(opt)
if os.isatty(0) and os.isatty(1) and not self.manifest.IsMirror: if os.isatty(0) and os.isatty(1) and not self.manifest.IsMirror:
if opt.config_name or self._ShouldConfigureUser(opt, existing_checkout): if opt.config_name or self._ShouldConfigureUser(
opt, existing_checkout
):
self._ConfigureUser(opt) self._ConfigureUser(opt)
self._ConfigureColor() self._ConfigureColor()

View File

@ -36,30 +36,58 @@ This is similar to running: repo forall -c 'echo "$REPO_PATH : $REPO_PROJECT"'.
""" """
def _Options(self, p): def _Options(self, p):
p.add_option('-r', '--regex', p.add_option(
dest='regex', action='store_true', "-r",
help='filter the project list based on regex or wildcard matching of strings') "--regex",
p.add_option('-g', '--groups', dest="regex",
dest='groups', action="store_true",
help='filter the project list based on the groups the project is in') help="filter the project list based on regex or wildcard matching "
p.add_option('-a', '--all', "of strings",
action='store_true', )
help='show projects regardless of checkout state') p.add_option(
p.add_option('-n', '--name-only', "-g",
dest='name_only', action='store_true', "--groups",
help='display only the name of the repository') dest="groups",
p.add_option('-p', '--path-only', help="filter the project list based on the groups the project is "
dest='path_only', action='store_true', "in",
help='display only the path of the repository') )
p.add_option('-f', '--fullpath', p.add_option(
dest='fullpath', action='store_true', "-a",
help='display the full work tree path instead of the relative path') "--all",
p.add_option('--relative-to', metavar='PATH', action="store_true",
help='display paths relative to this one (default: top of repo client checkout)') help="show projects regardless of checkout state",
)
p.add_option(
"-n",
"--name-only",
dest="name_only",
action="store_true",
help="display only the name of the repository",
)
p.add_option(
"-p",
"--path-only",
dest="path_only",
action="store_true",
help="display only the path of the repository",
)
p.add_option(
"-f",
"--fullpath",
dest="fullpath",
action="store_true",
help="display the full work tree path instead of the relative path",
)
p.add_option(
"--relative-to",
metavar="PATH",
help="display paths relative to this one (default: top of repo "
"client checkout)",
)
def ValidateOptions(self, opt, args): def ValidateOptions(self, opt, args):
if opt.fullpath and opt.name_only: if opt.fullpath and opt.name_only:
self.OptionParser.error('cannot combine -f and -n') self.OptionParser.error("cannot combine -f and -n")
# Resolve any symlinks so the output is stable. # Resolve any symlinks so the output is stable.
if opt.relative_to: if opt.relative_to:
@ -77,10 +105,16 @@ This is similar to running: repo forall -c 'echo "$REPO_PATH : $REPO_PROJECT"'.
args: Positional args. Can be a list of projects to list, or empty. args: Positional args. Can be a list of projects to list, or empty.
""" """
if not opt.regex: if not opt.regex:
projects = self.GetProjects(args, groups=opt.groups, missing_ok=opt.all, projects = self.GetProjects(
all_manifests=not opt.this_manifest_only) args,
groups=opt.groups,
missing_ok=opt.all,
all_manifests=not opt.this_manifest_only,
)
else: else:
projects = self.FindProjects(args, all_manifests=not opt.this_manifest_only) projects = self.FindProjects(
args, all_manifests=not opt.this_manifest_only
)
def _getpath(x): def _getpath(x):
if opt.fullpath: if opt.fullpath:
@ -100,4 +134,4 @@ This is similar to running: repo forall -c 'echo "$REPO_PATH : $REPO_PROJECT"'.
if lines: if lines:
lines.sort() lines.sort()
print('\n'.join(lines)) print("\n".join(lines))

View File

@ -42,86 +42,130 @@ to indicate the remote ref to push changes to via 'repo upload'.
@property @property
def helpDescription(self): def helpDescription(self):
helptext = self._helpDescription + '\n' helptext = self._helpDescription + "\n"
r = os.path.dirname(__file__) r = os.path.dirname(__file__)
r = os.path.dirname(r) r = os.path.dirname(r)
with open(os.path.join(r, 'docs', 'manifest-format.md')) as fd: with open(os.path.join(r, "docs", "manifest-format.md")) as fd:
for line in fd: for line in fd:
helptext += line helptext += line
return helptext return helptext
def _Options(self, p): def _Options(self, p):
p.add_option('-r', '--revision-as-HEAD', p.add_option(
dest='peg_rev', action='store_true', "-r",
help='save revisions as current HEAD') "--revision-as-HEAD",
p.add_option('-m', '--manifest-name', dest="peg_rev",
help='temporary manifest to use for this sync', metavar='NAME.xml') action="store_true",
p.add_option('--suppress-upstream-revision', dest='peg_rev_upstream', help="save revisions as current HEAD",
default=True, action='store_false', )
help='if in -r mode, do not write the upstream field ' p.add_option(
'(only of use if the branch names for a sha1 manifest are ' "-m",
'sensitive)') "--manifest-name",
p.add_option('--suppress-dest-branch', dest='peg_rev_dest_branch', help="temporary manifest to use for this sync",
default=True, action='store_false', metavar="NAME.xml",
help='if in -r mode, do not write the dest-branch field ' )
'(only of use if the branch names for a sha1 manifest are ' p.add_option(
'sensitive)') "--suppress-upstream-revision",
p.add_option('--json', default=False, action='store_true', dest="peg_rev_upstream",
help='output manifest in JSON format (experimental)') default=True,
p.add_option('--pretty', default=False, action='store_true', action="store_false",
help='format output for humans to read') help="if in -r mode, do not write the upstream field "
p.add_option('--no-local-manifests', default=False, action='store_true', "(only of use if the branch names for a sha1 manifest are "
dest='ignore_local_manifests', help='ignore local manifests') "sensitive)",
p.add_option('-o', '--output-file', )
dest='output_file', p.add_option(
default='-', "--suppress-dest-branch",
help='file to save the manifest to. (Filename prefix for multi-tree.)', dest="peg_rev_dest_branch",
metavar='-|NAME.xml') default=True,
action="store_false",
help="if in -r mode, do not write the dest-branch field "
"(only of use if the branch names for a sha1 manifest are "
"sensitive)",
)
p.add_option(
"--json",
default=False,
action="store_true",
help="output manifest in JSON format (experimental)",
)
p.add_option(
"--pretty",
default=False,
action="store_true",
help="format output for humans to read",
)
p.add_option(
"--no-local-manifests",
default=False,
action="store_true",
dest="ignore_local_manifests",
help="ignore local manifests",
)
p.add_option(
"-o",
"--output-file",
dest="output_file",
default="-",
help="file to save the manifest to. (Filename prefix for "
"multi-tree.)",
metavar="-|NAME.xml",
)
def _Output(self, opt): def _Output(self, opt):
# If alternate manifest is specified, override the manifest file that we're using. # If alternate manifest is specified, override the manifest file that
# we're using.
if opt.manifest_name: if opt.manifest_name:
self.manifest.Override(opt.manifest_name, False) self.manifest.Override(opt.manifest_name, False)
for manifest in self.ManifestList(opt): for manifest in self.ManifestList(opt):
output_file = opt.output_file output_file = opt.output_file
if output_file == '-': if output_file == "-":
fd = sys.stdout fd = sys.stdout
else: else:
if manifest.path_prefix: if manifest.path_prefix:
output_file = f'{opt.output_file}:{manifest.path_prefix.replace("/", "%2f")}' output_file = (
fd = open(output_file, 'w') f"{opt.output_file}:"
f'{manifest.path_prefix.replace("/", "%2f")}'
)
fd = open(output_file, "w")
manifest.SetUseLocalManifests(not opt.ignore_local_manifests) manifest.SetUseLocalManifests(not opt.ignore_local_manifests)
if opt.json: if opt.json:
print('warning: --json is experimental!', file=sys.stderr) print("warning: --json is experimental!", file=sys.stderr)
doc = manifest.ToDict(peg_rev=opt.peg_rev, doc = manifest.ToDict(
peg_rev=opt.peg_rev,
peg_rev_upstream=opt.peg_rev_upstream, peg_rev_upstream=opt.peg_rev_upstream,
peg_rev_dest_branch=opt.peg_rev_dest_branch) peg_rev_dest_branch=opt.peg_rev_dest_branch,
)
json_settings = { json_settings = {
# JSON style guide says Uunicode characters are fully allowed. # JSON style guide says Unicode characters are fully
'ensure_ascii': False, # allowed.
"ensure_ascii": False,
# We use 2 space indent to match JSON style guide. # We use 2 space indent to match JSON style guide.
'indent': 2 if opt.pretty else None, "indent": 2 if opt.pretty else None,
'separators': (',', ': ') if opt.pretty else (',', ':'), "separators": (",", ": ") if opt.pretty else (",", ":"),
'sort_keys': True, "sort_keys": True,
} }
fd.write(json.dumps(doc, **json_settings)) fd.write(json.dumps(doc, **json_settings))
else: else:
manifest.Save(fd, manifest.Save(
fd,
peg_rev=opt.peg_rev, peg_rev=opt.peg_rev,
peg_rev_upstream=opt.peg_rev_upstream, peg_rev_upstream=opt.peg_rev_upstream,
peg_rev_dest_branch=opt.peg_rev_dest_branch) peg_rev_dest_branch=opt.peg_rev_dest_branch,
if output_file != '-': )
if output_file != "-":
fd.close() fd.close()
if manifest.path_prefix: if manifest.path_prefix:
print(f'Saved {manifest.path_prefix} submanifest to {output_file}', print(
file=sys.stderr) f"Saved {manifest.path_prefix} submanifest to "
f"{output_file}",
file=sys.stderr,
)
else: else:
print(f'Saved manifest to {output_file}', file=sys.stderr) print(f"Saved manifest to {output_file}", file=sys.stderr)
def ValidateOptions(self, opt, args): def ValidateOptions(self, opt, args):
if args: if args:

View File

@ -34,22 +34,33 @@ are displayed.
""" """
def _Options(self, p): def _Options(self, p):
p.add_option('-c', '--current-branch', p.add_option(
dest="current_branch", action="store_true", "-c",
help="consider only checked out branches") "--current-branch",
p.add_option('--no-current-branch', dest="current_branch",
dest='current_branch', action='store_false', action="store_true",
help='consider all local branches') help="consider only checked out branches",
)
p.add_option(
"--no-current-branch",
dest="current_branch",
action="store_false",
help="consider all local branches",
)
# Turn this into a warning & remove this someday. # Turn this into a warning & remove this someday.
p.add_option('-b', p.add_option(
dest='current_branch', action='store_true', "-b",
help=optparse.SUPPRESS_HELP) dest="current_branch",
action="store_true",
help=optparse.SUPPRESS_HELP,
)
def Execute(self, opt, args): def Execute(self, opt, args):
all_branches = [] all_branches = []
for project in self.GetProjects(args, all_manifests=not opt.this_manifest_only): for project in self.GetProjects(
br = [project.GetUploadableBranch(x) args, all_manifests=not opt.this_manifest_only
for x in project.GetBranches()] ):
br = [project.GetUploadableBranch(x) for x in project.GetBranches()]
br = [x for x in br if x] br = [x for x in br if x]
if opt.current_branch: if opt.current_branch:
br = [x for x in br if x.name == project.CurrentBranch] br = [x for x in br if x.name == project.CurrentBranch]
@ -60,14 +71,14 @@ are displayed.
class Report(Coloring): class Report(Coloring):
def __init__(self, config): def __init__(self, config):
Coloring.__init__(self, config, 'status') Coloring.__init__(self, config, "status")
self.project = self.printer('header', attr='bold') self.project = self.printer("header", attr="bold")
self.text = self.printer('text') self.text = self.printer("text")
out = Report(all_branches[0].project.config) out = Report(all_branches[0].project.config)
out.text("Deprecated. See repo info -o.") out.text("Deprecated. See repo info -o.")
out.nl() out.nl()
out.project('Projects Overview') out.project("Projects Overview")
out.nl() out.nl()
project = None project = None
@ -76,16 +87,23 @@ are displayed.
if project != branch.project: if project != branch.project:
project = branch.project project = branch.project
out.nl() out.nl()
out.project('project %s/' % project.RelPath(local=opt.this_manifest_only)) out.project(
"project %s/"
% project.RelPath(local=opt.this_manifest_only)
)
out.nl() out.nl()
commits = branch.commits commits = branch.commits
date = branch.date date = branch.date
print('%s %-33s (%2d commit%s, %s)' % ( print(
branch.name == project.CurrentBranch and '*' or ' ', "%s %-33s (%2d commit%s, %s)"
% (
branch.name == project.CurrentBranch and "*" or " ",
branch.name, branch.name,
len(commits), len(commits),
len(commits) != 1 and 's' or ' ', len(commits) != 1 and "s" or " ",
date)) date,
)
)
for commit in commits: for commit in commits:
print('%-35s - %s' % ('', commit)) print("%-35s - %s" % ("", commit))

View File

@ -31,10 +31,12 @@ class Prune(PagedCommand):
return project.PruneHeads() return project.PruneHeads()
def Execute(self, opt, args): def Execute(self, opt, args):
projects = self.GetProjects(args, all_manifests=not opt.this_manifest_only) projects = self.GetProjects(
args, all_manifests=not opt.this_manifest_only
)
# NB: Should be able to refactor this module to display summary as results # NB: Should be able to refactor this module to display summary as
# come back from children. # results come back from children.
def _ProcessResults(_pool, _output, results): def _ProcessResults(_pool, _output, results):
return list(itertools.chain.from_iterable(results)) return list(itertools.chain.from_iterable(results))
@ -43,18 +45,19 @@ class Prune(PagedCommand):
self._ExecuteOne, self._ExecuteOne,
projects, projects,
callback=_ProcessResults, callback=_ProcessResults,
ordered=True) ordered=True,
)
if not all_branches: if not all_branches:
return return
class Report(Coloring): class Report(Coloring):
def __init__(self, config): def __init__(self, config):
Coloring.__init__(self, config, 'status') Coloring.__init__(self, config, "status")
self.project = self.printer('header', attr='bold') self.project = self.printer("header", attr="bold")
out = Report(all_branches[0].project.config) out = Report(all_branches[0].project.config)
out.project('Pending Branches') out.project("Pending Branches")
out.nl() out.nl()
project = None project = None
@ -63,19 +66,29 @@ class Prune(PagedCommand):
if project != branch.project: if project != branch.project:
project = branch.project project = branch.project
out.nl() out.nl()
out.project('project %s/' % project.RelPath(local=opt.this_manifest_only)) out.project(
"project %s/"
% project.RelPath(local=opt.this_manifest_only)
)
out.nl() out.nl()
print('%s %-33s ' % ( print(
branch.name == project.CurrentBranch and '*' or ' ', "%s %-33s "
branch.name), end='') % (
branch.name == project.CurrentBranch and "*" or " ",
branch.name,
),
end="",
)
if not branch.base_exists: if not branch.base_exists:
print('(ignoring: tracking branch is gone: %s)' % (branch.base,)) print(
"(ignoring: tracking branch is gone: %s)" % (branch.base,)
)
else: else:
commits = branch.commits commits = branch.commits
date = branch.date date = branch.date
print('(%2d commit%s, %s)' % ( print(
len(commits), "(%2d commit%s, %s)"
len(commits) != 1 and 's' or ' ', % (len(commits), len(commits) != 1 and "s" or " ", date)
date)) )

View File

@ -21,9 +21,9 @@ from git_command import GitCommand
class RebaseColoring(Coloring): class RebaseColoring(Coloring):
def __init__(self, config): def __init__(self, config):
Coloring.__init__(self, config, 'rebase') Coloring.__init__(self, config, "rebase")
self.project = self.printer('project', attr='bold') self.project = self.printer("project", attr="bold")
self.fail = self.printer('fail', fg='red') self.fail = self.printer("fail", fg="red")
class Rebase(Command): class Rebase(Command):
@ -39,61 +39,98 @@ branch but need to incorporate new upstream changes "underneath" them.
""" """
def _Options(self, p): def _Options(self, p):
g = p.get_option_group('--quiet') g = p.get_option_group("--quiet")
g.add_option('-i', '--interactive', g.add_option(
dest="interactive", action="store_true", "-i",
help="interactive rebase (single project only)") "--interactive",
dest="interactive",
action="store_true",
help="interactive rebase (single project only)",
)
p.add_option('--fail-fast', p.add_option(
dest='fail_fast', action='store_true', "--fail-fast",
help='stop rebasing after first error is hit') dest="fail_fast",
p.add_option('-f', '--force-rebase', action="store_true",
dest='force_rebase', action='store_true', help="stop rebasing after first error is hit",
help='pass --force-rebase to git rebase') )
p.add_option('--no-ff', p.add_option(
dest='ff', default=True, action='store_false', "-f",
help='pass --no-ff to git rebase') "--force-rebase",
p.add_option('--autosquash', dest="force_rebase",
dest='autosquash', action='store_true', action="store_true",
help='pass --autosquash to git rebase') help="pass --force-rebase to git rebase",
p.add_option('--whitespace', )
dest='whitespace', action='store', metavar='WS', p.add_option(
help='pass --whitespace to git rebase') "--no-ff",
p.add_option('--auto-stash', dest="ff",
dest='auto_stash', action='store_true', default=True,
help='stash local modifications before starting') action="store_false",
p.add_option('-m', '--onto-manifest', help="pass --no-ff to git rebase",
dest='onto_manifest', action='store_true', )
help='rebase onto the manifest version instead of upstream ' p.add_option(
'HEAD (this helps to make sure the local tree stays ' "--autosquash",
'consistent if you previously synced to a manifest)') dest="autosquash",
action="store_true",
help="pass --autosquash to git rebase",
)
p.add_option(
"--whitespace",
dest="whitespace",
action="store",
metavar="WS",
help="pass --whitespace to git rebase",
)
p.add_option(
"--auto-stash",
dest="auto_stash",
action="store_true",
help="stash local modifications before starting",
)
p.add_option(
"-m",
"--onto-manifest",
dest="onto_manifest",
action="store_true",
help="rebase onto the manifest version instead of upstream "
"HEAD (this helps to make sure the local tree stays "
"consistent if you previously synced to a manifest)",
)
def Execute(self, opt, args): def Execute(self, opt, args):
all_projects = self.GetProjects(args, all_manifests=not opt.this_manifest_only) all_projects = self.GetProjects(
args, all_manifests=not opt.this_manifest_only
)
one_project = len(all_projects) == 1 one_project = len(all_projects) == 1
if opt.interactive and not one_project: if opt.interactive and not one_project:
print('error: interactive rebase not supported with multiple projects', print(
file=sys.stderr) "error: interactive rebase not supported with multiple "
"projects",
file=sys.stderr,
)
if len(args) == 1: if len(args) == 1:
print('note: project %s is mapped to more than one path' % (args[0],), print(
file=sys.stderr) "note: project %s is mapped to more than one path"
% (args[0],),
file=sys.stderr,
)
return 1 return 1
# Setup the common git rebase args that we use for all projects. # Setup the common git rebase args that we use for all projects.
common_args = ['rebase'] common_args = ["rebase"]
if opt.whitespace: if opt.whitespace:
common_args.append('--whitespace=%s' % opt.whitespace) common_args.append("--whitespace=%s" % opt.whitespace)
if opt.quiet: if opt.quiet:
common_args.append('--quiet') common_args.append("--quiet")
if opt.force_rebase: if opt.force_rebase:
common_args.append('--force-rebase') common_args.append("--force-rebase")
if not opt.ff: if not opt.ff:
common_args.append('--no-ff') common_args.append("--no-ff")
if opt.autosquash: if opt.autosquash:
common_args.append('--autosquash') common_args.append("--autosquash")
if opt.interactive: if opt.interactive:
common_args.append('-i') common_args.append("-i")
config = self.manifest.manifestProject.config config = self.manifest.manifestProject.config
out = RebaseColoring(config) out = RebaseColoring(config)
@ -108,30 +145,40 @@ branch but need to incorporate new upstream changes "underneath" them.
cb = project.CurrentBranch cb = project.CurrentBranch
if not cb: if not cb:
if one_project: if one_project:
print("error: project %s has a detached HEAD" % _RelPath(project), print(
file=sys.stderr) "error: project %s has a detached HEAD"
% _RelPath(project),
file=sys.stderr,
)
return 1 return 1
# ignore branches with detatched HEADs # Ignore branches with detached HEADs.
continue continue
upbranch = project.GetBranch(cb) upbranch = project.GetBranch(cb)
if not upbranch.LocalMerge: if not upbranch.LocalMerge:
if one_project: if one_project:
print("error: project %s does not track any remote branches" print(
% _RelPath(project), file=sys.stderr) "error: project %s does not track any remote branches"
% _RelPath(project),
file=sys.stderr,
)
return 1 return 1
# ignore branches without remotes # Ignore branches without remotes.
continue continue
args = common_args[:] args = common_args[:]
if opt.onto_manifest: if opt.onto_manifest:
args.append('--onto') args.append("--onto")
args.append(project.revisionExpr) args.append(project.revisionExpr)
args.append(upbranch.LocalMerge) args.append(upbranch.LocalMerge)
out.project('project %s: rebasing %s -> %s', out.project(
_RelPath(project), cb, upbranch.LocalMerge) "project %s: rebasing %s -> %s",
_RelPath(project),
cb,
upbranch.LocalMerge,
)
out.nl() out.nl()
out.flush() out.flush()
@ -153,13 +200,13 @@ branch but need to incorporate new upstream changes "underneath" them.
continue continue
if needs_stash: if needs_stash:
stash_args.append('pop') stash_args.append("pop")
stash_args.append('--quiet') stash_args.append("--quiet")
if GitCommand(project, stash_args).Wait() != 0: if GitCommand(project, stash_args).Wait() != 0:
ret += 1 ret += 1
if ret: if ret:
out.fail('%i projects had errors', ret) out.fail("%i projects had errors", ret)
out.nl() out.nl()
return ret return ret

View File

@ -35,13 +35,20 @@ need to be performed by an end-user.
""" """
def _Options(self, p): def _Options(self, p):
g = p.add_option_group('repo Version options') g = p.add_option_group("repo Version options")
g.add_option('--no-repo-verify', g.add_option(
dest='repo_verify', default=True, action='store_false', "--no-repo-verify",
help='do not verify repo source code') dest="repo_verify",
g.add_option('--repo-upgraded', default=True,
dest='repo_upgraded', action='store_true', action="store_false",
help=SUPPRESS_HELP) help="do not verify repo source code",
)
g.add_option(
"--repo-upgraded",
dest="repo_upgraded",
action="store_true",
help=SUPPRESS_HELP,
)
def Execute(self, opt, args): def Execute(self, opt, args):
rp = self.manifest.repoProject rp = self.manifest.repoProject
@ -55,7 +62,5 @@ need to be performed by an end-user.
print("error: can't update repo", file=sys.stderr) print("error: can't update repo", file=sys.stderr)
sys.exit(1) sys.exit(1)
rp.bare_git.gc('--auto') rp.bare_git.gc("--auto")
_PostRepoFetch(rp, _PostRepoFetch(rp, repo_verify=opt.repo_verify, verbose=True)
repo_verify=opt.repo_verify,
verbose=True)

View File

@ -21,10 +21,10 @@ from git_command import GitCommand
class _ProjectList(Coloring): class _ProjectList(Coloring):
def __init__(self, gc): def __init__(self, gc):
Coloring.__init__(self, gc, 'interactive') Coloring.__init__(self, gc, "interactive")
self.prompt = self.printer('prompt', fg='blue', attr='bold') self.prompt = self.printer("prompt", fg="blue", attr="bold")
self.header = self.printer('header', attr='bold') self.header = self.printer("header", attr="bold")
self.help = self.printer('help', fg='red', attr='bold') self.help = self.printer("help", fg="red", attr="bold")
class Stage(InteractiveCommand): class Stage(InteractiveCommand):
@ -38,10 +38,14 @@ The '%prog' command stages files to prepare the next commit.
""" """
def _Options(self, p): def _Options(self, p):
g = p.get_option_group('--quiet') g = p.get_option_group("--quiet")
g.add_option('-i', '--interactive', g.add_option(
dest='interactive', action='store_true', "-i",
help='use interactive staging') "--interactive",
dest="interactive",
action="store_true",
help="use interactive staging",
)
def Execute(self, opt, args): def Execute(self, opt, args):
if opt.interactive: if opt.interactive:
@ -51,42 +55,49 @@ The '%prog' command stages files to prepare the next commit.
def _Interactive(self, opt, args): def _Interactive(self, opt, args):
all_projects = [ all_projects = [
p for p in self.GetProjects(args, all_manifests=not opt.this_manifest_only) p
if p.IsDirty()] for p in self.GetProjects(
args, all_manifests=not opt.this_manifest_only
)
if p.IsDirty()
]
if not all_projects: if not all_projects:
print('no projects have uncommitted modifications', file=sys.stderr) print("no projects have uncommitted modifications", file=sys.stderr)
return return
out = _ProjectList(self.manifest.manifestProject.config) out = _ProjectList(self.manifest.manifestProject.config)
while True: while True:
out.header(' %s', 'project') out.header(" %s", "project")
out.nl() out.nl()
for i in range(len(all_projects)): for i in range(len(all_projects)):
project = all_projects[i] project = all_projects[i]
out.write('%3d: %s', i + 1, out.write(
project.RelPath(local=opt.this_manifest_only) + '/') "%3d: %s",
i + 1,
project.RelPath(local=opt.this_manifest_only) + "/",
)
out.nl() out.nl()
out.nl() out.nl()
out.write('%3d: (', 0) out.write("%3d: (", 0)
out.prompt('q') out.prompt("q")
out.write('uit)') out.write("uit)")
out.nl() out.nl()
out.prompt('project> ') out.prompt("project> ")
out.flush() out.flush()
try: try:
a = sys.stdin.readline() a = sys.stdin.readline()
except KeyboardInterrupt: except KeyboardInterrupt:
out.nl() out.nl()
break break
if a == '': if a == "":
out.nl() out.nl()
break break
a = a.strip() a = a.strip()
if a.lower() in ('q', 'quit', 'exit'): if a.lower() in ("q", "quit", "exit"):
break break
if not a: if not a:
continue continue
@ -104,14 +115,16 @@ The '%prog' command stages files to prepare the next commit.
continue continue
projects = [ projects = [
p for p in all_projects p
if a in [p.name, p.RelPath(local=opt.this_manifest_only)]] for p in all_projects
if a in [p.name, p.RelPath(local=opt.this_manifest_only)]
]
if len(projects) == 1: if len(projects) == 1:
_AddI(projects[0]) _AddI(projects[0])
continue continue
print('Bye.') print("Bye.")
def _AddI(project): def _AddI(project):
p = GitCommand(project, ['add', '--interactive'], bare=False) p = GitCommand(project, ["add", "--interactive"], bare=False)
p.Wait() p.Wait()

View File

@ -37,21 +37,34 @@ revision specified in the manifest.
PARALLEL_JOBS = DEFAULT_LOCAL_JOBS PARALLEL_JOBS = DEFAULT_LOCAL_JOBS
def _Options(self, p): def _Options(self, p):
p.add_option('--all', p.add_option(
dest='all', action='store_true', "--all",
help='begin branch in all projects') dest="all",
p.add_option('-r', '--rev', '--revision', dest='revision', action="store_true",
help='point branch at this revision instead of upstream') help="begin branch in all projects",
p.add_option('--head', '--HEAD', )
dest='revision', action='store_const', const='HEAD', p.add_option(
help='abbreviation for --rev HEAD') "-r",
"--rev",
"--revision",
dest="revision",
help="point branch at this revision instead of upstream",
)
p.add_option(
"--head",
"--HEAD",
dest="revision",
action="store_const",
const="HEAD",
help="abbreviation for --rev HEAD",
)
def ValidateOptions(self, opt, args): def ValidateOptions(self, opt, args):
if not args: if not args:
self.Usage() self.Usage()
nb = args[0] nb = args[0]
if not git.check_ref_format('heads/%s' % nb): if not git.check_ref_format("heads/%s" % nb):
self.OptionParser.error("'%s' is not a valid name" % nb) self.OptionParser.error("'%s' is not a valid name" % nb)
def _ExecuteOne(self, revision, nb, project): def _ExecuteOne(self, revision, nb, project):
@ -59,7 +72,7 @@ revision specified in the manifest.
# If the current revision is immutable, such as a SHA1, a tag or # If the current revision is immutable, such as a SHA1, a tag or
# a change, then we can't push back to it. Substitute with # a change, then we can't push back to it. Substitute with
# dest_branch, if defined; or with manifest default revision instead. # dest_branch, if defined; or with manifest default revision instead.
branch_merge = '' branch_merge = ""
if IsImmutable(project.revisionExpr): if IsImmutable(project.revisionExpr):
if project.dest_branch: if project.dest_branch:
branch_merge = project.dest_branch branch_merge = project.dest_branch
@ -68,9 +81,13 @@ revision specified in the manifest.
try: try:
ret = project.StartBranch( ret = project.StartBranch(
nb, branch_merge=branch_merge, revision=revision) nb, branch_merge=branch_merge, revision=revision
)
except Exception as e: except Exception as e:
print('error: unable to checkout %s: %s' % (project.name, e), file=sys.stderr) print(
"error: unable to checkout %s: %s" % (project.name, e),
file=sys.stderr,
)
ret = False ret = False
return (ret, project) return (ret, project)
@ -81,17 +98,21 @@ revision specified in the manifest.
if not opt.all: if not opt.all:
projects = args[1:] projects = args[1:]
if len(projects) < 1: if len(projects) < 1:
projects = ['.'] # start it in the local project by default projects = ["."] # start it in the local project by default
all_projects = self.GetProjects(projects, all_projects = self.GetProjects(
projects,
missing_ok=bool(self.gitc_manifest), missing_ok=bool(self.gitc_manifest),
all_manifests=not opt.this_manifest_only) all_manifests=not opt.this_manifest_only,
)
# This must happen after we find all_projects, since GetProjects may need # This must happen after we find all_projects, since GetProjects may
# the local directory, which will disappear once we save the GITC manifest. # need the local directory, which will disappear once we save the GITC
# manifest.
if self.gitc_manifest: if self.gitc_manifest:
gitc_projects = self.GetProjects(projects, manifest=self.gitc_manifest, gitc_projects = self.GetProjects(
missing_ok=True) projects, manifest=self.gitc_manifest, missing_ok=True
)
for project in gitc_projects: for project in gitc_projects:
if project.old_revision: if project.old_revision:
project.already_synced = True project.already_synced = True
@ -102,17 +123,18 @@ revision specified in the manifest.
# Save the GITC manifest. # Save the GITC manifest.
gitc_utils.save_manifest(self.gitc_manifest) gitc_utils.save_manifest(self.gitc_manifest)
# Make sure we have a valid CWD # Make sure we have a valid CWD.
if not os.path.exists(os.getcwd()): if not os.path.exists(os.getcwd()):
os.chdir(self.manifest.topdir) os.chdir(self.manifest.topdir)
pm = Progress('Syncing %s' % nb, len(all_projects), quiet=opt.quiet) pm = Progress("Syncing %s" % nb, len(all_projects), quiet=opt.quiet)
for project in all_projects: for project in all_projects:
gitc_project = self.gitc_manifest.paths[project.relpath] gitc_project = self.gitc_manifest.paths[project.relpath]
# Sync projects that have not been opened. # Sync projects that have not been opened.
if not gitc_project.already_synced: if not gitc_project.already_synced:
proj_localdir = os.path.join(self.gitc_manifest.gitc_client_dir, proj_localdir = os.path.join(
project.relpath) self.gitc_manifest.gitc_client_dir, project.relpath
)
project.worktree = proj_localdir project.worktree = proj_localdir
if not os.path.exists(proj_localdir): if not os.path.exists(proj_localdir):
os.makedirs(proj_localdir) os.makedirs(proj_localdir)
@ -124,7 +146,7 @@ revision specified in the manifest.
pm.end() pm.end()
def _ProcessResults(_pool, pm, results): def _ProcessResults(_pool, pm, results):
for (result, project) in results: for result, project in results:
if not result: if not result:
err.append(project) err.append(project)
pm.update() pm.update()
@ -134,10 +156,16 @@ revision specified in the manifest.
functools.partial(self._ExecuteOne, opt.revision, nb), functools.partial(self._ExecuteOne, opt.revision, nb),
all_projects, all_projects,
callback=_ProcessResults, callback=_ProcessResults,
output=Progress('Starting %s' % (nb,), len(all_projects), quiet=opt.quiet)) output=Progress(
"Starting %s" % (nb,), len(all_projects), quiet=opt.quiet
),
)
if err: if err:
for p in err: for p in err:
print("error: %s/: cannot start %s" % (p.RelPath(local=opt.this_manifest_only), nb), print(
file=sys.stderr) "error: %s/: cannot start %s"
% (p.RelPath(local=opt.this_manifest_only), nb),
file=sys.stderr,
)
sys.exit(1) sys.exit(1)

View File

@ -79,9 +79,14 @@ the following meanings:
PARALLEL_JOBS = DEFAULT_LOCAL_JOBS PARALLEL_JOBS = DEFAULT_LOCAL_JOBS
def _Options(self, p): def _Options(self, p):
p.add_option('-o', '--orphans', p.add_option(
dest='orphans', action='store_true', "-o",
help="include objects in working directory outside of repo projects") "--orphans",
dest="orphans",
action="store_true",
help="include objects in working directory outside of repo "
"projects",
)
def _StatusHelper(self, quiet, local, project): def _StatusHelper(self, quiet, local, project):
"""Obtains the status for a specific project. """Obtains the status for a specific project.
@ -92,92 +97,106 @@ the following meanings:
Args: Args:
quiet: Where to output the status. quiet: Where to output the status.
local: a boolean, if True, the path is relative to the local local: a boolean, if True, the path is relative to the local
(sub)manifest. If false, the path is relative to the (sub)manifest. If false, the path is relative to the outermost
outermost manifest. manifest.
project: Project to get status of. project: Project to get status of.
Returns: Returns:
The status of the project. The status of the project.
""" """
buf = io.StringIO() buf = io.StringIO()
ret = project.PrintWorkTreeStatus(quiet=quiet, output_redir=buf, ret = project.PrintWorkTreeStatus(
local=local) quiet=quiet, output_redir=buf, local=local
)
return (ret, buf.getvalue()) return (ret, buf.getvalue())
def _FindOrphans(self, dirs, proj_dirs, proj_dirs_parents, outstring): def _FindOrphans(self, dirs, proj_dirs, proj_dirs_parents, outstring):
"""find 'dirs' that are present in 'proj_dirs_parents' but not in 'proj_dirs'""" """find 'dirs' that are present in 'proj_dirs_parents' but not in 'proj_dirs'""" # noqa: E501
status_header = ' --\t' status_header = " --\t"
for item in dirs: for item in dirs:
if not platform_utils.isdir(item): if not platform_utils.isdir(item):
outstring.append(''.join([status_header, item])) outstring.append("".join([status_header, item]))
continue continue
if item in proj_dirs: if item in proj_dirs:
continue continue
if item in proj_dirs_parents: if item in proj_dirs_parents:
self._FindOrphans(glob.glob('%s/.*' % item) + self._FindOrphans(
glob.glob('%s/*' % item), glob.glob("%s/.*" % item) + glob.glob("%s/*" % item),
proj_dirs, proj_dirs_parents, outstring) proj_dirs,
proj_dirs_parents,
outstring,
)
continue continue
outstring.append(''.join([status_header, item, '/'])) outstring.append("".join([status_header, item, "/"]))
def Execute(self, opt, args): def Execute(self, opt, args):
all_projects = self.GetProjects(args, all_manifests=not opt.this_manifest_only) all_projects = self.GetProjects(
args, all_manifests=not opt.this_manifest_only
)
def _ProcessResults(_pool, _output, results): def _ProcessResults(_pool, _output, results):
ret = 0 ret = 0
for (state, output) in results: for state, output in results:
if output: if output:
print(output, end='') print(output, end="")
if state == 'CLEAN': if state == "CLEAN":
ret += 1 ret += 1
return ret return ret
counter = self.ExecuteInParallel( counter = self.ExecuteInParallel(
opt.jobs, opt.jobs,
functools.partial(self._StatusHelper, opt.quiet, opt.this_manifest_only), functools.partial(
self._StatusHelper, opt.quiet, opt.this_manifest_only
),
all_projects, all_projects,
callback=_ProcessResults, callback=_ProcessResults,
ordered=True) ordered=True,
)
if not opt.quiet and len(all_projects) == counter: if not opt.quiet and len(all_projects) == counter:
print('nothing to commit (working directory clean)') print("nothing to commit (working directory clean)")
if opt.orphans: if opt.orphans:
proj_dirs = set() proj_dirs = set()
proj_dirs_parents = set() proj_dirs_parents = set()
for project in self.GetProjects(None, missing_ok=True, all_manifests=not opt.this_manifest_only): for project in self.GetProjects(
None, missing_ok=True, all_manifests=not opt.this_manifest_only
):
relpath = project.RelPath(local=opt.this_manifest_only) relpath = project.RelPath(local=opt.this_manifest_only)
proj_dirs.add(relpath) proj_dirs.add(relpath)
(head, _tail) = os.path.split(relpath) (head, _tail) = os.path.split(relpath)
while head != "": while head != "":
proj_dirs_parents.add(head) proj_dirs_parents.add(head)
(head, _tail) = os.path.split(head) (head, _tail) = os.path.split(head)
proj_dirs.add('.repo') proj_dirs.add(".repo")
class StatusColoring(Coloring): class StatusColoring(Coloring):
def __init__(self, config): def __init__(self, config):
Coloring.__init__(self, config, 'status') Coloring.__init__(self, config, "status")
self.project = self.printer('header', attr='bold') self.project = self.printer("header", attr="bold")
self.untracked = self.printer('untracked', fg='red') self.untracked = self.printer("untracked", fg="red")
orig_path = os.getcwd() orig_path = os.getcwd()
try: try:
os.chdir(self.manifest.topdir) os.chdir(self.manifest.topdir)
outstring = [] outstring = []
self._FindOrphans(glob.glob('.*') + self._FindOrphans(
glob.glob('*'), glob.glob(".*") + glob.glob("*"),
proj_dirs, proj_dirs_parents, outstring) proj_dirs,
proj_dirs_parents,
outstring,
)
if outstring: if outstring:
output = StatusColoring(self.client.globalConfig) output = StatusColoring(self.client.globalConfig)
output.project('Objects not within a project (orphans)') output.project("Objects not within a project (orphans)")
output.nl() output.nl()
for entry in outstring: for entry in outstring:
output.untracked(entry) output.untracked(entry)
output.nl() output.nl()
else: else:
print('No orphan files or directories') print("No orphan files or directories")
finally: finally:
# Restore CWD. # Restore CWD.

File diff suppressed because it is too large Load Diff

View File

@ -49,7 +49,7 @@ def _VerifyPendingCommits(branches: List[ReviewableBranch]) -> bool:
# #
# Each branch may be configured to have a different threshold. # Each branch may be configured to have a different threshold.
remote = branch.project.GetBranch(branch.name).remote remote = branch.project.GetBranch(branch.name).remote
key = f'review.{remote.review}.uploadwarningthreshold' key = f"review.{remote.review}.uploadwarningthreshold"
threshold = branch.project.config.GetInt(key) threshold = branch.project.config.GetInt(key)
if threshold is None: if threshold is None:
threshold = _DEFAULT_UNUSUAL_COMMIT_THRESHOLD threshold = _DEFAULT_UNUSUAL_COMMIT_THRESHOLD
@ -62,29 +62,37 @@ def _VerifyPendingCommits(branches: List[ReviewableBranch]) -> bool:
# If any branch has many commits, prompt the user. # If any branch has many commits, prompt the user.
if many_commits: if many_commits:
if len(branches) > 1: if len(branches) > 1:
print('ATTENTION: One or more branches has an unusually high number ' print(
'of commits.') "ATTENTION: One or more branches has an unusually high number "
"of commits."
)
else: else:
print('ATTENTION: You are uploading an unusually high number of commits.') print(
print('YOU PROBABLY DO NOT MEAN TO DO THIS. (Did you rebase across ' "ATTENTION: You are uploading an unusually high number of "
'branches?)') "commits."
)
print(
"YOU PROBABLY DO NOT MEAN TO DO THIS. (Did you rebase across "
"branches?)"
)
answer = input( answer = input(
"If you are sure you intend to do this, type 'yes': ").strip() "If you are sure you intend to do this, type 'yes': "
return answer == 'yes' ).strip()
return answer == "yes"
return True return True
def _die(fmt, *args): def _die(fmt, *args):
msg = fmt % args msg = fmt % args
print('error: %s' % msg, file=sys.stderr) print("error: %s" % msg, file=sys.stderr)
sys.exit(1) sys.exit(1)
def _SplitEmails(values): def _SplitEmails(values):
result = [] result = []
for value in values: for value in values:
result.extend([s.strip() for s in value.split(',')]) result.extend([s.strip() for s in value.split(",")])
return result return result
@ -198,80 +206,169 @@ Gerrit Code Review: https://www.gerritcodereview.com/
PARALLEL_JOBS = DEFAULT_LOCAL_JOBS PARALLEL_JOBS = DEFAULT_LOCAL_JOBS
def _Options(self, p): def _Options(self, p):
p.add_option('-t', p.add_option(
dest='auto_topic', action='store_true', "-t",
help='send local branch name to Gerrit Code Review') dest="auto_topic",
p.add_option('--hashtag', '--ht', action="store_true",
dest='hashtags', action='append', default=[], help="send local branch name to Gerrit Code Review",
help='add hashtags (comma delimited) to the review') )
p.add_option('--hashtag-branch', '--htb', p.add_option(
action='store_true', "--hashtag",
help='add local branch name as a hashtag') "--ht",
p.add_option('-l', '--label', dest="hashtags",
dest='labels', action='append', default=[], action="append",
help='add a label when uploading')
p.add_option('--re', '--reviewers',
type='string', action='append', dest='reviewers',
help='request reviews from these people')
p.add_option('--cc',
type='string', action='append', dest='cc',
help='also send email to these email addresses')
p.add_option('--br', '--branch',
type='string', action='store', dest='branch',
help='(local) branch to upload')
p.add_option('-c', '--current-branch',
dest='current_branch', action='store_true',
help='upload current git branch')
p.add_option('--no-current-branch',
dest='current_branch', action='store_false',
help='upload all git branches')
# Turn this into a warning & remove this someday.
p.add_option('--cbr',
dest='current_branch', action='store_true',
help=optparse.SUPPRESS_HELP)
p.add_option('--ne', '--no-emails',
action='store_false', dest='notify', default=True,
help='do not send e-mails on upload')
p.add_option('-p', '--private',
action='store_true', dest='private', default=False,
help='upload as a private change (deprecated; use --wip)')
p.add_option('-w', '--wip',
action='store_true', dest='wip', default=False,
help='upload as a work-in-progress change')
p.add_option('-r', '--ready',
action='store_true', default=False,
help='mark change as ready (clears work-in-progress setting)')
p.add_option('-o', '--push-option',
type='string', action='append', dest='push_options',
default=[], default=[],
help='additional push options to transmit') help="add hashtags (comma delimited) to the review",
p.add_option('-D', '--destination', '--dest', )
type='string', action='store', dest='dest_branch', p.add_option(
metavar='BRANCH', "--hashtag-branch",
help='submit for review on this target branch') "--htb",
p.add_option('-n', '--dry-run', action="store_true",
dest='dryrun', default=False, action='store_true', help="add local branch name as a hashtag",
help='do everything except actually upload the CL') )
p.add_option('-y', '--yes', p.add_option(
default=False, action='store_true', "-l",
help='answer yes to all safe prompts') "--label",
p.add_option('--ignore-untracked-files', dest="labels",
action='store_true', default=False, action="append",
help='ignore untracked files in the working copy') default=[],
p.add_option('--no-ignore-untracked-files', help="add a label when uploading",
dest='ignore_untracked_files', action='store_false', )
help='always ask about untracked files in the working copy') p.add_option(
p.add_option('--no-cert-checks', "--re",
dest='validate_certs', action='store_false', default=True, "--reviewers",
help='disable verifying ssl certs (unsafe)') type="string",
RepoHook.AddOptionGroup(p, 'pre-upload') action="append",
dest="reviewers",
help="request reviews from these people",
)
p.add_option(
"--cc",
type="string",
action="append",
dest="cc",
help="also send email to these email addresses",
)
p.add_option(
"--br",
"--branch",
type="string",
action="store",
dest="branch",
help="(local) branch to upload",
)
p.add_option(
"-c",
"--current-branch",
dest="current_branch",
action="store_true",
help="upload current git branch",
)
p.add_option(
"--no-current-branch",
dest="current_branch",
action="store_false",
help="upload all git branches",
)
# Turn this into a warning & remove this someday.
p.add_option(
"--cbr",
dest="current_branch",
action="store_true",
help=optparse.SUPPRESS_HELP,
)
p.add_option(
"--ne",
"--no-emails",
action="store_false",
dest="notify",
default=True,
help="do not send e-mails on upload",
)
p.add_option(
"-p",
"--private",
action="store_true",
dest="private",
default=False,
help="upload as a private change (deprecated; use --wip)",
)
p.add_option(
"-w",
"--wip",
action="store_true",
dest="wip",
default=False,
help="upload as a work-in-progress change",
)
p.add_option(
"-r",
"--ready",
action="store_true",
default=False,
help="mark change as ready (clears work-in-progress setting)",
)
p.add_option(
"-o",
"--push-option",
type="string",
action="append",
dest="push_options",
default=[],
help="additional push options to transmit",
)
p.add_option(
"-D",
"--destination",
"--dest",
type="string",
action="store",
dest="dest_branch",
metavar="BRANCH",
help="submit for review on this target branch",
)
p.add_option(
"-n",
"--dry-run",
dest="dryrun",
default=False,
action="store_true",
help="do everything except actually upload the CL",
)
p.add_option(
"-y",
"--yes",
default=False,
action="store_true",
help="answer yes to all safe prompts",
)
p.add_option(
"--ignore-untracked-files",
action="store_true",
default=False,
help="ignore untracked files in the working copy",
)
p.add_option(
"--no-ignore-untracked-files",
dest="ignore_untracked_files",
action="store_false",
help="always ask about untracked files in the working copy",
)
p.add_option(
"--no-cert-checks",
dest="validate_certs",
action="store_false",
default=True,
help="disable verifying ssl certs (unsafe)",
)
RepoHook.AddOptionGroup(p, "pre-upload")
def _SingleBranch(self, opt, branch, people): def _SingleBranch(self, opt, branch, people):
project = branch.project project = branch.project
name = branch.name name = branch.name
remote = project.GetBranch(name).remote remote = project.GetBranch(name).remote
key = 'review.%s.autoupload' % remote.review key = "review.%s.autoupload" % remote.review
answer = project.config.GetBoolean(key) answer = project.config.GetBoolean(key)
if answer is False: if answer is False:
@ -281,25 +378,36 @@ Gerrit Code Review: https://www.gerritcodereview.com/
date = branch.date date = branch.date
commit_list = branch.commits commit_list = branch.commits
destination = opt.dest_branch or project.dest_branch or project.revisionExpr destination = (
print('Upload project %s/ to remote branch %s%s:' % opt.dest_branch or project.dest_branch or project.revisionExpr
(project.RelPath(local=opt.this_manifest_only), destination, )
' (private)' if opt.private else '')) print(
print(' branch %s (%2d commit%s, %s):' % ( "Upload project %s/ to remote branch %s%s:"
% (
project.RelPath(local=opt.this_manifest_only),
destination,
" (private)" if opt.private else "",
)
)
print(
" branch %s (%2d commit%s, %s):"
% (
name, name,
len(commit_list), len(commit_list),
len(commit_list) != 1 and 's' or '', len(commit_list) != 1 and "s" or "",
date)) date,
)
)
for commit in commit_list: for commit in commit_list:
print(' %s' % commit) print(" %s" % commit)
print('to %s (y/N)? ' % remote.review, end='', flush=True) print("to %s (y/N)? " % remote.review, end="", flush=True)
if opt.yes: if opt.yes:
print('<--yes>') print("<--yes>")
answer = True answer = True
else: else:
answer = sys.stdin.readline().strip().lower() answer = sys.stdin.readline().strip().lower()
answer = answer in ('y', 'yes', '1', 'true', 't') answer = answer in ("y", "yes", "1", "true", "t")
if not answer: if not answer:
_die("upload aborted by user") _die("upload aborted by user")
@ -314,11 +422,11 @@ Gerrit Code Review: https://www.gerritcodereview.com/
branches = {} branches = {}
script = [] script = []
script.append('# Uncomment the branches to upload:') script.append("# Uncomment the branches to upload:")
for project, avail in pending: for project, avail in pending:
project_path = project.RelPath(local=opt.this_manifest_only) project_path = project.RelPath(local=opt.this_manifest_only)
script.append('#') script.append("#")
script.append(f'# project {project_path}/:') script.append(f"# project {project_path}/:")
b = {} b = {}
for branch in avail: for branch in avail:
@ -329,26 +437,34 @@ Gerrit Code Review: https://www.gerritcodereview.com/
commit_list = branch.commits commit_list = branch.commits
if b: if b:
script.append('#') script.append("#")
destination = opt.dest_branch or project.dest_branch or project.revisionExpr destination = (
script.append('# branch %s (%2d commit%s, %s) to remote branch %s:' % ( opt.dest_branch
or project.dest_branch
or project.revisionExpr
)
script.append(
"# branch %s (%2d commit%s, %s) to remote branch %s:"
% (
name, name,
len(commit_list), len(commit_list),
len(commit_list) != 1 and 's' or '', len(commit_list) != 1 and "s" or "",
date, date,
destination)) destination,
)
)
for commit in commit_list: for commit in commit_list:
script.append('# %s' % commit) script.append("# %s" % commit)
b[name] = branch b[name] = branch
projects[project_path] = project projects[project_path] = project
branches[project_path] = b branches[project_path] = b
script.append('') script.append("")
script = Editor.EditString("\n".join(script)).split("\n") script = Editor.EditString("\n".join(script)).split("\n")
project_re = re.compile(r'^#?\s*project\s*([^\s]+)/:$') project_re = re.compile(r"^#?\s*project\s*([^\s]+)/:$")
branch_re = re.compile(r'^\s*branch\s*([^\s(]+)\s*\(.*') branch_re = re.compile(r"^\s*branch\s*([^\s(]+)\s*\(.*")
project = None project = None
todo = [] todo = []
@ -359,18 +475,18 @@ Gerrit Code Review: https://www.gerritcodereview.com/
name = m.group(1) name = m.group(1)
project = projects.get(name) project = projects.get(name)
if not project: if not project:
_die('project %s not available for upload', name) _die("project %s not available for upload", name)
continue continue
m = branch_re.match(line) m = branch_re.match(line)
if m: if m:
name = m.group(1) name = m.group(1)
if not project: if not project:
_die('project for branch %s not in script', name) _die("project for branch %s not in script", name)
project_path = project.RelPath(local=opt.this_manifest_only) project_path = project.RelPath(local=opt.this_manifest_only)
branch = branches[project_path].get(name) branch = branches[project_path].get(name)
if not branch: if not branch:
_die('branch %s not in %s', name, project_path) _die("branch %s not in %s", name, project_path)
todo.append(branch) todo.append(branch)
if not todo: if not todo:
_die("nothing uncommented for upload") _die("nothing uncommented for upload")
@ -384,21 +500,21 @@ Gerrit Code Review: https://www.gerritcodereview.com/
def _AppendAutoList(self, branch, people): def _AppendAutoList(self, branch, people):
""" """
Appends the list of reviewers in the git project's config. Appends the list of reviewers in the git project's config.
Appends the list of users in the CC list in the git project's config if a Appends the list of users in the CC list in the git project's config if
non-empty reviewer list was found. a non-empty reviewer list was found.
""" """
name = branch.name name = branch.name
project = branch.project project = branch.project
key = 'review.%s.autoreviewer' % project.GetBranch(name).remote.review key = "review.%s.autoreviewer" % project.GetBranch(name).remote.review
raw_list = project.config.GetString(key) raw_list = project.config.GetString(key)
if raw_list is not None: if raw_list is not None:
people[0].extend([entry.strip() for entry in raw_list.split(',')]) people[0].extend([entry.strip() for entry in raw_list.split(",")])
key = 'review.%s.autocopy' % project.GetBranch(name).remote.review key = "review.%s.autocopy" % project.GetBranch(name).remote.review
raw_list = project.config.GetString(key) raw_list = project.config.GetString(key)
if raw_list is not None and len(people[0]) > 0: if raw_list is not None and len(people[0]) > 0:
people[1].extend([entry.strip() for entry in raw_list.split(',')]) people[1].extend([entry.strip() for entry in raw_list.split(",")])
def _FindGerritChange(self, branch): def _FindGerritChange(self, branch):
last_pub = branch.project.WasPublished(branch.name) last_pub = branch.project.WasPublished(branch.name)
@ -408,7 +524,7 @@ Gerrit Code Review: https://www.gerritcodereview.com/
refs = branch.GetPublishedRefs() refs = branch.GetPublishedRefs()
try: try:
# refs/changes/XYZ/N --> XYZ # refs/changes/XYZ/N --> XYZ
return refs.get(last_pub).split('/')[-2] return refs.get(last_pub).split("/")[-2]
except (AttributeError, IndexError): except (AttributeError, IndexError):
return "" return ""
@ -419,93 +535,113 @@ Gerrit Code Review: https://www.gerritcodereview.com/
people = copy.deepcopy(original_people) people = copy.deepcopy(original_people)
self._AppendAutoList(branch, people) self._AppendAutoList(branch, people)
# Check if there are local changes that may have been forgotten # Check if there are local changes that may have been forgotten.
changes = branch.project.UncommitedFiles() changes = branch.project.UncommitedFiles()
if opt.ignore_untracked_files: if opt.ignore_untracked_files:
untracked = set(branch.project.UntrackedFiles()) untracked = set(branch.project.UntrackedFiles())
changes = [x for x in changes if x not in untracked] changes = [x for x in changes if x not in untracked]
if changes: if changes:
key = 'review.%s.autoupload' % branch.project.remote.review key = "review.%s.autoupload" % branch.project.remote.review
answer = branch.project.config.GetBoolean(key) answer = branch.project.config.GetBoolean(key)
# if they want to auto upload, let's not ask because it could be automated # If they want to auto upload, let's not ask because it
# could be automated.
if answer is None: if answer is None:
print() print()
print('Uncommitted changes in %s (did you forget to amend?):' print(
% branch.project.name) "Uncommitted changes in %s (did you forget to "
print('\n'.join(changes)) "amend?):" % branch.project.name
print('Continue uploading? (y/N) ', end='', flush=True) )
print("\n".join(changes))
print("Continue uploading? (y/N) ", end="", flush=True)
if opt.yes: if opt.yes:
print('<--yes>') print("<--yes>")
a = 'yes' a = "yes"
else: else:
a = sys.stdin.readline().strip().lower() a = sys.stdin.readline().strip().lower()
if a not in ('y', 'yes', 't', 'true', 'on'): if a not in ("y", "yes", "t", "true", "on"):
print("skipping upload", file=sys.stderr) print("skipping upload", file=sys.stderr)
branch.uploaded = False branch.uploaded = False
branch.error = 'User aborted' branch.error = "User aborted"
continue continue
# Check if topic branches should be sent to the server during upload # Check if topic branches should be sent to the server during
# upload.
if opt.auto_topic is not True: if opt.auto_topic is not True:
key = 'review.%s.uploadtopic' % branch.project.remote.review key = "review.%s.uploadtopic" % branch.project.remote.review
opt.auto_topic = branch.project.config.GetBoolean(key) opt.auto_topic = branch.project.config.GetBoolean(key)
def _ExpandCommaList(value): def _ExpandCommaList(value):
"""Split |value| up into comma delimited entries.""" """Split |value| up into comma delimited entries."""
if not value: if not value:
return return
for ret in value.split(','): for ret in value.split(","):
ret = ret.strip() ret = ret.strip()
if ret: if ret:
yield ret yield ret
# Check if hashtags should be included. # Check if hashtags should be included.
key = 'review.%s.uploadhashtags' % branch.project.remote.review key = "review.%s.uploadhashtags" % branch.project.remote.review
hashtags = set(_ExpandCommaList(branch.project.config.GetString(key))) hashtags = set(
_ExpandCommaList(branch.project.config.GetString(key))
)
for tag in opt.hashtags: for tag in opt.hashtags:
hashtags.update(_ExpandCommaList(tag)) hashtags.update(_ExpandCommaList(tag))
if opt.hashtag_branch: if opt.hashtag_branch:
hashtags.add(branch.name) hashtags.add(branch.name)
# Check if labels should be included. # Check if labels should be included.
key = 'review.%s.uploadlabels' % branch.project.remote.review key = "review.%s.uploadlabels" % branch.project.remote.review
labels = set(_ExpandCommaList(branch.project.config.GetString(key))) labels = set(
_ExpandCommaList(branch.project.config.GetString(key))
)
for label in opt.labels: for label in opt.labels:
labels.update(_ExpandCommaList(label)) labels.update(_ExpandCommaList(label))
# Handle e-mail notifications. # Handle e-mail notifications.
if opt.notify is False: if opt.notify is False:
notify = 'NONE' notify = "NONE"
else: else:
key = 'review.%s.uploadnotify' % branch.project.remote.review key = (
"review.%s.uploadnotify" % branch.project.remote.review
)
notify = branch.project.config.GetString(key) notify = branch.project.config.GetString(key)
destination = opt.dest_branch or branch.project.dest_branch destination = opt.dest_branch or branch.project.dest_branch
if branch.project.dest_branch and not opt.dest_branch: if branch.project.dest_branch and not opt.dest_branch:
merge_branch = self._GetMergeBranch( merge_branch = self._GetMergeBranch(
branch.project, local_branch=branch.name) branch.project, local_branch=branch.name
)
full_dest = destination full_dest = destination
if not full_dest.startswith(R_HEADS): if not full_dest.startswith(R_HEADS):
full_dest = R_HEADS + full_dest full_dest = R_HEADS + full_dest
# If the merge branch of the local branch is different from the # If the merge branch of the local branch is different from
# project's revision AND destination, this might not be intentional. # the project's revision AND destination, this might not be
if (merge_branch and merge_branch != branch.project.revisionExpr # intentional.
and merge_branch != full_dest): if (
print(f'For local branch {branch.name}: merge branch ' merge_branch
f'{merge_branch} does not match destination branch ' and merge_branch != branch.project.revisionExpr
f'{destination}') and merge_branch != full_dest
print('skipping upload.') ):
print(f'Please use `--destination {destination}` if this is intentional') print(
f"For local branch {branch.name}: merge branch "
f"{merge_branch} does not match destination branch "
f"{destination}"
)
print("skipping upload.")
print(
f"Please use `--destination {destination}` if this "
"is intentional"
)
branch.uploaded = False branch.uploaded = False
continue continue
branch.UploadForReview(people, branch.UploadForReview(
people,
dryrun=opt.dryrun, dryrun=opt.dryrun,
auto_topic=opt.auto_topic, auto_topic=opt.auto_topic,
hashtags=hashtags, hashtags=hashtags,
@ -516,7 +652,8 @@ Gerrit Code Review: https://www.gerritcodereview.com/
ready=opt.ready, ready=opt.ready,
dest_branch=destination, dest_branch=destination,
validate_certs=opt.validate_certs, validate_certs=opt.validate_certs,
push_options=opt.push_options) push_options=opt.push_options,
)
branch.uploaded = True branch.uploaded = True
except UploadError as e: except UploadError as e:
@ -525,44 +662,58 @@ Gerrit Code Review: https://www.gerritcodereview.com/
have_errors = True have_errors = True
print(file=sys.stderr) print(file=sys.stderr)
print('----------------------------------------------------------------------', file=sys.stderr) print("-" * 70, file=sys.stderr)
if have_errors: if have_errors:
for branch in todo: for branch in todo:
if not branch.uploaded: if not branch.uploaded:
if len(str(branch.error)) <= 30: if len(str(branch.error)) <= 30:
fmt = ' (%s)' fmt = " (%s)"
else: else:
fmt = '\n (%s)' fmt = "\n (%s)"
print(('[FAILED] %-15s %-15s' + fmt) % ( print(
branch.project.RelPath(local=opt.this_manifest_only) + '/', ("[FAILED] %-15s %-15s" + fmt)
% (
branch.project.RelPath(local=opt.this_manifest_only)
+ "/",
branch.name, branch.name,
str(branch.error)), str(branch.error),
file=sys.stderr) ),
file=sys.stderr,
)
print() print()
for branch in todo: for branch in todo:
if branch.uploaded: if branch.uploaded:
print('[OK ] %-15s %s' % ( print(
branch.project.RelPath(local=opt.this_manifest_only) + '/', "[OK ] %-15s %s"
branch.name), % (
file=sys.stderr) branch.project.RelPath(local=opt.this_manifest_only)
+ "/",
branch.name,
),
file=sys.stderr,
)
if have_errors: if have_errors:
sys.exit(1) sys.exit(1)
def _GetMergeBranch(self, project, local_branch=None): def _GetMergeBranch(self, project, local_branch=None):
if local_branch is None: if local_branch is None:
p = GitCommand(project, p = GitCommand(
['rev-parse', '--abbrev-ref', 'HEAD'], project,
["rev-parse", "--abbrev-ref", "HEAD"],
capture_stdout=True, capture_stdout=True,
capture_stderr=True) capture_stderr=True,
)
p.Wait() p.Wait()
local_branch = p.stdout.strip() local_branch = p.stdout.strip()
p = GitCommand(project, p = GitCommand(
['config', '--get', 'branch.%s.merge' % local_branch], project,
["config", "--get", "branch.%s.merge" % local_branch],
capture_stdout=True, capture_stdout=True,
capture_stderr=True) capture_stderr=True,
)
p.Wait() p.Wait()
merge_branch = p.stdout.strip() merge_branch = p.stdout.strip()
return merge_branch return merge_branch
@ -579,19 +730,26 @@ Gerrit Code Review: https://www.gerritcodereview.com/
return (project, avail) return (project, avail)
def Execute(self, opt, args): def Execute(self, opt, args):
projects = self.GetProjects(args, all_manifests=not opt.this_manifest_only) projects = self.GetProjects(
args, all_manifests=not opt.this_manifest_only
)
def _ProcessResults(_pool, _out, results): def _ProcessResults(_pool, _out, results):
pending = [] pending = []
for result in results: for result in results:
project, avail = result project, avail = result
if avail is None: if avail is None:
print('repo: error: %s: Unable to upload branch "%s". ' print(
'You might be able to fix the branch by running:\n' 'repo: error: %s: Unable to upload branch "%s". '
' git branch --set-upstream-to m/%s' % "You might be able to fix the branch by running:\n"
(project.RelPath(local=opt.this_manifest_only), project.CurrentBranch, " git branch --set-upstream-to m/%s"
project.manifest.branch), % (
file=sys.stderr) project.RelPath(local=opt.this_manifest_only),
project.CurrentBranch,
project.manifest.branch,
),
file=sys.stderr,
)
elif avail: elif avail:
pending.append(result) pending.append(result)
return pending return pending
@ -600,29 +758,47 @@ Gerrit Code Review: https://www.gerritcodereview.com/
opt.jobs, opt.jobs,
functools.partial(self._GatherOne, opt), functools.partial(self._GatherOne, opt),
projects, projects,
callback=_ProcessResults) callback=_ProcessResults,
)
if not pending: if not pending:
if opt.branch is None: if opt.branch is None:
print('repo: error: no branches ready for upload', file=sys.stderr) print(
"repo: error: no branches ready for upload", file=sys.stderr
)
else: else:
print('repo: error: no branches named "%s" ready for upload' % print(
(opt.branch,), file=sys.stderr) 'repo: error: no branches named "%s" ready for upload'
% (opt.branch,),
file=sys.stderr,
)
return 1 return 1
manifests = {project.manifest.topdir: project.manifest manifests = {
for (project, available) in pending} project.manifest.topdir: project.manifest
for (project, available) in pending
}
ret = 0 ret = 0
for manifest in manifests.values(): for manifest in manifests.values():
pending_proj_names = [project.name for (project, available) in pending pending_proj_names = [
if project.manifest.topdir == manifest.topdir] project.name
pending_worktrees = [project.worktree for (project, available) in pending for (project, available) in pending
if project.manifest.topdir == manifest.topdir] if project.manifest.topdir == manifest.topdir
]
pending_worktrees = [
project.worktree
for (project, available) in pending
if project.manifest.topdir == manifest.topdir
]
hook = RepoHook.FromSubcmd( hook = RepoHook.FromSubcmd(
hook_type='pre-upload', manifest=manifest, hook_type="pre-upload",
opt=opt, abort_if_user_denies=True) manifest=manifest,
if not hook.Run(project_list=pending_proj_names, opt=opt,
worktree_list=pending_worktrees): abort_if_user_denies=True,
)
if not hook.Run(
project_list=pending_proj_names, worktree_list=pending_worktrees
):
ret = 1 ret = 1
if ret: if ret:
return ret return ret

View File

@ -34,33 +34,40 @@ class Version(Command, MirrorSafeCommand):
def Execute(self, opt, args): def Execute(self, opt, args):
rp = self.manifest.repoProject rp = self.manifest.repoProject
rem = rp.GetRemote() rem = rp.GetRemote()
branch = rp.GetBranch('default') branch = rp.GetBranch("default")
# These might not be the same. Report them both. # These might not be the same. Report them both.
src_ver = RepoSourceVersion() src_ver = RepoSourceVersion()
rp_ver = rp.bare_git.describe(HEAD) rp_ver = rp.bare_git.describe(HEAD)
print('repo version %s' % rp_ver) print("repo version %s" % rp_ver)
print(' (from %s)' % rem.url) print(" (from %s)" % rem.url)
print(' (tracking %s)' % branch.merge) print(" (tracking %s)" % branch.merge)
print(' (%s)' % rp.bare_git.log('-1', '--format=%cD', HEAD)) print(" (%s)" % rp.bare_git.log("-1", "--format=%cD", HEAD))
if self.wrapper_path is not None: if self.wrapper_path is not None:
print('repo launcher version %s' % self.wrapper_version) print("repo launcher version %s" % self.wrapper_version)
print(' (from %s)' % self.wrapper_path) print(" (from %s)" % self.wrapper_path)
if src_ver != rp_ver: if src_ver != rp_ver:
print(' (currently at %s)' % src_ver) print(" (currently at %s)" % src_ver)
print('repo User-Agent %s' % user_agent.repo) print("repo User-Agent %s" % user_agent.repo)
print('git %s' % git.version_tuple().full) print("git %s" % git.version_tuple().full)
print('git User-Agent %s' % user_agent.git) print("git User-Agent %s" % user_agent.git)
print('Python %s' % sys.version) print("Python %s" % sys.version)
uname = platform.uname() uname = platform.uname()
if sys.version_info.major < 3: if sys.version_info.major < 3:
# Python 3 returns a named tuple, but Python 2 is simpler. # Python 3 returns a named tuple, but Python 2 is simpler.
print(uname) print(uname)
else: else:
print('OS %s %s (%s)' % (uname.system, uname.release, uname.version)) print(
print('CPU %s (%s)' % "OS %s %s (%s)" % (uname.system, uname.release, uname.version)
(uname.machine, uname.processor if uname.processor else 'unknown')) )
print('Bug reports:', Wrapper().BUG_URL) print(
"CPU %s (%s)"
% (
uname.machine,
uname.processor if uname.processor else "unknown",
)
)
print("Bug reports:", Wrapper().BUG_URL)

View File

@ -21,5 +21,5 @@ import repo_trace
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def disable_repo_trace(tmp_path): def disable_repo_trace(tmp_path):
"""Set an environment marker to relax certain strict checks for test code.""" """Set an environment marker to relax certain strict checks for test code.""" # noqa: E501
repo_trace._TRACE_FILE = str(tmp_path / 'TRACE_FILE_from_test') repo_trace._TRACE_FILE = str(tmp_path / "TRACE_FILE_from_test")

View File

@ -38,8 +38,8 @@ class GetEditor(EditorTestCase):
def test_basic(self): def test_basic(self):
"""Basic checking of _GetEditor.""" """Basic checking of _GetEditor."""
self.setEditor(':') self.setEditor(":")
self.assertEqual(':', Editor._GetEditor()) self.assertEqual(":", Editor._GetEditor())
class EditString(EditorTestCase): class EditString(EditorTestCase):
@ -47,10 +47,10 @@ class EditString(EditorTestCase):
def test_no_editor(self): def test_no_editor(self):
"""Check behavior when no editor is available.""" """Check behavior when no editor is available."""
self.setEditor(':') self.setEditor(":")
self.assertEqual('foo', Editor.EditString('foo')) self.assertEqual("foo", Editor.EditString("foo"))
def test_cat_editor(self): def test_cat_editor(self):
"""Check behavior when editor is `cat`.""" """Check behavior when editor is `cat`."""
self.setEditor('cat') self.setEditor("cat")
self.assertEqual('foo', Editor.EditString('foo')) self.assertEqual("foo", Editor.EditString("foo"))

View File

@ -47,7 +47,9 @@ class PickleTests(unittest.TestCase):
try: try:
newobj = pickle.loads(p) newobj = pickle.loads(p)
except Exception as e: # pylint: disable=broad-except except Exception as e: # pylint: disable=broad-except
self.fail('Class %s is unable to be pickled: %s\n' self.fail(
'Incomplete super().__init__(...) call?' % (cls, e)) "Class %s is unable to be pickled: %s\n"
"Incomplete super().__init__(...) call?" % (cls, e)
)
self.assertIsInstance(newobj, cls) self.assertIsInstance(newobj, cls)
self.assertEqual(str(obj), str(newobj)) self.assertEqual(str(obj), str(newobj))

View File

@ -31,32 +31,38 @@ class GitCommandTest(unittest.TestCase):
"""Tests the GitCommand class (via git_command.git).""" """Tests the GitCommand class (via git_command.git)."""
def setUp(self): def setUp(self):
def realpath_mock(val): def realpath_mock(val):
return val return val
mock.patch.object(os.path, 'realpath', side_effect=realpath_mock).start() mock.patch.object(
os.path, "realpath", side_effect=realpath_mock
).start()
def tearDown(self): def tearDown(self):
mock.patch.stopall() mock.patch.stopall()
def test_alternative_setting_when_matching(self): def test_alternative_setting_when_matching(self):
r = git_command._build_env( r = git_command._build_env(
objdir = os.path.join('zap', 'objects'), objdir=os.path.join("zap", "objects"), gitdir="zap"
gitdir = 'zap'
) )
self.assertIsNone(r.get('GIT_ALTERNATE_OBJECT_DIRECTORIES')) self.assertIsNone(r.get("GIT_ALTERNATE_OBJECT_DIRECTORIES"))
self.assertEqual(r.get('GIT_OBJECT_DIRECTORY'), os.path.join('zap', 'objects')) self.assertEqual(
r.get("GIT_OBJECT_DIRECTORY"), os.path.join("zap", "objects")
)
def test_alternative_setting_when_different(self): def test_alternative_setting_when_different(self):
r = git_command._build_env( r = git_command._build_env(
objdir = os.path.join('wow', 'objects'), objdir=os.path.join("wow", "objects"), gitdir="zap"
gitdir = 'zap'
) )
self.assertEqual(r.get('GIT_ALTERNATE_OBJECT_DIRECTORIES'), os.path.join('zap', 'objects')) self.assertEqual(
self.assertEqual(r.get('GIT_OBJECT_DIRECTORY'), os.path.join('wow', 'objects')) r.get("GIT_ALTERNATE_OBJECT_DIRECTORIES"),
os.path.join("zap", "objects"),
)
self.assertEqual(
r.get("GIT_OBJECT_DIRECTORY"), os.path.join("wow", "objects")
)
class GitCallUnitTest(unittest.TestCase): class GitCallUnitTest(unittest.TestCase):
@ -68,8 +74,8 @@ class GitCallUnitTest(unittest.TestCase):
self.assertIsNotNone(ver) self.assertIsNotNone(ver)
# We don't dive too deep into the values here to avoid having to update # We don't dive too deep into the values here to avoid having to update
# whenever git versions change. We do check relative to this min version # whenever git versions change. We do check relative to this min
# as this is what `repo` itself requires via MIN_GIT_VERSION. # version as this is what `repo` itself requires via MIN_GIT_VERSION.
MIN_GIT_VERSION = (2, 10, 2) MIN_GIT_VERSION = (2, 10, 2)
self.assertTrue(isinstance(ver.major, int)) self.assertTrue(isinstance(ver.major, int))
self.assertTrue(isinstance(ver.minor, int)) self.assertTrue(isinstance(ver.minor, int))
@ -82,7 +88,7 @@ class GitCallUnitTest(unittest.TestCase):
self.assertGreaterEqual(ver, MIN_GIT_VERSION) self.assertGreaterEqual(ver, MIN_GIT_VERSION)
self.assertLess(ver, (9999, 9999, 9999)) self.assertLess(ver, (9999, 9999, 9999))
self.assertNotEqual('', ver.full) self.assertNotEqual("", ver.full)
class UserAgentUnitTest(unittest.TestCase): class UserAgentUnitTest(unittest.TestCase):
@ -91,25 +97,25 @@ class UserAgentUnitTest(unittest.TestCase):
def test_smoke_os(self): def test_smoke_os(self):
"""Make sure UA OS setting returns something useful.""" """Make sure UA OS setting returns something useful."""
os_name = git_command.user_agent.os os_name = git_command.user_agent.os
# We can't dive too deep because of OS/tool differences, but we can check # We can't dive too deep because of OS/tool differences, but we can
# the general form. # check the general form.
m = re.match(r'^[^ ]+$', os_name) m = re.match(r"^[^ ]+$", os_name)
self.assertIsNotNone(m) self.assertIsNotNone(m)
def test_smoke_repo(self): def test_smoke_repo(self):
"""Make sure repo UA returns something useful.""" """Make sure repo UA returns something useful."""
ua = git_command.user_agent.repo ua = git_command.user_agent.repo
# We can't dive too deep because of OS/tool differences, but we can check # We can't dive too deep because of OS/tool differences, but we can
# the general form. # check the general form.
m = re.match(r'^git-repo/[^ ]+ ([^ ]+) git/[^ ]+ Python/[0-9.]+', ua) m = re.match(r"^git-repo/[^ ]+ ([^ ]+) git/[^ ]+ Python/[0-9.]+", ua)
self.assertIsNotNone(m) self.assertIsNotNone(m)
def test_smoke_git(self): def test_smoke_git(self):
"""Make sure git UA returns something useful.""" """Make sure git UA returns something useful."""
ua = git_command.user_agent.git ua = git_command.user_agent.git
# We can't dive too deep because of OS/tool differences, but we can check # We can't dive too deep because of OS/tool differences, but we can
# the general form. # check the general form.
m = re.match(r'^git/[^ ]+ ([^ ]+) git-repo/[^ ]+', ua) m = re.match(r"^git/[^ ]+ ([^ ]+) git-repo/[^ ]+", ua)
self.assertIsNotNone(m) self.assertIsNotNone(m)
@ -119,7 +125,9 @@ class GitRequireTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.wrapper = wrapper.Wrapper() self.wrapper = wrapper.Wrapper()
ver = self.wrapper.GitVersion(1, 2, 3, 4) ver = self.wrapper.GitVersion(1, 2, 3, 4)
mock.patch.object(git_command.git, 'version_tuple', return_value=ver).start() mock.patch.object(
git_command.git, "version_tuple", return_value=ver
).start()
def tearDown(self): def tearDown(self):
mock.patch.stopall() mock.patch.stopall()
@ -152,5 +160,5 @@ class GitRequireTests(unittest.TestCase):
def test_older_fatal_msg(self): def test_older_fatal_msg(self):
"""Test fatal require calls with old versions and message.""" """Test fatal require calls with old versions and message."""
with self.assertRaises(SystemExit) as e: with self.assertRaises(SystemExit) as e:
git_command.git_require((2,), fail=True, msg='so sad') git_command.git_require((2,), fail=True, msg="so sad")
self.assertNotEqual(0, e.code) self.assertNotEqual(0, e.code)

View File

@ -22,18 +22,16 @@ import git_config
def fixture(*paths): def fixture(*paths):
"""Return a path relative to test/fixtures. """Return a path relative to test/fixtures."""
""" return os.path.join(os.path.dirname(__file__), "fixtures", *paths)
return os.path.join(os.path.dirname(__file__), 'fixtures', *paths)
class GitConfigReadOnlyTests(unittest.TestCase): class GitConfigReadOnlyTests(unittest.TestCase):
"""Read-only tests of the GitConfig class.""" """Read-only tests of the GitConfig class."""
def setUp(self): def setUp(self):
"""Create a GitConfig object using the test.gitconfig fixture. """Create a GitConfig object using the test.gitconfig fixture."""
""" config_fixture = fixture("test.gitconfig")
config_fixture = fixture('test.gitconfig')
self.config = git_config.GitConfig(config_fixture) self.config = git_config.GitConfig(config_fixture)
def test_GetString_with_empty_config_values(self): def test_GetString_with_empty_config_values(self):
@ -44,7 +42,7 @@ class GitConfigReadOnlyTests(unittest.TestCase):
empty empty
""" """
val = self.config.GetString('section.empty') val = self.config.GetString("section.empty")
self.assertEqual(val, None) self.assertEqual(val, None)
def test_GetString_with_true_value(self): def test_GetString_with_true_value(self):
@ -55,54 +53,54 @@ class GitConfigReadOnlyTests(unittest.TestCase):
nonempty = true nonempty = true
""" """
val = self.config.GetString('section.nonempty') val = self.config.GetString("section.nonempty")
self.assertEqual(val, 'true') self.assertEqual(val, "true")
def test_GetString_from_missing_file(self): def test_GetString_from_missing_file(self):
""" """
Test missing config file Test missing config file
""" """
config_fixture = fixture('not.present.gitconfig') config_fixture = fixture("not.present.gitconfig")
config = git_config.GitConfig(config_fixture) config = git_config.GitConfig(config_fixture)
val = config.GetString('empty') val = config.GetString("empty")
self.assertEqual(val, None) self.assertEqual(val, None)
def test_GetBoolean_undefined(self): def test_GetBoolean_undefined(self):
"""Test GetBoolean on key that doesn't exist.""" """Test GetBoolean on key that doesn't exist."""
self.assertIsNone(self.config.GetBoolean('section.missing')) self.assertIsNone(self.config.GetBoolean("section.missing"))
def test_GetBoolean_invalid(self): def test_GetBoolean_invalid(self):
"""Test GetBoolean on invalid boolean value.""" """Test GetBoolean on invalid boolean value."""
self.assertIsNone(self.config.GetBoolean('section.boolinvalid')) self.assertIsNone(self.config.GetBoolean("section.boolinvalid"))
def test_GetBoolean_true(self): def test_GetBoolean_true(self):
"""Test GetBoolean on valid true boolean.""" """Test GetBoolean on valid true boolean."""
self.assertTrue(self.config.GetBoolean('section.booltrue')) self.assertTrue(self.config.GetBoolean("section.booltrue"))
def test_GetBoolean_false(self): def test_GetBoolean_false(self):
"""Test GetBoolean on valid false boolean.""" """Test GetBoolean on valid false boolean."""
self.assertFalse(self.config.GetBoolean('section.boolfalse')) self.assertFalse(self.config.GetBoolean("section.boolfalse"))
def test_GetInt_undefined(self): def test_GetInt_undefined(self):
"""Test GetInt on key that doesn't exist.""" """Test GetInt on key that doesn't exist."""
self.assertIsNone(self.config.GetInt('section.missing')) self.assertIsNone(self.config.GetInt("section.missing"))
def test_GetInt_invalid(self): def test_GetInt_invalid(self):
"""Test GetInt on invalid integer value.""" """Test GetInt on invalid integer value."""
self.assertIsNone(self.config.GetBoolean('section.intinvalid')) self.assertIsNone(self.config.GetBoolean("section.intinvalid"))
def test_GetInt_valid(self): def test_GetInt_valid(self):
"""Test GetInt on valid integers.""" """Test GetInt on valid integers."""
TESTS = ( TESTS = (
('inthex', 16), ("inthex", 16),
('inthexk', 16384), ("inthexk", 16384),
('int', 10), ("int", 10),
('intk', 10240), ("intk", 10240),
('intm', 10485760), ("intm", 10485760),
('intg', 10737418240), ("intg", 10737418240),
) )
for key, value in TESTS: for key, value in TESTS:
self.assertEqual(value, self.config.GetInt('section.%s' % (key,))) self.assertEqual(value, self.config.GetInt("section.%s" % (key,)))
class GitConfigReadWriteTests(unittest.TestCase): class GitConfigReadWriteTests(unittest.TestCase):
@ -119,70 +117,74 @@ class GitConfigReadWriteTests(unittest.TestCase):
def test_SetString(self): def test_SetString(self):
"""Test SetString behavior.""" """Test SetString behavior."""
# Set a value. # Set a value.
self.assertIsNone(self.config.GetString('foo.bar')) self.assertIsNone(self.config.GetString("foo.bar"))
self.config.SetString('foo.bar', 'val') self.config.SetString("foo.bar", "val")
self.assertEqual('val', self.config.GetString('foo.bar')) self.assertEqual("val", self.config.GetString("foo.bar"))
# Make sure the value was actually written out. # Make sure the value was actually written out.
config = self.get_config() config = self.get_config()
self.assertEqual('val', config.GetString('foo.bar')) self.assertEqual("val", config.GetString("foo.bar"))
# Update the value. # Update the value.
self.config.SetString('foo.bar', 'valll') self.config.SetString("foo.bar", "valll")
self.assertEqual('valll', self.config.GetString('foo.bar')) self.assertEqual("valll", self.config.GetString("foo.bar"))
config = self.get_config() config = self.get_config()
self.assertEqual('valll', config.GetString('foo.bar')) self.assertEqual("valll", config.GetString("foo.bar"))
# Delete the value. # Delete the value.
self.config.SetString('foo.bar', None) self.config.SetString("foo.bar", None)
self.assertIsNone(self.config.GetString('foo.bar')) self.assertIsNone(self.config.GetString("foo.bar"))
config = self.get_config() config = self.get_config()
self.assertIsNone(config.GetString('foo.bar')) self.assertIsNone(config.GetString("foo.bar"))
def test_SetBoolean(self): def test_SetBoolean(self):
"""Test SetBoolean behavior.""" """Test SetBoolean behavior."""
# Set a true value. # Set a true value.
self.assertIsNone(self.config.GetBoolean('foo.bar')) self.assertIsNone(self.config.GetBoolean("foo.bar"))
for val in (True, 1): for val in (True, 1):
self.config.SetBoolean('foo.bar', val) self.config.SetBoolean("foo.bar", val)
self.assertTrue(self.config.GetBoolean('foo.bar')) self.assertTrue(self.config.GetBoolean("foo.bar"))
# Make sure the value was actually written out. # Make sure the value was actually written out.
config = self.get_config() config = self.get_config()
self.assertTrue(config.GetBoolean('foo.bar')) self.assertTrue(config.GetBoolean("foo.bar"))
self.assertEqual('true', config.GetString('foo.bar')) self.assertEqual("true", config.GetString("foo.bar"))
# Set a false value. # Set a false value.
for val in (False, 0): for val in (False, 0):
self.config.SetBoolean('foo.bar', val) self.config.SetBoolean("foo.bar", val)
self.assertFalse(self.config.GetBoolean('foo.bar')) self.assertFalse(self.config.GetBoolean("foo.bar"))
# Make sure the value was actually written out. # Make sure the value was actually written out.
config = self.get_config() config = self.get_config()
self.assertFalse(config.GetBoolean('foo.bar')) self.assertFalse(config.GetBoolean("foo.bar"))
self.assertEqual('false', config.GetString('foo.bar')) self.assertEqual("false", config.GetString("foo.bar"))
# Delete the value. # Delete the value.
self.config.SetBoolean('foo.bar', None) self.config.SetBoolean("foo.bar", None)
self.assertIsNone(self.config.GetBoolean('foo.bar')) self.assertIsNone(self.config.GetBoolean("foo.bar"))
config = self.get_config() config = self.get_config()
self.assertIsNone(config.GetBoolean('foo.bar')) self.assertIsNone(config.GetBoolean("foo.bar"))
def test_GetSyncAnalysisStateData(self): def test_GetSyncAnalysisStateData(self):
"""Test config entries with a sync state analysis data.""" """Test config entries with a sync state analysis data."""
superproject_logging_data = {} superproject_logging_data = {}
superproject_logging_data['test'] = False superproject_logging_data["test"] = False
options = type('options', (object,), {})() options = type("options", (object,), {})()
options.verbose = 'true' options.verbose = "true"
options.mp_update = 'false' options.mp_update = "false"
TESTS = ( TESTS = (
('superproject.test', 'false'), ("superproject.test", "false"),
('options.verbose', 'true'), ("options.verbose", "true"),
('options.mpupdate', 'false'), ("options.mpupdate", "false"),
('main.version', '1'), ("main.version", "1"),
) )
self.config.UpdateSyncAnalysisState(options, superproject_logging_data) self.config.UpdateSyncAnalysisState(options, superproject_logging_data)
sync_data = self.config.GetSyncAnalysisStateData() sync_data = self.config.GetSyncAnalysisStateData()
for key, value in TESTS: for key, value in TESTS:
self.assertEqual(sync_data[f'{git_config.SYNC_STATE_PREFIX}{key}'], value) self.assertEqual(
self.assertTrue(sync_data[f'{git_config.SYNC_STATE_PREFIX}main.synctime']) sync_data[f"{git_config.SYNC_STATE_PREFIX}{key}"], value
)
self.assertTrue(
sync_data[f"{git_config.SYNC_STATE_PREFIX}main.synctime"]
)

View File

@ -30,18 +30,19 @@ from test_manifest_xml import sort_attributes
class SuperprojectTestCase(unittest.TestCase): class SuperprojectTestCase(unittest.TestCase):
"""TestCase for the Superproject module.""" """TestCase for the Superproject module."""
PARENT_SID_KEY = 'GIT_TRACE2_PARENT_SID' PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID"
PARENT_SID_VALUE = 'parent_sid' PARENT_SID_VALUE = "parent_sid"
SELF_SID_REGEX = r'repo-\d+T\d+Z-.*' SELF_SID_REGEX = r"repo-\d+T\d+Z-.*"
FULL_SID_REGEX = r'^%s/%s' % (PARENT_SID_VALUE, SELF_SID_REGEX) FULL_SID_REGEX = r"^%s/%s" % (PARENT_SID_VALUE, SELF_SID_REGEX)
def setUp(self): def setUp(self):
"""Set up superproject every time.""" """Set up superproject every time."""
self.tempdirobj = tempfile.TemporaryDirectory(prefix='repo_tests') self.tempdirobj = tempfile.TemporaryDirectory(prefix="repo_tests")
self.tempdir = self.tempdirobj.name self.tempdir = self.tempdirobj.name
self.repodir = os.path.join(self.tempdir, '.repo') self.repodir = os.path.join(self.tempdir, ".repo")
self.manifest_file = os.path.join( self.manifest_file = os.path.join(
self.repodir, manifest_xml.MANIFEST_FILE_NAME) self.repodir, manifest_xml.MANIFEST_FILE_NAME
)
os.mkdir(self.repodir) os.mkdir(self.repodir)
self.platform = platform.system().lower() self.platform = platform.system().lower()
@ -53,25 +54,35 @@ class SuperprojectTestCase(unittest.TestCase):
self.git_event_log = git_trace2_event_log.EventLog(env=env) self.git_event_log = git_trace2_event_log.EventLog(env=env)
# The manifest parsing really wants a git repo currently. # The manifest parsing really wants a git repo currently.
gitdir = os.path.join(self.repodir, 'manifests.git') gitdir = os.path.join(self.repodir, "manifests.git")
os.mkdir(gitdir) os.mkdir(gitdir)
with open(os.path.join(gitdir, 'config'), 'w') as fp: with open(os.path.join(gitdir, "config"), "w") as fp:
fp.write("""[remote "origin"] fp.write(
"""[remote "origin"]
url = https://localhost:0/manifest url = https://localhost:0/manifest
""") """
)
manifest = self.getXmlManifest(""" manifest = self.getXmlManifest(
"""
<manifest> <manifest>
<remote name="default-remote" fetch="http://localhost" /> <remote name="default-remote" fetch="http://localhost" />
<default remote="default-remote" revision="refs/heads/main" /> <default remote="default-remote" revision="refs/heads/main" />
<superproject name="superproject"/> <superproject name="superproject"/>
<project path="art" name="platform/art" groups="notdefault,platform-""" + self.platform + """ <project path="art" name="platform/art" groups="notdefault,platform-"""
+ self.platform
+ """
" /></manifest> " /></manifest>
""") """
)
self._superproject = git_superproject.Superproject( self._superproject = git_superproject.Superproject(
manifest, name='superproject', manifest,
remote=manifest.remotes.get('default-remote').ToRemoteSpec('superproject'), name="superproject",
revision='refs/heads/main') remote=manifest.remotes.get("default-remote").ToRemoteSpec(
"superproject"
),
revision="refs/heads/main",
)
def tearDown(self): def tearDown(self):
"""Tear down superproject every time.""" """Tear down superproject every time."""
@ -79,29 +90,29 @@ class SuperprojectTestCase(unittest.TestCase):
def getXmlManifest(self, data): def getXmlManifest(self, data):
"""Helper to initialize a manifest for testing.""" """Helper to initialize a manifest for testing."""
with open(self.manifest_file, 'w') as fp: with open(self.manifest_file, "w") as fp:
fp.write(data) fp.write(data)
return manifest_xml.XmlManifest(self.repodir, self.manifest_file) return manifest_xml.XmlManifest(self.repodir, self.manifest_file)
def verifyCommonKeys(self, log_entry, expected_event_name, full_sid=True): def verifyCommonKeys(self, log_entry, expected_event_name, full_sid=True):
"""Helper function to verify common event log keys.""" """Helper function to verify common event log keys."""
self.assertIn('event', log_entry) self.assertIn("event", log_entry)
self.assertIn('sid', log_entry) self.assertIn("sid", log_entry)
self.assertIn('thread', log_entry) self.assertIn("thread", log_entry)
self.assertIn('time', log_entry) self.assertIn("time", log_entry)
# Do basic data format validation. # Do basic data format validation.
self.assertEqual(expected_event_name, log_entry['event']) self.assertEqual(expected_event_name, log_entry["event"])
if full_sid: if full_sid:
self.assertRegex(log_entry['sid'], self.FULL_SID_REGEX) self.assertRegex(log_entry["sid"], self.FULL_SID_REGEX)
else: else:
self.assertRegex(log_entry['sid'], self.SELF_SID_REGEX) self.assertRegex(log_entry["sid"], self.SELF_SID_REGEX)
self.assertRegex(log_entry['time'], r'^\d+-\d+-\d+T\d+:\d+:\d+\.\d+Z$') self.assertRegex(log_entry["time"], r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+Z$")
def readLog(self, log_path): def readLog(self, log_path):
"""Helper function to read log data into a list.""" """Helper function to read log data into a list."""
log_data = [] log_data = []
with open(log_path, mode='rb') as f: with open(log_path, mode="rb") as f:
for line in f: for line in f:
log_data.append(json.loads(line)) log_data.append(json.loads(line))
return log_data return log_data
@ -109,57 +120,71 @@ class SuperprojectTestCase(unittest.TestCase):
def verifyErrorEvent(self): def verifyErrorEvent(self):
"""Helper to verify that error event is written.""" """Helper to verify that error event is written."""
with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self.git_event_log.Write(path=tempdir) log_path = self.git_event_log.Write(path=tempdir)
self.log_data = self.readLog(log_path) self.log_data = self.readLog(log_path)
self.assertEqual(len(self.log_data), 2) self.assertEqual(len(self.log_data), 2)
error_event = self.log_data[1] error_event = self.log_data[1]
self.verifyCommonKeys(self.log_data[0], expected_event_name='version') self.verifyCommonKeys(self.log_data[0], expected_event_name="version")
self.verifyCommonKeys(error_event, expected_event_name='error') self.verifyCommonKeys(error_event, expected_event_name="error")
# Check for 'error' event specific fields. # Check for 'error' event specific fields.
self.assertIn('msg', error_event) self.assertIn("msg", error_event)
self.assertIn('fmt', error_event) self.assertIn("fmt", error_event)
def test_superproject_get_superproject_no_superproject(self): def test_superproject_get_superproject_no_superproject(self):
"""Test with no url.""" """Test with no url."""
manifest = self.getXmlManifest(""" manifest = self.getXmlManifest(
"""
<manifest> <manifest>
</manifest> </manifest>
""") """
)
self.assertIsNone(manifest.superproject) self.assertIsNone(manifest.superproject)
def test_superproject_get_superproject_invalid_url(self): def test_superproject_get_superproject_invalid_url(self):
"""Test with an invalid url.""" """Test with an invalid url."""
manifest = self.getXmlManifest(""" manifest = self.getXmlManifest(
"""
<manifest> <manifest>
<remote name="test-remote" fetch="localhost" /> <remote name="test-remote" fetch="localhost" />
<default remote="test-remote" revision="refs/heads/main" /> <default remote="test-remote" revision="refs/heads/main" />
<superproject name="superproject"/> <superproject name="superproject"/>
</manifest> </manifest>
""") """
)
superproject = git_superproject.Superproject( superproject = git_superproject.Superproject(
manifest, name='superproject', manifest,
remote=manifest.remotes.get('test-remote').ToRemoteSpec('superproject'), name="superproject",
revision='refs/heads/main') remote=manifest.remotes.get("test-remote").ToRemoteSpec(
"superproject"
),
revision="refs/heads/main",
)
sync_result = superproject.Sync(self.git_event_log) sync_result = superproject.Sync(self.git_event_log)
self.assertFalse(sync_result.success) self.assertFalse(sync_result.success)
self.assertTrue(sync_result.fatal) self.assertTrue(sync_result.fatal)
def test_superproject_get_superproject_invalid_branch(self): def test_superproject_get_superproject_invalid_branch(self):
"""Test with an invalid branch.""" """Test with an invalid branch."""
manifest = self.getXmlManifest(""" manifest = self.getXmlManifest(
"""
<manifest> <manifest>
<remote name="test-remote" fetch="localhost" /> <remote name="test-remote" fetch="localhost" />
<default remote="test-remote" revision="refs/heads/main" /> <default remote="test-remote" revision="refs/heads/main" />
<superproject name="superproject"/> <superproject name="superproject"/>
</manifest> </manifest>
""") """
)
self._superproject = git_superproject.Superproject( self._superproject = git_superproject.Superproject(
manifest, name='superproject', manifest,
remote=manifest.remotes.get('test-remote').ToRemoteSpec('superproject'), name="superproject",
revision='refs/heads/main') remote=manifest.remotes.get("test-remote").ToRemoteSpec(
with mock.patch.object(self._superproject, '_branch', 'junk'): "superproject"
),
revision="refs/heads/main",
)
with mock.patch.object(self._superproject, "_branch", "junk"):
sync_result = self._superproject.Sync(self.git_event_log) sync_result = self._superproject.Sync(self.git_event_log)
self.assertFalse(sync_result.success) self.assertFalse(sync_result.success)
self.assertTrue(sync_result.fatal) self.assertTrue(sync_result.fatal)
@ -167,48 +192,61 @@ class SuperprojectTestCase(unittest.TestCase):
def test_superproject_get_superproject_mock_init(self): def test_superproject_get_superproject_mock_init(self):
"""Test with _Init failing.""" """Test with _Init failing."""
with mock.patch.object(self._superproject, '_Init', return_value=False): with mock.patch.object(self._superproject, "_Init", return_value=False):
sync_result = self._superproject.Sync(self.git_event_log) sync_result = self._superproject.Sync(self.git_event_log)
self.assertFalse(sync_result.success) self.assertFalse(sync_result.success)
self.assertTrue(sync_result.fatal) self.assertTrue(sync_result.fatal)
def test_superproject_get_superproject_mock_fetch(self): def test_superproject_get_superproject_mock_fetch(self):
"""Test with _Fetch failing.""" """Test with _Fetch failing."""
with mock.patch.object(self._superproject, '_Init', return_value=True): with mock.patch.object(self._superproject, "_Init", return_value=True):
os.mkdir(self._superproject._superproject_path) os.mkdir(self._superproject._superproject_path)
with mock.patch.object(self._superproject, '_Fetch', return_value=False): with mock.patch.object(
self._superproject, "_Fetch", return_value=False
):
sync_result = self._superproject.Sync(self.git_event_log) sync_result = self._superproject.Sync(self.git_event_log)
self.assertFalse(sync_result.success) self.assertFalse(sync_result.success)
self.assertTrue(sync_result.fatal) self.assertTrue(sync_result.fatal)
def test_superproject_get_all_project_commit_ids_mock_ls_tree(self): def test_superproject_get_all_project_commit_ids_mock_ls_tree(self):
"""Test with LsTree being a mock.""" """Test with LsTree being a mock."""
data = ('120000 blob 158258bdf146f159218e2b90f8b699c4d85b5804\tAndroid.bp\x00' data = (
'160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00' "120000 blob 158258bdf146f159218e2b90f8b699c4d85b5804\tAndroid.bp\x00"
'160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tbootable/recovery\x00' "160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00"
'120000 blob acc2cbdf438f9d2141f0ae424cec1d8fc4b5d97f\tbootstrap.bash\x00' "160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tbootable/recovery\x00"
'160000 commit ade9b7a0d874e25fff4bf2552488825c6f111928\tbuild/bazel\x00') "120000 blob acc2cbdf438f9d2141f0ae424cec1d8fc4b5d97f\tbootstrap.bash\x00"
with mock.patch.object(self._superproject, '_Init', return_value=True): "160000 commit ade9b7a0d874e25fff4bf2552488825c6f111928\tbuild/bazel\x00"
with mock.patch.object(self._superproject, '_Fetch', return_value=True): )
with mock.patch.object(self._superproject, '_LsTree', return_value=data): with mock.patch.object(self._superproject, "_Init", return_value=True):
commit_ids_result = self._superproject._GetAllProjectsCommitIds() with mock.patch.object(
self.assertEqual(commit_ids_result.commit_ids, { self._superproject, "_Fetch", return_value=True
'art': '2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea', ):
'bootable/recovery': 'e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06', with mock.patch.object(
'build/bazel': 'ade9b7a0d874e25fff4bf2552488825c6f111928' self._superproject, "_LsTree", return_value=data
}) ):
commit_ids_result = (
self._superproject._GetAllProjectsCommitIds()
)
self.assertEqual(
commit_ids_result.commit_ids,
{
"art": "2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea",
"bootable/recovery": "e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06",
"build/bazel": "ade9b7a0d874e25fff4bf2552488825c6f111928",
},
)
self.assertFalse(commit_ids_result.fatal) self.assertFalse(commit_ids_result.fatal)
def test_superproject_write_manifest_file(self): def test_superproject_write_manifest_file(self):
"""Test with writing manifest to a file after setting revisionId.""" """Test with writing manifest to a file after setting revisionId."""
self.assertEqual(len(self._superproject._manifest.projects), 1) self.assertEqual(len(self._superproject._manifest.projects), 1)
project = self._superproject._manifest.projects[0] project = self._superproject._manifest.projects[0]
project.SetRevisionId('ABCDEF') project.SetRevisionId("ABCDEF")
# Create temporary directory so that it can write the file. # Create temporary directory so that it can write the file.
os.mkdir(self._superproject._superproject_path) os.mkdir(self._superproject._superproject_path)
manifest_path = self._superproject._WriteManifestFile() manifest_path = self._superproject._WriteManifestFile()
self.assertIsNotNone(manifest_path) self.assertIsNotNone(manifest_path)
with open(manifest_path, 'r') as fp: with open(manifest_path, "r") as fp:
manifest_xml_data = fp.read() manifest_xml_data = fp.read()
self.assertEqual( self.assertEqual(
sort_attributes(manifest_xml_data), sort_attributes(manifest_xml_data),
@ -218,46 +256,58 @@ class SuperprojectTestCase(unittest.TestCase):
'<project groups="notdefault,platform-' + self.platform + '" ' '<project groups="notdefault,platform-' + self.platform + '" '
'name="platform/art" path="art" revision="ABCDEF" upstream="refs/heads/main"/>' 'name="platform/art" path="art" revision="ABCDEF" upstream="refs/heads/main"/>'
'<superproject name="superproject"/>' '<superproject name="superproject"/>'
'</manifest>') "</manifest>",
)
def test_superproject_update_project_revision_id(self): def test_superproject_update_project_revision_id(self):
"""Test with LsTree being a mock.""" """Test with LsTree being a mock."""
self.assertEqual(len(self._superproject._manifest.projects), 1) self.assertEqual(len(self._superproject._manifest.projects), 1)
projects = self._superproject._manifest.projects projects = self._superproject._manifest.projects
data = ('160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00' data = (
'160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tbootable/recovery\x00') "160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00"
with mock.patch.object(self._superproject, '_Init', return_value=True): "160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tbootable/recovery\x00"
with mock.patch.object(self._superproject, '_Fetch', return_value=True): )
with mock.patch.object(self._superproject, with mock.patch.object(self._superproject, "_Init", return_value=True):
'_LsTree', with mock.patch.object(
return_value=data): self._superproject, "_Fetch", return_value=True
):
with mock.patch.object(
self._superproject, "_LsTree", return_value=data
):
# Create temporary directory so that it can write the file. # Create temporary directory so that it can write the file.
os.mkdir(self._superproject._superproject_path) os.mkdir(self._superproject._superproject_path)
update_result = self._superproject.UpdateProjectsRevisionId(projects, self.git_event_log) update_result = self._superproject.UpdateProjectsRevisionId(
projects, self.git_event_log
)
self.assertIsNotNone(update_result.manifest_path) self.assertIsNotNone(update_result.manifest_path)
self.assertFalse(update_result.fatal) self.assertFalse(update_result.fatal)
with open(update_result.manifest_path, 'r') as fp: with open(update_result.manifest_path, "r") as fp:
manifest_xml_data = fp.read() manifest_xml_data = fp.read()
self.assertEqual( self.assertEqual(
sort_attributes(manifest_xml_data), sort_attributes(manifest_xml_data),
'<?xml version="1.0" ?><manifest>' '<?xml version="1.0" ?><manifest>'
'<remote fetch="http://localhost" name="default-remote"/>' '<remote fetch="http://localhost" name="default-remote"/>'
'<default remote="default-remote" revision="refs/heads/main"/>' '<default remote="default-remote" revision="refs/heads/main"/>'
'<project groups="notdefault,platform-' + self.platform + '" ' '<project groups="notdefault,platform-'
+ self.platform
+ '" '
'name="platform/art" path="art" ' 'name="platform/art" path="art" '
'revision="2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea" upstream="refs/heads/main"/>' 'revision="2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea" upstream="refs/heads/main"/>'
'<superproject name="superproject"/>' '<superproject name="superproject"/>'
'</manifest>') "</manifest>",
)
def test_superproject_update_project_revision_id_no_superproject_tag(self): def test_superproject_update_project_revision_id_no_superproject_tag(self):
"""Test update of commit ids of a manifest without superproject tag.""" """Test update of commit ids of a manifest without superproject tag."""
manifest = self.getXmlManifest(""" manifest = self.getXmlManifest(
"""
<manifest> <manifest>
<remote name="default-remote" fetch="http://localhost" /> <remote name="default-remote" fetch="http://localhost" />
<default remote="default-remote" revision="refs/heads/main" /> <default remote="default-remote" revision="refs/heads/main" />
<project name="test-name"/> <project name="test-name"/>
</manifest> </manifest>
""") """
)
self.maxDiff = None self.maxDiff = None
self.assertIsNone(manifest.superproject) self.assertIsNone(manifest.superproject)
self.assertEqual( self.assertEqual(
@ -266,59 +316,81 @@ class SuperprojectTestCase(unittest.TestCase):
'<remote fetch="http://localhost" name="default-remote"/>' '<remote fetch="http://localhost" name="default-remote"/>'
'<default remote="default-remote" revision="refs/heads/main"/>' '<default remote="default-remote" revision="refs/heads/main"/>'
'<project name="test-name"/>' '<project name="test-name"/>'
'</manifest>') "</manifest>",
)
def test_superproject_update_project_revision_id_from_local_manifest_group(self): def test_superproject_update_project_revision_id_from_local_manifest_group(
self,
):
"""Test update of commit ids of a manifest that have local manifest no superproject group.""" """Test update of commit ids of a manifest that have local manifest no superproject group."""
local_group = manifest_xml.LOCAL_MANIFEST_GROUP_PREFIX + ':local' local_group = manifest_xml.LOCAL_MANIFEST_GROUP_PREFIX + ":local"
manifest = self.getXmlManifest(""" manifest = self.getXmlManifest(
"""
<manifest> <manifest>
<remote name="default-remote" fetch="http://localhost" /> <remote name="default-remote" fetch="http://localhost" />
<remote name="goog" fetch="http://localhost2" /> <remote name="goog" fetch="http://localhost2" />
<default remote="default-remote" revision="refs/heads/main" /> <default remote="default-remote" revision="refs/heads/main" />
<superproject name="superproject"/> <superproject name="superproject"/>
<project path="vendor/x" name="platform/vendor/x" remote="goog" <project path="vendor/x" name="platform/vendor/x" remote="goog"
groups=\"""" + local_group + """ groups=\""""
+ local_group
+ """
" revision="master-with-vendor" clone-depth="1" /> " revision="master-with-vendor" clone-depth="1" />
<project path="art" name="platform/art" groups="notdefault,platform-""" + self.platform + """ <project path="art" name="platform/art" groups="notdefault,platform-"""
+ self.platform
+ """
" /></manifest> " /></manifest>
""") """
)
self.maxDiff = None self.maxDiff = None
self._superproject = git_superproject.Superproject( self._superproject = git_superproject.Superproject(
manifest, name='superproject', manifest,
remote=manifest.remotes.get('default-remote').ToRemoteSpec('superproject'), name="superproject",
revision='refs/heads/main') remote=manifest.remotes.get("default-remote").ToRemoteSpec(
"superproject"
),
revision="refs/heads/main",
)
self.assertEqual(len(self._superproject._manifest.projects), 2) self.assertEqual(len(self._superproject._manifest.projects), 2)
projects = self._superproject._manifest.projects projects = self._superproject._manifest.projects
data = ('160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00') data = "160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00"
with mock.patch.object(self._superproject, '_Init', return_value=True): with mock.patch.object(self._superproject, "_Init", return_value=True):
with mock.patch.object(self._superproject, '_Fetch', return_value=True): with mock.patch.object(
with mock.patch.object(self._superproject, self._superproject, "_Fetch", return_value=True
'_LsTree', ):
return_value=data): with mock.patch.object(
self._superproject, "_LsTree", return_value=data
):
# Create temporary directory so that it can write the file. # Create temporary directory so that it can write the file.
os.mkdir(self._superproject._superproject_path) os.mkdir(self._superproject._superproject_path)
update_result = self._superproject.UpdateProjectsRevisionId(projects, self.git_event_log) update_result = self._superproject.UpdateProjectsRevisionId(
projects, self.git_event_log
)
self.assertIsNotNone(update_result.manifest_path) self.assertIsNotNone(update_result.manifest_path)
self.assertFalse(update_result.fatal) self.assertFalse(update_result.fatal)
with open(update_result.manifest_path, 'r') as fp: with open(update_result.manifest_path, "r") as fp:
manifest_xml_data = fp.read() manifest_xml_data = fp.read()
# Verify platform/vendor/x's project revision hasn't changed. # Verify platform/vendor/x's project revision hasn't
# changed.
self.assertEqual( self.assertEqual(
sort_attributes(manifest_xml_data), sort_attributes(manifest_xml_data),
'<?xml version="1.0" ?><manifest>' '<?xml version="1.0" ?><manifest>'
'<remote fetch="http://localhost" name="default-remote"/>' '<remote fetch="http://localhost" name="default-remote"/>'
'<remote fetch="http://localhost2" name="goog"/>' '<remote fetch="http://localhost2" name="goog"/>'
'<default remote="default-remote" revision="refs/heads/main"/>' '<default remote="default-remote" revision="refs/heads/main"/>'
'<project groups="notdefault,platform-' + self.platform + '" ' '<project groups="notdefault,platform-'
+ self.platform
+ '" '
'name="platform/art" path="art" ' 'name="platform/art" path="art" '
'revision="2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea" upstream="refs/heads/main"/>' 'revision="2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea" upstream="refs/heads/main"/>'
'<superproject name="superproject"/>' '<superproject name="superproject"/>'
'</manifest>') "</manifest>",
)
def test_superproject_update_project_revision_id_with_pinned_manifest(self): def test_superproject_update_project_revision_id_with_pinned_manifest(self):
"""Test update of commit ids of a pinned manifest.""" """Test update of commit ids of a pinned manifest."""
manifest = self.getXmlManifest(""" manifest = self.getXmlManifest(
"""
<manifest> <manifest>
<remote name="default-remote" fetch="http://localhost" /> <remote name="default-remote" fetch="http://localhost" />
<default remote="default-remote" revision="refs/heads/main" /> <default remote="default-remote" revision="refs/heads/main" />
@ -326,37 +398,53 @@ class SuperprojectTestCase(unittest.TestCase):
<project path="vendor/x" name="platform/vendor/x" revision="" /> <project path="vendor/x" name="platform/vendor/x" revision="" />
<project path="vendor/y" name="platform/vendor/y" <project path="vendor/y" name="platform/vendor/y"
revision="52d3c9f7c107839ece2319d077de0cd922aa9d8f" /> revision="52d3c9f7c107839ece2319d077de0cd922aa9d8f" />
<project path="art" name="platform/art" groups="notdefault,platform-""" + self.platform + """ <project path="art" name="platform/art" groups="notdefault,platform-"""
+ self.platform
+ """
" /></manifest> " /></manifest>
""") """
)
self.maxDiff = None self.maxDiff = None
self._superproject = git_superproject.Superproject( self._superproject = git_superproject.Superproject(
manifest, name='superproject', manifest,
remote=manifest.remotes.get('default-remote').ToRemoteSpec('superproject'), name="superproject",
revision='refs/heads/main') remote=manifest.remotes.get("default-remote").ToRemoteSpec(
"superproject"
),
revision="refs/heads/main",
)
self.assertEqual(len(self._superproject._manifest.projects), 3) self.assertEqual(len(self._superproject._manifest.projects), 3)
projects = self._superproject._manifest.projects projects = self._superproject._manifest.projects
data = ('160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00' data = (
'160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tvendor/x\x00') "160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00"
with mock.patch.object(self._superproject, '_Init', return_value=True): "160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tvendor/x\x00"
with mock.patch.object(self._superproject, '_Fetch', return_value=True): )
with mock.patch.object(self._superproject, with mock.patch.object(self._superproject, "_Init", return_value=True):
'_LsTree', with mock.patch.object(
return_value=data): self._superproject, "_Fetch", return_value=True
):
with mock.patch.object(
self._superproject, "_LsTree", return_value=data
):
# Create temporary directory so that it can write the file. # Create temporary directory so that it can write the file.
os.mkdir(self._superproject._superproject_path) os.mkdir(self._superproject._superproject_path)
update_result = self._superproject.UpdateProjectsRevisionId(projects, self.git_event_log) update_result = self._superproject.UpdateProjectsRevisionId(
projects, self.git_event_log
)
self.assertIsNotNone(update_result.manifest_path) self.assertIsNotNone(update_result.manifest_path)
self.assertFalse(update_result.fatal) self.assertFalse(update_result.fatal)
with open(update_result.manifest_path, 'r') as fp: with open(update_result.manifest_path, "r") as fp:
manifest_xml_data = fp.read() manifest_xml_data = fp.read()
# Verify platform/vendor/x's project revision hasn't changed. # Verify platform/vendor/x's project revision hasn't
# changed.
self.assertEqual( self.assertEqual(
sort_attributes(manifest_xml_data), sort_attributes(manifest_xml_data),
'<?xml version="1.0" ?><manifest>' '<?xml version="1.0" ?><manifest>'
'<remote fetch="http://localhost" name="default-remote"/>' '<remote fetch="http://localhost" name="default-remote"/>'
'<default remote="default-remote" revision="refs/heads/main"/>' '<default remote="default-remote" revision="refs/heads/main"/>'
'<project groups="notdefault,platform-' + self.platform + '" ' '<project groups="notdefault,platform-'
+ self.platform
+ '" '
'name="platform/art" path="art" ' 'name="platform/art" path="art" '
'revision="2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea" upstream="refs/heads/main"/>' 'revision="2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea" upstream="refs/heads/main"/>'
'<project name="platform/vendor/x" path="vendor/x" ' '<project name="platform/vendor/x" path="vendor/x" '
@ -364,42 +452,78 @@ class SuperprojectTestCase(unittest.TestCase):
'<project name="platform/vendor/y" path="vendor/y" ' '<project name="platform/vendor/y" path="vendor/y" '
'revision="52d3c9f7c107839ece2319d077de0cd922aa9d8f"/>' 'revision="52d3c9f7c107839ece2319d077de0cd922aa9d8f"/>'
'<superproject name="superproject"/>' '<superproject name="superproject"/>'
'</manifest>') "</manifest>",
)
def test_Fetch(self): def test_Fetch(self):
manifest = self.getXmlManifest(""" manifest = self.getXmlManifest(
"""
<manifest> <manifest>
<remote name="default-remote" fetch="http://localhost" /> <remote name="default-remote" fetch="http://localhost" />
<default remote="default-remote" revision="refs/heads/main" /> <default remote="default-remote" revision="refs/heads/main" />
<superproject name="superproject"/> <superproject name="superproject"/>
" /></manifest> " /></manifest>
""") """
)
self.maxDiff = None self.maxDiff = None
self._superproject = git_superproject.Superproject( self._superproject = git_superproject.Superproject(
manifest, name='superproject', manifest,
remote=manifest.remotes.get('default-remote').ToRemoteSpec('superproject'), name="superproject",
revision='refs/heads/main') remote=manifest.remotes.get("default-remote").ToRemoteSpec(
"superproject"
),
revision="refs/heads/main",
)
os.mkdir(self._superproject._superproject_path) os.mkdir(self._superproject._superproject_path)
os.mkdir(self._superproject._work_git) os.mkdir(self._superproject._work_git)
with mock.patch.object(self._superproject, '_Init', return_value=True): with mock.patch.object(self._superproject, "_Init", return_value=True):
with mock.patch('git_superproject.GitCommand', autospec=True) as mock_git_command: with mock.patch(
with mock.patch('git_superproject.GitRefs.get', autospec=True) as mock_git_refs: "git_superproject.GitCommand", autospec=True
) as mock_git_command:
with mock.patch(
"git_superproject.GitRefs.get", autospec=True
) as mock_git_refs:
instance = mock_git_command.return_value instance = mock_git_command.return_value
instance.Wait.return_value = 0 instance.Wait.return_value = 0
mock_git_refs.side_effect = ['', '1234'] mock_git_refs.side_effect = ["", "1234"]
self.assertTrue(self._superproject._Fetch()) self.assertTrue(self._superproject._Fetch())
self.assertEqual(mock_git_command.call_args.args,(None, [ self.assertEqual(
'fetch', 'http://localhost/superproject', '--depth', '1', mock_git_command.call_args.args,
'--force', '--no-tags', '--filter', 'blob:none', (
'refs/heads/main:refs/heads/main' None,
])) [
"fetch",
"http://localhost/superproject",
"--depth",
"1",
"--force",
"--no-tags",
"--filter",
"blob:none",
"refs/heads/main:refs/heads/main",
],
),
)
# If branch for revision exists, set as --negotiation-tip. # If branch for revision exists, set as --negotiation-tip.
self.assertTrue(self._superproject._Fetch()) self.assertTrue(self._superproject._Fetch())
self.assertEqual(mock_git_command.call_args.args,(None, [ self.assertEqual(
'fetch', 'http://localhost/superproject', '--depth', '1', mock_git_command.call_args.args,
'--force', '--no-tags', '--filter', 'blob:none', (
'--negotiation-tip', '1234', None,
'refs/heads/main:refs/heads/main' [
])) "fetch",
"http://localhost/superproject",
"--depth",
"1",
"--force",
"--no-tags",
"--filter",
"blob:none",
"--negotiation-tip",
"1234",
"refs/heads/main:refs/heads/main",
],
),
)

View File

@ -29,17 +29,18 @@ import platform_utils
def serverLoggingThread(socket_path, server_ready, received_traces): def serverLoggingThread(socket_path, server_ready, received_traces):
"""Helper function to receive logs over a Unix domain socket. """Helper function to receive logs over a Unix domain socket.
Appends received messages on the provided socket and appends to received_traces. Appends received messages on the provided socket and appends to
received_traces.
Args: Args:
socket_path: path to a Unix domain socket on which to listen for traces socket_path: path to a Unix domain socket on which to listen for traces
server_ready: a threading.Condition used to signal to the caller that this thread is ready to server_ready: a threading.Condition used to signal to the caller that
accept connections this thread is ready to accept connections
received_traces: a list to which received traces will be appended (after decoding to a utf-8 received_traces: a list to which received traces will be appended (after
string). decoding to a utf-8 string).
""" """
platform_utils.remove(socket_path, missing_ok=True) platform_utils.remove(socket_path, missing_ok=True)
data = b'' data = b""
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
sock.bind(socket_path) sock.bind(socket_path)
sock.listen(0) sock.listen(0)
@ -51,16 +52,16 @@ def serverLoggingThread(socket_path, server_ready, received_traces):
if not recved: if not recved:
break break
data += recved data += recved
received_traces.extend(data.decode('utf-8').splitlines()) received_traces.extend(data.decode("utf-8").splitlines())
class EventLogTestCase(unittest.TestCase): class EventLogTestCase(unittest.TestCase):
"""TestCase for the EventLog module.""" """TestCase for the EventLog module."""
PARENT_SID_KEY = 'GIT_TRACE2_PARENT_SID' PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID"
PARENT_SID_VALUE = 'parent_sid' PARENT_SID_VALUE = "parent_sid"
SELF_SID_REGEX = r'repo-\d+T\d+Z-.*' SELF_SID_REGEX = r"repo-\d+T\d+Z-.*"
FULL_SID_REGEX = r'^%s/%s' % (PARENT_SID_VALUE, SELF_SID_REGEX) FULL_SID_REGEX = r"^%s/%s" % (PARENT_SID_VALUE, SELF_SID_REGEX)
def setUp(self): def setUp(self):
"""Load the event_log module every time.""" """Load the event_log module every time."""
@ -73,34 +74,37 @@ class EventLogTestCase(unittest.TestCase):
self._event_log_module = git_trace2_event_log.EventLog(env=env) self._event_log_module = git_trace2_event_log.EventLog(env=env)
self._log_data = None self._log_data = None
def verifyCommonKeys(self, log_entry, expected_event_name=None, full_sid=True): def verifyCommonKeys(
self, log_entry, expected_event_name=None, full_sid=True
):
"""Helper function to verify common event log keys.""" """Helper function to verify common event log keys."""
self.assertIn('event', log_entry) self.assertIn("event", log_entry)
self.assertIn('sid', log_entry) self.assertIn("sid", log_entry)
self.assertIn('thread', log_entry) self.assertIn("thread", log_entry)
self.assertIn('time', log_entry) self.assertIn("time", log_entry)
# Do basic data format validation. # Do basic data format validation.
if expected_event_name: if expected_event_name:
self.assertEqual(expected_event_name, log_entry['event']) self.assertEqual(expected_event_name, log_entry["event"])
if full_sid: if full_sid:
self.assertRegex(log_entry['sid'], self.FULL_SID_REGEX) self.assertRegex(log_entry["sid"], self.FULL_SID_REGEX)
else: else:
self.assertRegex(log_entry['sid'], self.SELF_SID_REGEX) self.assertRegex(log_entry["sid"], self.SELF_SID_REGEX)
self.assertRegex(log_entry['time'], r'^\d+-\d+-\d+T\d+:\d+:\d+\.\d+Z$') self.assertRegex(log_entry["time"], r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+Z$")
def readLog(self, log_path): def readLog(self, log_path):
"""Helper function to read log data into a list.""" """Helper function to read log data into a list."""
log_data = [] log_data = []
with open(log_path, mode='rb') as f: with open(log_path, mode="rb") as f:
for line in f: for line in f:
log_data.append(json.loads(line)) log_data.append(json.loads(line))
return log_data return log_data
def remove_prefix(self, s, prefix): def remove_prefix(self, s, prefix):
"""Return a copy string after removing |prefix| from |s|, if present or the original string.""" """Return a copy string after removing |prefix| from |s|, if present or
the original string."""
if s.startswith(prefix): if s.startswith(prefix):
return s[len(prefix):] return s[len(prefix) :]
else: else:
return s return s
@ -123,19 +127,19 @@ class EventLogTestCase(unittest.TestCase):
Expected event log: Expected event log:
<version event> <version event>
""" """
with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log_module.Write(path=tempdir) log_path = self._event_log_module.Write(path=tempdir)
self._log_data = self.readLog(log_path) self._log_data = self.readLog(log_path)
# A log with no added events should only have the version entry. # A log with no added events should only have the version entry.
self.assertEqual(len(self._log_data), 1) self.assertEqual(len(self._log_data), 1)
version_event = self._log_data[0] version_event = self._log_data[0]
self.verifyCommonKeys(version_event, expected_event_name='version') self.verifyCommonKeys(version_event, expected_event_name="version")
# Check for 'version' event specific fields. # Check for 'version' event specific fields.
self.assertIn('evt', version_event) self.assertIn("evt", version_event)
self.assertIn('exe', version_event) self.assertIn("exe", version_event)
# Verify "evt" version field is a string. # Verify "evt" version field is a string.
self.assertIsInstance(version_event['evt'], str) self.assertIsInstance(version_event["evt"], str)
def test_start_event(self): def test_start_event(self):
"""Test and validate 'start' event data is valid. """Test and validate 'start' event data is valid.
@ -145,17 +149,17 @@ class EventLogTestCase(unittest.TestCase):
<start event> <start event>
""" """
self._event_log_module.StartEvent() self._event_log_module.StartEvent()
with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log_module.Write(path=tempdir) log_path = self._event_log_module.Write(path=tempdir)
self._log_data = self.readLog(log_path) self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2) self.assertEqual(len(self._log_data), 2)
start_event = self._log_data[1] start_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name='version') self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(start_event, expected_event_name='start') self.verifyCommonKeys(start_event, expected_event_name="start")
# Check for 'start' event specific fields. # Check for 'start' event specific fields.
self.assertIn('argv', start_event) self.assertIn("argv", start_event)
self.assertTrue(isinstance(start_event['argv'], list)) self.assertTrue(isinstance(start_event["argv"], list))
def test_exit_event_result_none(self): def test_exit_event_result_none(self):
"""Test 'exit' event data is valid when result is None. """Test 'exit' event data is valid when result is None.
@ -167,18 +171,18 @@ class EventLogTestCase(unittest.TestCase):
<exit event> <exit event>
""" """
self._event_log_module.ExitEvent(None) self._event_log_module.ExitEvent(None)
with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log_module.Write(path=tempdir) log_path = self._event_log_module.Write(path=tempdir)
self._log_data = self.readLog(log_path) self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2) self.assertEqual(len(self._log_data), 2)
exit_event = self._log_data[1] exit_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name='version') self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(exit_event, expected_event_name='exit') self.verifyCommonKeys(exit_event, expected_event_name="exit")
# Check for 'exit' event specific fields. # Check for 'exit' event specific fields.
self.assertIn('code', exit_event) self.assertIn("code", exit_event)
# 'None' result should convert to 0 (successful) return code. # 'None' result should convert to 0 (successful) return code.
self.assertEqual(exit_event['code'], 0) self.assertEqual(exit_event["code"], 0)
def test_exit_event_result_integer(self): def test_exit_event_result_integer(self):
"""Test 'exit' event data is valid when result is an integer. """Test 'exit' event data is valid when result is an integer.
@ -188,17 +192,17 @@ class EventLogTestCase(unittest.TestCase):
<exit event> <exit event>
""" """
self._event_log_module.ExitEvent(2) self._event_log_module.ExitEvent(2)
with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log_module.Write(path=tempdir) log_path = self._event_log_module.Write(path=tempdir)
self._log_data = self.readLog(log_path) self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2) self.assertEqual(len(self._log_data), 2)
exit_event = self._log_data[1] exit_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name='version') self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(exit_event, expected_event_name='exit') self.verifyCommonKeys(exit_event, expected_event_name="exit")
# Check for 'exit' event specific fields. # Check for 'exit' event specific fields.
self.assertIn('code', exit_event) self.assertIn("code", exit_event)
self.assertEqual(exit_event['code'], 2) self.assertEqual(exit_event["code"], 2)
def test_command_event(self): def test_command_event(self):
"""Test and validate 'command' event data is valid. """Test and validate 'command' event data is valid.
@ -207,22 +211,24 @@ class EventLogTestCase(unittest.TestCase):
<version event> <version event>
<command event> <command event>
""" """
name = 'repo' name = "repo"
subcommands = ['init' 'this'] subcommands = ["init" "this"]
self._event_log_module.CommandEvent(name='repo', subcommands=subcommands) self._event_log_module.CommandEvent(
with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: name="repo", subcommands=subcommands
)
with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log_module.Write(path=tempdir) log_path = self._event_log_module.Write(path=tempdir)
self._log_data = self.readLog(log_path) self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2) self.assertEqual(len(self._log_data), 2)
command_event = self._log_data[1] command_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name='version') self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(command_event, expected_event_name='command') self.verifyCommonKeys(command_event, expected_event_name="command")
# Check for 'command' event specific fields. # Check for 'command' event specific fields.
self.assertIn('name', command_event) self.assertIn("name", command_event)
self.assertIn('subcommands', command_event) self.assertIn("subcommands", command_event)
self.assertEqual(command_event['name'], name) self.assertEqual(command_event["name"], name)
self.assertEqual(command_event['subcommands'], subcommands) self.assertEqual(command_event["subcommands"], subcommands)
def test_def_params_event_repo_config(self): def test_def_params_event_repo_config(self):
"""Test 'def_params' event data outputs only repo config keys. """Test 'def_params' event data outputs only repo config keys.
@ -233,26 +239,26 @@ class EventLogTestCase(unittest.TestCase):
<def_param event> <def_param event>
""" """
config = { config = {
'git.foo': 'bar', "git.foo": "bar",
'repo.partialclone': 'true', "repo.partialclone": "true",
'repo.partialclonefilter': 'blob:none', "repo.partialclonefilter": "blob:none",
} }
self._event_log_module.DefParamRepoEvents(config) self._event_log_module.DefParamRepoEvents(config)
with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log_module.Write(path=tempdir) log_path = self._event_log_module.Write(path=tempdir)
self._log_data = self.readLog(log_path) self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 3) self.assertEqual(len(self._log_data), 3)
def_param_events = self._log_data[1:] def_param_events = self._log_data[1:]
self.verifyCommonKeys(self._log_data[0], expected_event_name='version') self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
for event in def_param_events: for event in def_param_events:
self.verifyCommonKeys(event, expected_event_name='def_param') self.verifyCommonKeys(event, expected_event_name="def_param")
# Check for 'def_param' event specific fields. # Check for 'def_param' event specific fields.
self.assertIn('param', event) self.assertIn("param", event)
self.assertIn('value', event) self.assertIn("value", event)
self.assertTrue(event['param'].startswith('repo.')) self.assertTrue(event["param"].startswith("repo."))
def test_def_params_event_no_repo_config(self): def test_def_params_event_no_repo_config(self):
"""Test 'def_params' event data won't output non-repo config keys. """Test 'def_params' event data won't output non-repo config keys.
@ -261,17 +267,17 @@ class EventLogTestCase(unittest.TestCase):
<version event> <version event>
""" """
config = { config = {
'git.foo': 'bar', "git.foo": "bar",
'git.core.foo2': 'baz', "git.core.foo2": "baz",
} }
self._event_log_module.DefParamRepoEvents(config) self._event_log_module.DefParamRepoEvents(config)
with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log_module.Write(path=tempdir) log_path = self._event_log_module.Write(path=tempdir)
self._log_data = self.readLog(log_path) self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 1) self.assertEqual(len(self._log_data), 1)
self.verifyCommonKeys(self._log_data[0], expected_event_name='version') self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
def test_data_event_config(self): def test_data_event_config(self):
"""Test 'data' event data outputs all config keys. """Test 'data' event data outputs all config keys.
@ -282,31 +288,33 @@ class EventLogTestCase(unittest.TestCase):
<data event> <data event>
""" """
config = { config = {
'git.foo': 'bar', "git.foo": "bar",
'repo.partialclone': 'false', "repo.partialclone": "false",
'repo.syncstate.superproject.hassuperprojecttag': 'true', "repo.syncstate.superproject.hassuperprojecttag": "true",
'repo.syncstate.superproject.sys.argv': ['--', 'sync', 'protobuf'], "repo.syncstate.superproject.sys.argv": ["--", "sync", "protobuf"],
} }
prefix_value = 'prefix' prefix_value = "prefix"
self._event_log_module.LogDataConfigEvents(config, prefix_value) self._event_log_module.LogDataConfigEvents(config, prefix_value)
with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log_module.Write(path=tempdir) log_path = self._event_log_module.Write(path=tempdir)
self._log_data = self.readLog(log_path) self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 5) self.assertEqual(len(self._log_data), 5)
data_events = self._log_data[1:] data_events = self._log_data[1:]
self.verifyCommonKeys(self._log_data[0], expected_event_name='version') self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
for event in data_events: for event in data_events:
self.verifyCommonKeys(event) self.verifyCommonKeys(event)
# Check for 'data' event specific fields. # Check for 'data' event specific fields.
self.assertIn('key', event) self.assertIn("key", event)
self.assertIn('value', event) self.assertIn("value", event)
key = event['key'] key = event["key"]
key = self.remove_prefix(key, f'{prefix_value}/') key = self.remove_prefix(key, f"{prefix_value}/")
value = event['value'] value = event["value"]
self.assertEqual(self._event_log_module.GetDataEventName(value), event['event']) self.assertEqual(
self._event_log_module.GetDataEventName(value), event["event"]
)
self.assertTrue(key in config and value == config[key]) self.assertTrue(key in config and value == config[key])
def test_error_event(self): def test_error_event(self):
@ -316,38 +324,45 @@ class EventLogTestCase(unittest.TestCase):
<version event> <version event>
<error event> <error event>
""" """
msg = 'invalid option: --cahced' msg = "invalid option: --cahced"
fmt = 'invalid option: %s' fmt = "invalid option: %s"
self._event_log_module.ErrorEvent(msg, fmt) self._event_log_module.ErrorEvent(msg, fmt)
with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
log_path = self._event_log_module.Write(path=tempdir) log_path = self._event_log_module.Write(path=tempdir)
self._log_data = self.readLog(log_path) self._log_data = self.readLog(log_path)
self.assertEqual(len(self._log_data), 2) self.assertEqual(len(self._log_data), 2)
error_event = self._log_data[1] error_event = self._log_data[1]
self.verifyCommonKeys(self._log_data[0], expected_event_name='version') self.verifyCommonKeys(self._log_data[0], expected_event_name="version")
self.verifyCommonKeys(error_event, expected_event_name='error') self.verifyCommonKeys(error_event, expected_event_name="error")
# Check for 'error' event specific fields. # Check for 'error' event specific fields.
self.assertIn('msg', error_event) self.assertIn("msg", error_event)
self.assertIn('fmt', error_event) self.assertIn("fmt", error_event)
self.assertEqual(error_event['msg'], msg) self.assertEqual(error_event["msg"], msg)
self.assertEqual(error_event['fmt'], fmt) self.assertEqual(error_event["fmt"], fmt)
def test_write_with_filename(self): def test_write_with_filename(self):
"""Test Write() with a path to a file exits with None.""" """Test Write() with a path to a file exits with None."""
self.assertIsNone(self._event_log_module.Write(path='path/to/file')) self.assertIsNone(self._event_log_module.Write(path="path/to/file"))
def test_write_with_git_config(self): def test_write_with_git_config(self):
"""Test Write() uses the git config path when 'git config' call succeeds.""" """Test Write() uses the git config path when 'git config' call
with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: succeeds."""
with mock.patch.object(self._event_log_module, with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir:
'_GetEventTargetPath', return_value=tempdir): with mock.patch.object(
self.assertEqual(os.path.dirname(self._event_log_module.Write()), tempdir) self._event_log_module,
"_GetEventTargetPath",
return_value=tempdir,
):
self.assertEqual(
os.path.dirname(self._event_log_module.Write()), tempdir
)
def test_write_no_git_config(self): def test_write_no_git_config(self):
"""Test Write() with no git config variable present exits with None.""" """Test Write() with no git config variable present exits with None."""
with mock.patch.object(self._event_log_module, with mock.patch.object(
'_GetEventTargetPath', return_value=None): self._event_log_module, "_GetEventTargetPath", return_value=None
):
self.assertIsNone(self._event_log_module.Write()) self.assertIsNone(self._event_log_module.Write())
def test_write_non_string(self): def test_write_non_string(self):
@ -356,32 +371,38 @@ class EventLogTestCase(unittest.TestCase):
self._event_log_module.Write(path=1234) self._event_log_module.Write(path=1234)
def test_write_socket(self): def test_write_socket(self):
"""Test Write() with Unix domain socket for |path| and validate received traces.""" """Test Write() with Unix domain socket for |path| and validate received
traces."""
received_traces = [] received_traces = []
with tempfile.TemporaryDirectory(prefix='test_server_sockets') as tempdir: with tempfile.TemporaryDirectory(
prefix="test_server_sockets"
) as tempdir:
socket_path = os.path.join(tempdir, "server.sock") socket_path = os.path.join(tempdir, "server.sock")
server_ready = threading.Condition() server_ready = threading.Condition()
# Start "server" listening on Unix domain socket at socket_path. # Start "server" listening on Unix domain socket at socket_path.
try: try:
server_thread = threading.Thread( server_thread = threading.Thread(
target=serverLoggingThread, target=serverLoggingThread,
args=(socket_path, server_ready, received_traces)) args=(socket_path, server_ready, received_traces),
)
server_thread.start() server_thread.start()
with server_ready: with server_ready:
server_ready.wait(timeout=120) server_ready.wait(timeout=120)
self._event_log_module.StartEvent() self._event_log_module.StartEvent()
path = self._event_log_module.Write(path=f'af_unix:{socket_path}') path = self._event_log_module.Write(
path=f"af_unix:{socket_path}"
)
finally: finally:
server_thread.join(timeout=5) server_thread.join(timeout=5)
self.assertEqual(path, f'af_unix:stream:{socket_path}') self.assertEqual(path, f"af_unix:stream:{socket_path}")
self.assertEqual(len(received_traces), 2) self.assertEqual(len(received_traces), 2)
version_event = json.loads(received_traces[0]) version_event = json.loads(received_traces[0])
start_event = json.loads(received_traces[1]) start_event = json.loads(received_traces[1])
self.verifyCommonKeys(version_event, expected_event_name='version') self.verifyCommonKeys(version_event, expected_event_name="version")
self.verifyCommonKeys(start_event, expected_event_name='start') self.verifyCommonKeys(start_event, expected_event_name="start")
# Check for 'start' event specific fields. # Check for 'start' event specific fields.
self.assertIn('argv', start_event) self.assertIn("argv", start_event)
self.assertIsInstance(start_event['argv'], list) self.assertIsInstance(start_event["argv"], list)

View File

@ -17,39 +17,38 @@
import hooks import hooks
import unittest import unittest
class RepoHookShebang(unittest.TestCase): class RepoHookShebang(unittest.TestCase):
"""Check shebang parsing in RepoHook.""" """Check shebang parsing in RepoHook."""
def test_no_shebang(self): def test_no_shebang(self):
"""Lines w/out shebangs should be rejected.""" """Lines w/out shebangs should be rejected."""
DATA = ( DATA = ("", "#\n# foo\n", "# Bad shebang in script\n#!/foo\n")
'',
'#\n# foo\n',
'# Bad shebang in script\n#!/foo\n'
)
for data in DATA: for data in DATA:
self.assertIsNone(hooks.RepoHook._ExtractInterpFromShebang(data)) self.assertIsNone(hooks.RepoHook._ExtractInterpFromShebang(data))
def test_direct_interp(self): def test_direct_interp(self):
"""Lines whose shebang points directly to the interpreter.""" """Lines whose shebang points directly to the interpreter."""
DATA = ( DATA = (
('#!/foo', '/foo'), ("#!/foo", "/foo"),
('#! /foo', '/foo'), ("#! /foo", "/foo"),
('#!/bin/foo ', '/bin/foo'), ("#!/bin/foo ", "/bin/foo"),
('#! /usr/foo ', '/usr/foo'), ("#! /usr/foo ", "/usr/foo"),
('#! /usr/foo -args', '/usr/foo'), ("#! /usr/foo -args", "/usr/foo"),
) )
for shebang, interp in DATA: for shebang, interp in DATA:
self.assertEqual(hooks.RepoHook._ExtractInterpFromShebang(shebang), self.assertEqual(
interp) hooks.RepoHook._ExtractInterpFromShebang(shebang), interp
)
def test_env_interp(self): def test_env_interp(self):
"""Lines whose shebang launches through `env`.""" """Lines whose shebang launches through `env`."""
DATA = ( DATA = (
('#!/usr/bin/env foo', 'foo'), ("#!/usr/bin/env foo", "foo"),
('#!/bin/env foo', 'foo'), ("#!/bin/env foo", "foo"),
('#! /bin/env /bin/foo ', '/bin/foo'), ("#! /bin/env /bin/foo ", "/bin/foo"),
) )
for shebang, interp in DATA: for shebang, interp in DATA:
self.assertEqual(hooks.RepoHook._ExtractInterpFromShebang(shebang), self.assertEqual(
interp) hooks.RepoHook._ExtractInterpFromShebang(shebang), interp
)

File diff suppressed because it is too large Load Diff

View File

@ -27,24 +27,26 @@ class RemoveTests(unittest.TestCase):
def testMissingOk(self): def testMissingOk(self):
"""Check missing_ok handling.""" """Check missing_ok handling."""
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'test') path = os.path.join(tmpdir, "test")
# Should not fail. # Should not fail.
platform_utils.remove(path, missing_ok=True) platform_utils.remove(path, missing_ok=True)
# Should fail. # Should fail.
self.assertRaises(OSError, platform_utils.remove, path) self.assertRaises(OSError, platform_utils.remove, path)
self.assertRaises(OSError, platform_utils.remove, path, missing_ok=False) self.assertRaises(
OSError, platform_utils.remove, path, missing_ok=False
)
# Should not fail if it exists. # Should not fail if it exists.
open(path, 'w').close() open(path, "w").close()
platform_utils.remove(path, missing_ok=True) platform_utils.remove(path, missing_ok=True)
self.assertFalse(os.path.exists(path)) self.assertFalse(os.path.exists(path))
open(path, 'w').close() open(path, "w").close()
platform_utils.remove(path) platform_utils.remove(path)
self.assertFalse(os.path.exists(path)) self.assertFalse(os.path.exists(path))
open(path, 'w').close() open(path, "w").close()
platform_utils.remove(path, missing_ok=False) platform_utils.remove(path, missing_ok=False)
self.assertFalse(os.path.exists(path)) self.assertFalse(os.path.exists(path))

View File

@ -32,18 +32,18 @@ import project
@contextlib.contextmanager @contextlib.contextmanager
def TempGitTree(): def TempGitTree():
"""Create a new empty git checkout for testing.""" """Create a new empty git checkout for testing."""
with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
# Tests need to assume, that main is default branch at init, # Tests need to assume, that main is default branch at init,
# which is not supported in config until 2.28. # which is not supported in config until 2.28.
cmd = ['git', 'init'] cmd = ["git", "init"]
if git_command.git_require((2, 28, 0)): if git_command.git_require((2, 28, 0)):
cmd += ['--initial-branch=main'] cmd += ["--initial-branch=main"]
else: else:
# Use template dir for init. # Use template dir for init.
templatedir = tempfile.mkdtemp(prefix='.test-template') templatedir = tempfile.mkdtemp(prefix=".test-template")
with open(os.path.join(templatedir, 'HEAD'), 'w') as fp: with open(os.path.join(templatedir, "HEAD"), "w") as fp:
fp.write('ref: refs/heads/main\n') fp.write("ref: refs/heads/main\n")
cmd += ['--template', templatedir] cmd += ["--template", templatedir]
subprocess.check_call(cmd, cwd=tempdir) subprocess.check_call(cmd, cwd=tempdir)
yield tempdir yield tempdir
@ -53,12 +53,14 @@ class FakeProject(object):
def __init__(self, worktree): def __init__(self, worktree):
self.worktree = worktree self.worktree = worktree
self.gitdir = os.path.join(worktree, '.git') self.gitdir = os.path.join(worktree, ".git")
self.name = 'fakeproject' self.name = "fakeproject"
self.work_git = project.Project._GitGetByExec( self.work_git = project.Project._GitGetByExec(
self, bare=False, gitdir=self.gitdir) self, bare=False, gitdir=self.gitdir
)
self.bare_git = project.Project._GitGetByExec( self.bare_git = project.Project._GitGetByExec(
self, bare=True, gitdir=self.gitdir) self, bare=True, gitdir=self.gitdir
)
self.config = git_config.GitConfig.ForRepository(gitdir=self.gitdir) self.config = git_config.GitConfig.ForRepository(gitdir=self.gitdir)
@ -71,20 +73,21 @@ class ReviewableBranchTests(unittest.TestCase):
fakeproj = FakeProject(tempdir) fakeproj = FakeProject(tempdir)
# Generate some commits. # Generate some commits.
with open(os.path.join(tempdir, 'readme'), 'w') as fp: with open(os.path.join(tempdir, "readme"), "w") as fp:
fp.write('txt') fp.write("txt")
fakeproj.work_git.add('readme') fakeproj.work_git.add("readme")
fakeproj.work_git.commit('-mAdd file') fakeproj.work_git.commit("-mAdd file")
fakeproj.work_git.checkout('-b', 'work') fakeproj.work_git.checkout("-b", "work")
fakeproj.work_git.rm('-f', 'readme') fakeproj.work_git.rm("-f", "readme")
fakeproj.work_git.commit('-mDel file') fakeproj.work_git.commit("-mDel file")
# Start off with the normal details. # Start off with the normal details.
rb = project.ReviewableBranch( rb = project.ReviewableBranch(
fakeproj, fakeproj.config.GetBranch('work'), 'main') fakeproj, fakeproj.config.GetBranch("work"), "main"
self.assertEqual('work', rb.name) )
self.assertEqual("work", rb.name)
self.assertEqual(1, len(rb.commits)) self.assertEqual(1, len(rb.commits))
self.assertIn('Del file', rb.commits[0]) self.assertIn("Del file", rb.commits[0])
d = rb.unabbrev_commits d = rb.unabbrev_commits
self.assertEqual(1, len(d)) self.assertEqual(1, len(d))
short, long = next(iter(d.items())) short, long = next(iter(d.items()))
@ -94,9 +97,10 @@ class ReviewableBranchTests(unittest.TestCase):
self.assertTrue(rb.date) self.assertTrue(rb.date)
# Now delete the tracking branch! # Now delete the tracking branch!
fakeproj.work_git.branch('-D', 'main') fakeproj.work_git.branch("-D", "main")
rb = project.ReviewableBranch( rb = project.ReviewableBranch(
fakeproj, fakeproj.config.GetBranch('work'), 'main') fakeproj, fakeproj.config.GetBranch("work"), "main"
)
self.assertEqual(0, len(rb.commits)) self.assertEqual(0, len(rb.commits))
self.assertFalse(rb.base_exists) self.assertFalse(rb.base_exists)
# Hard to assert anything useful about this. # Hard to assert anything useful about this.
@ -118,10 +122,10 @@ class CopyLinkTestCase(unittest.TestCase):
""" """
def setUp(self): def setUp(self):
self.tempdirobj = tempfile.TemporaryDirectory(prefix='repo_tests') self.tempdirobj = tempfile.TemporaryDirectory(prefix="repo_tests")
self.tempdir = self.tempdirobj.name self.tempdir = self.tempdirobj.name
self.topdir = os.path.join(self.tempdir, 'checkout') self.topdir = os.path.join(self.tempdir, "checkout")
self.worktree = os.path.join(self.topdir, 'git-project') self.worktree = os.path.join(self.topdir, "git-project")
os.makedirs(self.topdir) os.makedirs(self.topdir)
os.makedirs(self.worktree) os.makedirs(self.worktree)
@ -130,7 +134,7 @@ class CopyLinkTestCase(unittest.TestCase):
@staticmethod @staticmethod
def touch(path): def touch(path):
with open(path, 'w'): with open(path, "w"):
pass pass
def assertExists(self, path, msg=None): def assertExists(self, path, msg=None):
@ -139,18 +143,19 @@ class CopyLinkTestCase(unittest.TestCase):
return return
if msg is None: if msg is None:
msg = ['path is missing: %s' % path] msg = ["path is missing: %s" % path]
while path != '/': while path != "/":
path = os.path.dirname(path) path = os.path.dirname(path)
if not path: if not path:
# If we're given something like "foo", abort once we get to "". # If we're given something like "foo", abort once we get to
# "".
break break
result = os.path.exists(path) result = os.path.exists(path)
msg.append('\tos.path.exists(%s): %s' % (path, result)) msg.append("\tos.path.exists(%s): %s" % (path, result))
if result: if result:
msg.append('\tcontents: %r' % os.listdir(path)) msg.append("\tcontents: %r" % os.listdir(path))
break break
msg = '\n'.join(msg) msg = "\n".join(msg)
raise self.failureException(msg) raise self.failureException(msg)
@ -163,98 +168,99 @@ class CopyFile(CopyLinkTestCase):
def test_basic(self): def test_basic(self):
"""Basic test of copying a file from a project to the toplevel.""" """Basic test of copying a file from a project to the toplevel."""
src = os.path.join(self.worktree, 'foo.txt') src = os.path.join(self.worktree, "foo.txt")
self.touch(src) self.touch(src)
cf = self.CopyFile('foo.txt', 'foo') cf = self.CopyFile("foo.txt", "foo")
cf._Copy() cf._Copy()
self.assertExists(os.path.join(self.topdir, 'foo')) self.assertExists(os.path.join(self.topdir, "foo"))
def test_src_subdir(self): def test_src_subdir(self):
"""Copy a file from a subdir of a project.""" """Copy a file from a subdir of a project."""
src = os.path.join(self.worktree, 'bar', 'foo.txt') src = os.path.join(self.worktree, "bar", "foo.txt")
os.makedirs(os.path.dirname(src)) os.makedirs(os.path.dirname(src))
self.touch(src) self.touch(src)
cf = self.CopyFile('bar/foo.txt', 'new.txt') cf = self.CopyFile("bar/foo.txt", "new.txt")
cf._Copy() cf._Copy()
self.assertExists(os.path.join(self.topdir, 'new.txt')) self.assertExists(os.path.join(self.topdir, "new.txt"))
def test_dest_subdir(self): def test_dest_subdir(self):
"""Copy a file to a subdir of a checkout.""" """Copy a file to a subdir of a checkout."""
src = os.path.join(self.worktree, 'foo.txt') src = os.path.join(self.worktree, "foo.txt")
self.touch(src) self.touch(src)
cf = self.CopyFile('foo.txt', 'sub/dir/new.txt') cf = self.CopyFile("foo.txt", "sub/dir/new.txt")
self.assertFalse(os.path.exists(os.path.join(self.topdir, 'sub'))) self.assertFalse(os.path.exists(os.path.join(self.topdir, "sub")))
cf._Copy() cf._Copy()
self.assertExists(os.path.join(self.topdir, 'sub', 'dir', 'new.txt')) self.assertExists(os.path.join(self.topdir, "sub", "dir", "new.txt"))
def test_update(self): def test_update(self):
"""Make sure changed files get copied again.""" """Make sure changed files get copied again."""
src = os.path.join(self.worktree, 'foo.txt') src = os.path.join(self.worktree, "foo.txt")
dest = os.path.join(self.topdir, 'bar') dest = os.path.join(self.topdir, "bar")
with open(src, 'w') as f: with open(src, "w") as f:
f.write('1st') f.write("1st")
cf = self.CopyFile('foo.txt', 'bar') cf = self.CopyFile("foo.txt", "bar")
cf._Copy() cf._Copy()
self.assertExists(dest) self.assertExists(dest)
with open(dest) as f: with open(dest) as f:
self.assertEqual(f.read(), '1st') self.assertEqual(f.read(), "1st")
with open(src, 'w') as f: with open(src, "w") as f:
f.write('2nd!') f.write("2nd!")
cf._Copy() cf._Copy()
with open(dest) as f: with open(dest) as f:
self.assertEqual(f.read(), '2nd!') self.assertEqual(f.read(), "2nd!")
def test_src_block_symlink(self): def test_src_block_symlink(self):
"""Do not allow reading from a symlinked path.""" """Do not allow reading from a symlinked path."""
src = os.path.join(self.worktree, 'foo.txt') src = os.path.join(self.worktree, "foo.txt")
sym = os.path.join(self.worktree, 'sym') sym = os.path.join(self.worktree, "sym")
self.touch(src) self.touch(src)
platform_utils.symlink('foo.txt', sym) platform_utils.symlink("foo.txt", sym)
self.assertExists(sym) self.assertExists(sym)
cf = self.CopyFile('sym', 'foo') cf = self.CopyFile("sym", "foo")
self.assertRaises(error.ManifestInvalidPathError, cf._Copy) self.assertRaises(error.ManifestInvalidPathError, cf._Copy)
def test_src_block_symlink_traversal(self): def test_src_block_symlink_traversal(self):
"""Do not allow reading through a symlink dir.""" """Do not allow reading through a symlink dir."""
realfile = os.path.join(self.tempdir, 'file.txt') realfile = os.path.join(self.tempdir, "file.txt")
self.touch(realfile) self.touch(realfile)
src = os.path.join(self.worktree, 'bar', 'file.txt') src = os.path.join(self.worktree, "bar", "file.txt")
platform_utils.symlink(self.tempdir, os.path.join(self.worktree, 'bar')) platform_utils.symlink(self.tempdir, os.path.join(self.worktree, "bar"))
self.assertExists(src) self.assertExists(src)
cf = self.CopyFile('bar/file.txt', 'foo') cf = self.CopyFile("bar/file.txt", "foo")
self.assertRaises(error.ManifestInvalidPathError, cf._Copy) self.assertRaises(error.ManifestInvalidPathError, cf._Copy)
def test_src_block_copy_from_dir(self): def test_src_block_copy_from_dir(self):
"""Do not allow copying from a directory.""" """Do not allow copying from a directory."""
src = os.path.join(self.worktree, 'dir') src = os.path.join(self.worktree, "dir")
os.makedirs(src) os.makedirs(src)
cf = self.CopyFile('dir', 'foo') cf = self.CopyFile("dir", "foo")
self.assertRaises(error.ManifestInvalidPathError, cf._Copy) self.assertRaises(error.ManifestInvalidPathError, cf._Copy)
def test_dest_block_symlink(self): def test_dest_block_symlink(self):
"""Do not allow writing to a symlink.""" """Do not allow writing to a symlink."""
src = os.path.join(self.worktree, 'foo.txt') src = os.path.join(self.worktree, "foo.txt")
self.touch(src) self.touch(src)
platform_utils.symlink('dest', os.path.join(self.topdir, 'sym')) platform_utils.symlink("dest", os.path.join(self.topdir, "sym"))
cf = self.CopyFile('foo.txt', 'sym') cf = self.CopyFile("foo.txt", "sym")
self.assertRaises(error.ManifestInvalidPathError, cf._Copy) self.assertRaises(error.ManifestInvalidPathError, cf._Copy)
def test_dest_block_symlink_traversal(self): def test_dest_block_symlink_traversal(self):
"""Do not allow writing through a symlink dir.""" """Do not allow writing through a symlink dir."""
src = os.path.join(self.worktree, 'foo.txt') src = os.path.join(self.worktree, "foo.txt")
self.touch(src) self.touch(src)
platform_utils.symlink(tempfile.gettempdir(), platform_utils.symlink(
os.path.join(self.topdir, 'sym')) tempfile.gettempdir(), os.path.join(self.topdir, "sym")
cf = self.CopyFile('foo.txt', 'sym/foo.txt') )
cf = self.CopyFile("foo.txt", "sym/foo.txt")
self.assertRaises(error.ManifestInvalidPathError, cf._Copy) self.assertRaises(error.ManifestInvalidPathError, cf._Copy)
def test_src_block_copy_to_dir(self): def test_src_block_copy_to_dir(self):
"""Do not allow copying to a directory.""" """Do not allow copying to a directory."""
src = os.path.join(self.worktree, 'foo.txt') src = os.path.join(self.worktree, "foo.txt")
self.touch(src) self.touch(src)
os.makedirs(os.path.join(self.topdir, 'dir')) os.makedirs(os.path.join(self.topdir, "dir"))
cf = self.CopyFile('foo.txt', 'dir') cf = self.CopyFile("foo.txt", "dir")
self.assertRaises(error.ManifestInvalidPathError, cf._Copy) self.assertRaises(error.ManifestInvalidPathError, cf._Copy)
@ -266,86 +272,106 @@ class LinkFile(CopyLinkTestCase):
def test_basic(self): def test_basic(self):
"""Basic test of linking a file from a project into the toplevel.""" """Basic test of linking a file from a project into the toplevel."""
src = os.path.join(self.worktree, 'foo.txt') src = os.path.join(self.worktree, "foo.txt")
self.touch(src) self.touch(src)
lf = self.LinkFile('foo.txt', 'foo') lf = self.LinkFile("foo.txt", "foo")
lf._Link() lf._Link()
dest = os.path.join(self.topdir, 'foo') dest = os.path.join(self.topdir, "foo")
self.assertExists(dest) self.assertExists(dest)
self.assertTrue(os.path.islink(dest)) self.assertTrue(os.path.islink(dest))
self.assertEqual(os.path.join('git-project', 'foo.txt'), os.readlink(dest)) self.assertEqual(
os.path.join("git-project", "foo.txt"), os.readlink(dest)
)
def test_src_subdir(self): def test_src_subdir(self):
"""Link to a file in a subdir of a project.""" """Link to a file in a subdir of a project."""
src = os.path.join(self.worktree, 'bar', 'foo.txt') src = os.path.join(self.worktree, "bar", "foo.txt")
os.makedirs(os.path.dirname(src)) os.makedirs(os.path.dirname(src))
self.touch(src) self.touch(src)
lf = self.LinkFile('bar/foo.txt', 'foo') lf = self.LinkFile("bar/foo.txt", "foo")
lf._Link() lf._Link()
self.assertExists(os.path.join(self.topdir, 'foo')) self.assertExists(os.path.join(self.topdir, "foo"))
def test_src_self(self): def test_src_self(self):
"""Link to the project itself.""" """Link to the project itself."""
dest = os.path.join(self.topdir, 'foo', 'bar') dest = os.path.join(self.topdir, "foo", "bar")
lf = self.LinkFile('.', 'foo/bar') lf = self.LinkFile(".", "foo/bar")
lf._Link() lf._Link()
self.assertExists(dest) self.assertExists(dest)
self.assertEqual(os.path.join('..', 'git-project'), os.readlink(dest)) self.assertEqual(os.path.join("..", "git-project"), os.readlink(dest))
def test_dest_subdir(self): def test_dest_subdir(self):
"""Link a file to a subdir of a checkout.""" """Link a file to a subdir of a checkout."""
src = os.path.join(self.worktree, 'foo.txt') src = os.path.join(self.worktree, "foo.txt")
self.touch(src) self.touch(src)
lf = self.LinkFile('foo.txt', 'sub/dir/foo/bar') lf = self.LinkFile("foo.txt", "sub/dir/foo/bar")
self.assertFalse(os.path.exists(os.path.join(self.topdir, 'sub'))) self.assertFalse(os.path.exists(os.path.join(self.topdir, "sub")))
lf._Link() lf._Link()
self.assertExists(os.path.join(self.topdir, 'sub', 'dir', 'foo', 'bar')) self.assertExists(os.path.join(self.topdir, "sub", "dir", "foo", "bar"))
def test_src_block_relative(self): def test_src_block_relative(self):
"""Do not allow relative symlinks.""" """Do not allow relative symlinks."""
BAD_SOURCES = ( BAD_SOURCES = (
'./', "./",
'..', "..",
'../', "../",
'foo/.', "foo/.",
'foo/./bar', "foo/./bar",
'foo/..', "foo/..",
'foo/../foo', "foo/../foo",
) )
for src in BAD_SOURCES: for src in BAD_SOURCES:
lf = self.LinkFile(src, 'foo') lf = self.LinkFile(src, "foo")
self.assertRaises(error.ManifestInvalidPathError, lf._Link) self.assertRaises(error.ManifestInvalidPathError, lf._Link)
def test_update(self): def test_update(self):
"""Make sure changed targets get updated.""" """Make sure changed targets get updated."""
dest = os.path.join(self.topdir, 'sym') dest = os.path.join(self.topdir, "sym")
src = os.path.join(self.worktree, 'foo.txt') src = os.path.join(self.worktree, "foo.txt")
self.touch(src) self.touch(src)
lf = self.LinkFile('foo.txt', 'sym') lf = self.LinkFile("foo.txt", "sym")
lf._Link() lf._Link()
self.assertEqual(os.path.join('git-project', 'foo.txt'), os.readlink(dest)) self.assertEqual(
os.path.join("git-project", "foo.txt"), os.readlink(dest)
)
# Point the symlink somewhere else. # Point the symlink somewhere else.
os.unlink(dest) os.unlink(dest)
platform_utils.symlink(self.tempdir, dest) platform_utils.symlink(self.tempdir, dest)
lf._Link() lf._Link()
self.assertEqual(os.path.join('git-project', 'foo.txt'), os.readlink(dest)) self.assertEqual(
os.path.join("git-project", "foo.txt"), os.readlink(dest)
)
class MigrateWorkTreeTests(unittest.TestCase): class MigrateWorkTreeTests(unittest.TestCase):
"""Check _MigrateOldWorkTreeGitDir handling.""" """Check _MigrateOldWorkTreeGitDir handling."""
_SYMLINKS = { _SYMLINKS = {
'config', 'description', 'hooks', 'info', 'logs', 'objects', "config",
'packed-refs', 'refs', 'rr-cache', 'shallow', 'svn', "description",
"hooks",
"info",
"logs",
"objects",
"packed-refs",
"refs",
"rr-cache",
"shallow",
"svn",
} }
_FILES = { _FILES = {
'COMMIT_EDITMSG', 'FETCH_HEAD', 'HEAD', 'index', 'ORIG_HEAD', "COMMIT_EDITMSG",
'unknown-file-should-be-migrated', "FETCH_HEAD",
"HEAD",
"index",
"ORIG_HEAD",
"unknown-file-should-be-migrated",
} }
_CLEAN_FILES = { _CLEAN_FILES = {
'a-vim-temp-file~', '#an-emacs-temp-file#', "a-vim-temp-file~",
"#an-emacs-temp-file#",
} }
@classmethod @classmethod
@ -355,15 +381,17 @@ class MigrateWorkTreeTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tempdir: with tempfile.TemporaryDirectory() as tempdir:
tempdir = Path(tempdir) tempdir = Path(tempdir)
gitdir = tempdir / '.repo/projects/src/test.git' gitdir = tempdir / ".repo/projects/src/test.git"
gitdir.mkdir(parents=True) gitdir.mkdir(parents=True)
cmd = ['git', 'init', '--bare', str(gitdir)] cmd = ["git", "init", "--bare", str(gitdir)]
subprocess.check_call(cmd) subprocess.check_call(cmd)
dotgit = tempdir / 'src/test/.git' dotgit = tempdir / "src/test/.git"
dotgit.mkdir(parents=True) dotgit.mkdir(parents=True)
for name in cls._SYMLINKS: for name in cls._SYMLINKS:
(dotgit / name).symlink_to(f'../../../.repo/projects/src/test.git/{name}') (dotgit / name).symlink_to(
f"../../../.repo/projects/src/test.git/{name}"
)
for name in cls._FILES | cls._CLEAN_FILES: for name in cls._FILES | cls._CLEAN_FILES:
(dotgit / name).write_text(name) (dotgit / name).write_text(name)
@ -372,15 +400,18 @@ class MigrateWorkTreeTests(unittest.TestCase):
def test_standard(self): def test_standard(self):
"""Migrate a standard checkout that we expect.""" """Migrate a standard checkout that we expect."""
with self._simple_layout() as tempdir: with self._simple_layout() as tempdir:
dotgit = tempdir / 'src/test/.git' dotgit = tempdir / "src/test/.git"
project.Project._MigrateOldWorkTreeGitDir(str(dotgit)) project.Project._MigrateOldWorkTreeGitDir(str(dotgit))
# Make sure the dir was transformed into a symlink. # Make sure the dir was transformed into a symlink.
self.assertTrue(dotgit.is_symlink()) self.assertTrue(dotgit.is_symlink())
self.assertEqual(os.readlink(dotgit), os.path.normpath('../../.repo/projects/src/test.git')) self.assertEqual(
os.readlink(dotgit),
os.path.normpath("../../.repo/projects/src/test.git"),
)
# Make sure files were moved over. # Make sure files were moved over.
gitdir = tempdir / '.repo/projects/src/test.git' gitdir = tempdir / ".repo/projects/src/test.git"
for name in self._FILES: for name in self._FILES:
self.assertEqual(name, (gitdir / name).read_text()) self.assertEqual(name, (gitdir / name).read_text())
# Make sure files were removed. # Make sure files were removed.
@ -390,9 +421,11 @@ class MigrateWorkTreeTests(unittest.TestCase):
def test_unknown(self): def test_unknown(self):
"""A checkout with unknown files should abort.""" """A checkout with unknown files should abort."""
with self._simple_layout() as tempdir: with self._simple_layout() as tempdir:
dotgit = tempdir / 'src/test/.git' dotgit = tempdir / "src/test/.git"
(tempdir / '.repo/projects/src/test.git/random-file').write_text('one') (tempdir / ".repo/projects/src/test.git/random-file").write_text(
(dotgit / 'random-file').write_text('two') "one"
)
(dotgit / "random-file").write_text("two")
with self.assertRaises(error.GitError): with self.assertRaises(error.GitError):
project.Project._MigrateOldWorkTreeGitDir(str(dotgit)) project.Project._MigrateOldWorkTreeGitDir(str(dotgit))
@ -410,18 +443,16 @@ class ManifestPropertiesFetchedCorrectly(unittest.TestCase):
"""Ensure properties are fetched properly.""" """Ensure properties are fetched properly."""
def setUpManifest(self, tempdir): def setUpManifest(self, tempdir):
repodir = os.path.join(tempdir, '.repo') repodir = os.path.join(tempdir, ".repo")
manifest_dir = os.path.join(repodir, 'manifests') manifest_dir = os.path.join(repodir, "manifests")
manifest_file = os.path.join( manifest_file = os.path.join(repodir, manifest_xml.MANIFEST_FILE_NAME)
repodir, manifest_xml.MANIFEST_FILE_NAME)
local_manifest_dir = os.path.join(
repodir, manifest_xml.LOCAL_MANIFESTS_DIR_NAME)
os.mkdir(repodir) os.mkdir(repodir)
os.mkdir(manifest_dir) os.mkdir(manifest_dir)
manifest = manifest_xml.XmlManifest(repodir, manifest_file) manifest = manifest_xml.XmlManifest(repodir, manifest_file)
return project.ManifestProject( return project.ManifestProject(
manifest, 'test/manifest', os.path.join(tempdir, '.git'), tempdir) manifest, "test/manifest", os.path.join(tempdir, ".git"), tempdir
)
def test_manifest_config_properties(self): def test_manifest_config_properties(self):
"""Test we are fetching the manifest config properties correctly.""" """Test we are fetching the manifest config properties correctly."""
@ -432,51 +463,61 @@ class ManifestPropertiesFetchedCorrectly(unittest.TestCase):
# Set property using the expected Set method, then ensure # Set property using the expected Set method, then ensure
# the porperty functions are using the correct Get methods. # the porperty functions are using the correct Get methods.
fakeproj.config.SetString( fakeproj.config.SetString(
'manifest.standalone', 'https://chicken/manifest.git') "manifest.standalone", "https://chicken/manifest.git"
)
self.assertEqual( self.assertEqual(
fakeproj.standalone_manifest_url, 'https://chicken/manifest.git') fakeproj.standalone_manifest_url, "https://chicken/manifest.git"
)
fakeproj.config.SetString('manifest.groups', 'test-group, admin-group') fakeproj.config.SetString(
self.assertEqual(fakeproj.manifest_groups, 'test-group, admin-group') "manifest.groups", "test-group, admin-group"
)
self.assertEqual(
fakeproj.manifest_groups, "test-group, admin-group"
)
fakeproj.config.SetString('repo.reference', 'mirror/ref') fakeproj.config.SetString("repo.reference", "mirror/ref")
self.assertEqual(fakeproj.reference, 'mirror/ref') self.assertEqual(fakeproj.reference, "mirror/ref")
fakeproj.config.SetBoolean('repo.dissociate', False) fakeproj.config.SetBoolean("repo.dissociate", False)
self.assertFalse(fakeproj.dissociate) self.assertFalse(fakeproj.dissociate)
fakeproj.config.SetBoolean('repo.archive', False) fakeproj.config.SetBoolean("repo.archive", False)
self.assertFalse(fakeproj.archive) self.assertFalse(fakeproj.archive)
fakeproj.config.SetBoolean('repo.mirror', False) fakeproj.config.SetBoolean("repo.mirror", False)
self.assertFalse(fakeproj.mirror) self.assertFalse(fakeproj.mirror)
fakeproj.config.SetBoolean('repo.worktree', False) fakeproj.config.SetBoolean("repo.worktree", False)
self.assertFalse(fakeproj.use_worktree) self.assertFalse(fakeproj.use_worktree)
fakeproj.config.SetBoolean('repo.clonebundle', False) fakeproj.config.SetBoolean("repo.clonebundle", False)
self.assertFalse(fakeproj.clone_bundle) self.assertFalse(fakeproj.clone_bundle)
fakeproj.config.SetBoolean('repo.submodules', False) fakeproj.config.SetBoolean("repo.submodules", False)
self.assertFalse(fakeproj.submodules) self.assertFalse(fakeproj.submodules)
fakeproj.config.SetBoolean('repo.git-lfs', False) fakeproj.config.SetBoolean("repo.git-lfs", False)
self.assertFalse(fakeproj.git_lfs) self.assertFalse(fakeproj.git_lfs)
fakeproj.config.SetBoolean('repo.superproject', False) fakeproj.config.SetBoolean("repo.superproject", False)
self.assertFalse(fakeproj.use_superproject) self.assertFalse(fakeproj.use_superproject)
fakeproj.config.SetBoolean('repo.partialclone', False) fakeproj.config.SetBoolean("repo.partialclone", False)
self.assertFalse(fakeproj.partial_clone) self.assertFalse(fakeproj.partial_clone)
fakeproj.config.SetString('repo.depth', '48') fakeproj.config.SetString("repo.depth", "48")
self.assertEqual(fakeproj.depth, '48') self.assertEqual(fakeproj.depth, "48")
fakeproj.config.SetString('repo.clonefilter', 'blob:limit=10M') fakeproj.config.SetString("repo.clonefilter", "blob:limit=10M")
self.assertEqual(fakeproj.clone_filter, 'blob:limit=10M') self.assertEqual(fakeproj.clone_filter, "blob:limit=10M")
fakeproj.config.SetString('repo.partialcloneexclude', 'third_party/big_repo') fakeproj.config.SetString(
self.assertEqual(fakeproj.partial_clone_exclude, 'third_party/big_repo') "repo.partialcloneexclude", "third_party/big_repo"
)
self.assertEqual(
fakeproj.partial_clone_exclude, "third_party/big_repo"
)
fakeproj.config.SetString('manifest.platform', 'auto') fakeproj.config.SetString("manifest.platform", "auto")
self.assertEqual(fakeproj.manifest_platform, 'auto') self.assertEqual(fakeproj.manifest_platform, "auto")

View File

@ -25,7 +25,7 @@ class TraceTests(unittest.TestCase):
"""Check Trace behavior.""" """Check Trace behavior."""
def testTrace_MaxSizeEnforced(self): def testTrace_MaxSizeEnforced(self):
content = 'git chicken' content = "git chicken"
with repo_trace.Trace(content, first_trace=True): with repo_trace.Trace(content, first_trace=True):
pass pass
@ -34,23 +34,27 @@ class TraceTests(unittest.TestCase):
with repo_trace.Trace(content): with repo_trace.Trace(content):
pass pass
self.assertGreater( self.assertGreater(
os.path.getsize(repo_trace._TRACE_FILE), first_trace_size) os.path.getsize(repo_trace._TRACE_FILE), first_trace_size
)
# Check we clear everything is the last chunk is larger than _MAX_SIZE. # Check we clear everything is the last chunk is larger than _MAX_SIZE.
with mock.patch('repo_trace._MAX_SIZE', 0): with mock.patch("repo_trace._MAX_SIZE", 0):
with repo_trace.Trace(content, first_trace=True): with repo_trace.Trace(content, first_trace=True):
pass pass
self.assertEqual(first_trace_size, self.assertEqual(
os.path.getsize(repo_trace._TRACE_FILE)) first_trace_size, os.path.getsize(repo_trace._TRACE_FILE)
)
# Check we only clear the chunks we need to. # Check we only clear the chunks we need to.
repo_trace._MAX_SIZE = (first_trace_size + 1) / (1024 * 1024) repo_trace._MAX_SIZE = (first_trace_size + 1) / (1024 * 1024)
with repo_trace.Trace(content, first_trace=True): with repo_trace.Trace(content, first_trace=True):
pass pass
self.assertEqual(first_trace_size * 2, self.assertEqual(
os.path.getsize(repo_trace._TRACE_FILE)) first_trace_size * 2, os.path.getsize(repo_trace._TRACE_FILE)
)
with repo_trace.Trace(content, first_trace=True): with repo_trace.Trace(content, first_trace=True):
pass pass
self.assertEqual(first_trace_size * 2, self.assertEqual(
os.path.getsize(repo_trace._TRACE_FILE)) first_trace_size * 2, os.path.getsize(repo_trace._TRACE_FILE)
)

View File

@ -27,18 +27,22 @@ class SshTests(unittest.TestCase):
def test_parse_ssh_version(self): def test_parse_ssh_version(self):
"""Check _parse_ssh_version() handling.""" """Check _parse_ssh_version() handling."""
ver = ssh._parse_ssh_version('Unknown\n') ver = ssh._parse_ssh_version("Unknown\n")
self.assertEqual(ver, ()) self.assertEqual(ver, ())
ver = ssh._parse_ssh_version('OpenSSH_1.0\n') ver = ssh._parse_ssh_version("OpenSSH_1.0\n")
self.assertEqual(ver, (1, 0)) self.assertEqual(ver, (1, 0))
ver = ssh._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n') ver = ssh._parse_ssh_version(
"OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n"
)
self.assertEqual(ver, (6, 6, 1)) self.assertEqual(ver, (6, 6, 1))
ver = ssh._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n') ver = ssh._parse_ssh_version(
"OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n"
)
self.assertEqual(ver, (7, 6)) self.assertEqual(ver, (7, 6))
def test_version(self): def test_version(self):
"""Check version() handling.""" """Check version() handling."""
with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'): with mock.patch("ssh._run_ssh_version", return_value="OpenSSH_1.2\n"):
self.assertEqual(ssh.version(), (1, 2)) self.assertEqual(ssh.version(), (1, 2))
def test_context_manager_empty(self): def test_context_manager_empty(self):
@ -51,9 +55,9 @@ class SshTests(unittest.TestCase):
"""Verify orphaned clients & masters get cleaned up.""" """Verify orphaned clients & masters get cleaned up."""
with multiprocessing.Manager() as manager: with multiprocessing.Manager() as manager:
with ssh.ProxyManager(manager) as ssh_proxy: with ssh.ProxyManager(manager) as ssh_proxy:
client = subprocess.Popen(['sleep', '964853320']) client = subprocess.Popen(["sleep", "964853320"])
ssh_proxy.add_client(client) ssh_proxy.add_client(client)
master = subprocess.Popen(['sleep', '964853321']) master = subprocess.Popen(["sleep", "964853321"])
ssh_proxy.add_master(master) ssh_proxy.add_master(master)
# If the process still exists, these will throw timeout errors. # If the process still exists, these will throw timeout errors.
client.wait(0) client.wait(0)
@ -63,12 +67,12 @@ class SshTests(unittest.TestCase):
"""Check sock() function.""" """Check sock() function."""
manager = multiprocessing.Manager() manager = multiprocessing.Manager()
proxy = ssh.ProxyManager(manager) proxy = ssh.ProxyManager(manager)
with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'): with mock.patch("tempfile.mkdtemp", return_value="/tmp/foo"):
# old ssh version uses port # Old ssh version uses port.
with mock.patch('ssh.version', return_value=(6, 6)): with mock.patch("ssh.version", return_value=(6, 6)):
self.assertTrue(proxy.sock().endswith('%p')) self.assertTrue(proxy.sock().endswith("%p"))
proxy._sock_path = None proxy._sock_path = None
# new ssh version uses hash # New ssh version uses hash.
with mock.patch('ssh.version', return_value=(6, 7)): with mock.patch("ssh.version", return_value=(6, 7)):
self.assertTrue(proxy.sock().endswith('%C')) self.assertTrue(proxy.sock().endswith("%C"))

View File

@ -25,30 +25,30 @@ class AllCommands(unittest.TestCase):
def test_required_basic(self): def test_required_basic(self):
"""Basic checking of registered commands.""" """Basic checking of registered commands."""
# NB: We don't test all subcommands as we want to avoid "change detection" # NB: We don't test all subcommands as we want to avoid "change
# tests, so we just look for the most common/important ones here that are # detection" tests, so we just look for the most common/important ones
# unlikely to ever change. # here that are unlikely to ever change.
for cmd in {'cherry-pick', 'help', 'init', 'start', 'sync', 'upload'}: for cmd in {"cherry-pick", "help", "init", "start", "sync", "upload"}:
self.assertIn(cmd, subcmds.all_commands) self.assertIn(cmd, subcmds.all_commands)
def test_naming(self): def test_naming(self):
"""Verify we don't add things that we shouldn't.""" """Verify we don't add things that we shouldn't."""
for cmd in subcmds.all_commands: for cmd in subcmds.all_commands:
# Reject filename suffixes like "help.py". # Reject filename suffixes like "help.py".
self.assertNotIn('.', cmd) self.assertNotIn(".", cmd)
# Make sure all '_' were converted to '-'. # Make sure all '_' were converted to '-'.
self.assertNotIn('_', cmd) self.assertNotIn("_", cmd)
# Reject internal python paths like "__init__". # Reject internal python paths like "__init__".
self.assertFalse(cmd.startswith('__')) self.assertFalse(cmd.startswith("__"))
def test_help_desc_style(self): def test_help_desc_style(self):
"""Force some consistency in option descriptions. """Force some consistency in option descriptions.
Python's optparse & argparse has a few default options like --help. Their Python's optparse & argparse has a few default options like --help.
option description text uses lowercase sentence fragments, so enforce our Their option description text uses lowercase sentence fragments, so
options follow the same style so UI is consistent. enforce our options follow the same style so UI is consistent.
We enforce: We enforce:
* Text starts with lowercase. * Text starts with lowercase.
@ -63,11 +63,15 @@ class AllCommands(unittest.TestCase):
c = option.help[0] c = option.help[0]
self.assertEqual( self.assertEqual(
c.lower(), c, c.lower(),
msg=f'subcmds/{name}.py: {option.get_opt_string()}: help text ' c,
f'should start with lowercase: "{option.help}"') msg=f"subcmds/{name}.py: {option.get_opt_string()}: "
f'help text should start with lowercase: "{option.help}"',
)
self.assertNotEqual( self.assertNotEqual(
option.help[-1], '.', option.help[-1],
msg=f'subcmds/{name}.py: {option.get_opt_string()}: help text ' ".",
f'should not end in a period: "{option.help}"') msg=f"subcmds/{name}.py: {option.get_opt_string()}: "
f'help text should not end in a period: "{option.help}"',
)

View File

@ -27,9 +27,7 @@ class InitCommand(unittest.TestCase):
def test_cli_parser_good(self): def test_cli_parser_good(self):
"""Check valid command line options.""" """Check valid command line options."""
ARGV = ( ARGV = ([],)
[],
)
for argv in ARGV: for argv in ARGV:
opts, args = self.cmd.OptionParser.parse_args(argv) opts, args = self.cmd.OptionParser.parse_args(argv)
self.cmd.ValidateOptions(opts, args) self.cmd.ValidateOptions(opts, args)
@ -38,10 +36,9 @@ class InitCommand(unittest.TestCase):
"""Check invalid command line options.""" """Check invalid command line options."""
ARGV = ( ARGV = (
# Too many arguments. # Too many arguments.
['url', 'asdf'], ["url", "asdf"],
# Conflicting options. # Conflicting options.
['--mirror', '--archive'], ["--mirror", "--archive"],
) )
for argv in ARGV: for argv in ARGV:
opts, args = self.cmd.OptionParser.parse_args(argv) opts, args = self.cmd.OptionParser.parse_args(argv)

View File

@ -23,56 +23,70 @@ import command
from subcmds import sync from subcmds import sync
@pytest.mark.parametrize('use_superproject, cli_args, result', [ @pytest.mark.parametrize(
(True, ['--current-branch'], True), "use_superproject, cli_args, result",
(True, ['--no-current-branch'], True), [
(True, ["--current-branch"], True),
(True, ["--no-current-branch"], True),
(True, [], True), (True, [], True),
(False, ['--current-branch'], True), (False, ["--current-branch"], True),
(False, ['--no-current-branch'], False), (False, ["--no-current-branch"], False),
(False, [], None), (False, [], None),
]) ],
)
def test_get_current_branch_only(use_superproject, cli_args, result): def test_get_current_branch_only(use_superproject, cli_args, result):
"""Test Sync._GetCurrentBranchOnly logic. """Test Sync._GetCurrentBranchOnly logic.
Sync._GetCurrentBranchOnly should return True if a superproject is requested, Sync._GetCurrentBranchOnly should return True if a superproject is
and otherwise the value of the current_branch_only option. requested, and otherwise the value of the current_branch_only option.
""" """
cmd = sync.Sync() cmd = sync.Sync()
opts, _ = cmd.OptionParser.parse_args(cli_args) opts, _ = cmd.OptionParser.parse_args(cli_args)
with mock.patch('git_superproject.UseSuperproject', with mock.patch(
return_value=use_superproject): "git_superproject.UseSuperproject", return_value=use_superproject
):
assert cmd._GetCurrentBranchOnly(opts, cmd.manifest) == result assert cmd._GetCurrentBranchOnly(opts, cmd.manifest) == result
# Used to patch os.cpu_count() for reliable results. # Used to patch os.cpu_count() for reliable results.
OS_CPU_COUNT = 24 OS_CPU_COUNT = 24
@pytest.mark.parametrize('argv, jobs_manifest, jobs, jobs_net, jobs_check', [
@pytest.mark.parametrize(
"argv, jobs_manifest, jobs, jobs_net, jobs_check",
[
# No user or manifest settings. # No user or manifest settings.
([], None, OS_CPU_COUNT, 1, command.DEFAULT_LOCAL_JOBS), ([], None, OS_CPU_COUNT, 1, command.DEFAULT_LOCAL_JOBS),
# No user settings, so manifest settings control. # No user settings, so manifest settings control.
([], 3, 3, 3, 3), ([], 3, 3, 3, 3),
# User settings, but no manifest. # User settings, but no manifest.
(['--jobs=4'], None, 4, 4, 4), (["--jobs=4"], None, 4, 4, 4),
(['--jobs=4', '--jobs-network=5'], None, 4, 5, 4), (["--jobs=4", "--jobs-network=5"], None, 4, 5, 4),
(['--jobs=4', '--jobs-checkout=6'], None, 4, 4, 6), (["--jobs=4", "--jobs-checkout=6"], None, 4, 4, 6),
(['--jobs=4', '--jobs-network=5', '--jobs-checkout=6'], None, 4, 5, 6), (["--jobs=4", "--jobs-network=5", "--jobs-checkout=6"], None, 4, 5, 6),
(['--jobs-network=5'], None, OS_CPU_COUNT, 5, command.DEFAULT_LOCAL_JOBS), (
(['--jobs-checkout=6'], None, OS_CPU_COUNT, 1, 6), ["--jobs-network=5"],
(['--jobs-network=5', '--jobs-checkout=6'], None, OS_CPU_COUNT, 5, 6), None,
OS_CPU_COUNT,
5,
command.DEFAULT_LOCAL_JOBS,
),
(["--jobs-checkout=6"], None, OS_CPU_COUNT, 1, 6),
(["--jobs-network=5", "--jobs-checkout=6"], None, OS_CPU_COUNT, 5, 6),
# User settings with manifest settings. # User settings with manifest settings.
(['--jobs=4'], 3, 4, 4, 4), (["--jobs=4"], 3, 4, 4, 4),
(['--jobs=4', '--jobs-network=5'], 3, 4, 5, 4), (["--jobs=4", "--jobs-network=5"], 3, 4, 5, 4),
(['--jobs=4', '--jobs-checkout=6'], 3, 4, 4, 6), (["--jobs=4", "--jobs-checkout=6"], 3, 4, 4, 6),
(['--jobs=4', '--jobs-network=5', '--jobs-checkout=6'], 3, 4, 5, 6), (["--jobs=4", "--jobs-network=5", "--jobs-checkout=6"], 3, 4, 5, 6),
(['--jobs-network=5'], 3, 3, 5, 3), (["--jobs-network=5"], 3, 3, 5, 3),
(['--jobs-checkout=6'], 3, 3, 3, 6), (["--jobs-checkout=6"], 3, 3, 3, 6),
(['--jobs-network=5', '--jobs-checkout=6'], 3, 3, 5, 6), (["--jobs-network=5", "--jobs-checkout=6"], 3, 3, 5, 6),
# Settings that exceed rlimits get capped. # Settings that exceed rlimits get capped.
(['--jobs=1000000'], None, 83, 83, 83), (["--jobs=1000000"], None, 83, 83, 83),
([], 1000000, 83, 83, 83), ([], 1000000, 83, 83, 83),
]) ],
)
def test_cli_jobs(argv, jobs_manifest, jobs, jobs_net, jobs_check): def test_cli_jobs(argv, jobs_manifest, jobs, jobs_net, jobs_check):
"""Tests --jobs option behavior.""" """Tests --jobs option behavior."""
mp = mock.MagicMock() mp = mock.MagicMock()
@ -82,8 +96,8 @@ def test_cli_jobs(argv, jobs_manifest, jobs, jobs_net, jobs_check):
opts, args = cmd.OptionParser.parse_args(argv) opts, args = cmd.OptionParser.parse_args(argv)
cmd.ValidateOptions(opts, args) cmd.ValidateOptions(opts, args)
with mock.patch.object(sync, '_rlimit_nofile', return_value=(256, 256)): with mock.patch.object(sync, "_rlimit_nofile", return_value=(256, 256)):
with mock.patch.object(os, 'cpu_count', return_value=OS_CPU_COUNT): with mock.patch.object(os, "cpu_count", return_value=OS_CPU_COUNT):
cmd._ValidateOptionsWithManifest(opts, mp) cmd._ValidateOptionsWithManifest(opts, mp)
assert opts.jobs == jobs assert opts.jobs == jobs
assert opts.jobs_network == jobs_net assert opts.jobs_network == jobs_net
@ -96,38 +110,51 @@ class GetPreciousObjectsState(unittest.TestCase):
def setUp(self): def setUp(self):
"""Common setup.""" """Common setup."""
self.cmd = sync.Sync() self.cmd = sync.Sync()
self.project = p = mock.MagicMock(use_git_worktrees=False, self.project = p = mock.MagicMock(
UseAlternates=False) use_git_worktrees=False, UseAlternates=False
)
p.manifest.GetProjectsWithName.return_value = [p] p.manifest.GetProjectsWithName.return_value = [p]
self.opt = mock.Mock(spec_set=['this_manifest_only']) self.opt = mock.Mock(spec_set=["this_manifest_only"])
self.opt.this_manifest_only = False self.opt.this_manifest_only = False
def test_worktrees(self): def test_worktrees(self):
"""False for worktrees.""" """False for worktrees."""
self.project.use_git_worktrees = True self.project.use_git_worktrees = True
self.assertFalse(self.cmd._GetPreciousObjectsState(self.project, self.opt)) self.assertFalse(
self.cmd._GetPreciousObjectsState(self.project, self.opt)
)
def test_not_shared(self): def test_not_shared(self):
"""Singleton project.""" """Singleton project."""
self.assertFalse(self.cmd._GetPreciousObjectsState(self.project, self.opt)) self.assertFalse(
self.cmd._GetPreciousObjectsState(self.project, self.opt)
)
def test_shared(self): def test_shared(self):
"""Shared project.""" """Shared project."""
self.project.manifest.GetProjectsWithName.return_value = [ self.project.manifest.GetProjectsWithName.return_value = [
self.project, self.project self.project,
self.project,
] ]
self.assertTrue(self.cmd._GetPreciousObjectsState(self.project, self.opt)) self.assertTrue(
self.cmd._GetPreciousObjectsState(self.project, self.opt)
)
def test_shared_with_alternates(self): def test_shared_with_alternates(self):
"""Shared project, with alternates.""" """Shared project, with alternates."""
self.project.manifest.GetProjectsWithName.return_value = [ self.project.manifest.GetProjectsWithName.return_value = [
self.project, self.project self.project,
self.project,
] ]
self.project.UseAlternates = True self.project.UseAlternates = True
self.assertFalse(self.cmd._GetPreciousObjectsState(self.project, self.opt)) self.assertFalse(
self.cmd._GetPreciousObjectsState(self.project, self.opt)
)
def test_not_found(self): def test_not_found(self):
"""Project not found in manifest.""" """Project not found in manifest."""
self.project.manifest.GetProjectsWithName.return_value = [] self.project.manifest.GetProjectsWithName.return_value = []
self.assertFalse(self.cmd._GetPreciousObjectsState(self.project, self.opt)) self.assertFalse(
self.cmd._GetPreciousObjectsState(self.project, self.opt)
)

View File

@ -24,5 +24,5 @@ class UpdateManpagesTest(unittest.TestCase):
def test_replace_regex(self): def test_replace_regex(self):
"""Check that replace_regex works.""" """Check that replace_regex works."""
data = '\n\033[1mSummary\033[m\n' data = "\n\033[1mSummary\033[m\n"
self.assertEqual(update_manpages.replace_regex(data),'\nSummary\n') self.assertEqual(update_manpages.replace_regex(data), "\nSummary\n")

View File

@ -28,9 +28,8 @@ import wrapper
def fixture(*paths): def fixture(*paths):
"""Return a path relative to tests/fixtures. """Return a path relative to tests/fixtures."""
""" return os.path.join(os.path.dirname(__file__), "fixtures", *paths)
return os.path.join(os.path.dirname(__file__), 'fixtures', *paths)
class RepoWrapperTestCase(unittest.TestCase): class RepoWrapperTestCase(unittest.TestCase):
@ -43,28 +42,31 @@ class RepoWrapperTestCase(unittest.TestCase):
class RepoWrapperUnitTest(RepoWrapperTestCase): class RepoWrapperUnitTest(RepoWrapperTestCase):
"""Tests helper functions in the repo wrapper """Tests helper functions in the repo wrapper"""
"""
def test_version(self): def test_version(self):
"""Make sure _Version works.""" """Make sure _Version works."""
with self.assertRaises(SystemExit) as e: with self.assertRaises(SystemExit) as e:
with mock.patch('sys.stdout', new_callable=StringIO) as stdout: with mock.patch("sys.stdout", new_callable=StringIO) as stdout:
with mock.patch('sys.stderr', new_callable=StringIO) as stderr: with mock.patch("sys.stderr", new_callable=StringIO) as stderr:
self.wrapper._Version() self.wrapper._Version()
self.assertEqual(0, e.exception.code) self.assertEqual(0, e.exception.code)
self.assertEqual('', stderr.getvalue()) self.assertEqual("", stderr.getvalue())
self.assertIn('repo launcher version', stdout.getvalue()) self.assertIn("repo launcher version", stdout.getvalue())
def test_python_constraints(self): def test_python_constraints(self):
"""The launcher should never require newer than main.py.""" """The launcher should never require newer than main.py."""
self.assertGreaterEqual(main.MIN_PYTHON_VERSION_HARD, self.assertGreaterEqual(
self.wrapper.MIN_PYTHON_VERSION_HARD) main.MIN_PYTHON_VERSION_HARD, self.wrapper.MIN_PYTHON_VERSION_HARD
self.assertGreaterEqual(main.MIN_PYTHON_VERSION_SOFT, )
self.wrapper.MIN_PYTHON_VERSION_SOFT) self.assertGreaterEqual(
main.MIN_PYTHON_VERSION_SOFT, self.wrapper.MIN_PYTHON_VERSION_SOFT
)
# Make sure the versions are themselves in sync. # Make sure the versions are themselves in sync.
self.assertGreaterEqual(self.wrapper.MIN_PYTHON_VERSION_SOFT, self.assertGreaterEqual(
self.wrapper.MIN_PYTHON_VERSION_HARD) self.wrapper.MIN_PYTHON_VERSION_SOFT,
self.wrapper.MIN_PYTHON_VERSION_HARD,
)
def test_init_parser(self): def test_init_parser(self):
"""Make sure 'init' GetParser works.""" """Make sure 'init' GetParser works."""
@ -84,48 +86,76 @@ class RepoWrapperUnitTest(RepoWrapperTestCase):
""" """
Test reading a missing gitc config file Test reading a missing gitc config file
""" """
self.wrapper.GITC_CONFIG_FILE = fixture('missing_gitc_config') self.wrapper.GITC_CONFIG_FILE = fixture("missing_gitc_config")
val = self.wrapper.get_gitc_manifest_dir() val = self.wrapper.get_gitc_manifest_dir()
self.assertEqual(val, '') self.assertEqual(val, "")
def test_get_gitc_manifest_dir(self): def test_get_gitc_manifest_dir(self):
""" """
Test reading the gitc config file and parsing the directory Test reading the gitc config file and parsing the directory
""" """
self.wrapper.GITC_CONFIG_FILE = fixture('gitc_config') self.wrapper.GITC_CONFIG_FILE = fixture("gitc_config")
val = self.wrapper.get_gitc_manifest_dir() val = self.wrapper.get_gitc_manifest_dir()
self.assertEqual(val, '/test/usr/local/google/gitc') self.assertEqual(val, "/test/usr/local/google/gitc")
def test_gitc_parse_clientdir_no_gitc(self): def test_gitc_parse_clientdir_no_gitc(self):
""" """
Test parsing the gitc clientdir without gitc running Test parsing the gitc clientdir without gitc running
""" """
self.wrapper.GITC_CONFIG_FILE = fixture('missing_gitc_config') self.wrapper.GITC_CONFIG_FILE = fixture("missing_gitc_config")
self.assertEqual(self.wrapper.gitc_parse_clientdir('/something'), None) self.assertEqual(self.wrapper.gitc_parse_clientdir("/something"), None)
self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/test'), 'test') self.assertEqual(
self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/test"), "test"
)
def test_gitc_parse_clientdir(self): def test_gitc_parse_clientdir(self):
""" """
Test parsing the gitc clientdir Test parsing the gitc clientdir
""" """
self.wrapper.GITC_CONFIG_FILE = fixture('gitc_config') self.wrapper.GITC_CONFIG_FILE = fixture("gitc_config")
self.assertEqual(self.wrapper.gitc_parse_clientdir('/something'), None) self.assertEqual(self.wrapper.gitc_parse_clientdir("/something"), None)
self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/test'), 'test') self.assertEqual(
self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/test/'), 'test') self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/test"), "test"
self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/test/extra'), 'test') )
self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test'), 'test') self.assertEqual(
self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test/'), 'test') self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/test/"), "test"
self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test/extra'), )
'test') self.assertEqual(
self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/'), None) self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/test/extra"),
self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/'), None) "test",
)
self.assertEqual(
self.wrapper.gitc_parse_clientdir(
"/test/usr/local/google/gitc/test"
),
"test",
)
self.assertEqual(
self.wrapper.gitc_parse_clientdir(
"/test/usr/local/google/gitc/test/"
),
"test",
)
self.assertEqual(
self.wrapper.gitc_parse_clientdir(
"/test/usr/local/google/gitc/test/extra"
),
"test",
)
self.assertEqual(
self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/"), None
)
self.assertEqual(
self.wrapper.gitc_parse_clientdir("/test/usr/local/google/gitc/"),
None,
)
class SetGitTrace2ParentSid(RepoWrapperTestCase): class SetGitTrace2ParentSid(RepoWrapperTestCase):
"""Check SetGitTrace2ParentSid behavior.""" """Check SetGitTrace2ParentSid behavior."""
KEY = 'GIT_TRACE2_PARENT_SID' KEY = "GIT_TRACE2_PARENT_SID"
VALID_FORMAT = re.compile(r'^repo-[0-9]{8}T[0-9]{6}Z-P[0-9a-f]{8}$') VALID_FORMAT = re.compile(r"^repo-[0-9]{8}T[0-9]{6}Z-P[0-9a-f]{8}$")
def test_first_set(self): def test_first_set(self):
"""Test env var not yet set.""" """Test env var not yet set."""
@ -137,11 +167,11 @@ class SetGitTrace2ParentSid(RepoWrapperTestCase):
def test_append(self): def test_append(self):
"""Test env var is appended.""" """Test env var is appended."""
env = {self.KEY: 'pfx'} env = {self.KEY: "pfx"}
self.wrapper.SetGitTrace2ParentSid(env) self.wrapper.SetGitTrace2ParentSid(env)
self.assertIn(self.KEY, env) self.assertIn(self.KEY, env)
value = env[self.KEY] value = env[self.KEY]
self.assertTrue(value.startswith('pfx/')) self.assertTrue(value.startswith("pfx/"))
self.assertRegex(value[4:], self.VALID_FORMAT) self.assertRegex(value[4:], self.VALID_FORMAT)
def test_global_context(self): def test_global_context(self):
@ -158,18 +188,18 @@ class RunCommand(RepoWrapperTestCase):
def test_capture(self): def test_capture(self):
"""Check capture_output handling.""" """Check capture_output handling."""
ret = self.wrapper.run_command(['echo', 'hi'], capture_output=True) ret = self.wrapper.run_command(["echo", "hi"], capture_output=True)
# echo command appends OS specific linesep, but on Windows + Git Bash # echo command appends OS specific linesep, but on Windows + Git Bash
# we get UNIX ending, so we allow both. # we get UNIX ending, so we allow both.
self.assertIn(ret.stdout, ['hi' + os.linesep, 'hi\n']) self.assertIn(ret.stdout, ["hi" + os.linesep, "hi\n"])
def test_check(self): def test_check(self):
"""Check check handling.""" """Check check handling."""
self.wrapper.run_command(['true'], check=False) self.wrapper.run_command(["true"], check=False)
self.wrapper.run_command(['true'], check=True) self.wrapper.run_command(["true"], check=True)
self.wrapper.run_command(['false'], check=False) self.wrapper.run_command(["false"], check=False)
with self.assertRaises(self.wrapper.RunError): with self.assertRaises(self.wrapper.RunError):
self.wrapper.run_command(['false'], check=True) self.wrapper.run_command(["false"], check=True)
class RunGit(RepoWrapperTestCase): class RunGit(RepoWrapperTestCase):
@ -177,14 +207,14 @@ class RunGit(RepoWrapperTestCase):
def test_capture(self): def test_capture(self):
"""Check capture_output handling.""" """Check capture_output handling."""
ret = self.wrapper.run_git('--version') ret = self.wrapper.run_git("--version")
self.assertIn('git', ret.stdout) self.assertIn("git", ret.stdout)
def test_check(self): def test_check(self):
"""Check check handling.""" """Check check handling."""
with self.assertRaises(self.wrapper.CloneFailure): with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper.run_git('--version-asdfasdf') self.wrapper.run_git("--version-asdfasdf")
self.wrapper.run_git('--version-asdfasdf', check=False) self.wrapper.run_git("--version-asdfasdf", check=False)
class ParseGitVersion(RepoWrapperTestCase): class ParseGitVersion(RepoWrapperTestCase):
@ -197,25 +227,26 @@ class ParseGitVersion(RepoWrapperTestCase):
def test_bad_ver(self): def test_bad_ver(self):
"""Check handling of bad git versions.""" """Check handling of bad git versions."""
ret = self.wrapper.ParseGitVersion(ver_str='asdf') ret = self.wrapper.ParseGitVersion(ver_str="asdf")
self.assertIsNone(ret) self.assertIsNone(ret)
def test_normal_ver(self): def test_normal_ver(self):
"""Check handling of normal git versions.""" """Check handling of normal git versions."""
ret = self.wrapper.ParseGitVersion(ver_str='git version 2.25.1') ret = self.wrapper.ParseGitVersion(ver_str="git version 2.25.1")
self.assertEqual(2, ret.major) self.assertEqual(2, ret.major)
self.assertEqual(25, ret.minor) self.assertEqual(25, ret.minor)
self.assertEqual(1, ret.micro) self.assertEqual(1, ret.micro)
self.assertEqual('2.25.1', ret.full) self.assertEqual("2.25.1", ret.full)
def test_extended_ver(self): def test_extended_ver(self):
"""Check handling of extended distro git versions.""" """Check handling of extended distro git versions."""
ret = self.wrapper.ParseGitVersion( ret = self.wrapper.ParseGitVersion(
ver_str='git version 1.30.50.696.g5e7596f4ac-goog') ver_str="git version 1.30.50.696.g5e7596f4ac-goog"
)
self.assertEqual(1, ret.major) self.assertEqual(1, ret.major)
self.assertEqual(30, ret.minor) self.assertEqual(30, ret.minor)
self.assertEqual(50, ret.micro) self.assertEqual(50, ret.micro)
self.assertEqual('1.30.50.696.g5e7596f4ac-goog', ret.full) self.assertEqual("1.30.50.696.g5e7596f4ac-goog", ret.full)
class CheckGitVersion(RepoWrapperTestCase): class CheckGitVersion(RepoWrapperTestCase):
@ -223,23 +254,29 @@ class CheckGitVersion(RepoWrapperTestCase):
def test_unknown(self): def test_unknown(self):
"""Unknown versions should abort.""" """Unknown versions should abort."""
with mock.patch.object(self.wrapper, 'ParseGitVersion', return_value=None): with mock.patch.object(
self.wrapper, "ParseGitVersion", return_value=None
):
with self.assertRaises(self.wrapper.CloneFailure): with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper._CheckGitVersion() self.wrapper._CheckGitVersion()
def test_old(self): def test_old(self):
"""Old versions should abort.""" """Old versions should abort."""
with mock.patch.object( with mock.patch.object(
self.wrapper, 'ParseGitVersion', self.wrapper,
return_value=self.wrapper.GitVersion(1, 0, 0, '1.0.0')): "ParseGitVersion",
return_value=self.wrapper.GitVersion(1, 0, 0, "1.0.0"),
):
with self.assertRaises(self.wrapper.CloneFailure): with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper._CheckGitVersion() self.wrapper._CheckGitVersion()
def test_new(self): def test_new(self):
"""Newer versions should run fine.""" """Newer versions should run fine."""
with mock.patch.object( with mock.patch.object(
self.wrapper, 'ParseGitVersion', self.wrapper,
return_value=self.wrapper.GitVersion(100, 0, 0, '100.0.0')): "ParseGitVersion",
return_value=self.wrapper.GitVersion(100, 0, 0, "100.0.0"),
):
self.wrapper._CheckGitVersion() self.wrapper._CheckGitVersion()
@ -250,26 +287,34 @@ class Requirements(RepoWrapperTestCase):
"""Don't crash if the file is missing (old version).""" """Don't crash if the file is missing (old version)."""
testdir = os.path.dirname(os.path.realpath(__file__)) testdir = os.path.dirname(os.path.realpath(__file__))
self.assertIsNone(self.wrapper.Requirements.from_dir(testdir)) self.assertIsNone(self.wrapper.Requirements.from_dir(testdir))
self.assertIsNone(self.wrapper.Requirements.from_file( self.assertIsNone(
os.path.join(testdir, 'xxxxxxxxxxxxxxxxxxxxxxxx'))) self.wrapper.Requirements.from_file(
os.path.join(testdir, "xxxxxxxxxxxxxxxxxxxxxxxx")
)
)
def test_corrupt_data(self): def test_corrupt_data(self):
"""If the file can't be parsed, don't blow up.""" """If the file can't be parsed, don't blow up."""
self.assertIsNone(self.wrapper.Requirements.from_file(__file__)) self.assertIsNone(self.wrapper.Requirements.from_file(__file__))
self.assertIsNone(self.wrapper.Requirements.from_data(b'x')) self.assertIsNone(self.wrapper.Requirements.from_data(b"x"))
def test_valid_data(self): def test_valid_data(self):
"""Make sure we can parse the file we ship.""" """Make sure we can parse the file we ship."""
self.assertIsNotNone(self.wrapper.Requirements.from_data(b'{}')) self.assertIsNotNone(self.wrapper.Requirements.from_data(b"{}"))
rootdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) rootdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
self.assertIsNotNone(self.wrapper.Requirements.from_dir(rootdir)) self.assertIsNotNone(self.wrapper.Requirements.from_dir(rootdir))
self.assertIsNotNone(self.wrapper.Requirements.from_file(os.path.join( self.assertIsNotNone(
rootdir, 'requirements.json'))) self.wrapper.Requirements.from_file(
os.path.join(rootdir, "requirements.json")
)
)
def test_format_ver(self): def test_format_ver(self):
"""Check format_ver can format.""" """Check format_ver can format."""
self.assertEqual('1.2.3', self.wrapper.Requirements._format_ver((1, 2, 3))) self.assertEqual(
self.assertEqual('1', self.wrapper.Requirements._format_ver([1])) "1.2.3", self.wrapper.Requirements._format_ver((1, 2, 3))
)
self.assertEqual("1", self.wrapper.Requirements._format_ver([1]))
def test_assert_all_unknown(self): def test_assert_all_unknown(self):
"""Check assert_all works with incompatible file.""" """Check assert_all works with incompatible file."""
@ -278,44 +323,48 @@ class Requirements(RepoWrapperTestCase):
def test_assert_all_new_repo(self): def test_assert_all_new_repo(self):
"""Check assert_all accepts new enough repo.""" """Check assert_all accepts new enough repo."""
reqs = self.wrapper.Requirements({'repo': {'hard': [1, 0]}}) reqs = self.wrapper.Requirements({"repo": {"hard": [1, 0]}})
reqs.assert_all() reqs.assert_all()
def test_assert_all_old_repo(self): def test_assert_all_old_repo(self):
"""Check assert_all rejects old repo.""" """Check assert_all rejects old repo."""
reqs = self.wrapper.Requirements({'repo': {'hard': [99999, 0]}}) reqs = self.wrapper.Requirements({"repo": {"hard": [99999, 0]}})
with self.assertRaises(SystemExit): with self.assertRaises(SystemExit):
reqs.assert_all() reqs.assert_all()
def test_assert_all_new_python(self): def test_assert_all_new_python(self):
"""Check assert_all accepts new enough python.""" """Check assert_all accepts new enough python."""
reqs = self.wrapper.Requirements({'python': {'hard': sys.version_info}}) reqs = self.wrapper.Requirements({"python": {"hard": sys.version_info}})
reqs.assert_all() reqs.assert_all()
def test_assert_all_old_python(self): def test_assert_all_old_python(self):
"""Check assert_all rejects old python.""" """Check assert_all rejects old python."""
reqs = self.wrapper.Requirements({'python': {'hard': [99999, 0]}}) reqs = self.wrapper.Requirements({"python": {"hard": [99999, 0]}})
with self.assertRaises(SystemExit): with self.assertRaises(SystemExit):
reqs.assert_all() reqs.assert_all()
def test_assert_ver_unknown(self): def test_assert_ver_unknown(self):
"""Check assert_ver works with incompatible file.""" """Check assert_ver works with incompatible file."""
reqs = self.wrapper.Requirements({}) reqs = self.wrapper.Requirements({})
reqs.assert_ver('xxx', (1, 0)) reqs.assert_ver("xxx", (1, 0))
def test_assert_ver_new(self): def test_assert_ver_new(self):
"""Check assert_ver allows new enough versions.""" """Check assert_ver allows new enough versions."""
reqs = self.wrapper.Requirements({'git': {'hard': [1, 0], 'soft': [2, 0]}}) reqs = self.wrapper.Requirements(
reqs.assert_ver('git', (1, 0)) {"git": {"hard": [1, 0], "soft": [2, 0]}}
reqs.assert_ver('git', (1, 5)) )
reqs.assert_ver('git', (2, 0)) reqs.assert_ver("git", (1, 0))
reqs.assert_ver('git', (2, 5)) reqs.assert_ver("git", (1, 5))
reqs.assert_ver("git", (2, 0))
reqs.assert_ver("git", (2, 5))
def test_assert_ver_old(self): def test_assert_ver_old(self):
"""Check assert_ver rejects old versions.""" """Check assert_ver rejects old versions."""
reqs = self.wrapper.Requirements({'git': {'hard': [1, 0], 'soft': [2, 0]}}) reqs = self.wrapper.Requirements(
{"git": {"hard": [1, 0], "soft": [2, 0]}}
)
with self.assertRaises(SystemExit): with self.assertRaises(SystemExit):
reqs.assert_ver('git', (0, 5)) reqs.assert_ver("git", (0, 5))
class NeedSetupGnuPG(RepoWrapperTestCase): class NeedSetupGnuPG(RepoWrapperTestCase):
@ -323,38 +372,38 @@ class NeedSetupGnuPG(RepoWrapperTestCase):
def test_missing_dir(self): def test_missing_dir(self):
"""The ~/.repoconfig tree doesn't exist yet.""" """The ~/.repoconfig tree doesn't exist yet."""
with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = os.path.join(tempdir, 'foo') self.wrapper.home_dot_repo = os.path.join(tempdir, "foo")
self.assertTrue(self.wrapper.NeedSetupGnuPG()) self.assertTrue(self.wrapper.NeedSetupGnuPG())
def test_missing_keyring(self): def test_missing_keyring(self):
"""The keyring-version file doesn't exist yet.""" """The keyring-version file doesn't exist yet."""
with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir self.wrapper.home_dot_repo = tempdir
self.assertTrue(self.wrapper.NeedSetupGnuPG()) self.assertTrue(self.wrapper.NeedSetupGnuPG())
def test_empty_keyring(self): def test_empty_keyring(self):
"""The keyring-version file exists, but is empty.""" """The keyring-version file exists, but is empty."""
with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir self.wrapper.home_dot_repo = tempdir
with open(os.path.join(tempdir, 'keyring-version'), 'w'): with open(os.path.join(tempdir, "keyring-version"), "w"):
pass pass
self.assertTrue(self.wrapper.NeedSetupGnuPG()) self.assertTrue(self.wrapper.NeedSetupGnuPG())
def test_old_keyring(self): def test_old_keyring(self):
"""The keyring-version file exists, but it's old.""" """The keyring-version file exists, but it's old."""
with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir self.wrapper.home_dot_repo = tempdir
with open(os.path.join(tempdir, 'keyring-version'), 'w') as fp: with open(os.path.join(tempdir, "keyring-version"), "w") as fp:
fp.write('1.0\n') fp.write("1.0\n")
self.assertTrue(self.wrapper.NeedSetupGnuPG()) self.assertTrue(self.wrapper.NeedSetupGnuPG())
def test_new_keyring(self): def test_new_keyring(self):
"""The keyring-version file exists, and is up-to-date.""" """The keyring-version file exists, and is up-to-date."""
with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir self.wrapper.home_dot_repo = tempdir
with open(os.path.join(tempdir, 'keyring-version'), 'w') as fp: with open(os.path.join(tempdir, "keyring-version"), "w") as fp:
fp.write('1000.0\n') fp.write("1000.0\n")
self.assertFalse(self.wrapper.NeedSetupGnuPG()) self.assertFalse(self.wrapper.NeedSetupGnuPG())
@ -363,14 +412,18 @@ class SetupGnuPG(RepoWrapperTestCase):
def test_full(self): def test_full(self):
"""Make sure it works completely.""" """Make sure it works completely."""
with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir:
self.wrapper.home_dot_repo = tempdir self.wrapper.home_dot_repo = tempdir
self.wrapper.gpg_dir = os.path.join(self.wrapper.home_dot_repo, 'gnupg') self.wrapper.gpg_dir = os.path.join(
self.wrapper.home_dot_repo, "gnupg"
)
self.assertTrue(self.wrapper.SetupGnuPG(True)) self.assertTrue(self.wrapper.SetupGnuPG(True))
with open(os.path.join(tempdir, 'keyring-version'), 'r') as fp: with open(os.path.join(tempdir, "keyring-version"), "r") as fp:
data = fp.read() data = fp.read()
self.assertEqual('.'.join(str(x) for x in self.wrapper.KEYRING_VERSION), self.assertEqual(
data.strip()) ".".join(str(x) for x in self.wrapper.KEYRING_VERSION),
data.strip(),
)
class VerifyRev(RepoWrapperTestCase): class VerifyRev(RepoWrapperTestCase):
@ -378,30 +431,37 @@ class VerifyRev(RepoWrapperTestCase):
def test_verify_passes(self): def test_verify_passes(self):
"""Check when we have a valid signed tag.""" """Check when we have a valid signed tag."""
desc_result = self.wrapper.RunResult(0, 'v1.0\n', '') desc_result = self.wrapper.RunResult(0, "v1.0\n", "")
gpg_result = self.wrapper.RunResult(0, '', '') gpg_result = self.wrapper.RunResult(0, "", "")
with mock.patch.object(self.wrapper, 'run_git', with mock.patch.object(
side_effect=(desc_result, gpg_result)): self.wrapper, "run_git", side_effect=(desc_result, gpg_result)
ret = self.wrapper.verify_rev('/', 'refs/heads/stable', '1234', True) ):
self.assertEqual('v1.0^0', ret) ret = self.wrapper.verify_rev(
"/", "refs/heads/stable", "1234", True
)
self.assertEqual("v1.0^0", ret)
def test_unsigned_commit(self): def test_unsigned_commit(self):
"""Check we fall back to signed tag when we have an unsigned commit.""" """Check we fall back to signed tag when we have an unsigned commit."""
desc_result = self.wrapper.RunResult(0, 'v1.0-10-g1234\n', '') desc_result = self.wrapper.RunResult(0, "v1.0-10-g1234\n", "")
gpg_result = self.wrapper.RunResult(0, '', '') gpg_result = self.wrapper.RunResult(0, "", "")
with mock.patch.object(self.wrapper, 'run_git', with mock.patch.object(
side_effect=(desc_result, gpg_result)): self.wrapper, "run_git", side_effect=(desc_result, gpg_result)
ret = self.wrapper.verify_rev('/', 'refs/heads/stable', '1234', True) ):
self.assertEqual('v1.0^0', ret) ret = self.wrapper.verify_rev(
"/", "refs/heads/stable", "1234", True
)
self.assertEqual("v1.0^0", ret)
def test_verify_fails(self): def test_verify_fails(self):
"""Check we fall back to signed tag when we have an unsigned commit.""" """Check we fall back to signed tag when we have an unsigned commit."""
desc_result = self.wrapper.RunResult(0, 'v1.0-10-g1234\n', '') desc_result = self.wrapper.RunResult(0, "v1.0-10-g1234\n", "")
gpg_result = Exception gpg_result = Exception
with mock.patch.object(self.wrapper, 'run_git', with mock.patch.object(
side_effect=(desc_result, gpg_result)): self.wrapper, "run_git", side_effect=(desc_result, gpg_result)
):
with self.assertRaises(Exception): with self.assertRaises(Exception):
self.wrapper.verify_rev('/', 'refs/heads/stable', '1234', True) self.wrapper.verify_rev("/", "refs/heads/stable", "1234", True)
class GitCheckoutTestCase(RepoWrapperTestCase): class GitCheckoutTestCase(RepoWrapperTestCase):
@ -413,33 +473,40 @@ class GitCheckoutTestCase(RepoWrapperTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
# Create a repo to operate on, but do it once per-class. # Create a repo to operate on, but do it once per-class.
cls.tempdirobj = tempfile.TemporaryDirectory(prefix='repo-rev-tests') cls.tempdirobj = tempfile.TemporaryDirectory(prefix="repo-rev-tests")
cls.GIT_DIR = cls.tempdirobj.name cls.GIT_DIR = cls.tempdirobj.name
run_git = wrapper.Wrapper().run_git run_git = wrapper.Wrapper().run_git
remote = os.path.join(cls.GIT_DIR, 'remote') remote = os.path.join(cls.GIT_DIR, "remote")
os.mkdir(remote) os.mkdir(remote)
# Tests need to assume, that main is default branch at init, # Tests need to assume, that main is default branch at init,
# which is not supported in config until 2.28. # which is not supported in config until 2.28.
if git_command.git_require((2, 28, 0)): if git_command.git_require((2, 28, 0)):
initstr = '--initial-branch=main' initstr = "--initial-branch=main"
else: else:
# Use template dir for init. # Use template dir for init.
templatedir = tempfile.mkdtemp(prefix='.test-template') templatedir = tempfile.mkdtemp(prefix=".test-template")
with open(os.path.join(templatedir, 'HEAD'), 'w') as fp: with open(os.path.join(templatedir, "HEAD"), "w") as fp:
fp.write('ref: refs/heads/main\n') fp.write("ref: refs/heads/main\n")
initstr = '--template=' + templatedir initstr = "--template=" + templatedir
run_git('init', initstr, cwd=remote) run_git("init", initstr, cwd=remote)
run_git('commit', '--allow-empty', '-minit', cwd=remote) run_git("commit", "--allow-empty", "-minit", cwd=remote)
run_git('branch', 'stable', cwd=remote) run_git("branch", "stable", cwd=remote)
run_git('tag', 'v1.0', cwd=remote) run_git("tag", "v1.0", cwd=remote)
run_git('commit', '--allow-empty', '-m2nd commit', cwd=remote) run_git("commit", "--allow-empty", "-m2nd commit", cwd=remote)
cls.REV_LIST = run_git('rev-list', 'HEAD', cwd=remote).stdout.splitlines() cls.REV_LIST = run_git(
"rev-list", "HEAD", cwd=remote
).stdout.splitlines()
run_git('init', cwd=cls.GIT_DIR) run_git("init", cwd=cls.GIT_DIR)
run_git('fetch', remote, '+refs/heads/*:refs/remotes/origin/*', cwd=cls.GIT_DIR) run_git(
"fetch",
remote,
"+refs/heads/*:refs/remotes/origin/*",
cwd=cls.GIT_DIR,
)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
@ -454,36 +521,40 @@ class ResolveRepoRev(GitCheckoutTestCase):
def test_explicit_branch(self): def test_explicit_branch(self):
"""Check refs/heads/branch argument.""" """Check refs/heads/branch argument."""
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/heads/stable') rrev, lrev = self.wrapper.resolve_repo_rev(
self.assertEqual('refs/heads/stable', rrev) self.GIT_DIR, "refs/heads/stable"
)
self.assertEqual("refs/heads/stable", rrev)
self.assertEqual(self.REV_LIST[1], lrev) self.assertEqual(self.REV_LIST[1], lrev)
with self.assertRaises(self.wrapper.CloneFailure): with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/heads/unknown') self.wrapper.resolve_repo_rev(self.GIT_DIR, "refs/heads/unknown")
def test_explicit_tag(self): def test_explicit_tag(self):
"""Check refs/tags/tag argument.""" """Check refs/tags/tag argument."""
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/tags/v1.0') rrev, lrev = self.wrapper.resolve_repo_rev(
self.assertEqual('refs/tags/v1.0', rrev) self.GIT_DIR, "refs/tags/v1.0"
)
self.assertEqual("refs/tags/v1.0", rrev)
self.assertEqual(self.REV_LIST[1], lrev) self.assertEqual(self.REV_LIST[1], lrev)
with self.assertRaises(self.wrapper.CloneFailure): with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/tags/unknown') self.wrapper.resolve_repo_rev(self.GIT_DIR, "refs/tags/unknown")
def test_branch_name(self): def test_branch_name(self):
"""Check branch argument.""" """Check branch argument."""
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'stable') rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "stable")
self.assertEqual('refs/heads/stable', rrev) self.assertEqual("refs/heads/stable", rrev)
self.assertEqual(self.REV_LIST[1], lrev) self.assertEqual(self.REV_LIST[1], lrev)
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'main') rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "main")
self.assertEqual('refs/heads/main', rrev) self.assertEqual("refs/heads/main", rrev)
self.assertEqual(self.REV_LIST[0], lrev) self.assertEqual(self.REV_LIST[0], lrev)
def test_tag_name(self): def test_tag_name(self):
"""Check tag argument.""" """Check tag argument."""
rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'v1.0') rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "v1.0")
self.assertEqual('refs/tags/v1.0', rrev) self.assertEqual("refs/tags/v1.0", rrev)
self.assertEqual(self.REV_LIST[1], lrev) self.assertEqual(self.REV_LIST[1], lrev)
def test_full_commit(self): def test_full_commit(self):
@ -503,7 +574,7 @@ class ResolveRepoRev(GitCheckoutTestCase):
def test_unknown(self): def test_unknown(self):
"""Check unknown ref/commit argument.""" """Check unknown ref/commit argument."""
with self.assertRaises(self.wrapper.CloneFailure): with self.assertRaises(self.wrapper.CloneFailure):
self.wrapper.resolve_repo_rev(self.GIT_DIR, 'boooooooya') self.wrapper.resolve_repo_rev(self.GIT_DIR, "boooooooya")
class CheckRepoVerify(RepoWrapperTestCase): class CheckRepoVerify(RepoWrapperTestCase):
@ -515,13 +586,17 @@ class CheckRepoVerify(RepoWrapperTestCase):
def test_gpg_initialized(self): def test_gpg_initialized(self):
"""Should pass if gpg is setup already.""" """Should pass if gpg is setup already."""
with mock.patch.object(self.wrapper, 'NeedSetupGnuPG', return_value=False): with mock.patch.object(
self.wrapper, "NeedSetupGnuPG", return_value=False
):
self.assertTrue(self.wrapper.check_repo_verify(True)) self.assertTrue(self.wrapper.check_repo_verify(True))
def test_need_gpg_setup(self): def test_need_gpg_setup(self):
"""Should pass/fail based on gpg setup.""" """Should pass/fail based on gpg setup."""
with mock.patch.object(self.wrapper, 'NeedSetupGnuPG', return_value=True): with mock.patch.object(
with mock.patch.object(self.wrapper, 'SetupGnuPG') as m: self.wrapper, "NeedSetupGnuPG", return_value=True
):
with mock.patch.object(self.wrapper, "SetupGnuPG") as m:
m.return_value = True m.return_value = True
self.assertTrue(self.wrapper.check_repo_verify(True)) self.assertTrue(self.wrapper.check_repo_verify(True))
@ -534,22 +609,34 @@ class CheckRepoRev(GitCheckoutTestCase):
def test_verify_works(self): def test_verify_works(self):
"""Should pass when verification passes.""" """Should pass when verification passes."""
with mock.patch.object(self.wrapper, 'check_repo_verify', return_value=True): with mock.patch.object(
with mock.patch.object(self.wrapper, 'verify_rev', return_value='12345'): self.wrapper, "check_repo_verify", return_value=True
rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, 'stable') ):
self.assertEqual('refs/heads/stable', rrev) with mock.patch.object(
self.assertEqual('12345', lrev) self.wrapper, "verify_rev", return_value="12345"
):
rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, "stable")
self.assertEqual("refs/heads/stable", rrev)
self.assertEqual("12345", lrev)
def test_verify_fails(self): def test_verify_fails(self):
"""Should fail when verification fails.""" """Should fail when verification fails."""
with mock.patch.object(self.wrapper, 'check_repo_verify', return_value=True): with mock.patch.object(
with mock.patch.object(self.wrapper, 'verify_rev', side_effect=Exception): self.wrapper, "check_repo_verify", return_value=True
):
with mock.patch.object(
self.wrapper, "verify_rev", side_effect=Exception
):
with self.assertRaises(Exception): with self.assertRaises(Exception):
self.wrapper.check_repo_rev(self.GIT_DIR, 'stable') self.wrapper.check_repo_rev(self.GIT_DIR, "stable")
def test_verify_ignore(self): def test_verify_ignore(self):
"""Should pass when verification is disabled.""" """Should pass when verification is disabled."""
with mock.patch.object(self.wrapper, 'verify_rev', side_effect=Exception): with mock.patch.object(
rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, 'stable', repo_verify=False) self.wrapper, "verify_rev", side_effect=Exception
self.assertEqual('refs/heads/stable', rrev) ):
rrev, lrev = self.wrapper.check_repo_rev(
self.GIT_DIR, "stable", repo_verify=False
)
self.assertEqual("refs/heads/stable", rrev)
self.assertEqual(self.REV_LIST[1], lrev) self.assertEqual(self.REV_LIST[1], lrev)

View File

@ -27,6 +27,7 @@ python =
[testenv] [testenv]
deps = deps =
black
pytest pytest
pytest-timeout pytest-timeout
commands = {envpython} run_tests {posargs} commands = {envpython} run_tests {posargs}

View File

@ -19,12 +19,12 @@ import os
def WrapperPath(): def WrapperPath():
return os.path.join(os.path.dirname(__file__), 'repo') return os.path.join(os.path.dirname(__file__), "repo")
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def Wrapper(): def Wrapper():
modname = 'wrapper' modname = "wrapper"
loader = importlib.machinery.SourceFileLoader(modname, WrapperPath()) loader = importlib.machinery.SourceFileLoader(modname, WrapperPath())
spec = importlib.util.spec_from_loader(modname, loader) spec = importlib.util.spec_from_loader(modname, loader)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)