diff --git a/.flake8 b/.flake8 index 82453b56..dd7f4d36 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,8 @@ [flake8] max-line-length = 80 +per-file-ignores = + # E501: line too long + tests/test_git_superproject.py: E501 extend-ignore = # E203: Whitespace before ':' # See https://github.com/PyCQA/pycodestyle/issues/373 diff --git a/color.py b/color.py index fdd72534..8f29b59f 100644 --- a/color.py +++ b/color.py @@ -17,196 +17,200 @@ import sys import pager -COLORS = {None: -1, - 'normal': -1, - 'black': 0, - 'red': 1, - 'green': 2, - 'yellow': 3, - 'blue': 4, - 'magenta': 5, - 'cyan': 6, - 'white': 7} +COLORS = { + None: -1, + "normal": -1, + "black": 0, + "red": 1, + "green": 2, + "yellow": 3, + "blue": 4, + "magenta": 5, + "cyan": 6, + "white": 7, +} -ATTRS = {None: -1, - 'bold': 1, - 'dim': 2, - 'ul': 4, - 'blink': 5, - 'reverse': 7} +ATTRS = {None: -1, "bold": 1, "dim": 2, "ul": 4, "blink": 5, "reverse": 7} RESET = "\033[m" def is_color(s): - return s in COLORS + return s in COLORS def is_attr(s): - return s in ATTRS + return s in ATTRS def _Color(fg=None, bg=None, attr=None): - fg = COLORS[fg] - bg = COLORS[bg] - attr = ATTRS[attr] + fg = COLORS[fg] + bg = COLORS[bg] + attr = ATTRS[attr] - if attr >= 0 or fg >= 0 or bg >= 0: - need_sep = False - code = "\033[" + if attr >= 0 or fg >= 0 or bg >= 0: + need_sep = False + code = "\033[" - if attr >= 0: - code += chr(ord('0') + attr) - need_sep = True + if attr >= 0: + code += chr(ord("0") + attr) + need_sep = True - if fg >= 0: - if need_sep: - code += ';' - need_sep = True + if fg >= 0: + if need_sep: + code += ";" + need_sep = True - if fg < 8: - code += '3%c' % (ord('0') + fg) - else: - code += '38;5;%d' % fg + if fg < 8: + code += "3%c" % (ord("0") + fg) + else: + code += "38;5;%d" % fg - if bg >= 0: - if need_sep: - code += ';' + if bg >= 0: + if need_sep: + code += ";" - if bg < 8: - code += '4%c' % (ord('0') + bg) - else: - code += '48;5;%d' % bg - code += 'm' - else: - code = '' - return code + if bg < 8: 
+ code += "4%c" % (ord("0") + bg) + else: + code += "48;5;%d" % bg + code += "m" + else: + code = "" + return code DEFAULT = None def SetDefaultColoring(state): - """Set coloring behavior to |state|. + """Set coloring behavior to |state|. - This is useful for overriding config options via the command line. - """ - if state is None: - # Leave it alone -- return quick! - return + This is useful for overriding config options via the command line. + """ + if state is None: + # Leave it alone -- return quick! + return - global DEFAULT - state = state.lower() - if state in ('auto',): - DEFAULT = state - elif state in ('always', 'yes', 'true', True): - DEFAULT = 'always' - elif state in ('never', 'no', 'false', False): - DEFAULT = 'never' + global DEFAULT + state = state.lower() + if state in ("auto",): + DEFAULT = state + elif state in ("always", "yes", "true", True): + DEFAULT = "always" + elif state in ("never", "no", "false", False): + DEFAULT = "never" class Coloring(object): - def __init__(self, config, section_type): - self._section = 'color.%s' % section_type - self._config = config - self._out = sys.stdout + def __init__(self, config, section_type): + self._section = "color.%s" % section_type + self._config = config + self._out = sys.stdout - on = DEFAULT - if on is None: - on = self._config.GetString(self._section) - if on is None: - on = self._config.GetString('color.ui') + on = DEFAULT + if on is None: + on = self._config.GetString(self._section) + if on is None: + on = self._config.GetString("color.ui") - if on == 'auto': - if pager.active or os.isatty(1): - self._on = True - else: - self._on = False - elif on in ('true', 'always'): - self._on = True - else: - self._on = False - - def redirect(self, out): - self._out = out - - @property - def is_on(self): - return self._on - - def write(self, fmt, *args): - self._out.write(fmt % args) - - def flush(self): - self._out.flush() - - def nl(self): - self._out.write('\n') - - def printer(self, opt=None, fg=None, 
bg=None, attr=None): - s = self - c = self.colorer(opt, fg, bg, attr) - - def f(fmt, *args): - s._out.write(c(fmt, *args)) - return f - - def nofmt_printer(self, opt=None, fg=None, bg=None, attr=None): - s = self - c = self.nofmt_colorer(opt, fg, bg, attr) - - def f(fmt): - s._out.write(c(fmt)) - return f - - def colorer(self, opt=None, fg=None, bg=None, attr=None): - if self._on: - c = self._parse(opt, fg, bg, attr) - - def f(fmt, *args): - output = fmt % args - return ''.join([c, output, RESET]) - return f - else: - - def f(fmt, *args): - return fmt % args - return f - - def nofmt_colorer(self, opt=None, fg=None, bg=None, attr=None): - if self._on: - c = self._parse(opt, fg, bg, attr) - - def f(fmt): - return ''.join([c, fmt, RESET]) - return f - else: - def f(fmt): - return fmt - return f - - def _parse(self, opt, fg, bg, attr): - if not opt: - return _Color(fg, bg, attr) - - v = self._config.GetString('%s.%s' % (self._section, opt)) - if v is None: - return _Color(fg, bg, attr) - - v = v.strip().lower() - if v == "reset": - return RESET - elif v == '': - return _Color(fg, bg, attr) - - have_fg = False - for a in v.split(' '): - if is_color(a): - if have_fg: - bg = a + if on == "auto": + if pager.active or os.isatty(1): + self._on = True + else: + self._on = False + elif on in ("true", "always"): + self._on = True else: - fg = a - elif is_attr(a): - attr = a + self._on = False - return _Color(fg, bg, attr) + def redirect(self, out): + self._out = out + + @property + def is_on(self): + return self._on + + def write(self, fmt, *args): + self._out.write(fmt % args) + + def flush(self): + self._out.flush() + + def nl(self): + self._out.write("\n") + + def printer(self, opt=None, fg=None, bg=None, attr=None): + s = self + c = self.colorer(opt, fg, bg, attr) + + def f(fmt, *args): + s._out.write(c(fmt, *args)) + + return f + + def nofmt_printer(self, opt=None, fg=None, bg=None, attr=None): + s = self + c = self.nofmt_colorer(opt, fg, bg, attr) + + def f(fmt): + 
s._out.write(c(fmt)) + + return f + + def colorer(self, opt=None, fg=None, bg=None, attr=None): + if self._on: + c = self._parse(opt, fg, bg, attr) + + def f(fmt, *args): + output = fmt % args + return "".join([c, output, RESET]) + + return f + else: + + def f(fmt, *args): + return fmt % args + + return f + + def nofmt_colorer(self, opt=None, fg=None, bg=None, attr=None): + if self._on: + c = self._parse(opt, fg, bg, attr) + + def f(fmt): + return "".join([c, fmt, RESET]) + + return f + else: + + def f(fmt): + return fmt + + return f + + def _parse(self, opt, fg, bg, attr): + if not opt: + return _Color(fg, bg, attr) + + v = self._config.GetString("%s.%s" % (self._section, opt)) + if v is None: + return _Color(fg, bg, attr) + + v = v.strip().lower() + if v == "reset": + return RESET + elif v == "": + return _Color(fg, bg, attr) + + have_fg = False + for a in v.split(" "): + if is_color(a): + if have_fg: + bg = a + else: + fg = a + elif is_attr(a): + attr = a + + return _Color(fg, bg, attr) diff --git a/command.py b/command.py index 68f36f03..939a4630 100644 --- a/command.py +++ b/command.py @@ -25,7 +25,7 @@ import progress # Are we generating man-pages? -GENERATE_MANPAGES = os.environ.get('_REPO_GENERATE_MANPAGES_') == ' indeed! ' +GENERATE_MANPAGES = os.environ.get("_REPO_GENERATE_MANPAGES_") == " indeed! " # Number of projects to submit to a single worker process at a time. @@ -43,403 +43,470 @@ DEFAULT_LOCAL_JOBS = min(os.cpu_count(), 8) class Command(object): - """Base class for any command line action in repo. - """ + """Base class for any command line action in repo.""" - # Singleton for all commands to track overall repo command execution and - # provide event summary to callers. Only used by sync subcommand currently. - # - # NB: This is being replaced by git trace2 events. See git_trace2_event_log. - event_log = EventLog() + # Singleton for all commands to track overall repo command execution and + # provide event summary to callers. 
Only used by sync subcommand currently. + # + # NB: This is being replaced by git trace2 events. See git_trace2_event_log. + event_log = EventLog() - # Whether this command is a "common" one, i.e. whether the user would commonly - # use it or it's a more uncommon command. This is used by the help command to - # show short-vs-full summaries. - COMMON = False + # Whether this command is a "common" one, i.e. whether the user would + # commonly use it or it's a more uncommon command. This is used by the help + # command to show short-vs-full summaries. + COMMON = False - # Whether this command supports running in parallel. If greater than 0, - # it is the number of parallel jobs to default to. - PARALLEL_JOBS = None + # Whether this command supports running in parallel. If greater than 0, + # it is the number of parallel jobs to default to. + PARALLEL_JOBS = None - # Whether this command supports Multi-manifest. If False, then main.py will - # iterate over the manifests and invoke the command once per (sub)manifest. - # This is only checked after calling ValidateOptions, so that partially - # migrated subcommands can set it to False. - MULTI_MANIFEST_SUPPORT = True + # Whether this command supports Multi-manifest. If False, then main.py will + # iterate over the manifests and invoke the command once per (sub)manifest. + # This is only checked after calling ValidateOptions, so that partially + # migrated subcommands can set it to False. 
+ MULTI_MANIFEST_SUPPORT = True - def __init__(self, repodir=None, client=None, manifest=None, gitc_manifest=None, - git_event_log=None, outer_client=None, outer_manifest=None): - self.repodir = repodir - self.client = client - self.outer_client = outer_client or client - self.manifest = manifest - self.gitc_manifest = gitc_manifest - self.git_event_log = git_event_log - self.outer_manifest = outer_manifest + def __init__( + self, + repodir=None, + client=None, + manifest=None, + gitc_manifest=None, + git_event_log=None, + outer_client=None, + outer_manifest=None, + ): + self.repodir = repodir + self.client = client + self.outer_client = outer_client or client + self.manifest = manifest + self.gitc_manifest = gitc_manifest + self.git_event_log = git_event_log + self.outer_manifest = outer_manifest - # Cache for the OptionParser property. - self._optparse = None + # Cache for the OptionParser property. + self._optparse = None - def WantPager(self, _opt): - return False + def WantPager(self, _opt): + return False - def ReadEnvironmentOptions(self, opts): - """ Set options from environment variables. """ + def ReadEnvironmentOptions(self, opts): + """Set options from environment variables.""" - env_options = self._RegisteredEnvironmentOptions() + env_options = self._RegisteredEnvironmentOptions() - for env_key, opt_key in env_options.items(): - # Get the user-set option value if any - opt_value = getattr(opts, opt_key) + for env_key, opt_key in env_options.items(): + # Get the user-set option value if any + opt_value = getattr(opts, opt_key) - # If the value is set, it means the user has passed it as a command - # line option, and we should use that. Otherwise we can try to set it - # with the value from the corresponding environment variable. - if opt_value is not None: - continue + # If the value is set, it means the user has passed it as a command + # line option, and we should use that. 
Otherwise we can try to set + # it with the value from the corresponding environment variable. + if opt_value is not None: + continue - env_value = os.environ.get(env_key) - if env_value is not None: - setattr(opts, opt_key, env_value) + env_value = os.environ.get(env_key) + if env_value is not None: + setattr(opts, opt_key, env_value) - return opts + return opts - @property - def OptionParser(self): - if self._optparse is None: - try: - me = 'repo %s' % self.NAME - usage = self.helpUsage.strip().replace('%prog', me) - except AttributeError: - usage = 'repo %s' % self.NAME - epilog = 'Run `repo help %s` to view the detailed manual.' % self.NAME - self._optparse = optparse.OptionParser(usage=usage, epilog=epilog) - self._CommonOptions(self._optparse) - self._Options(self._optparse) - return self._optparse + @property + def OptionParser(self): + if self._optparse is None: + try: + me = "repo %s" % self.NAME + usage = self.helpUsage.strip().replace("%prog", me) + except AttributeError: + usage = "repo %s" % self.NAME + epilog = ( + "Run `repo help %s` to view the detailed manual." % self.NAME + ) + self._optparse = optparse.OptionParser(usage=usage, epilog=epilog) + self._CommonOptions(self._optparse) + self._Options(self._optparse) + return self._optparse - def _CommonOptions(self, p, opt_v=True): - """Initialize the option parser with common options. + def _CommonOptions(self, p, opt_v=True): + """Initialize the option parser with common options. - These will show up for *all* subcommands, so use sparingly. - NB: Keep in sync with repo:InitParser(). - """ - g = p.add_option_group('Logging options') - opts = ['-v'] if opt_v else [] - g.add_option(*opts, '--verbose', - dest='output_mode', action='store_true', - help='show all output') - g.add_option('-q', '--quiet', - dest='output_mode', action='store_false', - help='only show errors') + These will show up for *all* subcommands, so use sparingly. + NB: Keep in sync with repo:InitParser(). 
+ """ + g = p.add_option_group("Logging options") + opts = ["-v"] if opt_v else [] + g.add_option( + *opts, + "--verbose", + dest="output_mode", + action="store_true", + help="show all output", + ) + g.add_option( + "-q", + "--quiet", + dest="output_mode", + action="store_false", + help="only show errors", + ) - if self.PARALLEL_JOBS is not None: - default = 'based on number of CPU cores' - if not GENERATE_MANPAGES: - # Only include active cpu count if we aren't generating man pages. - default = f'%default; {default}' - p.add_option( - '-j', '--jobs', - type=int, default=self.PARALLEL_JOBS, - help=f'number of jobs to run in parallel (default: {default})') + if self.PARALLEL_JOBS is not None: + default = "based on number of CPU cores" + if not GENERATE_MANPAGES: + # Only include active cpu count if we aren't generating man + # pages. + default = f"%default; {default}" + p.add_option( + "-j", + "--jobs", + type=int, + default=self.PARALLEL_JOBS, + help=f"number of jobs to run in parallel (default: {default})", + ) - m = p.add_option_group('Multi-manifest options') - m.add_option('--outer-manifest', action='store_true', default=None, - help='operate starting at the outermost manifest') - m.add_option('--no-outer-manifest', dest='outer_manifest', - action='store_false', help='do not operate on outer manifests') - m.add_option('--this-manifest-only', action='store_true', default=None, - help='only operate on this (sub)manifest') - m.add_option('--no-this-manifest-only', '--all-manifests', - dest='this_manifest_only', action='store_false', - help='operate on this manifest and its submanifests') + m = p.add_option_group("Multi-manifest options") + m.add_option( + "--outer-manifest", + action="store_true", + default=None, + help="operate starting at the outermost manifest", + ) + m.add_option( + "--no-outer-manifest", + dest="outer_manifest", + action="store_false", + help="do not operate on outer manifests", + ) + m.add_option( + "--this-manifest-only", + 
action="store_true", + default=None, + help="only operate on this (sub)manifest", + ) + m.add_option( + "--no-this-manifest-only", + "--all-manifests", + dest="this_manifest_only", + action="store_false", + help="operate on this manifest and its submanifests", + ) - def _Options(self, p): - """Initialize the option parser with subcommand-specific options.""" + def _Options(self, p): + """Initialize the option parser with subcommand-specific options.""" - def _RegisteredEnvironmentOptions(self): - """Get options that can be set from environment variables. + def _RegisteredEnvironmentOptions(self): + """Get options that can be set from environment variables. - Return a dictionary mapping environment variable name - to option key name that it can override. + Return a dictionary mapping environment variable name + to option key name that it can override. - Example: {'REPO_MY_OPTION': 'my_option'} + Example: {'REPO_MY_OPTION': 'my_option'} - Will allow the option with key value 'my_option' to be set - from the value in the environment variable named 'REPO_MY_OPTION'. + Will allow the option with key value 'my_option' to be set + from the value in the environment variable named 'REPO_MY_OPTION'. - Note: This does not work properly for options that are explicitly - set to None by the user, or options that are defined with a - default value other than None. + Note: This does not work properly for options that are explicitly + set to None by the user, or options that are defined with a + default value other than None. - """ - return {} + """ + return {} - def Usage(self): - """Display usage and terminate. 
- """ - self.OptionParser.print_usage() - sys.exit(1) + def Usage(self): + """Display usage and terminate.""" + self.OptionParser.print_usage() + sys.exit(1) - def CommonValidateOptions(self, opt, args): - """Validate common options.""" - opt.quiet = opt.output_mode is False - opt.verbose = opt.output_mode is True - if opt.outer_manifest is None: - # By default, treat multi-manifest instances as a single manifest from - # the user's perspective. - opt.outer_manifest = True + def CommonValidateOptions(self, opt, args): + """Validate common options.""" + opt.quiet = opt.output_mode is False + opt.verbose = opt.output_mode is True + if opt.outer_manifest is None: + # By default, treat multi-manifest instances as a single manifest + # from the user's perspective. + opt.outer_manifest = True - def ValidateOptions(self, opt, args): - """Validate the user options & arguments before executing. + def ValidateOptions(self, opt, args): + """Validate the user options & arguments before executing. - This is meant to help break the code up into logical steps. Some tips: - * Use self.OptionParser.error to display CLI related errors. - * Adjust opt member defaults as makes sense. - * Adjust the args list, but do so inplace so the caller sees updates. - * Try to avoid updating self state. Leave that to Execute. - """ + This is meant to help break the code up into logical steps. Some tips: + * Use self.OptionParser.error to display CLI related errors. + * Adjust opt member defaults as makes sense. + * Adjust the args list, but do so inplace so the caller sees updates. + * Try to avoid updating self state. Leave that to Execute. + """ - def Execute(self, opt, args): - """Perform the action, after option parsing is complete. 
- """ - raise NotImplementedError + def Execute(self, opt, args): + """Perform the action, after option parsing is complete.""" + raise NotImplementedError - @staticmethod - def ExecuteInParallel(jobs, func, inputs, callback, output=None, ordered=False): - """Helper for managing parallel execution boiler plate. + @staticmethod + def ExecuteInParallel( + jobs, func, inputs, callback, output=None, ordered=False + ): + """Helper for managing parallel execution boiler plate. - For subcommands that can easily split their work up. + For subcommands that can easily split their work up. - Args: - jobs: How many parallel processes to use. - func: The function to apply to each of the |inputs|. Usually a - functools.partial for wrapping additional arguments. It will be run - in a separate process, so it must be pickalable, so nested functions - won't work. Methods on the subcommand Command class should work. - inputs: The list of items to process. Must be a list. - callback: The function to pass the results to for processing. It will be - executed in the main thread and process the results of |func| as they - become available. Thus it may be a local nested function. Its return - value is passed back directly. It takes three arguments: - - The processing pool (or None with one job). - - The |output| argument. - - An iterator for the results. - output: An output manager. May be progress.Progess or color.Coloring. - ordered: Whether the jobs should be processed in order. + Args: + jobs: How many parallel processes to use. + func: The function to apply to each of the |inputs|. Usually a + functools.partial for wrapping additional arguments. It will be + run in a separate process, so it must be pickalable, so nested + functions won't work. Methods on the subcommand Command class + should work. + inputs: The list of items to process. Must be a list. + callback: The function to pass the results to for processing. 
It + will be executed in the main thread and process the results of + |func| as they become available. Thus it may be a local nested + function. Its return value is passed back directly. It takes + three arguments: + - The processing pool (or None with one job). + - The |output| argument. + - An iterator for the results. + output: An output manager. May be progress.Progess or + color.Coloring. + ordered: Whether the jobs should be processed in order. - Returns: - The |callback| function's results are returned. - """ - try: - # NB: Multiprocessing is heavy, so don't spin it up for one job. - if len(inputs) == 1 or jobs == 1: - return callback(None, output, (func(x) for x in inputs)) - else: - with multiprocessing.Pool(jobs) as pool: - submit = pool.imap if ordered else pool.imap_unordered - return callback(pool, output, submit(func, inputs, chunksize=WORKER_BATCH_SIZE)) - finally: - if isinstance(output, progress.Progress): - output.end() - - def _ResetPathToProjectMap(self, projects): - self._by_path = dict((p.worktree, p) for p in projects) - - def _UpdatePathToProjectMap(self, project): - self._by_path[project.worktree] = project - - def _GetProjectByPath(self, manifest, path): - project = None - if os.path.exists(path): - oldpath = None - while (path and - path != oldpath and - path != manifest.topdir): + Returns: + The |callback| function's results are returned. + """ try: - project = self._by_path[path] - break - except KeyError: - oldpath = path - path = os.path.dirname(path) - if not project and path == manifest.topdir: - try: - project = self._by_path[path] - except KeyError: - pass - else: - try: - project = self._by_path[path] - except KeyError: - pass - return project + # NB: Multiprocessing is heavy, so don't spin it up for one job. 
+ if len(inputs) == 1 or jobs == 1: + return callback(None, output, (func(x) for x in inputs)) + else: + with multiprocessing.Pool(jobs) as pool: + submit = pool.imap if ordered else pool.imap_unordered + return callback( + pool, + output, + submit(func, inputs, chunksize=WORKER_BATCH_SIZE), + ) + finally: + if isinstance(output, progress.Progress): + output.end() - def GetProjects(self, args, manifest=None, groups='', missing_ok=False, - submodules_ok=False, all_manifests=False): - """A list of projects that match the arguments. + def _ResetPathToProjectMap(self, projects): + self._by_path = dict((p.worktree, p) for p in projects) - Args: - args: a list of (case-insensitive) strings, projects to search for. - manifest: an XmlManifest, the manifest to use, or None for default. - groups: a string, the manifest groups in use. - missing_ok: a boolean, whether to allow missing projects. - submodules_ok: a boolean, whether to allow submodules. - all_manifests: a boolean, if True then all manifests and submanifests are - used. If False, then only the local (sub)manifest is used. + def _UpdatePathToProjectMap(self, project): + self._by_path[project.worktree] = project - Returns: - A list of matching Project instances. 
- """ - if all_manifests: - if not manifest: - manifest = self.manifest.outer_client - all_projects_list = manifest.all_projects - else: - if not manifest: - manifest = self.manifest - all_projects_list = manifest.projects - result = [] + def _GetProjectByPath(self, manifest, path): + project = None + if os.path.exists(path): + oldpath = None + while path and path != oldpath and path != manifest.topdir: + try: + project = self._by_path[path] + break + except KeyError: + oldpath = path + path = os.path.dirname(path) + if not project and path == manifest.topdir: + try: + project = self._by_path[path] + except KeyError: + pass + else: + try: + project = self._by_path[path] + except KeyError: + pass + return project - if not groups: - groups = manifest.GetGroupsStr() - groups = [x for x in re.split(r'[,\s]+', groups) if x] + def GetProjects( + self, + args, + manifest=None, + groups="", + missing_ok=False, + submodules_ok=False, + all_manifests=False, + ): + """A list of projects that match the arguments. - if not args: - derived_projects = {} - for project in all_projects_list: - if submodules_ok or project.sync_s: - derived_projects.update((p.name, p) - for p in project.GetDerivedSubprojects()) - all_projects_list.extend(derived_projects.values()) - for project in all_projects_list: - if (missing_ok or project.Exists) and project.MatchesGroups(groups): - result.append(project) - else: - self._ResetPathToProjectMap(all_projects_list) + Args: + args: a list of (case-insensitive) strings, projects to search for. + manifest: an XmlManifest, the manifest to use, or None for default. + groups: a string, the manifest groups in use. + missing_ok: a boolean, whether to allow missing projects. + submodules_ok: a boolean, whether to allow submodules. + all_manifests: a boolean, if True then all manifests and + submanifests are used. If False, then only the local + (sub)manifest is used. 
- for arg in args: - # We have to filter by manifest groups in case the requested project is - # checked out multiple times or differently based on them. - projects = [project + Returns: + A list of matching Project instances. + """ + if all_manifests: + if not manifest: + manifest = self.manifest.outer_client + all_projects_list = manifest.all_projects + else: + if not manifest: + manifest = self.manifest + all_projects_list = manifest.projects + result = [] + + if not groups: + groups = manifest.GetGroupsStr() + groups = [x for x in re.split(r"[,\s]+", groups) if x] + + if not args: + derived_projects = {} + for project in all_projects_list: + if submodules_ok or project.sync_s: + derived_projects.update( + (p.name, p) for p in project.GetDerivedSubprojects() + ) + all_projects_list.extend(derived_projects.values()) + for project in all_projects_list: + if (missing_ok or project.Exists) and project.MatchesGroups( + groups + ): + result.append(project) + else: + self._ResetPathToProjectMap(all_projects_list) + + for arg in args: + # We have to filter by manifest groups in case the requested + # project is checked out multiple times or differently based on + # them. + projects = [ + project for project in manifest.GetProjectsWithName( - arg, all_manifests=all_manifests) - if project.MatchesGroups(groups)] + arg, all_manifests=all_manifests + ) + if project.MatchesGroups(groups) + ] - if not projects: - path = os.path.abspath(arg).replace('\\', '/') - tree = manifest - if all_manifests: - # Look for the deepest matching submanifest. - for tree in reversed(list(manifest.all_manifests)): - if path.startswith(tree.topdir): - break - project = self._GetProjectByPath(tree, path) + if not projects: + path = os.path.abspath(arg).replace("\\", "/") + tree = manifest + if all_manifests: + # Look for the deepest matching submanifest. 
+ for tree in reversed(list(manifest.all_manifests)): + if path.startswith(tree.topdir): + break + project = self._GetProjectByPath(tree, path) - # If it's not a derived project, update path->project mapping and - # search again, as arg might actually point to a derived subproject. - if (project and not project.Derived and (submodules_ok or - project.sync_s)): - search_again = False - for subproject in project.GetDerivedSubprojects(): - self._UpdatePathToProjectMap(subproject) - search_again = True - if search_again: - project = self._GetProjectByPath(manifest, path) or project + # If it's not a derived project, update path->project + # mapping and search again, as arg might actually point to + # a derived subproject. + if ( + project + and not project.Derived + and (submodules_ok or project.sync_s) + ): + search_again = False + for subproject in project.GetDerivedSubprojects(): + self._UpdatePathToProjectMap(subproject) + search_again = True + if search_again: + project = ( + self._GetProjectByPath(manifest, path) + or project + ) - if project: - projects = [project] + if project: + projects = [project] - if not projects: - raise NoSuchProjectError(arg) + if not projects: + raise NoSuchProjectError(arg) - for project in projects: - if not missing_ok and not project.Exists: - raise NoSuchProjectError('%s (%s)' % ( - arg, project.RelPath(local=not all_manifests))) - if not project.MatchesGroups(groups): - raise InvalidProjectGroupsError(arg) + for project in projects: + if not missing_ok and not project.Exists: + raise NoSuchProjectError( + "%s (%s)" + % (arg, project.RelPath(local=not all_manifests)) + ) + if not project.MatchesGroups(groups): + raise InvalidProjectGroupsError(arg) - result.extend(projects) + result.extend(projects) - def _getpath(x): - return x.relpath - result.sort(key=_getpath) - return result + def _getpath(x): + return x.relpath - def FindProjects(self, args, inverse=False, all_manifests=False): - """Find projects from command line arguments. 
+ result.sort(key=_getpath) + return result - Args: - args: a list of (case-insensitive) strings, projects to search for. - inverse: a boolean, if True, then projects not matching any |args| are - returned. - all_manifests: a boolean, if True then all manifests and submanifests are - used. If False, then only the local (sub)manifest is used. - """ - result = [] - patterns = [re.compile(r'%s' % a, re.IGNORECASE) for a in args] - for project in self.GetProjects('', all_manifests=all_manifests): - paths = [project.name, project.RelPath(local=not all_manifests)] - for pattern in patterns: - match = any(pattern.search(x) for x in paths) - if not inverse and match: - result.append(project) - break - if inverse and match: - break - else: - if inverse: - result.append(project) - result.sort(key=lambda project: (project.manifest.path_prefix, - project.relpath)) - return result + def FindProjects(self, args, inverse=False, all_manifests=False): + """Find projects from command line arguments. - def ManifestList(self, opt): - """Yields all of the manifests to traverse. + Args: + args: a list of (case-insensitive) strings, projects to search for. + inverse: a boolean, if True, then projects not matching any |args| + are returned. + all_manifests: a boolean, if True then all manifests and + submanifests are used. If False, then only the local + (sub)manifest is used. + """ + result = [] + patterns = [re.compile(r"%s" % a, re.IGNORECASE) for a in args] + for project in self.GetProjects("", all_manifests=all_manifests): + paths = [project.name, project.RelPath(local=not all_manifests)] + for pattern in patterns: + match = any(pattern.search(x) for x in paths) + if not inverse and match: + result.append(project) + break + if inverse and match: + break + else: + if inverse: + result.append(project) + result.sort( + key=lambda project: (project.manifest.path_prefix, project.relpath) + ) + return result - Args: - opt: The command options. 
- """ - top = self.outer_manifest - if not opt.outer_manifest or opt.this_manifest_only: - top = self.manifest - yield top - if not opt.this_manifest_only: - for child in top.all_children: - yield child + def ManifestList(self, opt): + """Yields all of the manifests to traverse. + + Args: + opt: The command options. + """ + top = self.outer_manifest + if not opt.outer_manifest or opt.this_manifest_only: + top = self.manifest + yield top + if not opt.this_manifest_only: + for child in top.all_children: + yield child class InteractiveCommand(Command): - """Command which requires user interaction on the tty and - must not run within a pager, even if the user asks to. - """ + """Command which requires user interaction on the tty and must not run + within a pager, even if the user asks to. + """ - def WantPager(self, _opt): - return False + def WantPager(self, _opt): + return False class PagedCommand(Command): - """Command which defaults to output in a pager, as its - display tends to be larger than one screen full. - """ + """Command which defaults to output in a pager, as its display tends to be + larger than one screen full. + """ - def WantPager(self, _opt): - return True + def WantPager(self, _opt): + return True class MirrorSafeCommand(object): - """Command permits itself to run within a mirror, - and does not require a working directory. - """ + """Command permits itself to run within a mirror, and does not require a + working directory. + """ class GitcAvailableCommand(object): - """Command that requires GITC to be available, but does - not require the local client to be a GITC client. - """ + """Command that requires GITC to be available, but does not require the + local client to be a GITC client. + """ class GitcClientCommand(object): - """Command that requires the local client to be a GITC - client. 
- """ + """Command that requires the local client to be a GITC client.""" diff --git a/editor.py b/editor.py index b84a42d4..96835aba 100644 --- a/editor.py +++ b/editor.py @@ -23,93 +23,99 @@ import platform_utils class Editor(object): - """Manages the user's preferred text editor.""" + """Manages the user's preferred text editor.""" - _editor = None - globalConfig = None + _editor = None + globalConfig = None - @classmethod - def _GetEditor(cls): - if cls._editor is None: - cls._editor = cls._SelectEditor() - return cls._editor + @classmethod + def _GetEditor(cls): + if cls._editor is None: + cls._editor = cls._SelectEditor() + return cls._editor - @classmethod - def _SelectEditor(cls): - e = os.getenv('GIT_EDITOR') - if e: - return e + @classmethod + def _SelectEditor(cls): + e = os.getenv("GIT_EDITOR") + if e: + return e - if cls.globalConfig: - e = cls.globalConfig.GetString('core.editor') - if e: - return e + if cls.globalConfig: + e = cls.globalConfig.GetString("core.editor") + if e: + return e - e = os.getenv('VISUAL') - if e: - return e + e = os.getenv("VISUAL") + if e: + return e - e = os.getenv('EDITOR') - if e: - return e + e = os.getenv("EDITOR") + if e: + return e - if os.getenv('TERM') == 'dumb': - print( - """No editor specified in GIT_EDITOR, core.editor, VISUAL or EDITOR. + if os.getenv("TERM") == "dumb": + print( + """No editor specified in GIT_EDITOR, core.editor, VISUAL or EDITOR. Tried to fall back to vi but terminal is dumb. Please configure at -least one of these before using this command.""", file=sys.stderr) - sys.exit(1) +least one of these before using this command.""", # noqa: E501 + file=sys.stderr, + ) + sys.exit(1) - return 'vi' + return "vi" - @classmethod - def EditString(cls, data): - """Opens an editor to edit the given content. + @classmethod + def EditString(cls, data): + """Opens an editor to edit the given content. - Args: - data: The text to edit. + Args: + data: The text to edit. - Returns: - New value of edited text. 
+ Returns: + New value of edited text. - Raises: - EditorError: The editor failed to run. - """ - editor = cls._GetEditor() - if editor == ':': - return data + Raises: + EditorError: The editor failed to run. + """ + editor = cls._GetEditor() + if editor == ":": + return data - fd, path = tempfile.mkstemp() - try: - os.write(fd, data.encode('utf-8')) - os.close(fd) - fd = None + fd, path = tempfile.mkstemp() + try: + os.write(fd, data.encode("utf-8")) + os.close(fd) + fd = None - if platform_utils.isWindows(): - # Split on spaces, respecting quoted strings - import shlex - args = shlex.split(editor) - shell = False - elif re.compile("^.*[$ \t'].*$").match(editor): - args = [editor + ' "$@"', 'sh'] - shell = True - else: - args = [editor] - shell = False - args.append(path) + if platform_utils.isWindows(): + # Split on spaces, respecting quoted strings + import shlex - try: - rc = subprocess.Popen(args, shell=shell).wait() - except OSError as e: - raise EditorError('editor failed, %s: %s %s' - % (str(e), editor, path)) - if rc != 0: - raise EditorError('editor failed with exit status %d: %s %s' - % (rc, editor, path)) + args = shlex.split(editor) + shell = False + elif re.compile("^.*[$ \t'].*$").match(editor): + args = [editor + ' "$@"', "sh"] + shell = True + else: + args = [editor] + shell = False + args.append(path) - with open(path, mode='rb') as fd2: - return fd2.read().decode('utf-8') - finally: - if fd: - os.close(fd) - platform_utils.remove(path) + try: + rc = subprocess.Popen(args, shell=shell).wait() + except OSError as e: + raise EditorError( + "editor failed, %s: %s %s" % (str(e), editor, path) + ) + if rc != 0: + raise EditorError( + "editor failed with exit status %d: %s %s" + % (rc, editor, path) + ) + + with open(path, mode="rb") as fd2: + return fd2.read().decode("utf-8") + finally: + if fd: + os.close(fd) + platform_utils.remove(path) diff --git a/error.py b/error.py index cbefcb7e..3cf34d54 100644 --- a/error.py +++ b/error.py @@ -14,122 +14,112 
@@ class ManifestParseError(Exception): - """Failed to parse the manifest file. - """ + """Failed to parse the manifest file.""" class ManifestInvalidRevisionError(ManifestParseError): - """The revision value in a project is incorrect. - """ + """The revision value in a project is incorrect.""" class ManifestInvalidPathError(ManifestParseError): - """A path used in or is incorrect. - """ + """A path used in or is incorrect.""" class NoManifestException(Exception): - """The required manifest does not exist. - """ + """The required manifest does not exist.""" - def __init__(self, path, reason): - super().__init__(path, reason) - self.path = path - self.reason = reason + def __init__(self, path, reason): + super().__init__(path, reason) + self.path = path + self.reason = reason - def __str__(self): - return self.reason + def __str__(self): + return self.reason class EditorError(Exception): - """Unspecified error from the user's text editor. - """ + """Unspecified error from the user's text editor.""" - def __init__(self, reason): - super().__init__(reason) - self.reason = reason + def __init__(self, reason): + super().__init__(reason) + self.reason = reason - def __str__(self): - return self.reason + def __str__(self): + return self.reason class GitError(Exception): - """Unspecified internal error from git. - """ + """Unspecified internal error from git.""" - def __init__(self, command): - super().__init__(command) - self.command = command + def __init__(self, command): + super().__init__(command) + self.command = command - def __str__(self): - return self.command + def __str__(self): + return self.command class UploadError(Exception): - """A bundle upload to Gerrit did not succeed. 
- """ + """A bundle upload to Gerrit did not succeed.""" - def __init__(self, reason): - super().__init__(reason) - self.reason = reason + def __init__(self, reason): + super().__init__(reason) + self.reason = reason - def __str__(self): - return self.reason + def __str__(self): + return self.reason class DownloadError(Exception): - """Cannot download a repository. - """ + """Cannot download a repository.""" - def __init__(self, reason): - super().__init__(reason) - self.reason = reason + def __init__(self, reason): + super().__init__(reason) + self.reason = reason - def __str__(self): - return self.reason + def __str__(self): + return self.reason class NoSuchProjectError(Exception): - """A specified project does not exist in the work tree. - """ + """A specified project does not exist in the work tree.""" - def __init__(self, name=None): - super().__init__(name) - self.name = name + def __init__(self, name=None): + super().__init__(name) + self.name = name - def __str__(self): - if self.name is None: - return 'in current directory' - return self.name + def __str__(self): + if self.name is None: + return "in current directory" + return self.name class InvalidProjectGroupsError(Exception): - """A specified project is not suitable for the specified groups - """ + """A specified project is not suitable for the specified groups""" - def __init__(self, name=None): - super().__init__(name) - self.name = name + def __init__(self, name=None): + super().__init__(name) + self.name = name - def __str__(self): - if self.name is None: - return 'in current directory' - return self.name + def __str__(self): + if self.name is None: + return "in current directory" + return self.name class RepoChangedException(Exception): - """Thrown if 'repo sync' results in repo updating its internal - repo or manifest repositories. In this special case we must - use exec to re-execute repo with the new code and manifest. 
- """ + """Thrown if 'repo sync' results in repo updating its internal + repo or manifest repositories. In this special case we must + use exec to re-execute repo with the new code and manifest. + """ - def __init__(self, extra_args=None): - super().__init__(extra_args) - self.extra_args = extra_args or [] + def __init__(self, extra_args=None): + super().__init__(extra_args) + self.extra_args = extra_args or [] class HookError(Exception): - """Thrown if a 'repo-hook' could not be run. + """Thrown if a 'repo-hook' could not be run. - The common case is that the file wasn't present when we tried to run it. - """ + The common case is that the file wasn't present when we tried to run it. + """ diff --git a/event_log.py b/event_log.py index c77c5648..b1f8bdf9 100644 --- a/event_log.py +++ b/event_log.py @@ -15,161 +15,169 @@ import json import multiprocessing -TASK_COMMAND = 'command' -TASK_SYNC_NETWORK = 'sync-network' -TASK_SYNC_LOCAL = 'sync-local' +TASK_COMMAND = "command" +TASK_SYNC_NETWORK = "sync-network" +TASK_SYNC_LOCAL = "sync-local" class EventLog(object): - """Event log that records events that occurred during a repo invocation. + """Event log that records events that occurred during a repo invocation. - Events are written to the log as a consecutive JSON entries, one per line. - Each entry contains the following keys: - - id: A ('RepoOp', ID) tuple, suitable for storing in a datastore. - The ID is only unique for the invocation of the repo command. - - name: Name of the object being operated upon. - - task_name: The task that was performed. - - start: Timestamp of when the operation started. - - finish: Timestamp of when the operation finished. - - success: Boolean indicating if the operation was successful. - - try_count: A counter indicating the try count of this task. + Events are written to the log as a consecutive JSON entries, one per line. + Each entry contains the following keys: + - id: A ('RepoOp', ID) tuple, suitable for storing in a datastore. 
+ The ID is only unique for the invocation of the repo command. + - name: Name of the object being operated upon. + - task_name: The task that was performed. + - start: Timestamp of when the operation started. + - finish: Timestamp of when the operation finished. + - success: Boolean indicating if the operation was successful. + - try_count: A counter indicating the try count of this task. - Optionally: - - parent: A ('RepoOp', ID) tuple indicating the parent event for nested - events. + Optionally: + - parent: A ('RepoOp', ID) tuple indicating the parent event for nested + events. - Valid task_names include: - - command: The invocation of a subcommand. - - sync-network: The network component of a sync command. - - sync-local: The local component of a sync command. + Valid task_names include: + - command: The invocation of a subcommand. + - sync-network: The network component of a sync command. + - sync-local: The local component of a sync command. - Specific tasks may include additional informational properties. - """ - - def __init__(self): - """Initializes the event log.""" - self._log = [] - self._parent = None - - def Add(self, name, task_name, start, finish=None, success=None, - try_count=1, kind='RepoOp'): - """Add an event to the log. - - Args: - name: Name of the object being operated upon. - task_name: A sub-task that was performed for name. - start: Timestamp of when the operation started. - finish: Timestamp of when the operation finished. - success: Boolean indicating if the operation was successful. - try_count: A counter indicating the try count of this task. - kind: The kind of the object for the unique identifier. - - Returns: - A dictionary of the event added to the log. + Specific tasks may include additional informational properties. 
""" - event = { - 'id': (kind, _NextEventId()), - 'name': name, - 'task_name': task_name, - 'start_time': start, - 'try': try_count, - } - if self._parent: - event['parent'] = self._parent['id'] + def __init__(self): + """Initializes the event log.""" + self._log = [] + self._parent = None - if success is not None or finish is not None: - self.FinishEvent(event, finish, success) + def Add( + self, + name, + task_name, + start, + finish=None, + success=None, + try_count=1, + kind="RepoOp", + ): + """Add an event to the log. - self._log.append(event) - return event + Args: + name: Name of the object being operated upon. + task_name: A sub-task that was performed for name. + start: Timestamp of when the operation started. + finish: Timestamp of when the operation finished. + success: Boolean indicating if the operation was successful. + try_count: A counter indicating the try count of this task. + kind: The kind of the object for the unique identifier. - def AddSync(self, project, task_name, start, finish, success): - """Add a event to the log for a sync command. + Returns: + A dictionary of the event added to the log. + """ + event = { + "id": (kind, _NextEventId()), + "name": name, + "task_name": task_name, + "start_time": start, + "try": try_count, + } - Args: - project: Project being synced. - task_name: A sub-task that was performed for name. - One of (TASK_SYNC_NETWORK, TASK_SYNC_LOCAL) - start: Timestamp of when the operation started. - finish: Timestamp of when the operation finished. - success: Boolean indicating if the operation was successful. + if self._parent: + event["parent"] = self._parent["id"] - Returns: - A dictionary of the event added to the log. 
- """ - event = self.Add(project.relpath, task_name, start, finish, success) - if event is not None: - event['project'] = project.name - if project.revisionExpr: - event['revision'] = project.revisionExpr - if project.remote.url: - event['project_url'] = project.remote.url - if project.remote.fetchUrl: - event['remote_url'] = project.remote.fetchUrl - try: - event['git_hash'] = project.GetCommitRevisionId() - except Exception: - pass - return event + if success is not None or finish is not None: + self.FinishEvent(event, finish, success) - def GetStatusString(self, success): - """Converst a boolean success to a status string. + self._log.append(event) + return event - Args: - success: Boolean indicating if the operation was successful. + def AddSync(self, project, task_name, start, finish, success): + """Add a event to the log for a sync command. - Returns: - status string. - """ - return 'pass' if success else 'fail' + Args: + project: Project being synced. + task_name: A sub-task that was performed for name. + One of (TASK_SYNC_NETWORK, TASK_SYNC_LOCAL) + start: Timestamp of when the operation started. + finish: Timestamp of when the operation finished. + success: Boolean indicating if the operation was successful. - def FinishEvent(self, event, finish, success): - """Finishes an incomplete event. + Returns: + A dictionary of the event added to the log. + """ + event = self.Add(project.relpath, task_name, start, finish, success) + if event is not None: + event["project"] = project.name + if project.revisionExpr: + event["revision"] = project.revisionExpr + if project.remote.url: + event["project_url"] = project.remote.url + if project.remote.fetchUrl: + event["remote_url"] = project.remote.fetchUrl + try: + event["git_hash"] = project.GetCommitRevisionId() + except Exception: + pass + return event - Args: - event: An event that has been added to the log. - finish: Timestamp of when the operation finished. 
- success: Boolean indicating if the operation was successful. + def GetStatusString(self, success): + """Converst a boolean success to a status string. - Returns: - A dictionary of the event added to the log. - """ - event['status'] = self.GetStatusString(success) - event['finish_time'] = finish - return event + Args: + success: Boolean indicating if the operation was successful. - def SetParent(self, event): - """Set a parent event for all new entities. + Returns: + status string. + """ + return "pass" if success else "fail" - Args: - event: The event to use as a parent. - """ - self._parent = event + def FinishEvent(self, event, finish, success): + """Finishes an incomplete event. - def Write(self, filename): - """Writes the log out to a file. + Args: + event: An event that has been added to the log. + finish: Timestamp of when the operation finished. + success: Boolean indicating if the operation was successful. - Args: - filename: The file to write the log to. - """ - with open(filename, 'w+') as f: - for e in self._log: - json.dump(e, f, sort_keys=True) - f.write('\n') + Returns: + A dictionary of the event added to the log. + """ + event["status"] = self.GetStatusString(success) + event["finish_time"] = finish + return event + + def SetParent(self, event): + """Set a parent event for all new entities. + + Args: + event: The event to use as a parent. + """ + self._parent = event + + def Write(self, filename): + """Writes the log out to a file. + + Args: + filename: The file to write the log to. + """ + with open(filename, "w+") as f: + for e in self._log: + json.dump(e, f, sort_keys=True) + f.write("\n") # An integer id that is unique across this invocation of the program. -_EVENT_ID = multiprocessing.Value('i', 1) +_EVENT_ID = multiprocessing.Value("i", 1) def _NextEventId(): - """Helper function for grabbing the next unique id. + """Helper function for grabbing the next unique id. - Returns: - A unique, to this invocation of the program, integer id. 
- """ - with _EVENT_ID.get_lock(): - val = _EVENT_ID.value - _EVENT_ID.value += 1 - return val + Returns: + A unique, to this invocation of the program, integer id. + """ + with _EVENT_ID.get_lock(): + val = _EVENT_ID.value + _EVENT_ID.value += 1 + return val diff --git a/fetch.py b/fetch.py index c954a9c2..31f8152f 100644 --- a/fetch.py +++ b/fetch.py @@ -21,25 +21,29 @@ from urllib.request import urlopen def fetch_file(url, verbose=False): - """Fetch a file from the specified source using the appropriate protocol. + """Fetch a file from the specified source using the appropriate protocol. - Returns: - The contents of the file as bytes. - """ - scheme = urlparse(url).scheme - if scheme == 'gs': - cmd = ['gsutil', 'cat', url] - try: - result = subprocess.run( - cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - check=True) - if result.stderr and verbose: - print('warning: non-fatal error running "gsutil": %s' % result.stderr, - file=sys.stderr) - return result.stdout - except subprocess.CalledProcessError as e: - print('fatal: error running "gsutil": %s' % e.stderr, - file=sys.stderr) - sys.exit(1) - with urlopen(url) as f: - return f.read() + Returns: + The contents of the file as bytes. 
+ """ + scheme = urlparse(url).scheme + if scheme == "gs": + cmd = ["gsutil", "cat", url] + try: + result = subprocess.run( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True + ) + if result.stderr and verbose: + print( + 'warning: non-fatal error running "gsutil": %s' + % result.stderr, + file=sys.stderr, + ) + return result.stdout + except subprocess.CalledProcessError as e: + print( + 'fatal: error running "gsutil": %s' % e.stderr, file=sys.stderr + ) + sys.exit(1) + with urlopen(url) as f: + return f.read() diff --git a/git_command.py b/git_command.py index d4d4bed4..c7245ade 100644 --- a/git_command.py +++ b/git_command.py @@ -24,7 +24,7 @@ import platform_utils from repo_trace import REPO_TRACE, IsTrace, Trace from wrapper import Wrapper -GIT = 'git' +GIT = "git" # NB: These do not need to be kept in sync with the repo launcher script. # These may be much newer as it allows the repo launcher to roll between # different repo releases while source versions might require a newer git. @@ -36,126 +36,138 @@ GIT = 'git' # git-1.7 is in (EOL) Ubuntu Precise. git-1.9 is in Ubuntu Trusty. 
MIN_GIT_VERSION_SOFT = (1, 9, 1) MIN_GIT_VERSION_HARD = (1, 7, 2) -GIT_DIR = 'GIT_DIR' +GIT_DIR = "GIT_DIR" LAST_GITDIR = None LAST_CWD = None class _GitCall(object): - @functools.lru_cache(maxsize=None) - def version_tuple(self): - ret = Wrapper().ParseGitVersion() - if ret is None: - print('fatal: unable to detect git version', file=sys.stderr) - sys.exit(1) - return ret + @functools.lru_cache(maxsize=None) + def version_tuple(self): + ret = Wrapper().ParseGitVersion() + if ret is None: + print("fatal: unable to detect git version", file=sys.stderr) + sys.exit(1) + return ret - def __getattr__(self, name): - name = name.replace('_', '-') + def __getattr__(self, name): + name = name.replace("_", "-") - def fun(*cmdv): - command = [name] - command.extend(cmdv) - return GitCommand(None, command).Wait() == 0 - return fun + def fun(*cmdv): + command = [name] + command.extend(cmdv) + return GitCommand(None, command).Wait() == 0 + + return fun git = _GitCall() def RepoSourceVersion(): - """Return the version of the repo.git tree.""" - ver = getattr(RepoSourceVersion, 'version', None) + """Return the version of the repo.git tree.""" + ver = getattr(RepoSourceVersion, "version", None) - # We avoid GitCommand so we don't run into circular deps -- GitCommand needs - # to initialize version info we provide. - if ver is None: - env = GitCommand._GetBasicEnv() + # We avoid GitCommand so we don't run into circular deps -- GitCommand needs + # to initialize version info we provide. 
+ if ver is None: + env = GitCommand._GetBasicEnv() - proj = os.path.dirname(os.path.abspath(__file__)) - env[GIT_DIR] = os.path.join(proj, '.git') - result = subprocess.run([GIT, 'describe', HEAD], stdout=subprocess.PIPE, - stderr=subprocess.DEVNULL, encoding='utf-8', - env=env, check=False) - if result.returncode == 0: - ver = result.stdout.strip() - if ver.startswith('v'): - ver = ver[1:] - else: - ver = 'unknown' - setattr(RepoSourceVersion, 'version', ver) + proj = os.path.dirname(os.path.abspath(__file__)) + env[GIT_DIR] = os.path.join(proj, ".git") + result = subprocess.run( + [GIT, "describe", HEAD], + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + encoding="utf-8", + env=env, + check=False, + ) + if result.returncode == 0: + ver = result.stdout.strip() + if ver.startswith("v"): + ver = ver[1:] + else: + ver = "unknown" + setattr(RepoSourceVersion, "version", ver) - return ver + return ver class UserAgent(object): - """Mange User-Agent settings when talking to external services + """Mange User-Agent settings when talking to external services - We follow the style as documented here: - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/User-Agent - """ + We follow the style as documented here: + https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/User-Agent + """ - _os = None - _repo_ua = None - _git_ua = None + _os = None + _repo_ua = None + _git_ua = None - @property - def os(self): - """The operating system name.""" - if self._os is None: - os_name = sys.platform - if os_name.lower().startswith('linux'): - os_name = 'Linux' - elif os_name == 'win32': - os_name = 'Win32' - elif os_name == 'cygwin': - os_name = 'Cygwin' - elif os_name == 'darwin': - os_name = 'Darwin' - self._os = os_name + @property + def os(self): + """The operating system name.""" + if self._os is None: + os_name = sys.platform + if os_name.lower().startswith("linux"): + os_name = "Linux" + elif os_name == "win32": + os_name = "Win32" + elif os_name == "cygwin": + 
os_name = "Cygwin" + elif os_name == "darwin": + os_name = "Darwin" + self._os = os_name - return self._os + return self._os - @property - def repo(self): - """The UA when connecting directly from repo.""" - if self._repo_ua is None: - py_version = sys.version_info - self._repo_ua = 'git-repo/%s (%s) git/%s Python/%d.%d.%d' % ( - RepoSourceVersion(), - self.os, - git.version_tuple().full, - py_version.major, py_version.minor, py_version.micro) + @property + def repo(self): + """The UA when connecting directly from repo.""" + if self._repo_ua is None: + py_version = sys.version_info + self._repo_ua = "git-repo/%s (%s) git/%s Python/%d.%d.%d" % ( + RepoSourceVersion(), + self.os, + git.version_tuple().full, + py_version.major, + py_version.minor, + py_version.micro, + ) - return self._repo_ua + return self._repo_ua - @property - def git(self): - """The UA when running git.""" - if self._git_ua is None: - self._git_ua = 'git/%s (%s) git-repo/%s' % ( - git.version_tuple().full, - self.os, - RepoSourceVersion()) + @property + def git(self): + """The UA when running git.""" + if self._git_ua is None: + self._git_ua = "git/%s (%s) git-repo/%s" % ( + git.version_tuple().full, + self.os, + RepoSourceVersion(), + ) - return self._git_ua + return self._git_ua user_agent = UserAgent() -def git_require(min_version, fail=False, msg=''): - git_version = git.version_tuple() - if min_version <= git_version: - return True - if fail: - need = '.'.join(map(str, min_version)) - if msg: - msg = ' for ' + msg - print('fatal: git %s or later required%s' % (need, msg), file=sys.stderr) - sys.exit(1) - return False +def git_require(min_version, fail=False, msg=""): + git_version = git.version_tuple() + if min_version <= git_version: + return True + if fail: + need = ".".join(map(str, min_version)) + if msg: + msg = " for " + msg + print( + "fatal: git %s or later required%s" % (need, msg), file=sys.stderr + ) + sys.exit(1) + return False def _build_env( @@ -164,175 +176,194 @@ def 
_build_env( disable_editor: Optional[bool] = False, ssh_proxy: Optional[Any] = None, gitdir: Optional[str] = None, - objdir: Optional[str] = None + objdir: Optional[str] = None, ): - """Constucts an env dict for command execution.""" + """Constucts an env dict for command execution.""" - assert _kwargs_only == (), '_build_env only accepts keyword arguments.' + assert _kwargs_only == (), "_build_env only accepts keyword arguments." - env = GitCommand._GetBasicEnv() + env = GitCommand._GetBasicEnv() - if disable_editor: - env['GIT_EDITOR'] = ':' - if ssh_proxy: - env['REPO_SSH_SOCK'] = ssh_proxy.sock() - env['GIT_SSH'] = ssh_proxy.proxy - env['GIT_SSH_VARIANT'] = 'ssh' - if 'http_proxy' in env and 'darwin' == sys.platform: - s = "'http.proxy=%s'" % (env['http_proxy'],) - p = env.get('GIT_CONFIG_PARAMETERS') - if p is not None: - s = p + ' ' + s - env['GIT_CONFIG_PARAMETERS'] = s - if 'GIT_ALLOW_PROTOCOL' not in env: - env['GIT_ALLOW_PROTOCOL'] = ( - 'file:git:http:https:ssh:persistent-http:persistent-https:sso:rpc') - env['GIT_HTTP_USER_AGENT'] = user_agent.git + if disable_editor: + env["GIT_EDITOR"] = ":" + if ssh_proxy: + env["REPO_SSH_SOCK"] = ssh_proxy.sock() + env["GIT_SSH"] = ssh_proxy.proxy + env["GIT_SSH_VARIANT"] = "ssh" + if "http_proxy" in env and "darwin" == sys.platform: + s = "'http.proxy=%s'" % (env["http_proxy"],) + p = env.get("GIT_CONFIG_PARAMETERS") + if p is not None: + s = p + " " + s + env["GIT_CONFIG_PARAMETERS"] = s + if "GIT_ALLOW_PROTOCOL" not in env: + env[ + "GIT_ALLOW_PROTOCOL" + ] = "file:git:http:https:ssh:persistent-http:persistent-https:sso:rpc" + env["GIT_HTTP_USER_AGENT"] = user_agent.git - if objdir: - # Set to the place we want to save the objects. - env['GIT_OBJECT_DIRECTORY'] = objdir + if objdir: + # Set to the place we want to save the objects. 
+ env["GIT_OBJECT_DIRECTORY"] = objdir - alt_objects = os.path.join(gitdir, 'objects') if gitdir else None - if alt_objects and os.path.realpath(alt_objects) != os.path.realpath(objdir): - # Allow git to search the original place in case of local or unique refs - # that git will attempt to resolve even if we aren't fetching them. - env['GIT_ALTERNATE_OBJECT_DIRECTORIES'] = alt_objects - if bare and gitdir is not None: - env[GIT_DIR] = gitdir + alt_objects = os.path.join(gitdir, "objects") if gitdir else None + if alt_objects and os.path.realpath(alt_objects) != os.path.realpath( + objdir + ): + # Allow git to search the original place in case of local or unique + # refs that git will attempt to resolve even if we aren't fetching + # them. + env["GIT_ALTERNATE_OBJECT_DIRECTORIES"] = alt_objects + if bare and gitdir is not None: + env[GIT_DIR] = gitdir - return env + return env class GitCommand(object): - """Wrapper around a single git invocation.""" + """Wrapper around a single git invocation.""" - def __init__(self, - project, - cmdv, - bare=False, - input=None, - capture_stdout=False, - capture_stderr=False, - merge_output=False, - disable_editor=False, - ssh_proxy=None, - cwd=None, - gitdir=None, - objdir=None): + def __init__( + self, + project, + cmdv, + bare=False, + input=None, + capture_stdout=False, + capture_stderr=False, + merge_output=False, + disable_editor=False, + ssh_proxy=None, + cwd=None, + gitdir=None, + objdir=None, + ): + if project: + if not cwd: + cwd = project.worktree + if not gitdir: + gitdir = project.gitdir - if project: - if not cwd: - cwd = project.worktree - if not gitdir: - gitdir = project.gitdir + # Git on Windows wants its paths only using / for reliability. + if platform_utils.isWindows(): + if objdir: + objdir = objdir.replace("\\", "/") + if gitdir: + gitdir = gitdir.replace("\\", "/") - # Git on Windows wants its paths only using / for reliability. 
- if platform_utils.isWindows(): - if objdir: - objdir = objdir.replace('\\', '/') - if gitdir: - gitdir = gitdir.replace('\\', '/') + env = _build_env( + disable_editor=disable_editor, + ssh_proxy=ssh_proxy, + objdir=objdir, + gitdir=gitdir, + bare=bare, + ) - env = _build_env( - disable_editor=disable_editor, - ssh_proxy=ssh_proxy, - objdir=objdir, - gitdir=gitdir, - bare=bare, - ) + command = [GIT] + if bare: + cwd = None + command.append(cmdv[0]) + # Need to use the --progress flag for fetch/clone so output will be + # displayed as by default git only does progress output if stderr is a + # TTY. + if sys.stderr.isatty() and cmdv[0] in ("fetch", "clone"): + if "--progress" not in cmdv and "--quiet" not in cmdv: + command.append("--progress") + command.extend(cmdv[1:]) - command = [GIT] - if bare: - cwd = None - command.append(cmdv[0]) - # Need to use the --progress flag for fetch/clone so output will be - # displayed as by default git only does progress output if stderr is a TTY. - if sys.stderr.isatty() and cmdv[0] in ('fetch', 'clone'): - if '--progress' not in cmdv and '--quiet' not in cmdv: - command.append('--progress') - command.extend(cmdv[1:]) + stdin = subprocess.PIPE if input else None + stdout = subprocess.PIPE if capture_stdout else None + stderr = ( + subprocess.STDOUT + if merge_output + else (subprocess.PIPE if capture_stderr else None) + ) - stdin = subprocess.PIPE if input else None - stdout = subprocess.PIPE if capture_stdout else None - stderr = (subprocess.STDOUT if merge_output else - (subprocess.PIPE if capture_stderr else None)) + dbg = "" + if IsTrace(): + global LAST_CWD + global LAST_GITDIR - dbg = '' - if IsTrace(): - global LAST_CWD - global LAST_GITDIR + if cwd and LAST_CWD != cwd: + if LAST_GITDIR or LAST_CWD: + dbg += "\n" + dbg += ": cd %s\n" % cwd + LAST_CWD = cwd - if cwd and LAST_CWD != cwd: - if LAST_GITDIR or LAST_CWD: - dbg += '\n' - dbg += ': cd %s\n' % cwd - LAST_CWD = cwd + if GIT_DIR in env and LAST_GITDIR != 
env[GIT_DIR]: + if LAST_GITDIR or LAST_CWD: + dbg += "\n" + dbg += ": export GIT_DIR=%s\n" % env[GIT_DIR] + LAST_GITDIR = env[GIT_DIR] - if GIT_DIR in env and LAST_GITDIR != env[GIT_DIR]: - if LAST_GITDIR or LAST_CWD: - dbg += '\n' - dbg += ': export GIT_DIR=%s\n' % env[GIT_DIR] - LAST_GITDIR = env[GIT_DIR] + if "GIT_OBJECT_DIRECTORY" in env: + dbg += ( + ": export GIT_OBJECT_DIRECTORY=%s\n" + % env["GIT_OBJECT_DIRECTORY"] + ) + if "GIT_ALTERNATE_OBJECT_DIRECTORIES" in env: + dbg += ": export GIT_ALTERNATE_OBJECT_DIRECTORIES=%s\n" % ( + env["GIT_ALTERNATE_OBJECT_DIRECTORIES"] + ) - if 'GIT_OBJECT_DIRECTORY' in env: - dbg += ': export GIT_OBJECT_DIRECTORY=%s\n' % env['GIT_OBJECT_DIRECTORY'] - if 'GIT_ALTERNATE_OBJECT_DIRECTORIES' in env: - dbg += ': export GIT_ALTERNATE_OBJECT_DIRECTORIES=%s\n' % ( - env['GIT_ALTERNATE_OBJECT_DIRECTORIES']) + dbg += ": " + dbg += " ".join(command) + if stdin == subprocess.PIPE: + dbg += " 0<|" + if stdout == subprocess.PIPE: + dbg += " 1>|" + if stderr == subprocess.PIPE: + dbg += " 2>|" + elif stderr == subprocess.STDOUT: + dbg += " 2>&1" - dbg += ': ' - dbg += ' '.join(command) - if stdin == subprocess.PIPE: - dbg += ' 0<|' - if stdout == subprocess.PIPE: - dbg += ' 1>|' - if stderr == subprocess.PIPE: - dbg += ' 2>|' - elif stderr == subprocess.STDOUT: - dbg += ' 2>&1' + with Trace( + "git command %s %s with debug: %s", LAST_GITDIR, command, dbg + ): + try: + p = subprocess.Popen( + command, + cwd=cwd, + env=env, + encoding="utf-8", + errors="backslashreplace", + stdin=stdin, + stdout=stdout, + stderr=stderr, + ) + except Exception as e: + raise GitError("%s: %s" % (command[1], e)) - with Trace('git command %s %s with debug: %s', LAST_GITDIR, command, dbg): - try: - p = subprocess.Popen(command, - cwd=cwd, - env=env, - encoding='utf-8', - errors='backslashreplace', - stdin=stdin, - stdout=stdout, - stderr=stderr) - except Exception as e: - raise GitError('%s: %s' % (command[1], e)) + if ssh_proxy: + ssh_proxy.add_client(p) - if 
ssh_proxy: - ssh_proxy.add_client(p) + self.process = p - self.process = p + try: + self.stdout, self.stderr = p.communicate(input=input) + finally: + if ssh_proxy: + ssh_proxy.remove_client(p) + self.rc = p.wait() - try: - self.stdout, self.stderr = p.communicate(input=input) - finally: - if ssh_proxy: - ssh_proxy.remove_client(p) - self.rc = p.wait() + @staticmethod + def _GetBasicEnv(): + """Return a basic env for running git under. - @staticmethod - def _GetBasicEnv(): - """Return a basic env for running git under. + This is guaranteed to be side-effect free. + """ + env = os.environ.copy() + for key in ( + REPO_TRACE, + GIT_DIR, + "GIT_ALTERNATE_OBJECT_DIRECTORIES", + "GIT_OBJECT_DIRECTORY", + "GIT_WORK_TREE", + "GIT_GRAFT_FILE", + "GIT_INDEX_FILE", + ): + env.pop(key, None) + return env - This is guaranteed to be side-effect free. - """ - env = os.environ.copy() - for key in (REPO_TRACE, - GIT_DIR, - 'GIT_ALTERNATE_OBJECT_DIRECTORIES', - 'GIT_OBJECT_DIRECTORY', - 'GIT_WORK_TREE', - 'GIT_GRAFT_FILE', - 'GIT_INDEX_FILE'): - env.pop(key, None) - return env - - def Wait(self): - return self.rc + def Wait(self): + return self.rc diff --git a/git_config.py b/git_config.py index 9ad979ad..05b3c1ee 100644 --- a/git_config.py +++ b/git_config.py @@ -34,23 +34,23 @@ from git_refs import R_CHANGES, R_HEADS, R_TAGS # Prefix that is prepended to all the keys of SyncAnalysisState's data # that is saved in the config. -SYNC_STATE_PREFIX = 'repo.syncstate.' +SYNC_STATE_PREFIX = "repo.syncstate." 
-ID_RE = re.compile(r'^[0-9a-f]{40}$') +ID_RE = re.compile(r"^[0-9a-f]{40}$") REVIEW_CACHE = dict() def IsChange(rev): - return rev.startswith(R_CHANGES) + return rev.startswith(R_CHANGES) def IsId(rev): - return ID_RE.match(rev) + return ID_RE.match(rev) def IsTag(rev): - return rev.startswith(R_TAGS) + return rev.startswith(R_TAGS) def IsImmutable(rev): @@ -58,765 +58,785 @@ def IsImmutable(rev): def _key(name): - parts = name.split('.') - if len(parts) < 2: - return name.lower() - parts[0] = parts[0].lower() - parts[-1] = parts[-1].lower() - return '.'.join(parts) + parts = name.split(".") + if len(parts) < 2: + return name.lower() + parts[0] = parts[0].lower() + parts[-1] = parts[-1].lower() + return ".".join(parts) class GitConfig(object): - _ForUser = None + _ForUser = None - _ForSystem = None - _SYSTEM_CONFIG = '/etc/gitconfig' + _ForSystem = None + _SYSTEM_CONFIG = "/etc/gitconfig" - @classmethod - def ForSystem(cls): - if cls._ForSystem is None: - cls._ForSystem = cls(configfile=cls._SYSTEM_CONFIG) - return cls._ForSystem + @classmethod + def ForSystem(cls): + if cls._ForSystem is None: + cls._ForSystem = cls(configfile=cls._SYSTEM_CONFIG) + return cls._ForSystem - @classmethod - def ForUser(cls): - if cls._ForUser is None: - cls._ForUser = cls(configfile=cls._getUserConfig()) - return cls._ForUser + @classmethod + def ForUser(cls): + if cls._ForUser is None: + cls._ForUser = cls(configfile=cls._getUserConfig()) + return cls._ForUser - @staticmethod - def _getUserConfig(): - return os.path.expanduser('~/.gitconfig') + @staticmethod + def _getUserConfig(): + return os.path.expanduser("~/.gitconfig") - @classmethod - def ForRepository(cls, gitdir, defaults=None): - return cls(configfile=os.path.join(gitdir, 'config'), - defaults=defaults) + @classmethod + def ForRepository(cls, gitdir, defaults=None): + return cls(configfile=os.path.join(gitdir, "config"), defaults=defaults) - def __init__(self, configfile, defaults=None, jsonFile=None): - self.file = 
configfile - self.defaults = defaults - self._cache_dict = None - self._section_dict = None - self._remotes = {} - self._branches = {} + def __init__(self, configfile, defaults=None, jsonFile=None): + self.file = configfile + self.defaults = defaults + self._cache_dict = None + self._section_dict = None + self._remotes = {} + self._branches = {} - self._json = jsonFile - if self._json is None: - self._json = os.path.join( - os.path.dirname(self.file), - '.repo_' + os.path.basename(self.file) + '.json') + self._json = jsonFile + if self._json is None: + self._json = os.path.join( + os.path.dirname(self.file), + ".repo_" + os.path.basename(self.file) + ".json", + ) - def ClearCache(self): - """Clear the in-memory cache of config.""" - self._cache_dict = None + def ClearCache(self): + """Clear the in-memory cache of config.""" + self._cache_dict = None - def Has(self, name, include_defaults=True): - """Return true if this configuration file has the key. - """ - if _key(name) in self._cache: - return True - if include_defaults and self.defaults: - return self.defaults.Has(name, include_defaults=True) - return False + def Has(self, name, include_defaults=True): + """Return true if this configuration file has the key.""" + if _key(name) in self._cache: + return True + if include_defaults and self.defaults: + return self.defaults.Has(name, include_defaults=True) + return False - def GetInt(self, name: str) -> Union[int, None]: - """Returns an integer from the configuration file. + def GetInt(self, name: str) -> Union[int, None]: + """Returns an integer from the configuration file. - This follows the git config syntax. + This follows the git config syntax. - Args: - name: The key to lookup. + Args: + name: The key to lookup. - Returns: - None if the value was not defined, or is not an int. - Otherwise, the number itself. - """ - v = self.GetString(name) - if v is None: - return None - v = v.strip() + Returns: + None if the value was not defined, or is not an int. 
+ Otherwise, the number itself. + """ + v = self.GetString(name) + if v is None: + return None + v = v.strip() - mult = 1 - if v.endswith('k'): - v = v[:-1] - mult = 1024 - elif v.endswith('m'): - v = v[:-1] - mult = 1024 * 1024 - elif v.endswith('g'): - v = v[:-1] - mult = 1024 * 1024 * 1024 + mult = 1 + if v.endswith("k"): + v = v[:-1] + mult = 1024 + elif v.endswith("m"): + v = v[:-1] + mult = 1024 * 1024 + elif v.endswith("g"): + v = v[:-1] + mult = 1024 * 1024 * 1024 - base = 10 - if v.startswith('0x'): - base = 16 + base = 10 + if v.startswith("0x"): + base = 16 - try: - return int(v, base=base) * mult - except ValueError: - print( - f"warning: expected {name} to represent an integer, got {v} instead", - file=sys.stderr) - return None + try: + return int(v, base=base) * mult + except ValueError: + print( + f"warning: expected {name} to represent an integer, got {v} " + "instead", + file=sys.stderr, + ) + return None - def DumpConfigDict(self): - """Returns the current configuration dict. + def DumpConfigDict(self): + """Returns the current configuration dict. - Configuration data is information only (e.g. logging) and - should not be considered a stable data-source. + Configuration data is information only (e.g. logging) and + should not be considered a stable data-source. - Returns: - dict of {, } for git configuration cache. - are strings converted by GetString. - """ - config_dict = {} - for key in self._cache: - config_dict[key] = self.GetString(key) - return config_dict + Returns: + dict of {, } for git configuration cache. + are strings converted by GetString. + """ + config_dict = {} + for key in self._cache: + config_dict[key] = self.GetString(key) + return config_dict - def GetBoolean(self, name: str) -> Union[str, None]: - """Returns a boolean from the configuration file. - None : The value was not defined, or is not a boolean. - True : The value was set to true or yes. - False: The value was set to false or no. 
- """ - v = self.GetString(name) - if v is None: - return None - v = v.lower() - if v in ('true', 'yes'): - return True - if v in ('false', 'no'): - return False - print(f"warning: expected {name} to represent a boolean, got {v} instead", - file=sys.stderr) - return None + def GetBoolean(self, name: str) -> Union[str, None]: + """Returns a boolean from the configuration file. - def SetBoolean(self, name, value): - """Set the truthy value for a key.""" - if value is not None: - value = 'true' if value else 'false' - self.SetString(name, value) - - def GetString(self, name: str, all_keys: bool = False) -> Union[str, None]: - """Get the first value for a key, or None if it is not defined. - - This configuration file is used first, if the key is not - defined or all_keys = True then the defaults are also searched. - """ - try: - v = self._cache[_key(name)] - except KeyError: - if self.defaults: - return self.defaults.GetString(name, all_keys=all_keys) - v = [] - - if not all_keys: - if v: - return v[0] - return None - - r = [] - r.extend(v) - if self.defaults: - r.extend(self.defaults.GetString(name, all_keys=True)) - return r - - def SetString(self, name, value): - """Set the value(s) for a key. - Only this configuration file is modified. - - The supplied value should be either a string, or a list of strings (to - store multiple values), or None (to delete the key). 
- """ - key = _key(name) - - try: - old = self._cache[key] - except KeyError: - old = [] - - if value is None: - if old: - del self._cache[key] - self._do('--unset-all', name) - - elif isinstance(value, list): - if len(value) == 0: - self.SetString(name, None) - - elif len(value) == 1: - self.SetString(name, value[0]) - - elif old != value: - self._cache[key] = list(value) - self._do('--replace-all', name, value[0]) - for i in range(1, len(value)): - self._do('--add', name, value[i]) - - elif len(old) != 1 or old[0] != value: - self._cache[key] = [value] - self._do('--replace-all', name, value) - - def GetRemote(self, name): - """Get the remote.$name.* configuration values as an object. - """ - try: - r = self._remotes[name] - except KeyError: - r = Remote(self, name) - self._remotes[r.name] = r - return r - - def GetBranch(self, name): - """Get the branch.$name.* configuration values as an object. - """ - try: - b = self._branches[name] - except KeyError: - b = Branch(self, name) - self._branches[b.name] = b - return b - - def GetSyncAnalysisStateData(self): - """Returns data to be logged for the analysis of sync performance.""" - return {k: v for k, v in self.DumpConfigDict().items() if k.startswith(SYNC_STATE_PREFIX)} - - def UpdateSyncAnalysisState(self, options, superproject_logging_data): - """Update Config's SYNC_STATE_PREFIX* data with the latest sync data. - - Args: - options: Options passed to sync returned from optparse. See _Options(). - superproject_logging_data: A dictionary of superproject data that is to be logged. - - Returns: - SyncAnalysisState object. - """ - return SyncAnalysisState(self, options, superproject_logging_data) - - def GetSubSections(self, section): - """List all subsection names matching $section.*.* - """ - return self._sections.get(section, set()) - - def HasSection(self, section, subsection=''): - """Does at least one key in section.subsection exist? 
- """ - try: - return subsection in self._sections[section] - except KeyError: - return False - - def UrlInsteadOf(self, url): - """Resolve any url.*.insteadof references. - """ - for new_url in self.GetSubSections('url'): - for old_url in self.GetString('url.%s.insteadof' % new_url, True): - if old_url is not None and url.startswith(old_url): - return new_url + url[len(old_url):] - return url - - @property - def _sections(self): - d = self._section_dict - if d is None: - d = {} - for name in self._cache.keys(): - p = name.split('.') - if 2 == len(p): - section = p[0] - subsect = '' - else: - section = p[0] - subsect = '.'.join(p[1:-1]) - if section not in d: - d[section] = set() - d[section].add(subsect) - self._section_dict = d - return d - - @property - def _cache(self): - if self._cache_dict is None: - self._cache_dict = self._Read() - return self._cache_dict - - def _Read(self): - d = self._ReadJson() - if d is None: - d = self._ReadGit() - self._SaveJson(d) - return d - - def _ReadJson(self): - try: - if os.path.getmtime(self._json) <= os.path.getmtime(self.file): - platform_utils.remove(self._json) + Returns: + None: The value was not defined, or is not a boolean. + True: The value was set to true or yes. + False: The value was set to false or no. 
+ """ + v = self.GetString(name) + if v is None: + return None + v = v.lower() + if v in ("true", "yes"): + return True + if v in ("false", "no"): + return False + print( + f"warning: expected {name} to represent a boolean, got {v} instead", + file=sys.stderr, + ) return None - except OSError: - return None - try: - with Trace(': parsing %s', self.file): - with open(self._json) as fd: - return json.load(fd) - except (IOError, ValueError): - platform_utils.remove(self._json, missing_ok=True) - return None - def _SaveJson(self, cache): - try: - with open(self._json, 'w') as fd: - json.dump(cache, fd, indent=2) - except (IOError, TypeError): - platform_utils.remove(self._json, missing_ok=True) + def SetBoolean(self, name, value): + """Set the truthy value for a key.""" + if value is not None: + value = "true" if value else "false" + self.SetString(name, value) - def _ReadGit(self): - """ - Read configuration data from git. + def GetString(self, name: str, all_keys: bool = False) -> Union[str, None]: + """Get the first value for a key, or None if it is not defined. - This internal method populates the GitConfig cache. + This configuration file is used first, if the key is not + defined or all_keys = True then the defaults are also searched. + """ + try: + v = self._cache[_key(name)] + except KeyError: + if self.defaults: + return self.defaults.GetString(name, all_keys=all_keys) + v = [] - """ - c = {} - if not os.path.exists(self.file): - return c + if not all_keys: + if v: + return v[0] + return None - d = self._do('--null', '--list') - for line in d.rstrip('\0').split('\0'): - if '\n' in line: - key, val = line.split('\n', 1) - else: - key = line - val = None + r = [] + r.extend(v) + if self.defaults: + r.extend(self.defaults.GetString(name, all_keys=True)) + return r - if key in c: - c[key].append(val) - else: - c[key] = [val] + def SetString(self, name, value): + """Set the value(s) for a key. + Only this configuration file is modified. 
- return c + The supplied value should be either a string, or a list of strings (to + store multiple values), or None (to delete the key). + """ + key = _key(name) - def _do(self, *args): - if self.file == self._SYSTEM_CONFIG: - command = ['config', '--system', '--includes'] - else: - command = ['config', '--file', self.file, '--includes'] - command.extend(args) + try: + old = self._cache[key] + except KeyError: + old = [] - p = GitCommand(None, - command, - capture_stdout=True, - capture_stderr=True) - if p.Wait() == 0: - return p.stdout - else: - raise GitError('git config %s: %s' % (str(args), p.stderr)) + if value is None: + if old: + del self._cache[key] + self._do("--unset-all", name) + + elif isinstance(value, list): + if len(value) == 0: + self.SetString(name, None) + + elif len(value) == 1: + self.SetString(name, value[0]) + + elif old != value: + self._cache[key] = list(value) + self._do("--replace-all", name, value[0]) + for i in range(1, len(value)): + self._do("--add", name, value[i]) + + elif len(old) != 1 or old[0] != value: + self._cache[key] = [value] + self._do("--replace-all", name, value) + + def GetRemote(self, name): + """Get the remote.$name.* configuration values as an object.""" + try: + r = self._remotes[name] + except KeyError: + r = Remote(self, name) + self._remotes[r.name] = r + return r + + def GetBranch(self, name): + """Get the branch.$name.* configuration values as an object.""" + try: + b = self._branches[name] + except KeyError: + b = Branch(self, name) + self._branches[b.name] = b + return b + + def GetSyncAnalysisStateData(self): + """Returns data to be logged for the analysis of sync performance.""" + return { + k: v + for k, v in self.DumpConfigDict().items() + if k.startswith(SYNC_STATE_PREFIX) + } + + def UpdateSyncAnalysisState(self, options, superproject_logging_data): + """Update Config's SYNC_STATE_PREFIX* data with the latest sync data. + + Args: + options: Options passed to sync returned from optparse. 
See + _Options(). + superproject_logging_data: A dictionary of superproject data that is + to be logged. + + Returns: + SyncAnalysisState object. + """ + return SyncAnalysisState(self, options, superproject_logging_data) + + def GetSubSections(self, section): + """List all subsection names matching $section.*.*""" + return self._sections.get(section, set()) + + def HasSection(self, section, subsection=""): + """Does at least one key in section.subsection exist?""" + try: + return subsection in self._sections[section] + except KeyError: + return False + + def UrlInsteadOf(self, url): + """Resolve any url.*.insteadof references.""" + for new_url in self.GetSubSections("url"): + for old_url in self.GetString("url.%s.insteadof" % new_url, True): + if old_url is not None and url.startswith(old_url): + return new_url + url[len(old_url) :] + return url + + @property + def _sections(self): + d = self._section_dict + if d is None: + d = {} + for name in self._cache.keys(): + p = name.split(".") + if 2 == len(p): + section = p[0] + subsect = "" + else: + section = p[0] + subsect = ".".join(p[1:-1]) + if section not in d: + d[section] = set() + d[section].add(subsect) + self._section_dict = d + return d + + @property + def _cache(self): + if self._cache_dict is None: + self._cache_dict = self._Read() + return self._cache_dict + + def _Read(self): + d = self._ReadJson() + if d is None: + d = self._ReadGit() + self._SaveJson(d) + return d + + def _ReadJson(self): + try: + if os.path.getmtime(self._json) <= os.path.getmtime(self.file): + platform_utils.remove(self._json) + return None + except OSError: + return None + try: + with Trace(": parsing %s", self.file): + with open(self._json) as fd: + return json.load(fd) + except (IOError, ValueError): + platform_utils.remove(self._json, missing_ok=True) + return None + + def _SaveJson(self, cache): + try: + with open(self._json, "w") as fd: + json.dump(cache, fd, indent=2) + except (IOError, TypeError): + 
platform_utils.remove(self._json, missing_ok=True) + + def _ReadGit(self): + """ + Read configuration data from git. + + This internal method populates the GitConfig cache. + + """ + c = {} + if not os.path.exists(self.file): + return c + + d = self._do("--null", "--list") + for line in d.rstrip("\0").split("\0"): + if "\n" in line: + key, val = line.split("\n", 1) + else: + key = line + val = None + + if key in c: + c[key].append(val) + else: + c[key] = [val] + + return c + + def _do(self, *args): + if self.file == self._SYSTEM_CONFIG: + command = ["config", "--system", "--includes"] + else: + command = ["config", "--file", self.file, "--includes"] + command.extend(args) + + p = GitCommand(None, command, capture_stdout=True, capture_stderr=True) + if p.Wait() == 0: + return p.stdout + else: + raise GitError("git config %s: %s" % (str(args), p.stderr)) class RepoConfig(GitConfig): - """User settings for repo itself.""" + """User settings for repo itself.""" - @staticmethod - def _getUserConfig(): - repo_config_dir = os.getenv('REPO_CONFIG_DIR', os.path.expanduser('~')) - return os.path.join(repo_config_dir, '.repoconfig/config') + @staticmethod + def _getUserConfig(): + repo_config_dir = os.getenv("REPO_CONFIG_DIR", os.path.expanduser("~")) + return os.path.join(repo_config_dir, ".repoconfig/config") class RefSpec(object): - """A Git refspec line, split into its components: + """A Git refspec line, split into its components: - forced: True if the line starts with '+' - src: Left side of the line - dst: Right side of the line - """ + forced: True if the line starts with '+' + src: Left side of the line + dst: Right side of the line + """ - @classmethod - def FromString(cls, rs): - lhs, rhs = rs.split(':', 2) - if lhs.startswith('+'): - lhs = lhs[1:] - forced = True - else: - forced = False - return cls(forced, lhs, rhs) + @classmethod + def FromString(cls, rs): + lhs, rhs = rs.split(":", 2) + if lhs.startswith("+"): + lhs = lhs[1:] + forced = True + else: + forced = 
False + return cls(forced, lhs, rhs) - def __init__(self, forced, lhs, rhs): - self.forced = forced - self.src = lhs - self.dst = rhs + def __init__(self, forced, lhs, rhs): + self.forced = forced + self.src = lhs + self.dst = rhs - def SourceMatches(self, rev): - if self.src: - if rev == self.src: - return True - if self.src.endswith('/*') and rev.startswith(self.src[:-1]): - return True - return False + def SourceMatches(self, rev): + if self.src: + if rev == self.src: + return True + if self.src.endswith("/*") and rev.startswith(self.src[:-1]): + return True + return False - def DestMatches(self, ref): - if self.dst: - if ref == self.dst: - return True - if self.dst.endswith('/*') and ref.startswith(self.dst[:-1]): - return True - return False + def DestMatches(self, ref): + if self.dst: + if ref == self.dst: + return True + if self.dst.endswith("/*") and ref.startswith(self.dst[:-1]): + return True + return False - def MapSource(self, rev): - if self.src.endswith('/*'): - return self.dst[:-1] + rev[len(self.src) - 1:] - return self.dst + def MapSource(self, rev): + if self.src.endswith("/*"): + return self.dst[:-1] + rev[len(self.src) - 1 :] + return self.dst - def __str__(self): - s = '' - if self.forced: - s += '+' - if self.src: - s += self.src - if self.dst: - s += ':' - s += self.dst - return s + def __str__(self): + s = "" + if self.forced: + s += "+" + if self.src: + s += self.src + if self.dst: + s += ":" + s += self.dst + return s -URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/') +URI_ALL = re.compile(r"^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/") def GetSchemeFromUrl(url): - m = URI_ALL.match(url) - if m: - return m.group(1) - return None + m = URI_ALL.match(url) + if m: + return m.group(1) + return None @contextlib.contextmanager def GetUrlCookieFile(url, quiet): - if url.startswith('persistent-'): - try: - p = subprocess.Popen( - ['git-remote-persistent-https', '-print_config', url], - stdin=subprocess.PIPE, stdout=subprocess.PIPE, - 
stderr=subprocess.PIPE) - try: - cookieprefix = 'http.cookiefile=' - proxyprefix = 'http.proxy=' - cookiefile = None - proxy = None - for line in p.stdout: - line = line.strip().decode('utf-8') - if line.startswith(cookieprefix): - cookiefile = os.path.expanduser(line[len(cookieprefix):]) - if line.startswith(proxyprefix): - proxy = line[len(proxyprefix):] - # Leave subprocess open, as cookie file may be transient. - if cookiefile or proxy: - yield cookiefile, proxy - return - finally: - p.stdin.close() - if p.wait(): - err_msg = p.stderr.read().decode('utf-8') - if ' -print_config' in err_msg: - pass # Persistent proxy doesn't support -print_config. - elif not quiet: - print(err_msg, file=sys.stderr) - except OSError as e: - if e.errno == errno.ENOENT: - pass # No persistent proxy. - raise - cookiefile = GitConfig.ForUser().GetString('http.cookiefile') - if cookiefile: - cookiefile = os.path.expanduser(cookiefile) - yield cookiefile, None + if url.startswith("persistent-"): + try: + p = subprocess.Popen( + ["git-remote-persistent-https", "-print_config", url], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + try: + cookieprefix = "http.cookiefile=" + proxyprefix = "http.proxy=" + cookiefile = None + proxy = None + for line in p.stdout: + line = line.strip().decode("utf-8") + if line.startswith(cookieprefix): + cookiefile = os.path.expanduser( + line[len(cookieprefix) :] + ) + if line.startswith(proxyprefix): + proxy = line[len(proxyprefix) :] + # Leave subprocess open, as cookie file may be transient. + if cookiefile or proxy: + yield cookiefile, proxy + return + finally: + p.stdin.close() + if p.wait(): + err_msg = p.stderr.read().decode("utf-8") + if " -print_config" in err_msg: + pass # Persistent proxy doesn't support -print_config. + elif not quiet: + print(err_msg, file=sys.stderr) + except OSError as e: + if e.errno == errno.ENOENT: + pass # No persistent proxy. 
+ raise + cookiefile = GitConfig.ForUser().GetString("http.cookiefile") + if cookiefile: + cookiefile = os.path.expanduser(cookiefile) + yield cookiefile, None class Remote(object): - """Configuration options related to a remote. - """ + """Configuration options related to a remote.""" - def __init__(self, config, name): - self._config = config - self.name = name - self.url = self._Get('url') - self.pushUrl = self._Get('pushurl') - self.review = self._Get('review') - self.projectname = self._Get('projectname') - self.fetch = list(map(RefSpec.FromString, - self._Get('fetch', all_keys=True))) - self._review_url = None + def __init__(self, config, name): + self._config = config + self.name = name + self.url = self._Get("url") + self.pushUrl = self._Get("pushurl") + self.review = self._Get("review") + self.projectname = self._Get("projectname") + self.fetch = list( + map(RefSpec.FromString, self._Get("fetch", all_keys=True)) + ) + self._review_url = None - def _InsteadOf(self): - globCfg = GitConfig.ForUser() - urlList = globCfg.GetSubSections('url') - longest = "" - longestUrl = "" + def _InsteadOf(self): + globCfg = GitConfig.ForUser() + urlList = globCfg.GetSubSections("url") + longest = "" + longestUrl = "" - for url in urlList: - key = "url." + url + ".insteadOf" - insteadOfList = globCfg.GetString(key, all_keys=True) + for url in urlList: + key = "url." 
+ url + ".insteadOf" + insteadOfList = globCfg.GetString(key, all_keys=True) - for insteadOf in insteadOfList: - if (self.url.startswith(insteadOf) - and len(insteadOf) > len(longest)): - longest = insteadOf - longestUrl = url + for insteadOf in insteadOfList: + if self.url.startswith(insteadOf) and len(insteadOf) > len( + longest + ): + longest = insteadOf + longestUrl = url - if len(longest) == 0: - return self.url + if len(longest) == 0: + return self.url - return self.url.replace(longest, longestUrl, 1) + return self.url.replace(longest, longestUrl, 1) - def PreConnectFetch(self, ssh_proxy): - """Run any setup for this remote before we connect to it. + def PreConnectFetch(self, ssh_proxy): + """Run any setup for this remote before we connect to it. - In practice, if the remote is using SSH, we'll attempt to create a new - SSH master session to it for reuse across projects. + In practice, if the remote is using SSH, we'll attempt to create a new + SSH master session to it for reuse across projects. - Args: - ssh_proxy: The SSH settings for managing master sessions. + Args: + ssh_proxy: The SSH settings for managing master sessions. - Returns: - Whether the preconnect phase for this remote was successful. - """ - if not ssh_proxy: - return True + Returns: + Whether the preconnect phase for this remote was successful. 
+ """ + if not ssh_proxy: + return True - connectionUrl = self._InsteadOf() - return ssh_proxy.preconnect(connectionUrl) + connectionUrl = self._InsteadOf() + return ssh_proxy.preconnect(connectionUrl) - def ReviewUrl(self, userEmail, validate_certs): - if self._review_url is None: - if self.review is None: - return None + def ReviewUrl(self, userEmail, validate_certs): + if self._review_url is None: + if self.review is None: + return None - u = self.review - if u.startswith('persistent-'): - u = u[len('persistent-'):] - if u.split(':')[0] not in ('http', 'https', 'sso', 'ssh'): - u = 'http://%s' % u - if u.endswith('/Gerrit'): - u = u[:len(u) - len('/Gerrit')] - if u.endswith('/ssh_info'): - u = u[:len(u) - len('/ssh_info')] - if not u.endswith('/'): - u += '/' - http_url = u + u = self.review + if u.startswith("persistent-"): + u = u[len("persistent-") :] + if u.split(":")[0] not in ("http", "https", "sso", "ssh"): + u = "http://%s" % u + if u.endswith("/Gerrit"): + u = u[: len(u) - len("/Gerrit")] + if u.endswith("/ssh_info"): + u = u[: len(u) - len("/ssh_info")] + if not u.endswith("/"): + u += "/" + http_url = u - if u in REVIEW_CACHE: - self._review_url = REVIEW_CACHE[u] - elif 'REPO_HOST_PORT_INFO' in os.environ: - host, port = os.environ['REPO_HOST_PORT_INFO'].split() - self._review_url = self._SshReviewUrl(userEmail, host, port) - REVIEW_CACHE[u] = self._review_url - elif u.startswith('sso:') or u.startswith('ssh:'): - self._review_url = u # Assume it's right - REVIEW_CACHE[u] = self._review_url - elif 'REPO_IGNORE_SSH_INFO' in os.environ: - self._review_url = http_url - REVIEW_CACHE[u] = self._review_url - else: - try: - info_url = u + 'ssh_info' - if not validate_certs: - context = ssl._create_unverified_context() - info = urllib.request.urlopen(info_url, context=context).read() - else: - info = urllib.request.urlopen(info_url).read() - if info == b'NOT_AVAILABLE' or b'<' in info: - # If `info` contains '<', we assume the server gave us some sort - # of 
HTML response back, like maybe a login page. - # - # Assume HTTP if SSH is not enabled or ssh_info doesn't look right. - self._review_url = http_url - else: - info = info.decode('utf-8') - host, port = info.split() - self._review_url = self._SshReviewUrl(userEmail, host, port) - except urllib.error.HTTPError as e: - raise UploadError('%s: %s' % (self.review, str(e))) - except urllib.error.URLError as e: - raise UploadError('%s: %s' % (self.review, str(e))) - except HTTPException as e: - raise UploadError('%s: %s' % (self.review, e.__class__.__name__)) + if u in REVIEW_CACHE: + self._review_url = REVIEW_CACHE[u] + elif "REPO_HOST_PORT_INFO" in os.environ: + host, port = os.environ["REPO_HOST_PORT_INFO"].split() + self._review_url = self._SshReviewUrl(userEmail, host, port) + REVIEW_CACHE[u] = self._review_url + elif u.startswith("sso:") or u.startswith("ssh:"): + self._review_url = u # Assume it's right + REVIEW_CACHE[u] = self._review_url + elif "REPO_IGNORE_SSH_INFO" in os.environ: + self._review_url = http_url + REVIEW_CACHE[u] = self._review_url + else: + try: + info_url = u + "ssh_info" + if not validate_certs: + context = ssl._create_unverified_context() + info = urllib.request.urlopen( + info_url, context=context + ).read() + else: + info = urllib.request.urlopen(info_url).read() + if info == b"NOT_AVAILABLE" or b"<" in info: + # If `info` contains '<', we assume the server gave us + # some sort of HTML response back, like maybe a login + # page. + # + # Assume HTTP if SSH is not enabled or ssh_info doesn't + # look right. 
+ self._review_url = http_url + else: + info = info.decode("utf-8") + host, port = info.split() + self._review_url = self._SshReviewUrl( + userEmail, host, port + ) + except urllib.error.HTTPError as e: + raise UploadError("%s: %s" % (self.review, str(e))) + except urllib.error.URLError as e: + raise UploadError("%s: %s" % (self.review, str(e))) + except HTTPException as e: + raise UploadError( + "%s: %s" % (self.review, e.__class__.__name__) + ) - REVIEW_CACHE[u] = self._review_url - return self._review_url + self.projectname + REVIEW_CACHE[u] = self._review_url + return self._review_url + self.projectname - def _SshReviewUrl(self, userEmail, host, port): - username = self._config.GetString('review.%s.username' % self.review) - if username is None: - username = userEmail.split('@')[0] - return 'ssh://%s@%s:%s/' % (username, host, port) + def _SshReviewUrl(self, userEmail, host, port): + username = self._config.GetString("review.%s.username" % self.review) + if username is None: + username = userEmail.split("@")[0] + return "ssh://%s@%s:%s/" % (username, host, port) - def ToLocal(self, rev): - """Convert a remote revision string to something we have locally. - """ - if self.name == '.' or IsId(rev): - return rev + def ToLocal(self, rev): + """Convert a remote revision string to something we have locally.""" + if self.name == "." 
or IsId(rev): + return rev - if not rev.startswith('refs/'): - rev = R_HEADS + rev + if not rev.startswith("refs/"): + rev = R_HEADS + rev - for spec in self.fetch: - if spec.SourceMatches(rev): - return spec.MapSource(rev) + for spec in self.fetch: + if spec.SourceMatches(rev): + return spec.MapSource(rev) - if not rev.startswith(R_HEADS): - return rev + if not rev.startswith(R_HEADS): + return rev - raise GitError('%s: remote %s does not have %s' % - (self.projectname, self.name, rev)) + raise GitError( + "%s: remote %s does not have %s" + % (self.projectname, self.name, rev) + ) - def WritesTo(self, ref): - """True if the remote stores to the tracking ref. - """ - for spec in self.fetch: - if spec.DestMatches(ref): - return True - return False + def WritesTo(self, ref): + """True if the remote stores to the tracking ref.""" + for spec in self.fetch: + if spec.DestMatches(ref): + return True + return False - def ResetFetch(self, mirror=False): - """Set the fetch refspec to its default value. - """ - if mirror: - dst = 'refs/heads/*' - else: - dst = 'refs/remotes/%s/*' % self.name - self.fetch = [RefSpec(True, 'refs/heads/*', dst)] + def ResetFetch(self, mirror=False): + """Set the fetch refspec to its default value.""" + if mirror: + dst = "refs/heads/*" + else: + dst = "refs/remotes/%s/*" % self.name + self.fetch = [RefSpec(True, "refs/heads/*", dst)] - def Save(self): - """Save this remote to the configuration. 
- """ - self._Set('url', self.url) - if self.pushUrl is not None: - self._Set('pushurl', self.pushUrl + '/' + self.projectname) - else: - self._Set('pushurl', self.pushUrl) - self._Set('review', self.review) - self._Set('projectname', self.projectname) - self._Set('fetch', list(map(str, self.fetch))) + def Save(self): + """Save this remote to the configuration.""" + self._Set("url", self.url) + if self.pushUrl is not None: + self._Set("pushurl", self.pushUrl + "/" + self.projectname) + else: + self._Set("pushurl", self.pushUrl) + self._Set("review", self.review) + self._Set("projectname", self.projectname) + self._Set("fetch", list(map(str, self.fetch))) - def _Set(self, key, value): - key = 'remote.%s.%s' % (self.name, key) - return self._config.SetString(key, value) + def _Set(self, key, value): + key = "remote.%s.%s" % (self.name, key) + return self._config.SetString(key, value) - def _Get(self, key, all_keys=False): - key = 'remote.%s.%s' % (self.name, key) - return self._config.GetString(key, all_keys=all_keys) + def _Get(self, key, all_keys=False): + key = "remote.%s.%s" % (self.name, key) + return self._config.GetString(key, all_keys=all_keys) class Branch(object): - """Configuration options related to a single branch. - """ + """Configuration options related to a single branch.""" - def __init__(self, config, name): - self._config = config - self.name = name - self.merge = self._Get('merge') + def __init__(self, config, name): + self._config = config + self.name = name + self.merge = self._Get("merge") - r = self._Get('remote') - if r: - self.remote = self._config.GetRemote(r) - else: - self.remote = None + r = self._Get("remote") + if r: + self.remote = self._config.GetRemote(r) + else: + self.remote = None - @property - def LocalMerge(self): - """Convert the merge spec to a local name. 
- """ - if self.remote and self.merge: - return self.remote.ToLocal(self.merge) - return None + @property + def LocalMerge(self): + """Convert the merge spec to a local name.""" + if self.remote and self.merge: + return self.remote.ToLocal(self.merge) + return None - def Save(self): - """Save this branch back into the configuration. - """ - if self._config.HasSection('branch', self.name): - if self.remote: - self._Set('remote', self.remote.name) - else: - self._Set('remote', None) - self._Set('merge', self.merge) + def Save(self): + """Save this branch back into the configuration.""" + if self._config.HasSection("branch", self.name): + if self.remote: + self._Set("remote", self.remote.name) + else: + self._Set("remote", None) + self._Set("merge", self.merge) - else: - with open(self._config.file, 'a') as fd: - fd.write('[branch "%s"]\n' % self.name) - if self.remote: - fd.write('\tremote = %s\n' % self.remote.name) - if self.merge: - fd.write('\tmerge = %s\n' % self.merge) + else: + with open(self._config.file, "a") as fd: + fd.write('[branch "%s"]\n' % self.name) + if self.remote: + fd.write("\tremote = %s\n" % self.remote.name) + if self.merge: + fd.write("\tmerge = %s\n" % self.merge) - def _Set(self, key, value): - key = 'branch.%s.%s' % (self.name, key) - return self._config.SetString(key, value) + def _Set(self, key, value): + key = "branch.%s.%s" % (self.name, key) + return self._config.SetString(key, value) - def _Get(self, key, all_keys=False): - key = 'branch.%s.%s' % (self.name, key) - return self._config.GetString(key, all_keys=all_keys) + def _Get(self, key, all_keys=False): + key = "branch.%s.%s" % (self.name, key) + return self._config.GetString(key, all_keys=all_keys) class SyncAnalysisState: - """Configuration options related to logging of sync state for analysis. + """Configuration options related to logging of sync state for analysis. - This object is versioned. 
- """ - def __init__(self, config, options, superproject_logging_data): - """Initializes SyncAnalysisState. - - Saves the following data into the |config| object. - - sys.argv, options, superproject's logging data. - - repo.*, branch.* and remote.* parameters from config object. - - Current time as synctime. - - Version number of the object. - - All the keys saved by this object are prepended with SYNC_STATE_PREFIX. - - Args: - config: GitConfig object to store all options. - options: Options passed to sync returned from optparse. See _Options(). - superproject_logging_data: A dictionary of superproject data that is to be logged. + This object is versioned. """ - self._config = config - now = datetime.datetime.utcnow() - self._Set('main.synctime', now.isoformat() + 'Z') - self._Set('main.version', '1') - self._Set('sys.argv', sys.argv) - for key, value in superproject_logging_data.items(): - self._Set(f'superproject.{key}', value) - for key, value in options.__dict__.items(): - self._Set(f'options.{key}', value) - config_items = config.DumpConfigDict().items() - EXTRACT_NAMESPACES = {'repo', 'branch', 'remote'} - self._SetDictionary({k: v for k, v in config_items - if not k.startswith(SYNC_STATE_PREFIX) and - k.split('.', 1)[0] in EXTRACT_NAMESPACES}) - def _SetDictionary(self, data): - """Save all key/value pairs of |data| dictionary. + def __init__(self, config, options, superproject_logging_data): + """Initializes SyncAnalysisState. - Args: - data: A dictionary whose key/value are to be saved. - """ - for key, value in data.items(): - self._Set(key, value) + Saves the following data into the |config| object. + - sys.argv, options, superproject's logging data. + - repo.*, branch.* and remote.* parameters from config object. + - Current time as synctime. + - Version number of the object. - def _Set(self, key, value): - """Set the |value| for a |key| in the |_config| member. + All the keys saved by this object are prepended with SYNC_STATE_PREFIX. 
- |key| is prepended with the value of SYNC_STATE_PREFIX constant. + Args: + config: GitConfig object to store all options. + options: Options passed to sync returned from optparse. See + _Options(). + superproject_logging_data: A dictionary of superproject data that is + to be logged. + """ + self._config = config + now = datetime.datetime.utcnow() + self._Set("main.synctime", now.isoformat() + "Z") + self._Set("main.version", "1") + self._Set("sys.argv", sys.argv) + for key, value in superproject_logging_data.items(): + self._Set(f"superproject.{key}", value) + for key, value in options.__dict__.items(): + self._Set(f"options.{key}", value) + config_items = config.DumpConfigDict().items() + EXTRACT_NAMESPACES = {"repo", "branch", "remote"} + self._SetDictionary( + { + k: v + for k, v in config_items + if not k.startswith(SYNC_STATE_PREFIX) + and k.split(".", 1)[0] in EXTRACT_NAMESPACES + } + ) - Args: - key: Name of the key. - value: |value| could be of any type. If it is 'bool', it will be saved - as a Boolean and for all other types, it will be saved as a String. - """ - if value is None: - return - sync_key = f'{SYNC_STATE_PREFIX}{key}' - sync_key = sync_key.replace('_', '') - if isinstance(value, str): - self._config.SetString(sync_key, value) - elif isinstance(value, bool): - self._config.SetBoolean(sync_key, value) - else: - self._config.SetString(sync_key, str(value)) + def _SetDictionary(self, data): + """Save all key/value pairs of |data| dictionary. + + Args: + data: A dictionary whose key/value are to be saved. + """ + for key, value in data.items(): + self._Set(key, value) + + def _Set(self, key, value): + """Set the |value| for a |key| in the |_config| member. + + |key| is prepended with the value of SYNC_STATE_PREFIX constant. + + Args: + key: Name of the key. + value: |value| could be of any type. If it is 'bool', it will be + saved as a Boolean and for all other types, it will be saved as + a String. 
+ """ + if value is None: + return + sync_key = f"{SYNC_STATE_PREFIX}{key}" + sync_key = sync_key.replace("_", "") + if isinstance(value, str): + self._config.SetString(sync_key, value) + elif isinstance(value, bool): + self._config.SetBoolean(sync_key, value) + else: + self._config.SetString(sync_key, str(value)) diff --git a/git_refs.py b/git_refs.py index 300d2b30..aca1f90d 100644 --- a/git_refs.py +++ b/git_refs.py @@ -16,149 +16,150 @@ import os from repo_trace import Trace import platform_utils -HEAD = 'HEAD' -R_CHANGES = 'refs/changes/' -R_HEADS = 'refs/heads/' -R_TAGS = 'refs/tags/' -R_PUB = 'refs/published/' -R_WORKTREE = 'refs/worktree/' -R_WORKTREE_M = R_WORKTREE + 'm/' -R_M = 'refs/remotes/m/' +HEAD = "HEAD" +R_CHANGES = "refs/changes/" +R_HEADS = "refs/heads/" +R_TAGS = "refs/tags/" +R_PUB = "refs/published/" +R_WORKTREE = "refs/worktree/" +R_WORKTREE_M = R_WORKTREE + "m/" +R_M = "refs/remotes/m/" class GitRefs(object): - def __init__(self, gitdir): - self._gitdir = gitdir - self._phyref = None - self._symref = None - self._mtime = {} + def __init__(self, gitdir): + self._gitdir = gitdir + self._phyref = None + self._symref = None + self._mtime = {} - @property - def all(self): - self._EnsureLoaded() - return self._phyref + @property + def all(self): + self._EnsureLoaded() + return self._phyref - def get(self, name): - try: - return self.all[name] - except KeyError: - return '' - - def deleted(self, name): - if self._phyref is not None: - if name in self._phyref: - del self._phyref[name] - - if name in self._symref: - del self._symref[name] - - if name in self._mtime: - del self._mtime[name] - - def symref(self, name): - try: - self._EnsureLoaded() - return self._symref[name] - except KeyError: - return '' - - def _EnsureLoaded(self): - if self._phyref is None or self._NeedUpdate(): - self._LoadAll() - - def _NeedUpdate(self): - with Trace(': scan refs %s', self._gitdir): - for name, mtime in self._mtime.items(): + def get(self, name): try: - if mtime 
!= os.path.getmtime(os.path.join(self._gitdir, name)): - return True + return self.all[name] + except KeyError: + return "" + + def deleted(self, name): + if self._phyref is not None: + if name in self._phyref: + del self._phyref[name] + + if name in self._symref: + del self._symref[name] + + if name in self._mtime: + del self._mtime[name] + + def symref(self, name): + try: + self._EnsureLoaded() + return self._symref[name] + except KeyError: + return "" + + def _EnsureLoaded(self): + if self._phyref is None or self._NeedUpdate(): + self._LoadAll() + + def _NeedUpdate(self): + with Trace(": scan refs %s", self._gitdir): + for name, mtime in self._mtime.items(): + try: + if mtime != os.path.getmtime( + os.path.join(self._gitdir, name) + ): + return True + except OSError: + return True + return False + + def _LoadAll(self): + with Trace(": load refs %s", self._gitdir): + self._phyref = {} + self._symref = {} + self._mtime = {} + + self._ReadPackedRefs() + self._ReadLoose("refs/") + self._ReadLoose1(os.path.join(self._gitdir, HEAD), HEAD) + + scan = self._symref + attempts = 0 + while scan and attempts < 5: + scan_next = {} + for name, dest in scan.items(): + if dest in self._phyref: + self._phyref[name] = self._phyref[dest] + else: + scan_next[name] = dest + scan = scan_next + attempts += 1 + + def _ReadPackedRefs(self): + path = os.path.join(self._gitdir, "packed-refs") + try: + fd = open(path, "r") + mtime = os.path.getmtime(path) + except IOError: + return except OSError: - return True - return False + return + try: + for line in fd: + line = str(line) + if line[0] == "#": + continue + if line[0] == "^": + continue - def _LoadAll(self): - with Trace(': load refs %s', self._gitdir): + line = line[:-1] + p = line.split(" ") + ref_id = p[0] + name = p[1] - self._phyref = {} - self._symref = {} - self._mtime = {} + self._phyref[name] = ref_id + finally: + fd.close() + self._mtime["packed-refs"] = mtime - self._ReadPackedRefs() - self._ReadLoose('refs/') - 
self._ReadLoose1(os.path.join(self._gitdir, HEAD), HEAD) + def _ReadLoose(self, prefix): + base = os.path.join(self._gitdir, prefix) + for name in platform_utils.listdir(base): + p = os.path.join(base, name) + # We don't implement the full ref validation algorithm, just the + # simple rules that would show up in local filesystems. + # https://git-scm.com/docs/git-check-ref-format + if name.startswith(".") or name.endswith(".lock"): + pass + elif platform_utils.isdir(p): + self._mtime[prefix] = os.path.getmtime(base) + self._ReadLoose(prefix + name + "/") + else: + self._ReadLoose1(p, prefix + name) - scan = self._symref - attempts = 0 - while scan and attempts < 5: - scan_next = {} - for name, dest in scan.items(): - if dest in self._phyref: - self._phyref[name] = self._phyref[dest] - else: - scan_next[name] = dest - scan = scan_next - attempts += 1 + def _ReadLoose1(self, path, name): + try: + with open(path) as fd: + mtime = os.path.getmtime(path) + ref_id = fd.readline() + except (OSError, UnicodeError): + return - def _ReadPackedRefs(self): - path = os.path.join(self._gitdir, 'packed-refs') - try: - fd = open(path, 'r') - mtime = os.path.getmtime(path) - except IOError: - return - except OSError: - return - try: - for line in fd: - line = str(line) - if line[0] == '#': - continue - if line[0] == '^': - continue + try: + ref_id = ref_id.decode() + except AttributeError: + pass + if not ref_id: + return + ref_id = ref_id[:-1] - line = line[:-1] - p = line.split(' ') - ref_id = p[0] - name = p[1] - - self._phyref[name] = ref_id - finally: - fd.close() - self._mtime['packed-refs'] = mtime - - def _ReadLoose(self, prefix): - base = os.path.join(self._gitdir, prefix) - for name in platform_utils.listdir(base): - p = os.path.join(base, name) - # We don't implement the full ref validation algorithm, just the simple - # rules that would show up in local filesystems. 
- # https://git-scm.com/docs/git-check-ref-format - if name.startswith('.') or name.endswith('.lock'): - pass - elif platform_utils.isdir(p): - self._mtime[prefix] = os.path.getmtime(base) - self._ReadLoose(prefix + name + '/') - else: - self._ReadLoose1(p, prefix + name) - - def _ReadLoose1(self, path, name): - try: - with open(path) as fd: - mtime = os.path.getmtime(path) - ref_id = fd.readline() - except (OSError, UnicodeError): - return - - try: - ref_id = ref_id.decode() - except AttributeError: - pass - if not ref_id: - return - ref_id = ref_id[:-1] - - if ref_id.startswith('ref: '): - self._symref[name] = ref_id[5:] - else: - self._phyref[name] = ref_id - self._mtime[name] = mtime + if ref_id.startswith("ref: "): + self._symref[name] = ref_id[5:] + else: + self._phyref[name] = ref_id + self._mtime[name] = mtime diff --git a/git_superproject.py b/git_superproject.py index 69a4d1fe..f1b4f231 100644 --- a/git_superproject.py +++ b/git_superproject.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Provide functionality to get all projects and their commit ids from Superproject. +"""Provide functionality to get projects and their commit ids from Superproject. For more information on superproject, check out: https://en.wikibooks.org/wiki/Git/Submodules_and_Superprojects @@ -33,434 +33,524 @@ from git_command import git_require, GitCommand from git_config import RepoConfig from git_refs import GitRefs -_SUPERPROJECT_GIT_NAME = 'superproject.git' -_SUPERPROJECT_MANIFEST_NAME = 'superproject_override.xml' +_SUPERPROJECT_GIT_NAME = "superproject.git" +_SUPERPROJECT_MANIFEST_NAME = "superproject_override.xml" class SyncResult(NamedTuple): - """Return the status of sync and whether caller should exit.""" + """Return the status of sync and whether caller should exit.""" - # Whether the superproject sync was successful. - success: bool - # Whether the caller should exit. 
- fatal: bool + # Whether the superproject sync was successful. + success: bool + # Whether the caller should exit. + fatal: bool class CommitIdsResult(NamedTuple): - """Return the commit ids and whether caller should exit.""" + """Return the commit ids and whether caller should exit.""" - # A dictionary with the projects/commit ids on success, otherwise None. - commit_ids: dict - # Whether the caller should exit. - fatal: bool + # A dictionary with the projects/commit ids on success, otherwise None. + commit_ids: dict + # Whether the caller should exit. + fatal: bool class UpdateProjectsResult(NamedTuple): - """Return the overriding manifest file and whether caller should exit.""" + """Return the overriding manifest file and whether caller should exit.""" - # Path name of the overriding manifest file if successful, otherwise None. - manifest_path: str - # Whether the caller should exit. - fatal: bool + # Path name of the overriding manifest file if successful, otherwise None. + manifest_path: str + # Whether the caller should exit. + fatal: bool class Superproject(object): - """Get commit ids from superproject. + """Get commit ids from superproject. - Initializes a local copy of a superproject for the manifest. This allows - lookup of commit ids for all projects. It contains _project_commit_ids which - is a dictionary with project/commit id entries. - """ - def __init__(self, manifest, name, remote, revision, - superproject_dir='exp-superproject'): - """Initializes superproject. - - Args: - manifest: A Manifest object that is to be written to a file. - name: The unique name of the superproject - remote: The RemoteSpec for the remote. - revision: The name of the git branch to track. - superproject_dir: Relative path under |manifest.subdir| to checkout - superproject. + Initializes a local copy of a superproject for the manifest. This allows + lookup of commit ids for all projects. 
It contains _project_commit_ids which + is a dictionary with project/commit id entries. """ - self._project_commit_ids = None - self._manifest = manifest - self.name = name - self.remote = remote - self.revision = self._branch = revision - self._repodir = manifest.repodir - self._superproject_dir = superproject_dir - self._superproject_path = manifest.SubmanifestInfoDir(manifest.path_prefix, - superproject_dir) - self._manifest_path = os.path.join(self._superproject_path, - _SUPERPROJECT_MANIFEST_NAME) - git_name = hashlib.md5(remote.name.encode('utf8')).hexdigest() + '-' - self._remote_url = remote.url - self._work_git_name = git_name + _SUPERPROJECT_GIT_NAME - self._work_git = os.path.join(self._superproject_path, self._work_git_name) - # The following are command arguemnts, rather than superproject attributes, - # and were included here originally. They should eventually become - # arguments that are passed down from the public methods, instead of being - # treated as attributes. - self._git_event_log = None - self._quiet = False - self._print_messages = False + def __init__( + self, + manifest, + name, + remote, + revision, + superproject_dir="exp-superproject", + ): + """Initializes superproject. - def SetQuiet(self, value): - """Set the _quiet attribute.""" - self._quiet = value + Args: + manifest: A Manifest object that is to be written to a file. + name: The unique name of the superproject + remote: The RemoteSpec for the remote. + revision: The name of the git branch to track. + superproject_dir: Relative path under |manifest.subdir| to checkout + superproject. 
+ """ + self._project_commit_ids = None + self._manifest = manifest + self.name = name + self.remote = remote + self.revision = self._branch = revision + self._repodir = manifest.repodir + self._superproject_dir = superproject_dir + self._superproject_path = manifest.SubmanifestInfoDir( + manifest.path_prefix, superproject_dir + ) + self._manifest_path = os.path.join( + self._superproject_path, _SUPERPROJECT_MANIFEST_NAME + ) + git_name = hashlib.md5(remote.name.encode("utf8")).hexdigest() + "-" + self._remote_url = remote.url + self._work_git_name = git_name + _SUPERPROJECT_GIT_NAME + self._work_git = os.path.join( + self._superproject_path, self._work_git_name + ) - def SetPrintMessages(self, value): - """Set the _print_messages attribute.""" - self._print_messages = value + # The following are command arguemnts, rather than superproject + # attributes, and were included here originally. They should eventually + # become arguments that are passed down from the public methods, instead + # of being treated as attributes. 
+ self._git_event_log = None + self._quiet = False + self._print_messages = False - @property - def project_commit_ids(self): - """Returns a dictionary of projects and their commit ids.""" - return self._project_commit_ids + def SetQuiet(self, value): + """Set the _quiet attribute.""" + self._quiet = value - @property - def manifest_path(self): - """Returns the manifest path if the path exists or None.""" - return self._manifest_path if os.path.exists(self._manifest_path) else None + def SetPrintMessages(self, value): + """Set the _print_messages attribute.""" + self._print_messages = value - def _LogMessage(self, fmt, *inputs): - """Logs message to stderr and _git_event_log.""" - message = f'{self._LogMessagePrefix()} {fmt.format(*inputs)}' - if self._print_messages: - print(message, file=sys.stderr) - self._git_event_log.ErrorEvent(message, fmt) + @property + def project_commit_ids(self): + """Returns a dictionary of projects and their commit ids.""" + return self._project_commit_ids - def _LogMessagePrefix(self): - """Returns the prefix string to be logged in each log message""" - return f'repo superproject branch: {self._branch} url: {self._remote_url}' + @property + def manifest_path(self): + """Returns the manifest path if the path exists or None.""" + return ( + self._manifest_path if os.path.exists(self._manifest_path) else None + ) - def _LogError(self, fmt, *inputs): - """Logs error message to stderr and _git_event_log.""" - self._LogMessage(f'error: {fmt}', *inputs) + def _LogMessage(self, fmt, *inputs): + """Logs message to stderr and _git_event_log.""" + message = f"{self._LogMessagePrefix()} {fmt.format(*inputs)}" + if self._print_messages: + print(message, file=sys.stderr) + self._git_event_log.ErrorEvent(message, fmt) - def _LogWarning(self, fmt, *inputs): - """Logs warning message to stderr and _git_event_log.""" - self._LogMessage(f'warning: {fmt}', *inputs) + def _LogMessagePrefix(self): + """Returns the prefix string to be logged in each log 
message""" + return ( + f"repo superproject branch: {self._branch} url: {self._remote_url}" + ) - def _Init(self): - """Sets up a local Git repository to get a copy of a superproject. + def _LogError(self, fmt, *inputs): + """Logs error message to stderr and _git_event_log.""" + self._LogMessage(f"error: {fmt}", *inputs) - Returns: - True if initialization is successful, or False. - """ - if not os.path.exists(self._superproject_path): - os.mkdir(self._superproject_path) - if not self._quiet and not os.path.exists(self._work_git): - print('%s: Performing initial setup for superproject; this might take ' - 'several minutes.' % self._work_git) - cmd = ['init', '--bare', self._work_git_name] - p = GitCommand(None, - cmd, - cwd=self._superproject_path, - capture_stdout=True, - capture_stderr=True) - retval = p.Wait() - if retval: - self._LogWarning('git init call failed, command: git {}, ' - 'return code: {}, stderr: {}', cmd, retval, p.stderr) - return False - return True + def _LogWarning(self, fmt, *inputs): + """Logs warning message to stderr and _git_event_log.""" + self._LogMessage(f"warning: {fmt}", *inputs) - def _Fetch(self): - """Fetches a local copy of a superproject for the manifest based on |_remote_url|. + def _Init(self): + """Sets up a local Git repository to get a copy of a superproject. - Returns: - True if fetch is successful, or False. - """ - if not os.path.exists(self._work_git): - self._LogWarning('git fetch missing directory: {}', self._work_git) - return False - if not git_require((2, 28, 0)): - self._LogWarning('superproject requires a git version 2.28 or later') - return False - cmd = ['fetch', self._remote_url, '--depth', '1', '--force', '--no-tags', - '--filter', 'blob:none'] + Returns: + True if initialization is successful, or False. 
+ """ + if not os.path.exists(self._superproject_path): + os.mkdir(self._superproject_path) + if not self._quiet and not os.path.exists(self._work_git): + print( + "%s: Performing initial setup for superproject; this might " + "take several minutes." % self._work_git + ) + cmd = ["init", "--bare", self._work_git_name] + p = GitCommand( + None, + cmd, + cwd=self._superproject_path, + capture_stdout=True, + capture_stderr=True, + ) + retval = p.Wait() + if retval: + self._LogWarning( + "git init call failed, command: git {}, " + "return code: {}, stderr: {}", + cmd, + retval, + p.stderr, + ) + return False + return True - # Check if there is a local ref that we can pass to --negotiation-tip. - # If this is the first fetch, it does not exist yet. - # We use --negotiation-tip to speed up the fetch. Superproject branches do - # not share commits. So this lets git know it only needs to send commits - # reachable from the specified local refs. - rev_commit = GitRefs(self._work_git).get(f'refs/heads/{self.revision}') - if rev_commit: - cmd.extend(['--negotiation-tip', rev_commit]) + def _Fetch(self): + """Fetches a superproject for the manifest based on |_remote_url|. - if self._branch: - cmd += [self._branch + ':' + self._branch] - p = GitCommand(None, - cmd, - cwd=self._work_git, - capture_stdout=True, - capture_stderr=True) - retval = p.Wait() - if retval: - self._LogWarning('git fetch call failed, command: git {}, ' - 'return code: {}, stderr: {}', cmd, retval, p.stderr) - return False - return True + This runs git fetch which stores a local copy the superproject. - def _LsTree(self): - """Gets the commit ids for all projects. + Returns: + True if fetch is successful, or False. 
+ """ + if not os.path.exists(self._work_git): + self._LogWarning("git fetch missing directory: {}", self._work_git) + return False + if not git_require((2, 28, 0)): + self._LogWarning( + "superproject requires a git version 2.28 or later" + ) + return False + cmd = [ + "fetch", + self._remote_url, + "--depth", + "1", + "--force", + "--no-tags", + "--filter", + "blob:none", + ] - Works only in git repositories. + # Check if there is a local ref that we can pass to --negotiation-tip. + # If this is the first fetch, it does not exist yet. + # We use --negotiation-tip to speed up the fetch. Superproject branches + # do not share commits. So this lets git know it only needs to send + # commits reachable from the specified local refs. + rev_commit = GitRefs(self._work_git).get(f"refs/heads/{self.revision}") + if rev_commit: + cmd.extend(["--negotiation-tip", rev_commit]) - Returns: - data: data returned from 'git ls-tree ...' instead of None. - """ - if not os.path.exists(self._work_git): - self._LogWarning('git ls-tree missing directory: {}', self._work_git) - return None - data = None - branch = 'HEAD' if not self._branch else self._branch - cmd = ['ls-tree', '-z', '-r', branch] + if self._branch: + cmd += [self._branch + ":" + self._branch] + p = GitCommand( + None, + cmd, + cwd=self._work_git, + capture_stdout=True, + capture_stderr=True, + ) + retval = p.Wait() + if retval: + self._LogWarning( + "git fetch call failed, command: git {}, " + "return code: {}, stderr: {}", + cmd, + retval, + p.stderr, + ) + return False + return True - p = GitCommand(None, - cmd, - cwd=self._work_git, - capture_stdout=True, - capture_stderr=True) - retval = p.Wait() - if retval == 0: - data = p.stdout - else: - self._LogWarning('git ls-tree call failed, command: git {}, ' - 'return code: {}, stderr: {}', cmd, retval, p.stderr) - return data + def _LsTree(self): + """Gets the commit ids for all projects. 
- def Sync(self, git_event_log): - """Gets a local copy of a superproject for the manifest. + Works only in git repositories. - Args: - git_event_log: an EventLog, for git tracing. + Returns: + data: data returned from 'git ls-tree ...' instead of None. + """ + if not os.path.exists(self._work_git): + self._LogWarning( + "git ls-tree missing directory: {}", self._work_git + ) + return None + data = None + branch = "HEAD" if not self._branch else self._branch + cmd = ["ls-tree", "-z", "-r", branch] - Returns: - SyncResult - """ - self._git_event_log = git_event_log - if not self._manifest.superproject: - self._LogWarning('superproject tag is not defined in manifest: {}', - self._manifest.manifestFile) - return SyncResult(False, False) + p = GitCommand( + None, + cmd, + cwd=self._work_git, + capture_stdout=True, + capture_stderr=True, + ) + retval = p.Wait() + if retval == 0: + data = p.stdout + else: + self._LogWarning( + "git ls-tree call failed, command: git {}, " + "return code: {}, stderr: {}", + cmd, + retval, + p.stderr, + ) + return data - _PrintBetaNotice() + def Sync(self, git_event_log): + """Gets a local copy of a superproject for the manifest. - should_exit = True - if not self._remote_url: - self._LogWarning('superproject URL is not defined in manifest: {}', - self._manifest.manifestFile) - return SyncResult(False, should_exit) + Args: + git_event_log: an EventLog, for git tracing. - if not self._Init(): - return SyncResult(False, should_exit) - if not self._Fetch(): - return SyncResult(False, should_exit) - if not self._quiet: - print('%s: Initial setup for superproject completed.' 
% self._work_git) - return SyncResult(True, False) + Returns: + SyncResult + """ + self._git_event_log = git_event_log + if not self._manifest.superproject: + self._LogWarning( + "superproject tag is not defined in manifest: {}", + self._manifest.manifestFile, + ) + return SyncResult(False, False) - def _GetAllProjectsCommitIds(self): - """Get commit ids for all projects from superproject and save them in _project_commit_ids. + _PrintBetaNotice() - Returns: - CommitIdsResult - """ - sync_result = self.Sync(self._git_event_log) - if not sync_result.success: - return CommitIdsResult(None, sync_result.fatal) + should_exit = True + if not self._remote_url: + self._LogWarning( + "superproject URL is not defined in manifest: {}", + self._manifest.manifestFile, + ) + return SyncResult(False, should_exit) - data = self._LsTree() - if not data: - self._LogWarning('git ls-tree failed to return data for manifest: {}', - self._manifest.manifestFile) - return CommitIdsResult(None, True) + if not self._Init(): + return SyncResult(False, should_exit) + if not self._Fetch(): + return SyncResult(False, should_exit) + if not self._quiet: + print( + "%s: Initial setup for superproject completed." % self._work_git + ) + return SyncResult(True, False) - # Parse lines like the following to select lines starting with '160000' and - # build a dictionary with project path (last element) and its commit id (3rd element). - # - # 160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00 - # 120000 blob acc2cbdf438f9d2141f0ae424cec1d8fc4b5d97f\tbootstrap.bash\x00 - commit_ids = {} - for line in data.split('\x00'): - ls_data = line.split(None, 3) - if not ls_data: - break - if ls_data[0] == '160000': - commit_ids[ls_data[3]] = ls_data[2] + def _GetAllProjectsCommitIds(self): + """Get commit ids for all projects from superproject and save them. - self._project_commit_ids = commit_ids - return CommitIdsResult(commit_ids, False) + Commit ids are saved in _project_commit_ids. 
- def _WriteManifestFile(self): - """Writes manifest to a file. + Returns: + CommitIdsResult + """ + sync_result = self.Sync(self._git_event_log) + if not sync_result.success: + return CommitIdsResult(None, sync_result.fatal) - Returns: - manifest_path: Path name of the file into which manifest is written instead of None. - """ - if not os.path.exists(self._superproject_path): - self._LogWarning('missing superproject directory: {}', self._superproject_path) - return None - manifest_str = self._manifest.ToXml(groups=self._manifest.GetGroupsStr(), - omit_local=True).toxml() - manifest_path = self._manifest_path - try: - with open(manifest_path, 'w', encoding='utf-8') as fp: - fp.write(manifest_str) - except IOError as e: - self._LogError('cannot write manifest to : {} {}', - manifest_path, e) - return None - return manifest_path + data = self._LsTree() + if not data: + self._LogWarning( + "git ls-tree failed to return data for manifest: {}", + self._manifest.manifestFile, + ) + return CommitIdsResult(None, True) - def _SkipUpdatingProjectRevisionId(self, project): - """Checks if a project's revision id needs to be updated or not. + # Parse lines like the following to select lines starting with '160000' + # and build a dictionary with project path (last element) and its commit + # id (3rd element). + # + # 160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00 + # 120000 blob acc2cbdf438f9d2141f0ae424cec1d8fc4b5d97f\tbootstrap.bash\x00 # noqa: E501 + commit_ids = {} + for line in data.split("\x00"): + ls_data = line.split(None, 3) + if not ls_data: + break + if ls_data[0] == "160000": + commit_ids[ls_data[3]] = ls_data[2] - Revision id for projects from local manifest will not be updated. + self._project_commit_ids = commit_ids + return CommitIdsResult(commit_ids, False) - Args: - project: project whose revision id is being updated. + def _WriteManifestFile(self): + """Writes manifest to a file. 
- Returns: - True if a project's revision id should not be updated, or False, - """ - path = project.relpath - if not path: - return True - # Skip the project with revisionId. - if project.revisionId: - return True - # Skip the project if it comes from the local manifest. - return project.manifest.IsFromLocalManifest(project) + Returns: + manifest_path: Path name of the file into which manifest is written + instead of None. + """ + if not os.path.exists(self._superproject_path): + self._LogWarning( + "missing superproject directory: {}", self._superproject_path + ) + return None + manifest_str = self._manifest.ToXml( + groups=self._manifest.GetGroupsStr(), omit_local=True + ).toxml() + manifest_path = self._manifest_path + try: + with open(manifest_path, "w", encoding="utf-8") as fp: + fp.write(manifest_str) + except IOError as e: + self._LogError("cannot write manifest to : {} {}", manifest_path, e) + return None + return manifest_path - def UpdateProjectsRevisionId(self, projects, git_event_log): - """Update revisionId of every project in projects with the commit id. + def _SkipUpdatingProjectRevisionId(self, project): + """Checks if a project's revision id needs to be updated or not. - Args: - projects: a list of projects whose revisionId needs to be updated. - git_event_log: an EventLog, for git tracing. + Revision id for projects from local manifest will not be updated. - Returns: - UpdateProjectsResult - """ - self._git_event_log = git_event_log - commit_ids_result = self._GetAllProjectsCommitIds() - commit_ids = commit_ids_result.commit_ids - if not commit_ids: - return UpdateProjectsResult(None, commit_ids_result.fatal) + Args: + project: project whose revision id is being updated. 
- projects_missing_commit_ids = [] - for project in projects: - if self._SkipUpdatingProjectRevisionId(project): - continue - path = project.relpath - commit_id = commit_ids.get(path) - if not commit_id: - projects_missing_commit_ids.append(path) + Returns: + True if a project's revision id should not be updated, or False, + """ + path = project.relpath + if not path: + return True + # Skip the project with revisionId. + if project.revisionId: + return True + # Skip the project if it comes from the local manifest. + return project.manifest.IsFromLocalManifest(project) - # If superproject doesn't have a commit id for a project, then report an - # error event and continue as if do not use superproject is specified. - if projects_missing_commit_ids: - self._LogWarning('please file a bug using {} to report missing ' - 'commit_ids for: {}', self._manifest.contactinfo.bugurl, - projects_missing_commit_ids) - return UpdateProjectsResult(None, False) + def UpdateProjectsRevisionId(self, projects, git_event_log): + """Update revisionId of every project in projects with the commit id. - for project in projects: - if not self._SkipUpdatingProjectRevisionId(project): - project.SetRevisionId(commit_ids.get(project.relpath)) + Args: + projects: a list of projects whose revisionId needs to be updated. + git_event_log: an EventLog, for git tracing. 
- manifest_path = self._WriteManifestFile() - return UpdateProjectsResult(manifest_path, False) + Returns: + UpdateProjectsResult + """ + self._git_event_log = git_event_log + commit_ids_result = self._GetAllProjectsCommitIds() + commit_ids = commit_ids_result.commit_ids + if not commit_ids: + return UpdateProjectsResult(None, commit_ids_result.fatal) + + projects_missing_commit_ids = [] + for project in projects: + if self._SkipUpdatingProjectRevisionId(project): + continue + path = project.relpath + commit_id = commit_ids.get(path) + if not commit_id: + projects_missing_commit_ids.append(path) + + # If superproject doesn't have a commit id for a project, then report an + # error event and continue as if do not use superproject is specified. + if projects_missing_commit_ids: + self._LogWarning( + "please file a bug using {} to report missing " + "commit_ids for: {}", + self._manifest.contactinfo.bugurl, + projects_missing_commit_ids, + ) + return UpdateProjectsResult(None, False) + + for project in projects: + if not self._SkipUpdatingProjectRevisionId(project): + project.SetRevisionId(commit_ids.get(project.relpath)) + + manifest_path = self._WriteManifestFile() + return UpdateProjectsResult(manifest_path, False) @functools.lru_cache(maxsize=10) def _PrintBetaNotice(): - """Print the notice of beta status.""" - print('NOTICE: --use-superproject is in beta; report any issues to the ' - 'address described in `repo version`', file=sys.stderr) + """Print the notice of beta status.""" + print( + "NOTICE: --use-superproject is in beta; report any issues to the " + "address described in `repo version`", + file=sys.stderr, + ) @functools.lru_cache(maxsize=None) def _UseSuperprojectFromConfiguration(): - """Returns the user choice of whether to use superproject.""" - user_cfg = RepoConfig.ForUser() - time_now = int(time.time()) + """Returns the user choice of whether to use superproject.""" + user_cfg = RepoConfig.ForUser() + time_now = int(time.time()) - user_value = 
user_cfg.GetBoolean('repo.superprojectChoice') - if user_value is not None: - user_expiration = user_cfg.GetInt('repo.superprojectChoiceExpire') - if user_expiration is None or user_expiration <= 0 or user_expiration >= time_now: - # TODO(b/190688390) - Remove prompt when we are comfortable with the new - # default value. - if user_value: - print(('You are currently enrolled in Git submodules experiment ' - '(go/android-submodules-quickstart). Use --no-use-superproject ' - 'to override.\n'), file=sys.stderr) - else: - print(('You are not currently enrolled in Git submodules experiment ' - '(go/android-submodules-quickstart). Use --use-superproject ' - 'to override.\n'), file=sys.stderr) - return user_value + user_value = user_cfg.GetBoolean("repo.superprojectChoice") + if user_value is not None: + user_expiration = user_cfg.GetInt("repo.superprojectChoiceExpire") + if ( + user_expiration is None + or user_expiration <= 0 + or user_expiration >= time_now + ): + # TODO(b/190688390) - Remove prompt when we are comfortable with the + # new default value. + if user_value: + print( + ( + "You are currently enrolled in Git submodules " + "experiment (go/android-submodules-quickstart). Use " + "--no-use-superproject to override.\n" + ), + file=sys.stderr, + ) + else: + print( + ( + "You are not currently enrolled in Git submodules " + "experiment (go/android-submodules-quickstart). Use " + "--use-superproject to override.\n" + ), + file=sys.stderr, + ) + return user_value - # We don't have an unexpired choice, ask for one. - system_cfg = RepoConfig.ForSystem() - system_value = system_cfg.GetBoolean('repo.superprojectChoice') - if system_value: - # The system configuration is proposing that we should enable the - # use of superproject. Treat the user as enrolled for two weeks. - # - # TODO(b/190688390) - Remove prompt when we are comfortable with the new - # default value. 
- userchoice = True - time_choiceexpire = time_now + (86400 * 14) - user_cfg.SetString('repo.superprojectChoiceExpire', str(time_choiceexpire)) - user_cfg.SetBoolean('repo.superprojectChoice', userchoice) - print('You are automatically enrolled in Git submodules experiment ' - '(go/android-submodules-quickstart) for another two weeks.\n', - file=sys.stderr) - return True + # We don't have an unexpired choice, ask for one. + system_cfg = RepoConfig.ForSystem() + system_value = system_cfg.GetBoolean("repo.superprojectChoice") + if system_value: + # The system configuration is proposing that we should enable the + # use of superproject. Treat the user as enrolled for two weeks. + # + # TODO(b/190688390) - Remove prompt when we are comfortable with the new + # default value. + userchoice = True + time_choiceexpire = time_now + (86400 * 14) + user_cfg.SetString( + "repo.superprojectChoiceExpire", str(time_choiceexpire) + ) + user_cfg.SetBoolean("repo.superprojectChoice", userchoice) + print( + "You are automatically enrolled in Git submodules experiment " + "(go/android-submodules-quickstart) for another two weeks.\n", + file=sys.stderr, + ) + return True - # For all other cases, we would not use superproject by default. - return False + # For all other cases, we would not use superproject by default. + return False def PrintMessages(use_superproject, manifest): - """Returns a boolean if error/warning messages are to be printed. + """Returns a boolean if error/warning messages are to be printed. - Args: - use_superproject: option value from optparse. - manifest: manifest to use. - """ - return use_superproject is not None or bool(manifest.superproject) + Args: + use_superproject: option value from optparse. + manifest: manifest to use. + """ + return use_superproject is not None or bool(manifest.superproject) def UseSuperproject(use_superproject, manifest): - """Returns a boolean if use-superproject option is enabled. 
+ """Returns a boolean if use-superproject option is enabled. - Args: - use_superproject: option value from optparse. - manifest: manifest to use. + Args: + use_superproject: option value from optparse. + manifest: manifest to use. - Returns: - Whether the superproject should be used. - """ + Returns: + Whether the superproject should be used. + """ - if not manifest.superproject: - # This (sub) manifest does not have a superproject definition. - return False - elif use_superproject is not None: - return use_superproject - else: - client_value = manifest.manifestProject.use_superproject - if client_value is not None: - return client_value - elif manifest.superproject: - return _UseSuperprojectFromConfiguration() + if not manifest.superproject: + # This (sub) manifest does not have a superproject definition. + return False + elif use_superproject is not None: + return use_superproject else: - return False + client_value = manifest.manifestProject.use_superproject + if client_value is not None: + return client_value + elif manifest.superproject: + return _UseSuperprojectFromConfiguration() + else: + return False diff --git a/git_trace2_event_log.py b/git_trace2_event_log.py index 2edab0e1..d90e9039 100644 --- a/git_trace2_event_log.py +++ b/git_trace2_event_log.py @@ -41,291 +41,330 @@ from git_command import GitCommand, RepoSourceVersion class EventLog(object): - """Event log that records events that occurred during a repo invocation. + """Event log that records events that occurred during a repo invocation. - Events are written to the log as a consecutive JSON entries, one per line. - Entries follow the git trace2 EVENT format. + Events are written to the log as a consecutive JSON entries, one per line. + Entries follow the git trace2 EVENT format. - Each entry contains the following common keys: - - event: The event name - - sid: session-id - Unique string to allow process instance to be identified. - - thread: The thread name. 
- - time: is the UTC time of the event. + Each entry contains the following common keys: + - event: The event name + - sid: session-id - Unique string to allow process instance to be + identified. + - thread: The thread name. + - time: is the UTC time of the event. - Valid 'event' names and event specific fields are documented here: - https://git-scm.com/docs/api-trace2#_event_format - """ - - def __init__(self, env=None): - """Initializes the event log.""" - self._log = [] - # Try to get session-id (sid) from environment (setup in repo launcher). - KEY = 'GIT_TRACE2_PARENT_SID' - if env is None: - env = os.environ - - now = datetime.datetime.utcnow() - - # Save both our sid component and the complete sid. - # We use our sid component (self._sid) as the unique filename prefix and - # the full sid (self._full_sid) in the log itself. - self._sid = 'repo-%s-P%08x' % (now.strftime('%Y%m%dT%H%M%SZ'), os.getpid()) - parent_sid = env.get(KEY) - # Append our sid component to the parent sid (if it exists). - if parent_sid is not None: - self._full_sid = parent_sid + '/' + self._sid - else: - self._full_sid = self._sid - - # Set/update the environment variable. - # Environment handling across systems is messy. - try: - env[KEY] = self._full_sid - except UnicodeEncodeError: - env[KEY] = self._full_sid.encode() - - # Add a version event to front of the log. - self._AddVersionEvent() - - @property - def full_sid(self): - return self._full_sid - - def _AddVersionEvent(self): - """Adds a 'version' event at the beginning of current log.""" - version_event = self._CreateEventDict('version') - version_event['evt'] = "2" - version_event['exe'] = RepoSourceVersion() - self._log.insert(0, version_event) - - def _CreateEventDict(self, event_name): - """Returns a dictionary with the common keys/values for git trace2 events. - - Args: - event_name: The event name. - - Returns: - Dictionary with the common event fields populated. 
- """ - return { - 'event': event_name, - 'sid': self._full_sid, - 'thread': threading.current_thread().name, - 'time': datetime.datetime.utcnow().isoformat() + 'Z', - } - - def StartEvent(self): - """Append a 'start' event to the current log.""" - start_event = self._CreateEventDict('start') - start_event['argv'] = sys.argv - self._log.append(start_event) - - def ExitEvent(self, result): - """Append an 'exit' event to the current log. - - Args: - result: Exit code of the event - """ - exit_event = self._CreateEventDict('exit') - - # Consider 'None' success (consistent with event_log result handling). - if result is None: - result = 0 - exit_event['code'] = result - self._log.append(exit_event) - - def CommandEvent(self, name, subcommands): - """Append a 'command' event to the current log. - - Args: - name: Name of the primary command (ex: repo, git) - subcommands: List of the sub-commands (ex: version, init, sync) - """ - command_event = self._CreateEventDict('command') - command_event['name'] = name - command_event['subcommands'] = subcommands - self._log.append(command_event) - - def LogConfigEvents(self, config, event_dict_name): - """Append a |event_dict_name| event for each config key in |config|. - - Args: - config: Configuration dictionary. - event_dict_name: Name of the event dictionary for items to be logged under. - """ - for param, value in config.items(): - event = self._CreateEventDict(event_dict_name) - event['param'] = param - event['value'] = value - self._log.append(event) - - def DefParamRepoEvents(self, config): - """Append a 'def_param' event for each repo.* config key to the current log. - - Args: - config: Repo configuration dictionary - """ - # Only output the repo.* config parameters. 
- repo_config = {k: v for k, v in config.items() if k.startswith('repo.')} - self.LogConfigEvents(repo_config, 'def_param') - - def GetDataEventName(self, value): - """Returns 'data-json' if the value is an array else returns 'data'.""" - return 'data-json' if value[0] == '[' and value[-1] == ']' else 'data' - - def LogDataConfigEvents(self, config, prefix): - """Append a 'data' event for each config key/value in |config| to the current log. - - For each keyX and valueX of the config, "key" field of the event is '|prefix|/keyX' - and the "value" of the "key" field is valueX. - - Args: - config: Configuration dictionary. - prefix: Prefix for each key that is logged. - """ - for key, value in config.items(): - event = self._CreateEventDict(self.GetDataEventName(value)) - event['key'] = f'{prefix}/{key}' - event['value'] = value - self._log.append(event) - - def ErrorEvent(self, msg, fmt): - """Append a 'error' event to the current log.""" - error_event = self._CreateEventDict('error') - error_event['msg'] = msg - error_event['fmt'] = fmt - self._log.append(error_event) - - def _GetEventTargetPath(self): - """Get the 'trace2.eventtarget' path from git configuration. - - Returns: - path: git config's 'trace2.eventtarget' path if it exists, or None - """ - path = None - cmd = ['config', '--get', 'trace2.eventtarget'] - # TODO(https://crbug.com/gerrit/13706): Use GitConfig when it supports - # system git config variables. - p = GitCommand(None, cmd, capture_stdout=True, capture_stderr=True, - bare=True) - retval = p.Wait() - if retval == 0: - # Strip trailing carriage-return in path. - path = p.stdout.rstrip('\n') - elif retval != 1: - # `git config --get` is documented to produce an exit status of `1` if - # the requested variable is not present in the configuration. Report any - # other return value as an error. 
- print("repo: error: 'git config --get' call failed with return code: %r, stderr: %r" % ( - retval, p.stderr), file=sys.stderr) - return path - - def _WriteLog(self, write_fn): - """Writes the log out using a provided writer function. - - Generate compact JSON output for each item in the log, and write it using - write_fn. - - Args: - write_fn: A function that accepts byts and writes them to a destination. + Valid 'event' names and event specific fields are documented here: + https://git-scm.com/docs/api-trace2#_event_format """ - for e in self._log: - # Dump in compact encoding mode. - # See 'Compact encoding' in Python docs: - # https://docs.python.org/3/library/json.html#module-json - write_fn(json.dumps(e, indent=None, separators=(',', ':')).encode('utf-8') + b'\n') + def __init__(self, env=None): + """Initializes the event log.""" + self._log = [] + # Try to get session-id (sid) from environment (setup in repo launcher). + KEY = "GIT_TRACE2_PARENT_SID" + if env is None: + env = os.environ - def Write(self, path=None): - """Writes the log out to a file or socket. + now = datetime.datetime.utcnow() - Log is only written if 'path' or 'git config --get trace2.eventtarget' - provide a valid path (or socket) to write logs to. + # Save both our sid component and the complete sid. + # We use our sid component (self._sid) as the unique filename prefix and + # the full sid (self._full_sid) in the log itself. + self._sid = "repo-%s-P%08x" % ( + now.strftime("%Y%m%dT%H%M%SZ"), + os.getpid(), + ) + parent_sid = env.get(KEY) + # Append our sid component to the parent sid (if it exists). + if parent_sid is not None: + self._full_sid = parent_sid + "/" + self._sid + else: + self._full_sid = self._sid - Logging filename format follows the git trace2 style of being a unique - (exclusive writable) file. - - Args: - path: Path to where logs should be written. 
The path may have a prefix of - the form "af_unix:[{stream|dgram}:]", in which case the path is - treated as a Unix domain socket. See - https://git-scm.com/docs/api-trace2#_enabling_a_target for details. - - Returns: - log_path: Path to the log file or socket if log is written, otherwise None - """ - log_path = None - # If no logging path is specified, get the path from 'trace2.eventtarget'. - if path is None: - path = self._GetEventTargetPath() - - # If no logging path is specified, exit. - if path is None: - return None - - path_is_socket = False - socket_type = None - if isinstance(path, str): - parts = path.split(':', 1) - if parts[0] == 'af_unix' and len(parts) == 2: - path_is_socket = True - path = parts[1] - parts = path.split(':', 1) - if parts[0] == 'stream' and len(parts) == 2: - socket_type = socket.SOCK_STREAM - path = parts[1] - elif parts[0] == 'dgram' and len(parts) == 2: - socket_type = socket.SOCK_DGRAM - path = parts[1] - else: - # Get absolute path. - path = os.path.abspath(os.path.expanduser(path)) - else: - raise TypeError('path: str required but got %s.' % type(path)) - - # Git trace2 requires a directory to write log to. - - # TODO(https://crbug.com/gerrit/13706): Support file (append) mode also. - if not (path_is_socket or os.path.isdir(path)): - return None - - if path_is_socket: - if socket_type == socket.SOCK_STREAM or socket_type is None: + # Set/update the environment variable. + # Environment handling across systems is messy. try: - with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: - sock.connect(path) - self._WriteLog(sock.sendall) - return f'af_unix:stream:{path}' - except OSError as err: - # If we tried to connect to a DGRAM socket using STREAM, ignore the - # attempt and continue to DGRAM below. Otherwise, issue a warning. 
- if err.errno != errno.EPROTOTYPE: - print(f'repo: warning: git trace2 logging failed: {err}', file=sys.stderr) + env[KEY] = self._full_sid + except UnicodeEncodeError: + env[KEY] = self._full_sid.encode() + + # Add a version event to front of the log. + self._AddVersionEvent() + + @property + def full_sid(self): + return self._full_sid + + def _AddVersionEvent(self): + """Adds a 'version' event at the beginning of current log.""" + version_event = self._CreateEventDict("version") + version_event["evt"] = "2" + version_event["exe"] = RepoSourceVersion() + self._log.insert(0, version_event) + + def _CreateEventDict(self, event_name): + """Returns a dictionary with common keys/values for git trace2 events. + + Args: + event_name: The event name. + + Returns: + Dictionary with the common event fields populated. + """ + return { + "event": event_name, + "sid": self._full_sid, + "thread": threading.current_thread().name, + "time": datetime.datetime.utcnow().isoformat() + "Z", + } + + def StartEvent(self): + """Append a 'start' event to the current log.""" + start_event = self._CreateEventDict("start") + start_event["argv"] = sys.argv + self._log.append(start_event) + + def ExitEvent(self, result): + """Append an 'exit' event to the current log. + + Args: + result: Exit code of the event + """ + exit_event = self._CreateEventDict("exit") + + # Consider 'None' success (consistent with event_log result handling). + if result is None: + result = 0 + exit_event["code"] = result + self._log.append(exit_event) + + def CommandEvent(self, name, subcommands): + """Append a 'command' event to the current log. 
+ + Args: + name: Name of the primary command (ex: repo, git) + subcommands: List of the sub-commands (ex: version, init, sync) + """ + command_event = self._CreateEventDict("command") + command_event["name"] = name + command_event["subcommands"] = subcommands + self._log.append(command_event) + + def LogConfigEvents(self, config, event_dict_name): + """Append a |event_dict_name| event for each config key in |config|. + + Args: + config: Configuration dictionary. + event_dict_name: Name of the event dictionary for items to be logged + under. + """ + for param, value in config.items(): + event = self._CreateEventDict(event_dict_name) + event["param"] = param + event["value"] = value + self._log.append(event) + + def DefParamRepoEvents(self, config): + """Append 'def_param' events for repo config keys to the current log. + + This appends one event for each repo.* config key. + + Args: + config: Repo configuration dictionary + """ + # Only output the repo.* config parameters. + repo_config = {k: v for k, v in config.items() if k.startswith("repo.")} + self.LogConfigEvents(repo_config, "def_param") + + def GetDataEventName(self, value): + """Returns 'data-json' if the value is an array else returns 'data'.""" + return "data-json" if value[0] == "[" and value[-1] == "]" else "data" + + def LogDataConfigEvents(self, config, prefix): + """Append a 'data' event for each entry in |config| to the current log. + + For each keyX and valueX of the config, "key" field of the event is + '|prefix|/keyX' and the "value" of the "key" field is valueX. + + Args: + config: Configuration dictionary. + prefix: Prefix for each key that is logged. 
+ """ + for key, value in config.items(): + event = self._CreateEventDict(self.GetDataEventName(value)) + event["key"] = f"{prefix}/{key}" + event["value"] = value + self._log.append(event) + + def ErrorEvent(self, msg, fmt): + """Append a 'error' event to the current log.""" + error_event = self._CreateEventDict("error") + error_event["msg"] = msg + error_event["fmt"] = fmt + self._log.append(error_event) + + def _GetEventTargetPath(self): + """Get the 'trace2.eventtarget' path from git configuration. + + Returns: + path: git config's 'trace2.eventtarget' path if it exists, or None + """ + path = None + cmd = ["config", "--get", "trace2.eventtarget"] + # TODO(https://crbug.com/gerrit/13706): Use GitConfig when it supports + # system git config variables. + p = GitCommand( + None, cmd, capture_stdout=True, capture_stderr=True, bare=True + ) + retval = p.Wait() + if retval == 0: + # Strip trailing carriage-return in path. + path = p.stdout.rstrip("\n") + elif retval != 1: + # `git config --get` is documented to produce an exit status of `1` + # if the requested variable is not present in the configuration. + # Report any other return value as an error. + print( + "repo: error: 'git config --get' call failed with return code: " + "%r, stderr: %r" % (retval, p.stderr), + file=sys.stderr, + ) + return path + + def _WriteLog(self, write_fn): + """Writes the log out using a provided writer function. + + Generate compact JSON output for each item in the log, and write it + using write_fn. + + Args: + write_fn: A function that accepts byts and writes them to a + destination. + """ + + for e in self._log: + # Dump in compact encoding mode. + # See 'Compact encoding' in Python docs: + # https://docs.python.org/3/library/json.html#module-json + write_fn( + json.dumps(e, indent=None, separators=(",", ":")).encode( + "utf-8" + ) + + b"\n" + ) + + def Write(self, path=None): + """Writes the log out to a file or socket. 
+ + Log is only written if 'path' or 'git config --get trace2.eventtarget' + provide a valid path (or socket) to write logs to. + + Logging filename format follows the git trace2 style of being a unique + (exclusive writable) file. + + Args: + path: Path to where logs should be written. The path may have a + prefix of the form "af_unix:[{stream|dgram}:]", in which case + the path is treated as a Unix domain socket. See + https://git-scm.com/docs/api-trace2#_enabling_a_target for + details. + + Returns: + log_path: Path to the log file or socket if log is written, + otherwise None + """ + log_path = None + # If no logging path is specified, get the path from + # 'trace2.eventtarget'. + if path is None: + path = self._GetEventTargetPath() + + # If no logging path is specified, exit. + if path is None: return None - if socket_type == socket.SOCK_DGRAM or socket_type is None: - try: - with socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) as sock: - self._WriteLog(lambda bs: sock.sendto(bs, path)) - return f'af_unix:dgram:{path}' - except OSError as err: - print(f'repo: warning: git trace2 logging failed: {err}', file=sys.stderr) - return None - # Tried to open a socket but couldn't connect (SOCK_STREAM) or write - # (SOCK_DGRAM). - print('repo: warning: git trace2 logging failed: could not write to socket', file=sys.stderr) - return None - # Path is an absolute path - # Use NamedTemporaryFile to generate a unique filename as required by git trace2. - try: - with tempfile.NamedTemporaryFile(mode='xb', prefix=self._sid, dir=path, - delete=False) as f: - # TODO(https://crbug.com/gerrit/13706): Support writing events as they - # occur. 
- self._WriteLog(f.write) - log_path = f.name - except FileExistsError as err: - print('repo: warning: git trace2 logging failed: %r' % err, - file=sys.stderr) - return None - return log_path + path_is_socket = False + socket_type = None + if isinstance(path, str): + parts = path.split(":", 1) + if parts[0] == "af_unix" and len(parts) == 2: + path_is_socket = True + path = parts[1] + parts = path.split(":", 1) + if parts[0] == "stream" and len(parts) == 2: + socket_type = socket.SOCK_STREAM + path = parts[1] + elif parts[0] == "dgram" and len(parts) == 2: + socket_type = socket.SOCK_DGRAM + path = parts[1] + else: + # Get absolute path. + path = os.path.abspath(os.path.expanduser(path)) + else: + raise TypeError("path: str required but got %s." % type(path)) + + # Git trace2 requires a directory to write log to. + + # TODO(https://crbug.com/gerrit/13706): Support file (append) mode also. + if not (path_is_socket or os.path.isdir(path)): + return None + + if path_is_socket: + if socket_type == socket.SOCK_STREAM or socket_type is None: + try: + with socket.socket( + socket.AF_UNIX, socket.SOCK_STREAM + ) as sock: + sock.connect(path) + self._WriteLog(sock.sendall) + return f"af_unix:stream:{path}" + except OSError as err: + # If we tried to connect to a DGRAM socket using STREAM, + # ignore the attempt and continue to DGRAM below. Otherwise, + # issue a warning. + if err.errno != errno.EPROTOTYPE: + print( + f"repo: warning: git trace2 logging failed: {err}", + file=sys.stderr, + ) + return None + if socket_type == socket.SOCK_DGRAM or socket_type is None: + try: + with socket.socket( + socket.AF_UNIX, socket.SOCK_DGRAM + ) as sock: + self._WriteLog(lambda bs: sock.sendto(bs, path)) + return f"af_unix:dgram:{path}" + except OSError as err: + print( + f"repo: warning: git trace2 logging failed: {err}", + file=sys.stderr, + ) + return None + # Tried to open a socket but couldn't connect (SOCK_STREAM) or write + # (SOCK_DGRAM). 
+ print( + "repo: warning: git trace2 logging failed: could not write to " + "socket", + file=sys.stderr, + ) + return None + + # Path is an absolute path + # Use NamedTemporaryFile to generate a unique filename as required by + # git trace2. + try: + with tempfile.NamedTemporaryFile( + mode="xb", prefix=self._sid, dir=path, delete=False + ) as f: + # TODO(https://crbug.com/gerrit/13706): Support writing events + # as they occur. + self._WriteLog(f.write) + log_path = f.name + except FileExistsError as err: + print( + "repo: warning: git trace2 logging failed: %r" % err, + file=sys.stderr, + ) + return None + return log_path diff --git a/gitc_utils.py b/gitc_utils.py index dfcfd2a4..7b72048f 100644 --- a/gitc_utils.py +++ b/gitc_utils.py @@ -28,128 +28,139 @@ NUM_BATCH_RETRIEVE_REVISIONID = 32 def get_gitc_manifest_dir(): - return wrapper.Wrapper().get_gitc_manifest_dir() + return wrapper.Wrapper().get_gitc_manifest_dir() def parse_clientdir(gitc_fs_path): - return wrapper.Wrapper().gitc_parse_clientdir(gitc_fs_path) + return wrapper.Wrapper().gitc_parse_clientdir(gitc_fs_path) def _get_project_revision(args): - """Worker for _set_project_revisions to lookup one project remote.""" - (i, url, expr) = args - gitcmd = git_command.GitCommand( - None, ['ls-remote', url, expr], capture_stdout=True, cwd='/tmp') - rc = gitcmd.Wait() - return (i, rc, gitcmd.stdout.split('\t', 1)[0]) + """Worker for _set_project_revisions to lookup one project remote.""" + (i, url, expr) = args + gitcmd = git_command.GitCommand( + None, ["ls-remote", url, expr], capture_stdout=True, cwd="/tmp" + ) + rc = gitcmd.Wait() + return (i, rc, gitcmd.stdout.split("\t", 1)[0]) def _set_project_revisions(projects): - """Sets the revisionExpr for a list of projects. + """Sets the revisionExpr for a list of projects. - Because of the limit of open file descriptors allowed, length of projects - should not be overly large. 
Recommend calling this function multiple times - with each call not exceeding NUM_BATCH_RETRIEVE_REVISIONID projects. + Because of the limit of open file descriptors allowed, length of projects + should not be overly large. Recommend calling this function multiple times + with each call not exceeding NUM_BATCH_RETRIEVE_REVISIONID projects. - Args: - projects: List of project objects to set the revionExpr for. - """ - # Retrieve the commit id for each project based off of it's current - # revisionExpr and it is not already a commit id. - with multiprocessing.Pool(NUM_BATCH_RETRIEVE_REVISIONID) as pool: - results_iter = pool.imap_unordered( - _get_project_revision, - ((i, project.remote.url, project.revisionExpr) - for i, project in enumerate(projects) - if not git_config.IsId(project.revisionExpr)), - chunksize=8) - for (i, rc, revisionExpr) in results_iter: - project = projects[i] - if rc: - print('FATAL: Failed to retrieve revisionExpr for %s' % project.name) - pool.terminate() - sys.exit(1) - if not revisionExpr: - pool.terminate() - raise ManifestParseError('Invalid SHA-1 revision project %s (%s)' % - (project.remote.url, project.revisionExpr)) - project.revisionExpr = revisionExpr + Args: + projects: List of project objects to set the revionExpr for. + """ + # Retrieve the commit id for each project based off of its current + # revisionExpr and it is not already a commit id. 
+ with multiprocessing.Pool(NUM_BATCH_RETRIEVE_REVISIONID) as pool: + results_iter = pool.imap_unordered( + _get_project_revision, + ( + (i, project.remote.url, project.revisionExpr) + for i, project in enumerate(projects) + if not git_config.IsId(project.revisionExpr) + ), + chunksize=8, + ) + for i, rc, revisionExpr in results_iter: + project = projects[i] + if rc: + print( + "FATAL: Failed to retrieve revisionExpr for %s" + % project.name + ) + pool.terminate() + sys.exit(1) + if not revisionExpr: + pool.terminate() + raise ManifestParseError( + "Invalid SHA-1 revision project %s (%s)" + % (project.remote.url, project.revisionExpr) + ) + project.revisionExpr = revisionExpr def generate_gitc_manifest(gitc_manifest, manifest, paths=None): - """Generate a manifest for shafsd to use for this GITC client. + """Generate a manifest for shafsd to use for this GITC client. - Args: - gitc_manifest: Current gitc manifest, or None if there isn't one yet. - manifest: A GitcManifest object loaded with the current repo manifest. - paths: List of project paths we want to update. - """ + Args: + gitc_manifest: Current gitc manifest, or None if there isn't one yet. + manifest: A GitcManifest object loaded with the current repo manifest. + paths: List of project paths we want to update. + """ - print('Generating GITC Manifest by fetching revision SHAs for each ' - 'project.') - if paths is None: - paths = list(manifest.paths.keys()) + print( + "Generating GITC Manifest by fetching revision SHAs for each " + "project." + ) + if paths is None: + paths = list(manifest.paths.keys()) - groups = [x for x in re.split(r'[,\s]+', manifest.GetGroupsStr()) if x] + groups = [x for x in re.split(r"[,\s]+", manifest.GetGroupsStr()) if x] - # Convert the paths to projects, and filter them to the matched groups. - projects = [manifest.paths[p] for p in paths] - projects = [p for p in projects if p.MatchesGroups(groups)] + # Convert the paths to projects, and filter them to the matched groups. 
+ projects = [manifest.paths[p] for p in paths] + projects = [p for p in projects if p.MatchesGroups(groups)] - if gitc_manifest is not None: - for path, proj in manifest.paths.items(): - if not proj.MatchesGroups(groups): - continue + if gitc_manifest is not None: + for path, proj in manifest.paths.items(): + if not proj.MatchesGroups(groups): + continue - if not proj.upstream and not git_config.IsId(proj.revisionExpr): - proj.upstream = proj.revisionExpr + if not proj.upstream and not git_config.IsId(proj.revisionExpr): + proj.upstream = proj.revisionExpr - if path not in gitc_manifest.paths: - # Any new projects need their first revision, even if we weren't asked - # for them. - projects.append(proj) - elif path not in paths: - # And copy revisions from the previous manifest if we're not updating - # them now. - gitc_proj = gitc_manifest.paths[path] - if gitc_proj.old_revision: - proj.revisionExpr = None - proj.old_revision = gitc_proj.old_revision - else: - proj.revisionExpr = gitc_proj.revisionExpr + if path not in gitc_manifest.paths: + # Any new projects need their first revision, even if we weren't + # asked for them. + projects.append(proj) + elif path not in paths: + # And copy revisions from the previous manifest if we're not + # updating them now. + gitc_proj = gitc_manifest.paths[path] + if gitc_proj.old_revision: + proj.revisionExpr = None + proj.old_revision = gitc_proj.old_revision + else: + proj.revisionExpr = gitc_proj.revisionExpr - _set_project_revisions(projects) + _set_project_revisions(projects) - if gitc_manifest is not None: - for path, proj in gitc_manifest.paths.items(): - if proj.old_revision and path in paths: - # If we updated a project that has been started, keep the old-revision - # updated. 
- repo_proj = manifest.paths[path] - repo_proj.old_revision = repo_proj.revisionExpr - repo_proj.revisionExpr = None + if gitc_manifest is not None: + for path, proj in gitc_manifest.paths.items(): + if proj.old_revision and path in paths: + # If we updated a project that has been started, keep the + # old-revision updated. + repo_proj = manifest.paths[path] + repo_proj.old_revision = repo_proj.revisionExpr + repo_proj.revisionExpr = None - # Convert URLs from relative to absolute. - for _name, remote in manifest.remotes.items(): - remote.fetchUrl = remote.resolvedFetchUrl + # Convert URLs from relative to absolute. + for _name, remote in manifest.remotes.items(): + remote.fetchUrl = remote.resolvedFetchUrl - # Save the manifest. - save_manifest(manifest) + # Save the manifest. + save_manifest(manifest) def save_manifest(manifest, client_dir=None): - """Save the manifest file in the client_dir. + """Save the manifest file in the client_dir. - Args: - manifest: Manifest object to save. - client_dir: Client directory to save the manifest in. - """ - if not client_dir: - manifest_file = manifest.manifestFile - else: - manifest_file = os.path.join(client_dir, '.manifest') - with open(manifest_file, 'w') as f: - manifest.Save(f, groups=manifest.GetGroupsStr()) - # TODO(sbasi/jorg): Come up with a solution to remove the sleep below. - # Give the GITC filesystem time to register the manifest changes. - time.sleep(3) + Args: + manifest: Manifest object to save. + client_dir: Client directory to save the manifest in. + """ + if not client_dir: + manifest_file = manifest.manifestFile + else: + manifest_file = os.path.join(client_dir, ".manifest") + with open(manifest_file, "w") as f: + manifest.Save(f, groups=manifest.GetGroupsStr()) + # TODO(sbasi/jorg): Come up with a solution to remove the sleep below. + # Give the GITC filesystem time to register the manifest changes. 
+ time.sleep(3) diff --git a/hooks.py b/hooks.py index 67c21a25..decf0699 100644 --- a/hooks.py +++ b/hooks.py @@ -26,271 +26,293 @@ from git_refs import HEAD class RepoHook(object): - """A RepoHook contains information about a script to run as a hook. + """A RepoHook contains information about a script to run as a hook. - Hooks are used to run a python script before running an upload (for instance, - to run presubmit checks). Eventually, we may have hooks for other actions. + Hooks are used to run a python script before running an upload (for + instance, to run presubmit checks). Eventually, we may have hooks for other + actions. - This shouldn't be confused with files in the 'repo/hooks' directory. Those - files are copied into each '.git/hooks' folder for each project. Repo-level - hooks are associated instead with repo actions. + This shouldn't be confused with files in the 'repo/hooks' directory. Those + files are copied into each '.git/hooks' folder for each project. Repo-level + hooks are associated instead with repo actions. - Hooks are always python. When a hook is run, we will load the hook into the - interpreter and execute its main() function. + Hooks are always python. When a hook is run, we will load the hook into the + interpreter and execute its main() function. - Combinations of hook option flags: - - no-verify=False, verify=False (DEFAULT): - If stdout is a tty, can prompt about running hooks if needed. - If user denies running hooks, the action is cancelled. If stdout is - not a tty and we would need to prompt about hooks, action is - cancelled. - - no-verify=False, verify=True: - Always run hooks with no prompt. - - no-verify=True, verify=False: - Never run hooks, but run action anyway (AKA bypass hooks). 
- - no-verify=True, verify=True: - Invalid - """ - - def __init__(self, - hook_type, - hooks_project, - repo_topdir, - manifest_url, - bypass_hooks=False, - allow_all_hooks=False, - ignore_hooks=False, - abort_if_user_denies=False): - """RepoHook constructor. - - Params: - hook_type: A string representing the type of hook. This is also used - to figure out the name of the file containing the hook. For - example: 'pre-upload'. - hooks_project: The project containing the repo hooks. - If you have a manifest, this is manifest.repo_hooks_project. - OK if this is None, which will make the hook a no-op. - repo_topdir: The top directory of the repo client checkout. - This is the one containing the .repo directory. Scripts will - run with CWD as this directory. - If you have a manifest, this is manifest.topdir. - manifest_url: The URL to the manifest git repo. - bypass_hooks: If True, then 'Do not run the hook'. - allow_all_hooks: If True, then 'Run the hook without prompting'. - ignore_hooks: If True, then 'Do not abort action if hooks fail'. - abort_if_user_denies: If True, we'll abort running the hook if the user - doesn't allow us to run the hook. + Combinations of hook option flags: + - no-verify=False, verify=False (DEFAULT): + If stdout is a tty, can prompt about running hooks if needed. + If user denies running hooks, the action is cancelled. If stdout is + not a tty and we would need to prompt about hooks, action is + cancelled. + - no-verify=False, verify=True: + Always run hooks with no prompt. + - no-verify=True, verify=False: + Never run hooks, but run action anyway (AKA bypass hooks). 
+ - no-verify=True, verify=True: + Invalid """ - self._hook_type = hook_type - self._hooks_project = hooks_project - self._repo_topdir = repo_topdir - self._manifest_url = manifest_url - self._bypass_hooks = bypass_hooks - self._allow_all_hooks = allow_all_hooks - self._ignore_hooks = ignore_hooks - self._abort_if_user_denies = abort_if_user_denies - # Store the full path to the script for convenience. - if self._hooks_project: - self._script_fullpath = os.path.join(self._hooks_project.worktree, - self._hook_type + '.py') - else: - self._script_fullpath = None + def __init__( + self, + hook_type, + hooks_project, + repo_topdir, + manifest_url, + bypass_hooks=False, + allow_all_hooks=False, + ignore_hooks=False, + abort_if_user_denies=False, + ): + """RepoHook constructor. - def _GetHash(self): - """Return a hash of the contents of the hooks directory. + Params: + hook_type: A string representing the type of hook. This is also used + to figure out the name of the file containing the hook. For + example: 'pre-upload'. + hooks_project: The project containing the repo hooks. + If you have a manifest, this is manifest.repo_hooks_project. + OK if this is None, which will make the hook a no-op. + repo_topdir: The top directory of the repo client checkout. + This is the one containing the .repo directory. Scripts will + run with CWD as this directory. + If you have a manifest, this is manifest.topdir. + manifest_url: The URL to the manifest git repo. + bypass_hooks: If True, then 'Do not run the hook'. + allow_all_hooks: If True, then 'Run the hook without prompting'. + ignore_hooks: If True, then 'Do not abort action if hooks fail'. + abort_if_user_denies: If True, we'll abort running the hook if the + user doesn't allow us to run the hook. 
+ """ + self._hook_type = hook_type + self._hooks_project = hooks_project + self._repo_topdir = repo_topdir + self._manifest_url = manifest_url + self._bypass_hooks = bypass_hooks + self._allow_all_hooks = allow_all_hooks + self._ignore_hooks = ignore_hooks + self._abort_if_user_denies = abort_if_user_denies - We'll just use git to do this. This hash has the property that if anything - changes in the directory we will return a different has. + # Store the full path to the script for convenience. + if self._hooks_project: + self._script_fullpath = os.path.join( + self._hooks_project.worktree, self._hook_type + ".py" + ) + else: + self._script_fullpath = None - SECURITY CONSIDERATION: - This hash only represents the contents of files in the hook directory, not - any other files imported or called by hooks. Changes to imported files - can change the script behavior without affecting the hash. + def _GetHash(self): + """Return a hash of the contents of the hooks directory. - Returns: - A string representing the hash. This will always be ASCII so that it can - be printed to the user easily. - """ - assert self._hooks_project, "Must have hooks to calculate their hash." + We'll just use git to do this. This hash has the property that if + anything changes in the directory we will return a different has. - # We will use the work_git object rather than just calling GetRevisionId(). - # That gives us a hash of the latest checked in version of the files that - # the user will actually be executing. Specifically, GetRevisionId() - # doesn't appear to change even if a user checks out a different version - # of the hooks repo (via git checkout) nor if a user commits their own revs. - # - # NOTE: Local (non-committed) changes will not be factored into this hash. - # I think this is OK, since we're really only worried about warning the user - # about upstream changes. 
- return self._hooks_project.work_git.rev_parse(HEAD) + SECURITY CONSIDERATION: + This hash only represents the contents of files in the hook + directory, not any other files imported or called by hooks. Changes + to imported files can change the script behavior without affecting + the hash. - def _GetMustVerb(self): - """Return 'must' if the hook is required; 'should' if not.""" - if self._abort_if_user_denies: - return 'must' - else: - return 'should' + Returns: + A string representing the hash. This will always be ASCII so that + it can be printed to the user easily. + """ + assert self._hooks_project, "Must have hooks to calculate their hash." - def _CheckForHookApproval(self): - """Check to see whether this hook has been approved. + # We will use the work_git object rather than just calling + # GetRevisionId(). That gives us a hash of the latest checked in version + # of the files that the user will actually be executing. Specifically, + # GetRevisionId() doesn't appear to change even if a user checks out a + # different version of the hooks repo (via git checkout) nor if a user + # commits their own revs. + # + # NOTE: Local (non-committed) changes will not be factored into this + # hash. I think this is OK, since we're really only worried about + # warning the user about upstream changes. + return self._hooks_project.work_git.rev_parse(HEAD) - We'll accept approval of manifest URLs if they're using secure transports. - This way the user can say they trust the manifest hoster. For insecure - hosts, we fall back to checking the hash of the hooks repo. + def _GetMustVerb(self): + """Return 'must' if the hook is required; 'should' if not.""" + if self._abort_if_user_denies: + return "must" + else: + return "should" - Note that we ask permission for each individual hook even though we use - the hash of all hooks when detecting changes. We'd like the user to be - able to approve / deny each hook individually. 
We only use the hash of all - hooks because there is no other easy way to detect changes to local imports. + def _CheckForHookApproval(self): + """Check to see whether this hook has been approved. - Returns: - True if this hook is approved to run; False otherwise. + We'll accept approval of manifest URLs if they're using secure + transports. This way the user can say they trust the manifest hoster. + For insecure hosts, we fall back to checking the hash of the hooks repo. - Raises: - HookError: Raised if the user doesn't approve and abort_if_user_denies - was passed to the consturctor. - """ - if self._ManifestUrlHasSecureScheme(): - return self._CheckForHookApprovalManifest() - else: - return self._CheckForHookApprovalHash() + Note that we ask permission for each individual hook even though we use + the hash of all hooks when detecting changes. We'd like the user to be + able to approve / deny each hook individually. We only use the hash of + all hooks because there is no other easy way to detect changes to local + imports. - def _CheckForHookApprovalHelper(self, subkey, new_val, main_prompt, - changed_prompt): - """Check for approval for a particular attribute and hook. + Returns: + True if this hook is approved to run; False otherwise. - Args: - subkey: The git config key under [repo.hooks.] to store the - last approved string. - new_val: The new value to compare against the last approved one. - main_prompt: Message to display to the user to ask for approval. - changed_prompt: Message explaining why we're re-asking for approval. + Raises: + HookError: Raised if the user doesn't approve and + abort_if_user_denies was passed to the consturctor. + """ + if self._ManifestUrlHasSecureScheme(): + return self._CheckForHookApprovalManifest() + else: + return self._CheckForHookApprovalHash() - Returns: - True if this hook is approved to run; False otherwise. 
+ def _CheckForHookApprovalHelper( + self, subkey, new_val, main_prompt, changed_prompt + ): + """Check for approval for a particular attribute and hook. - Raises: - HookError: Raised if the user doesn't approve and abort_if_user_denies - was passed to the consturctor. - """ - hooks_config = self._hooks_project.config - git_approval_key = 'repo.hooks.%s.%s' % (self._hook_type, subkey) + Args: + subkey: The git config key under [repo.hooks.] to store + the last approved string. + new_val: The new value to compare against the last approved one. + main_prompt: Message to display to the user to ask for approval. + changed_prompt: Message explaining why we're re-asking for approval. - # Get the last value that the user approved for this hook; may be None. - old_val = hooks_config.GetString(git_approval_key) + Returns: + True if this hook is approved to run; False otherwise. - if old_val is not None: - # User previously approved hook and asked not to be prompted again. - if new_val == old_val: - # Approval matched. We're done. - return True - else: - # Give the user a reason why we're prompting, since they last told - # us to "never ask again". - prompt = 'WARNING: %s\n\n' % (changed_prompt,) - else: - prompt = '' + Raises: + HookError: Raised if the user doesn't approve and + abort_if_user_denies was passed to the consturctor. + """ + hooks_config = self._hooks_project.config + git_approval_key = "repo.hooks.%s.%s" % (self._hook_type, subkey) - # Prompt the user if we're not on a tty; on a tty we'll assume "no". - if sys.stdout.isatty(): - prompt += main_prompt + ' (yes/always/NO)? ' - response = input(prompt).lower() - print() + # Get the last value that the user approved for this hook; may be None. + old_val = hooks_config.GetString(git_approval_key) - # User is doing a one-time approval. 
- if response in ('y', 'yes'): - return True - elif response == 'always': - hooks_config.SetString(git_approval_key, new_val) - return True + if old_val is not None: + # User previously approved hook and asked not to be prompted again. + if new_val == old_val: + # Approval matched. We're done. + return True + else: + # Give the user a reason why we're prompting, since they last + # told us to "never ask again". + prompt = "WARNING: %s\n\n" % (changed_prompt,) + else: + prompt = "" - # For anything else, we'll assume no approval. - if self._abort_if_user_denies: - raise HookError('You must allow the %s hook or use --no-verify.' % - self._hook_type) + # Prompt the user if we're not on a tty; on a tty we'll assume "no". + if sys.stdout.isatty(): + prompt += main_prompt + " (yes/always/NO)? " + response = input(prompt).lower() + print() - return False + # User is doing a one-time approval. + if response in ("y", "yes"): + return True + elif response == "always": + hooks_config.SetString(git_approval_key, new_val) + return True - def _ManifestUrlHasSecureScheme(self): - """Check if the URI for the manifest is a secure transport.""" - secure_schemes = ('file', 'https', 'ssh', 'persistent-https', 'sso', 'rpc') - parse_results = urllib.parse.urlparse(self._manifest_url) - return parse_results.scheme in secure_schemes + # For anything else, we'll assume no approval. + if self._abort_if_user_denies: + raise HookError( + "You must allow the %s hook or use --no-verify." + % self._hook_type + ) - def _CheckForHookApprovalManifest(self): - """Check whether the user has approved this manifest host. + return False - Returns: - True if this hook is approved to run; False otherwise. - """ - return self._CheckForHookApprovalHelper( - 'approvedmanifest', - self._manifest_url, - 'Run hook scripts from %s' % (self._manifest_url,), - 'Manifest URL has changed since %s was allowed.' 
% (self._hook_type,)) + def _ManifestUrlHasSecureScheme(self): + """Check if the URI for the manifest is a secure transport.""" + secure_schemes = ( + "file", + "https", + "ssh", + "persistent-https", + "sso", + "rpc", + ) + parse_results = urllib.parse.urlparse(self._manifest_url) + return parse_results.scheme in secure_schemes - def _CheckForHookApprovalHash(self): - """Check whether the user has approved the hooks repo. + def _CheckForHookApprovalManifest(self): + """Check whether the user has approved this manifest host. - Returns: - True if this hook is approved to run; False otherwise. - """ - prompt = ('Repo %s run the script:\n' - ' %s\n' - '\n' - 'Do you want to allow this script to run') - return self._CheckForHookApprovalHelper( - 'approvedhash', - self._GetHash(), - prompt % (self._GetMustVerb(), self._script_fullpath), - 'Scripts have changed since %s was allowed.' % (self._hook_type,)) + Returns: + True if this hook is approved to run; False otherwise. + """ + return self._CheckForHookApprovalHelper( + "approvedmanifest", + self._manifest_url, + "Run hook scripts from %s" % (self._manifest_url,), + "Manifest URL has changed since %s was allowed." + % (self._hook_type,), + ) - @staticmethod - def _ExtractInterpFromShebang(data): - """Extract the interpreter used in the shebang. + def _CheckForHookApprovalHash(self): + """Check whether the user has approved the hooks repo. - Try to locate the interpreter the script is using (ignoring `env`). + Returns: + True if this hook is approved to run; False otherwise. + """ + prompt = ( + "Repo %s run the script:\n" + " %s\n" + "\n" + "Do you want to allow this script to run" + ) + return self._CheckForHookApprovalHelper( + "approvedhash", + self._GetHash(), + prompt % (self._GetMustVerb(), self._script_fullpath), + "Scripts have changed since %s was allowed." % (self._hook_type,), + ) - Args: - data: The file content of the script. 
+ @staticmethod + def _ExtractInterpFromShebang(data): + """Extract the interpreter used in the shebang. - Returns: - The basename of the main script interpreter, or None if a shebang is not - used or could not be parsed out. - """ - firstline = data.splitlines()[:1] - if not firstline: - return None + Try to locate the interpreter the script is using (ignoring `env`). - # The format here can be tricky. - shebang = firstline[0].strip() - m = re.match(r'^#!\s*([^\s]+)(?:\s+([^\s]+))?', shebang) - if not m: - return None + Args: + data: The file content of the script. - # If the using `env`, find the target program. - interp = m.group(1) - if os.path.basename(interp) == 'env': - interp = m.group(2) + Returns: + The basename of the main script interpreter, or None if a shebang is + not used or could not be parsed out. + """ + firstline = data.splitlines()[:1] + if not firstline: + return None - return interp + # The format here can be tricky. + shebang = firstline[0].strip() + m = re.match(r"^#!\s*([^\s]+)(?:\s+([^\s]+))?", shebang) + if not m: + return None - def _ExecuteHookViaReexec(self, interp, context, **kwargs): - """Execute the hook script through |interp|. + # If the using `env`, find the target program. + interp = m.group(1) + if os.path.basename(interp) == "env": + interp = m.group(2) - Note: Support for this feature should be dropped ~Jun 2021. + return interp - Args: - interp: The Python program to run. - context: Basic Python context to execute the hook inside. - kwargs: Arbitrary arguments to pass to the hook script. + def _ExecuteHookViaReexec(self, interp, context, **kwargs): + """Execute the hook script through |interp|. - Raises: - HookError: When the hooks failed for any reason. - """ - # This logic needs to be kept in sync with _ExecuteHookViaImport below. - script = """ + Note: Support for this feature should be dropped ~Jun 2021. + + Args: + interp: The Python program to run. + context: Basic Python context to execute the hook inside. 
+ kwargs: Arbitrary arguments to pass to the hook script. + + Raises: + HookError: When the hooks failed for any reason. + """ + # This logic needs to be kept in sync with _ExecuteHookViaImport below. + script = """ import json, os, sys path = '''%(path)s''' kwargs = json.loads('''%(kwargs)s''') @@ -300,210 +322,240 @@ data = open(path).read() exec(compile(data, path, 'exec'), context) context['main'](**kwargs) """ % { - 'path': self._script_fullpath, - 'kwargs': json.dumps(kwargs), - 'context': json.dumps(context), - } + "path": self._script_fullpath, + "kwargs": json.dumps(kwargs), + "context": json.dumps(context), + } - # We pass the script via stdin to avoid OS argv limits. It also makes - # unhandled exception tracebacks less verbose/confusing for users. - cmd = [interp, '-c', 'import sys; exec(sys.stdin.read())'] - proc = subprocess.Popen(cmd, stdin=subprocess.PIPE) - proc.communicate(input=script.encode('utf-8')) - if proc.returncode: - raise HookError('Failed to run %s hook.' % (self._hook_type,)) + # We pass the script via stdin to avoid OS argv limits. It also makes + # unhandled exception tracebacks less verbose/confusing for users. + cmd = [interp, "-c", "import sys; exec(sys.stdin.read())"] + proc = subprocess.Popen(cmd, stdin=subprocess.PIPE) + proc.communicate(input=script.encode("utf-8")) + if proc.returncode: + raise HookError("Failed to run %s hook." % (self._hook_type,)) - def _ExecuteHookViaImport(self, data, context, **kwargs): - """Execute the hook code in |data| directly. + def _ExecuteHookViaImport(self, data, context, **kwargs): + """Execute the hook code in |data| directly. - Args: - data: The code of the hook to execute. - context: Basic Python context to execute the hook inside. - kwargs: Arbitrary arguments to pass to the hook script. + Args: + data: The code of the hook to execute. + context: Basic Python context to execute the hook inside. + kwargs: Arbitrary arguments to pass to the hook script. 
- Raises: - HookError: When the hooks failed for any reason. - """ - # Exec, storing global context in the context dict. We catch exceptions - # and convert to a HookError w/ just the failing traceback. - try: - exec(compile(data, self._script_fullpath, 'exec'), context) - except Exception: - raise HookError('%s\nFailed to import %s hook; see traceback above.' % - (traceback.format_exc(), self._hook_type)) - - # Running the script should have defined a main() function. - if 'main' not in context: - raise HookError('Missing main() in: "%s"' % self._script_fullpath) - - # Call the main function in the hook. If the hook should cause the - # build to fail, it will raise an Exception. We'll catch that convert - # to a HookError w/ just the failing traceback. - try: - context['main'](**kwargs) - except Exception: - raise HookError('%s\nFailed to run main() for %s hook; see traceback ' - 'above.' % (traceback.format_exc(), self._hook_type)) - - def _ExecuteHook(self, **kwargs): - """Actually execute the given hook. - - This will run the hook's 'main' function in our python interpreter. - - Args: - kwargs: Keyword arguments to pass to the hook. These are often specific - to the hook type. For instance, pre-upload hooks will contain - a project_list. - """ - # Keep sys.path and CWD stashed away so that we can always restore them - # upon function exit. - orig_path = os.getcwd() - orig_syspath = sys.path - - try: - # Always run hooks with CWD as topdir. - os.chdir(self._repo_topdir) - - # Put the hook dir as the first item of sys.path so hooks can do - # relative imports. We want to replace the repo dir as [0] so - # hooks can't import repo files. - sys.path = [os.path.dirname(self._script_fullpath)] + sys.path[1:] - - # Initial global context for the hook to run within. - context = {'__file__': self._script_fullpath} - - # Add 'hook_should_take_kwargs' to the arguments to be passed to main. 
- # We don't actually want hooks to define their main with this argument-- - # it's there to remind them that their hook should always take **kwargs. - # For instance, a pre-upload hook should be defined like: - # def main(project_list, **kwargs): - # - # This allows us to later expand the API without breaking old hooks. - kwargs = kwargs.copy() - kwargs['hook_should_take_kwargs'] = True - - # See what version of python the hook has been written against. - data = open(self._script_fullpath).read() - interp = self._ExtractInterpFromShebang(data) - reexec = False - if interp: - prog = os.path.basename(interp) - if prog.startswith('python2') and sys.version_info.major != 2: - reexec = True - elif prog.startswith('python3') and sys.version_info.major == 2: - reexec = True - - # Attempt to execute the hooks through the requested version of Python. - if reexec: + Raises: + HookError: When the hooks failed for any reason. + """ + # Exec, storing global context in the context dict. We catch exceptions + # and convert to a HookError w/ just the failing traceback. try: - self._ExecuteHookViaReexec(interp, context, **kwargs) - except OSError as e: - if e.errno == errno.ENOENT: - # We couldn't find the interpreter, so fallback to importing. + exec(compile(data, self._script_fullpath, "exec"), context) + except Exception: + raise HookError( + "%s\nFailed to import %s hook; see traceback above." + % (traceback.format_exc(), self._hook_type) + ) + + # Running the script should have defined a main() function. + if "main" not in context: + raise HookError('Missing main() in: "%s"' % self._script_fullpath) + + # Call the main function in the hook. If the hook should cause the + # build to fail, it will raise an Exception. We'll catch that convert + # to a HookError w/ just the failing traceback. + try: + context["main"](**kwargs) + except Exception: + raise HookError( + "%s\nFailed to run main() for %s hook; see traceback " + "above." 
% (traceback.format_exc(), self._hook_type) + ) + + def _ExecuteHook(self, **kwargs): + """Actually execute the given hook. + + This will run the hook's 'main' function in our python interpreter. + + Args: + kwargs: Keyword arguments to pass to the hook. These are often + specific to the hook type. For instance, pre-upload hooks will + contain a project_list. + """ + # Keep sys.path and CWD stashed away so that we can always restore them + # upon function exit. + orig_path = os.getcwd() + orig_syspath = sys.path + + try: + # Always run hooks with CWD as topdir. + os.chdir(self._repo_topdir) + + # Put the hook dir as the first item of sys.path so hooks can do + # relative imports. We want to replace the repo dir as [0] so + # hooks can't import repo files. + sys.path = [os.path.dirname(self._script_fullpath)] + sys.path[1:] + + # Initial global context for the hook to run within. + context = {"__file__": self._script_fullpath} + + # Add 'hook_should_take_kwargs' to the arguments to be passed to + # main. We don't actually want hooks to define their main with this + # argument--it's there to remind them that their hook should always + # take **kwargs. + # For instance, a pre-upload hook should be defined like: + # def main(project_list, **kwargs): + # + # This allows us to later expand the API without breaking old hooks. + kwargs = kwargs.copy() + kwargs["hook_should_take_kwargs"] = True + + # See what version of python the hook has been written against. + data = open(self._script_fullpath).read() + interp = self._ExtractInterpFromShebang(data) reexec = False - else: - raise + if interp: + prog = os.path.basename(interp) + if prog.startswith("python2") and sys.version_info.major != 2: + reexec = True + elif prog.startswith("python3") and sys.version_info.major == 2: + reexec = True - # Run the hook by importing directly. - if not reexec: - self._ExecuteHookViaImport(data, context, **kwargs) - finally: - # Restore sys.path and CWD. 
- sys.path = orig_syspath - os.chdir(orig_path) + # Attempt to execute the hooks through the requested version of + # Python. + if reexec: + try: + self._ExecuteHookViaReexec(interp, context, **kwargs) + except OSError as e: + if e.errno == errno.ENOENT: + # We couldn't find the interpreter, so fallback to + # importing. + reexec = False + else: + raise - def _CheckHook(self): - # Bail with a nice error if we can't find the hook. - if not os.path.isfile(self._script_fullpath): - raise HookError('Couldn\'t find repo hook: %s' % self._script_fullpath) + # Run the hook by importing directly. + if not reexec: + self._ExecuteHookViaImport(data, context, **kwargs) + finally: + # Restore sys.path and CWD. + sys.path = orig_syspath + os.chdir(orig_path) - def Run(self, **kwargs): - """Run the hook. + def _CheckHook(self): + # Bail with a nice error if we can't find the hook. + if not os.path.isfile(self._script_fullpath): + raise HookError( + "Couldn't find repo hook: %s" % self._script_fullpath + ) - If the hook doesn't exist (because there is no hooks project or because - this particular hook is not enabled), this is a no-op. + def Run(self, **kwargs): + """Run the hook. - Args: - user_allows_all_hooks: If True, we will never prompt about running the - hook--we'll just assume it's OK to run it. - kwargs: Keyword arguments to pass to the hook. These are often specific - to the hook type. For instance, pre-upload hooks will contain - a project_list. + If the hook doesn't exist (because there is no hooks project or because + this particular hook is not enabled), this is a no-op. - Returns: - True: On success or ignore hooks by user-request - False: The hook failed. The caller should respond with aborting the action. 
- Some examples in which False is returned: - * Finding the hook failed while it was enabled, or - * the user declined to run a required hook (from _CheckForHookApproval) - In all these cases the user did not pass the proper arguments to - ignore the result through the option combinations as listed in - AddHookOptionGroup(). - """ - # Do not do anything in case bypass_hooks is set, or - # no-op if there is no hooks project or if hook is disabled. - if (self._bypass_hooks or - not self._hooks_project or - self._hook_type not in self._hooks_project.enabled_repo_hooks): - return True + Args: + user_allows_all_hooks: If True, we will never prompt about running + the hook--we'll just assume it's OK to run it. + kwargs: Keyword arguments to pass to the hook. These are often + specific to the hook type. For instance, pre-upload hooks will + contain a project_list. - passed = True - try: - self._CheckHook() + Returns: + True: On success or ignore hooks by user-request + False: The hook failed. The caller should respond with aborting the + action. Some examples in which False is returned: + * Finding the hook failed while it was enabled, or + * the user declined to run a required hook (from + _CheckForHookApproval) + In all these cases the user did not pass the proper arguments to + ignore the result through the option combinations as listed in + AddHookOptionGroup(). + """ + # Do not do anything in case bypass_hooks is set, or + # no-op if there is no hooks project or if hook is disabled. + if ( + self._bypass_hooks + or not self._hooks_project + or self._hook_type not in self._hooks_project.enabled_repo_hooks + ): + return True - # Make sure the user is OK with running the hook. - if self._allow_all_hooks or self._CheckForHookApproval(): - # Run the hook with the same version of python we're using. 
- self._ExecuteHook(**kwargs) - except SystemExit as e: - passed = False - print('ERROR: %s hooks exited with exit code: %s' % (self._hook_type, str(e)), - file=sys.stderr) - except HookError as e: - passed = False - print('ERROR: %s' % str(e), file=sys.stderr) + passed = True + try: + self._CheckHook() - if not passed and self._ignore_hooks: - print('\nWARNING: %s hooks failed, but continuing anyways.' % self._hook_type, - file=sys.stderr) - passed = True + # Make sure the user is OK with running the hook. + if self._allow_all_hooks or self._CheckForHookApproval(): + # Run the hook with the same version of python we're using. + self._ExecuteHook(**kwargs) + except SystemExit as e: + passed = False + print( + "ERROR: %s hooks exited with exit code: %s" + % (self._hook_type, str(e)), + file=sys.stderr, + ) + except HookError as e: + passed = False + print("ERROR: %s" % str(e), file=sys.stderr) - return passed + if not passed and self._ignore_hooks: + print( + "\nWARNING: %s hooks failed, but continuing anyways." + % self._hook_type, + file=sys.stderr, + ) + passed = True - @classmethod - def FromSubcmd(cls, manifest, opt, *args, **kwargs): - """Method to construct the repo hook class + return passed - Args: - manifest: The current active manifest for this command from which we - extract a couple of fields. - opt: Contains the commandline options for the action of this hook. - It should contain the options added by AddHookOptionGroup() in which - we are interested in RepoHook execution. 
- """ - for key in ('bypass_hooks', 'allow_all_hooks', 'ignore_hooks'): - kwargs.setdefault(key, getattr(opt, key)) - kwargs.update({ - 'hooks_project': manifest.repo_hooks_project, - 'repo_topdir': manifest.topdir, - 'manifest_url': manifest.manifestProject.GetRemote('origin').url, - }) - return cls(*args, **kwargs) + @classmethod + def FromSubcmd(cls, manifest, opt, *args, **kwargs): + """Method to construct the repo hook class - @staticmethod - def AddOptionGroup(parser, name): - """Help options relating to the various hooks.""" + Args: + manifest: The current active manifest for this command from which we + extract a couple of fields. + opt: Contains the commandline options for the action of this hook. + It should contain the options added by AddHookOptionGroup() in + which we are interested in RepoHook execution. + """ + for key in ("bypass_hooks", "allow_all_hooks", "ignore_hooks"): + kwargs.setdefault(key, getattr(opt, key)) + kwargs.update( + { + "hooks_project": manifest.repo_hooks_project, + "repo_topdir": manifest.topdir, + "manifest_url": manifest.manifestProject.GetRemote( + "origin" + ).url, + } + ) + return cls(*args, **kwargs) - # Note that verify and no-verify are NOT opposites of each other, which - # is why they store to different locations. We are using them to match - # 'git commit' syntax. - group = parser.add_option_group(name + ' hooks') - group.add_option('--no-verify', - dest='bypass_hooks', action='store_true', - help='Do not run the %s hook.' % name) - group.add_option('--verify', - dest='allow_all_hooks', action='store_true', - help='Run the %s hook without prompting.' % name) - group.add_option('--ignore-hooks', - action='store_true', - help='Do not abort if %s hooks fail.' % name) + @staticmethod + def AddOptionGroup(parser, name): + """Help options relating to the various hooks.""" + + # Note that verify and no-verify are NOT opposites of each other, which + # is why they store to different locations. 
We are using them to match + # 'git commit' syntax. + group = parser.add_option_group(name + " hooks") + group.add_option( + "--no-verify", + dest="bypass_hooks", + action="store_true", + help="Do not run the %s hook." % name, + ) + group.add_option( + "--verify", + dest="allow_all_hooks", + action="store_true", + help="Run the %s hook without prompting." % name, + ) + group.add_option( + "--ignore-hooks", + action="store_true", + help="Do not abort if %s hooks fail." % name, + ) diff --git a/main.py b/main.py index f4b6e7ac..6dcb66f6 100755 --- a/main.py +++ b/main.py @@ -31,9 +31,9 @@ import time import urllib.request try: - import kerberos + import kerberos except ImportError: - kerberos = None + kerberos = None from color import SetDefaultColoring import event_log @@ -74,347 +74,442 @@ MIN_PYTHON_VERSION_SOFT = (3, 6) MIN_PYTHON_VERSION_HARD = (3, 6) if sys.version_info.major < 3: - print('repo: error: Python 2 is no longer supported; ' - 'Please upgrade to Python {}.{}+.'.format(*MIN_PYTHON_VERSION_SOFT), - file=sys.stderr) - sys.exit(1) -else: - if sys.version_info < MIN_PYTHON_VERSION_HARD: - print('repo: error: Python 3 version is too old; ' - 'Please upgrade to Python {}.{}+.'.format(*MIN_PYTHON_VERSION_SOFT), - file=sys.stderr) + print( + "repo: error: Python 2 is no longer supported; " + "Please upgrade to Python {}.{}+.".format(*MIN_PYTHON_VERSION_SOFT), + file=sys.stderr, + ) sys.exit(1) - elif sys.version_info < MIN_PYTHON_VERSION_SOFT: - print('repo: warning: your Python 3 version is no longer supported; ' - 'Please upgrade to Python {}.{}+.'.format(*MIN_PYTHON_VERSION_SOFT), - file=sys.stderr) +else: + if sys.version_info < MIN_PYTHON_VERSION_HARD: + print( + "repo: error: Python 3 version is too old; " + "Please upgrade to Python {}.{}+.".format(*MIN_PYTHON_VERSION_SOFT), + file=sys.stderr, + ) + sys.exit(1) + elif sys.version_info < MIN_PYTHON_VERSION_SOFT: + print( + "repo: warning: your Python 3 version is no longer supported; " + "Please 
upgrade to Python {}.{}+.".format(*MIN_PYTHON_VERSION_SOFT), + file=sys.stderr, + ) global_options = optparse.OptionParser( - usage='repo [-p|--paginate|--no-pager] COMMAND [ARGS]', - add_help_option=False) -global_options.add_option('-h', '--help', action='store_true', - help='show this help message and exit') -global_options.add_option('--help-all', action='store_true', - help='show this help message with all subcommands and exit') -global_options.add_option('-p', '--paginate', - dest='pager', action='store_true', - help='display command output in the pager') -global_options.add_option('--no-pager', - dest='pager', action='store_false', - help='disable the pager') -global_options.add_option('--color', - choices=('auto', 'always', 'never'), default=None, - help='control color usage: auto, always, never') -global_options.add_option('--trace', - dest='trace', action='store_true', - help='trace git command execution (REPO_TRACE=1)') -global_options.add_option('--trace-to-stderr', - dest='trace_to_stderr', action='store_true', - help='trace outputs go to stderr in addition to .repo/TRACE_FILE') -global_options.add_option('--trace-python', - dest='trace_python', action='store_true', - help='trace python command execution') -global_options.add_option('--time', - dest='time', action='store_true', - help='time repo command execution') -global_options.add_option('--version', - dest='show_version', action='store_true', - help='display this version of repo') -global_options.add_option('--show-toplevel', - action='store_true', - help='display the path of the top-level directory of ' - 'the repo client checkout') -global_options.add_option('--event-log', - dest='event_log', action='store', - help='filename of event log to append timeline to') -global_options.add_option('--git-trace2-event-log', action='store', - help='directory to write git trace2 event log to') -global_options.add_option('--submanifest-path', action='store', - metavar='REL_PATH', help='submanifest path') + 
usage="repo [-p|--paginate|--no-pager] COMMAND [ARGS]", + add_help_option=False, +) +global_options.add_option( + "-h", "--help", action="store_true", help="show this help message and exit" +) +global_options.add_option( + "--help-all", + action="store_true", + help="show this help message with all subcommands and exit", +) +global_options.add_option( + "-p", + "--paginate", + dest="pager", + action="store_true", + help="display command output in the pager", +) +global_options.add_option( + "--no-pager", dest="pager", action="store_false", help="disable the pager" +) +global_options.add_option( + "--color", + choices=("auto", "always", "never"), + default=None, + help="control color usage: auto, always, never", +) +global_options.add_option( + "--trace", + dest="trace", + action="store_true", + help="trace git command execution (REPO_TRACE=1)", +) +global_options.add_option( + "--trace-to-stderr", + dest="trace_to_stderr", + action="store_true", + help="trace outputs go to stderr in addition to .repo/TRACE_FILE", +) +global_options.add_option( + "--trace-python", + dest="trace_python", + action="store_true", + help="trace python command execution", +) +global_options.add_option( + "--time", + dest="time", + action="store_true", + help="time repo command execution", +) +global_options.add_option( + "--version", + dest="show_version", + action="store_true", + help="display this version of repo", +) +global_options.add_option( + "--show-toplevel", + action="store_true", + help="display the path of the top-level directory of " + "the repo client checkout", +) +global_options.add_option( + "--event-log", + dest="event_log", + action="store", + help="filename of event log to append timeline to", +) +global_options.add_option( + "--git-trace2-event-log", + action="store", + help="directory to write git trace2 event log to", +) +global_options.add_option( + "--submanifest-path", + action="store", + metavar="REL_PATH", + help="submanifest path", +) class _Repo(object): - 
def __init__(self, repodir): - self.repodir = repodir - self.commands = all_commands + def __init__(self, repodir): + self.repodir = repodir + self.commands = all_commands - def _PrintHelp(self, short: bool = False, all_commands: bool = False): - """Show --help screen.""" - global_options.print_help() - print() - if short: - commands = ' '.join(sorted(self.commands)) - wrapped_commands = textwrap.wrap(commands, width=77) - print('Available commands:\n %s' % ('\n '.join(wrapped_commands),)) - print('\nRun `repo help ` for command-specific details.') - print('Bug reports:', Wrapper().BUG_URL) - else: - cmd = self.commands['help']() - if all_commands: - cmd.PrintAllCommandsBody() - else: - cmd.PrintCommonCommandsBody() - - def _ParseArgs(self, argv): - """Parse the main `repo` command line options.""" - for i, arg in enumerate(argv): - if not arg.startswith('-'): - name = arg - glob = argv[:i] - argv = argv[i + 1:] - break - else: - name = None - glob = argv - argv = [] - gopts, _gargs = global_options.parse_args(glob) - - if name: - name, alias_args = self._ExpandAlias(name) - argv = alias_args + argv - - return (name, gopts, argv) - - def _ExpandAlias(self, name): - """Look up user registered aliases.""" - # We don't resolve aliases for existing subcommands. This matches git. - if name in self.commands: - return name, [] - - key = 'alias.%s' % (name,) - alias = RepoConfig.ForRepository(self.repodir).GetString(key) - if alias is None: - alias = RepoConfig.ForUser().GetString(key) - if alias is None: - return name, [] - - args = alias.strip().split(' ', 1) - name = args[0] - if len(args) == 2: - args = shlex.split(args[1]) - else: - args = [] - return name, args - - def _Run(self, name, gopts, argv): - """Execute the requested subcommand.""" - result = 0 - - # Handle options that terminate quickly first. 
- if gopts.help or gopts.help_all: - self._PrintHelp(short=False, all_commands=gopts.help_all) - return 0 - elif gopts.show_version: - # Always allow global --version regardless of subcommand validity. - name = 'version' - elif gopts.show_toplevel: - print(os.path.dirname(self.repodir)) - return 0 - elif not name: - # No subcommand specified, so show the help/subcommand. - self._PrintHelp(short=True) - return 1 - - run = lambda: self._RunLong(name, gopts, argv) or 0 - with Trace('starting new command: %s', ', '.join([name] + argv), - first_trace=True): - if gopts.trace_python: - import trace - tracer = trace.Trace(count=False, trace=True, timing=True, - ignoredirs=set(sys.path[1:])) - result = tracer.runfunc(run) - else: - result = run() - return result - - def _RunLong(self, name, gopts, argv): - """Execute the (longer running) requested subcommand.""" - result = 0 - SetDefaultColoring(gopts.color) - - git_trace2_event_log = EventLog() - outer_client = RepoClient(self.repodir) - repo_client = outer_client - if gopts.submanifest_path: - repo_client = RepoClient(self.repodir, - submanifest_path=gopts.submanifest_path, - outer_client=outer_client) - gitc_manifest = None - gitc_client_name = gitc_utils.parse_clientdir(os.getcwd()) - if gitc_client_name: - gitc_manifest = GitcClient(self.repodir, gitc_client_name) - repo_client.isGitcClient = True - - try: - cmd = self.commands[name]( - repodir=self.repodir, - client=repo_client, - manifest=repo_client.manifest, - outer_client=outer_client, - outer_manifest=outer_client.manifest, - gitc_manifest=gitc_manifest, - git_event_log=git_trace2_event_log) - except KeyError: - print("repo: '%s' is not a repo command. See 'repo help'." 
% name, - file=sys.stderr) - return 1 - - Editor.globalConfig = cmd.client.globalConfig - - if not isinstance(cmd, MirrorSafeCommand) and cmd.manifest.IsMirror: - print("fatal: '%s' requires a working directory" % name, - file=sys.stderr) - return 1 - - if isinstance(cmd, GitcAvailableCommand) and not gitc_utils.get_gitc_manifest_dir(): - print("fatal: '%s' requires GITC to be available" % name, - file=sys.stderr) - return 1 - - if isinstance(cmd, GitcClientCommand) and not gitc_client_name: - print("fatal: '%s' requires a GITC client" % name, - file=sys.stderr) - return 1 - - try: - copts, cargs = cmd.OptionParser.parse_args(argv) - copts = cmd.ReadEnvironmentOptions(copts) - except NoManifestException as e: - print('error: in `%s`: %s' % (' '.join([name] + argv), str(e)), - file=sys.stderr) - print('error: manifest missing or unreadable -- please run init', - file=sys.stderr) - return 1 - - if gopts.pager is not False and not isinstance(cmd, InteractiveCommand): - config = cmd.client.globalConfig - if gopts.pager: - use_pager = True - else: - use_pager = config.GetBoolean('pager.%s' % name) - if use_pager is None: - use_pager = cmd.WantPager(copts) - if use_pager: - RunPager(config) - - start = time.time() - cmd_event = cmd.event_log.Add(name, event_log.TASK_COMMAND, start) - cmd.event_log.SetParent(cmd_event) - git_trace2_event_log.StartEvent() - git_trace2_event_log.CommandEvent(name='repo', subcommands=[name]) - - try: - cmd.CommonValidateOptions(copts, cargs) - cmd.ValidateOptions(copts, cargs) - - this_manifest_only = copts.this_manifest_only - outer_manifest = copts.outer_manifest - if cmd.MULTI_MANIFEST_SUPPORT or this_manifest_only: - result = cmd.Execute(copts, cargs) - elif outer_manifest and repo_client.manifest.is_submanifest: - # The command does not support multi-manifest, we are using a - # submanifest, and the command line is for the outermost manifest. - # Re-run using the outermost manifest, which will recurse through the - # submanifests. 
- gopts.submanifest_path = '' - result = self._Run(name, gopts, argv) - else: - # No multi-manifest support. Run the command in the current - # (sub)manifest, and then any child submanifests. - result = cmd.Execute(copts, cargs) - for submanifest in repo_client.manifest.submanifests.values(): - spec = submanifest.ToSubmanifestSpec() - gopts.submanifest_path = submanifest.repo_client.path_prefix - child_argv = argv[:] - child_argv.append('--no-outer-manifest') - # Not all subcommands support the 3 manifest options, so only add them - # if the original command includes them. - if hasattr(copts, 'manifest_url'): - child_argv.extend(['--manifest-url', spec.manifestUrl]) - if hasattr(copts, 'manifest_name'): - child_argv.extend(['--manifest-name', spec.manifestName]) - if hasattr(copts, 'manifest_branch'): - child_argv.extend(['--manifest-branch', spec.revision]) - result = self._Run(name, gopts, child_argv) or result - except (DownloadError, ManifestInvalidRevisionError, - NoManifestException) as e: - print('error: in `%s`: %s' % (' '.join([name] + argv), str(e)), - file=sys.stderr) - if isinstance(e, NoManifestException): - print('error: manifest missing or unreadable -- please run init', - file=sys.stderr) - result = 1 - except NoSuchProjectError as e: - if e.name: - print('error: project %s not found' % e.name, file=sys.stderr) - else: - print('error: no project in current directory', file=sys.stderr) - result = 1 - except InvalidProjectGroupsError as e: - if e.name: - print('error: project group must be enabled for project %s' % e.name, file=sys.stderr) - else: - print('error: project group must be enabled for the project in the current directory', - file=sys.stderr) - result = 1 - except SystemExit as e: - if e.code: - result = e.code - raise - finally: - finish = time.time() - elapsed = finish - start - hours, remainder = divmod(elapsed, 3600) - minutes, seconds = divmod(remainder, 60) - if gopts.time: - if hours == 0: - print('real\t%dm%.3fs' % (minutes, 
seconds), file=sys.stderr) + def _PrintHelp(self, short: bool = False, all_commands: bool = False): + """Show --help screen.""" + global_options.print_help() + print() + if short: + commands = " ".join(sorted(self.commands)) + wrapped_commands = textwrap.wrap(commands, width=77) + print( + "Available commands:\n %s" % ("\n ".join(wrapped_commands),) + ) + print("\nRun `repo help ` for command-specific details.") + print("Bug reports:", Wrapper().BUG_URL) else: - print('real\t%dh%dm%.3fs' % (hours, minutes, seconds), - file=sys.stderr) + cmd = self.commands["help"]() + if all_commands: + cmd.PrintAllCommandsBody() + else: + cmd.PrintCommonCommandsBody() - cmd.event_log.FinishEvent(cmd_event, finish, - result is None or result == 0) - git_trace2_event_log.DefParamRepoEvents( - cmd.manifest.manifestProject.config.DumpConfigDict()) - git_trace2_event_log.ExitEvent(result) + def _ParseArgs(self, argv): + """Parse the main `repo` command line options.""" + for i, arg in enumerate(argv): + if not arg.startswith("-"): + name = arg + glob = argv[:i] + argv = argv[i + 1 :] + break + else: + name = None + glob = argv + argv = [] + gopts, _gargs = global_options.parse_args(glob) - if gopts.event_log: - cmd.event_log.Write(os.path.abspath( - os.path.expanduser(gopts.event_log))) + if name: + name, alias_args = self._ExpandAlias(name) + argv = alias_args + argv - git_trace2_event_log.Write(gopts.git_trace2_event_log) - return result + return (name, gopts, argv) + + def _ExpandAlias(self, name): + """Look up user registered aliases.""" + # We don't resolve aliases for existing subcommands. This matches git. 
+ if name in self.commands: + return name, [] + + key = "alias.%s" % (name,) + alias = RepoConfig.ForRepository(self.repodir).GetString(key) + if alias is None: + alias = RepoConfig.ForUser().GetString(key) + if alias is None: + return name, [] + + args = alias.strip().split(" ", 1) + name = args[0] + if len(args) == 2: + args = shlex.split(args[1]) + else: + args = [] + return name, args + + def _Run(self, name, gopts, argv): + """Execute the requested subcommand.""" + result = 0 + + # Handle options that terminate quickly first. + if gopts.help or gopts.help_all: + self._PrintHelp(short=False, all_commands=gopts.help_all) + return 0 + elif gopts.show_version: + # Always allow global --version regardless of subcommand validity. + name = "version" + elif gopts.show_toplevel: + print(os.path.dirname(self.repodir)) + return 0 + elif not name: + # No subcommand specified, so show the help/subcommand. + self._PrintHelp(short=True) + return 1 + + run = lambda: self._RunLong(name, gopts, argv) or 0 + with Trace( + "starting new command: %s", + ", ".join([name] + argv), + first_trace=True, + ): + if gopts.trace_python: + import trace + + tracer = trace.Trace( + count=False, + trace=True, + timing=True, + ignoredirs=set(sys.path[1:]), + ) + result = tracer.runfunc(run) + else: + result = run() + return result + + def _RunLong(self, name, gopts, argv): + """Execute the (longer running) requested subcommand.""" + result = 0 + SetDefaultColoring(gopts.color) + + git_trace2_event_log = EventLog() + outer_client = RepoClient(self.repodir) + repo_client = outer_client + if gopts.submanifest_path: + repo_client = RepoClient( + self.repodir, + submanifest_path=gopts.submanifest_path, + outer_client=outer_client, + ) + gitc_manifest = None + gitc_client_name = gitc_utils.parse_clientdir(os.getcwd()) + if gitc_client_name: + gitc_manifest = GitcClient(self.repodir, gitc_client_name) + repo_client.isGitcClient = True + + try: + cmd = self.commands[name]( + repodir=self.repodir, + 
client=repo_client, + manifest=repo_client.manifest, + outer_client=outer_client, + outer_manifest=outer_client.manifest, + gitc_manifest=gitc_manifest, + git_event_log=git_trace2_event_log, + ) + except KeyError: + print( + "repo: '%s' is not a repo command. See 'repo help'." % name, + file=sys.stderr, + ) + return 1 + + Editor.globalConfig = cmd.client.globalConfig + + if not isinstance(cmd, MirrorSafeCommand) and cmd.manifest.IsMirror: + print( + "fatal: '%s' requires a working directory" % name, + file=sys.stderr, + ) + return 1 + + if ( + isinstance(cmd, GitcAvailableCommand) + and not gitc_utils.get_gitc_manifest_dir() + ): + print( + "fatal: '%s' requires GITC to be available" % name, + file=sys.stderr, + ) + return 1 + + if isinstance(cmd, GitcClientCommand) and not gitc_client_name: + print("fatal: '%s' requires a GITC client" % name, file=sys.stderr) + return 1 + + try: + copts, cargs = cmd.OptionParser.parse_args(argv) + copts = cmd.ReadEnvironmentOptions(copts) + except NoManifestException as e: + print( + "error: in `%s`: %s" % (" ".join([name] + argv), str(e)), + file=sys.stderr, + ) + print( + "error: manifest missing or unreadable -- please run init", + file=sys.stderr, + ) + return 1 + + if gopts.pager is not False and not isinstance(cmd, InteractiveCommand): + config = cmd.client.globalConfig + if gopts.pager: + use_pager = True + else: + use_pager = config.GetBoolean("pager.%s" % name) + if use_pager is None: + use_pager = cmd.WantPager(copts) + if use_pager: + RunPager(config) + + start = time.time() + cmd_event = cmd.event_log.Add(name, event_log.TASK_COMMAND, start) + cmd.event_log.SetParent(cmd_event) + git_trace2_event_log.StartEvent() + git_trace2_event_log.CommandEvent(name="repo", subcommands=[name]) + + try: + cmd.CommonValidateOptions(copts, cargs) + cmd.ValidateOptions(copts, cargs) + + this_manifest_only = copts.this_manifest_only + outer_manifest = copts.outer_manifest + if cmd.MULTI_MANIFEST_SUPPORT or this_manifest_only: + result = 
cmd.Execute(copts, cargs) + elif outer_manifest and repo_client.manifest.is_submanifest: + # The command does not support multi-manifest, we are using a + # submanifest, and the command line is for the outermost + # manifest. Re-run using the outermost manifest, which will + # recurse through the submanifests. + gopts.submanifest_path = "" + result = self._Run(name, gopts, argv) + else: + # No multi-manifest support. Run the command in the current + # (sub)manifest, and then any child submanifests. + result = cmd.Execute(copts, cargs) + for submanifest in repo_client.manifest.submanifests.values(): + spec = submanifest.ToSubmanifestSpec() + gopts.submanifest_path = submanifest.repo_client.path_prefix + child_argv = argv[:] + child_argv.append("--no-outer-manifest") + # Not all subcommands support the 3 manifest options, so + # only add them if the original command includes them. + if hasattr(copts, "manifest_url"): + child_argv.extend(["--manifest-url", spec.manifestUrl]) + if hasattr(copts, "manifest_name"): + child_argv.extend( + ["--manifest-name", spec.manifestName] + ) + if hasattr(copts, "manifest_branch"): + child_argv.extend(["--manifest-branch", spec.revision]) + result = self._Run(name, gopts, child_argv) or result + except ( + DownloadError, + ManifestInvalidRevisionError, + NoManifestException, + ) as e: + print( + "error: in `%s`: %s" % (" ".join([name] + argv), str(e)), + file=sys.stderr, + ) + if isinstance(e, NoManifestException): + print( + "error: manifest missing or unreadable -- please run init", + file=sys.stderr, + ) + result = 1 + except NoSuchProjectError as e: + if e.name: + print("error: project %s not found" % e.name, file=sys.stderr) + else: + print("error: no project in current directory", file=sys.stderr) + result = 1 + except InvalidProjectGroupsError as e: + if e.name: + print( + "error: project group must be enabled for project %s" + % e.name, + file=sys.stderr, + ) + else: + print( + "error: project group must be enabled for the 
project in " + "the current directory", + file=sys.stderr, + ) + result = 1 + except SystemExit as e: + if e.code: + result = e.code + raise + finally: + finish = time.time() + elapsed = finish - start + hours, remainder = divmod(elapsed, 3600) + minutes, seconds = divmod(remainder, 60) + if gopts.time: + if hours == 0: + print( + "real\t%dm%.3fs" % (minutes, seconds), file=sys.stderr + ) + else: + print( + "real\t%dh%dm%.3fs" % (hours, minutes, seconds), + file=sys.stderr, + ) + + cmd.event_log.FinishEvent( + cmd_event, finish, result is None or result == 0 + ) + git_trace2_event_log.DefParamRepoEvents( + cmd.manifest.manifestProject.config.DumpConfigDict() + ) + git_trace2_event_log.ExitEvent(result) + + if gopts.event_log: + cmd.event_log.Write( + os.path.abspath(os.path.expanduser(gopts.event_log)) + ) + + git_trace2_event_log.Write(gopts.git_trace2_event_log) + return result def _CheckWrapperVersion(ver_str, repo_path): - """Verify the repo launcher is new enough for this checkout. + """Verify the repo launcher is new enough for this checkout. - Args: - ver_str: The version string passed from the repo launcher when it ran us. - repo_path: The path to the repo launcher that loaded us. - """ - # Refuse to work with really old wrapper versions. We don't test these, - # so might as well require a somewhat recent sane version. - # v1.15 of the repo launcher was released in ~Mar 2012. - MIN_REPO_VERSION = (1, 15) - min_str = '.'.join(str(x) for x in MIN_REPO_VERSION) + Args: + ver_str: The version string passed from the repo launcher when it ran + us. + repo_path: The path to the repo launcher that loaded us. + """ + # Refuse to work with really old wrapper versions. We don't test these, + # so might as well require a somewhat recent sane version. + # v1.15 of the repo launcher was released in ~Mar 2012. 
+ MIN_REPO_VERSION = (1, 15) + min_str = ".".join(str(x) for x in MIN_REPO_VERSION) - if not repo_path: - repo_path = '~/bin/repo' + if not repo_path: + repo_path = "~/bin/repo" - if not ver_str: - print('no --wrapper-version argument', file=sys.stderr) - sys.exit(1) + if not ver_str: + print("no --wrapper-version argument", file=sys.stderr) + sys.exit(1) - # Pull out the version of the repo launcher we know about to compare. - exp = Wrapper().VERSION - ver = tuple(map(int, ver_str.split('.'))) + # Pull out the version of the repo launcher we know about to compare. + exp = Wrapper().VERSION + ver = tuple(map(int, ver_str.split("."))) - exp_str = '.'.join(map(str, exp)) - if ver < MIN_REPO_VERSION: - print(""" + exp_str = ".".join(map(str, exp)) + if ver < MIN_REPO_VERSION: + print( + """ repo: error: !!! Your version of repo %s is too old. !!! We need at least version %s. @@ -422,284 +517,321 @@ repo: error: !!! You must upgrade before you can continue: cp %s %s -""" % (ver_str, min_str, exp_str, WrapperPath(), repo_path), file=sys.stderr) - sys.exit(1) +""" + % (ver_str, min_str, exp_str, WrapperPath(), repo_path), + file=sys.stderr, + ) + sys.exit(1) - if exp > ver: - print('\n... A new version of repo (%s) is available.' % (exp_str,), - file=sys.stderr) - if os.access(repo_path, os.W_OK): - print("""\ + if exp > ver: + print( + "\n... A new version of repo (%s) is available." % (exp_str,), + file=sys.stderr, + ) + if os.access(repo_path, os.W_OK): + print( + """\ ... You should upgrade soon: cp %s %s -""" % (WrapperPath(), repo_path), file=sys.stderr) - else: - print("""\ +""" + % (WrapperPath(), repo_path), + file=sys.stderr, + ) + else: + print( + """\ ... New version is available at: %s ... The launcher is run from: %s !!! The launcher is not writable. Please talk to your sysadmin or distro !!! to get an update installed. 
-""" % (WrapperPath(), repo_path), file=sys.stderr) +""" + % (WrapperPath(), repo_path), + file=sys.stderr, + ) def _CheckRepoDir(repo_dir): - if not repo_dir: - print('no --repo-dir argument', file=sys.stderr) - sys.exit(1) + if not repo_dir: + print("no --repo-dir argument", file=sys.stderr) + sys.exit(1) def _PruneOptions(argv, opt): - i = 0 - while i < len(argv): - a = argv[i] - if a == '--': - break - if a.startswith('--'): - eq = a.find('=') - if eq > 0: - a = a[0:eq] - if not opt.has_option(a): - del argv[i] - continue - i += 1 + i = 0 + while i < len(argv): + a = argv[i] + if a == "--": + break + if a.startswith("--"): + eq = a.find("=") + if eq > 0: + a = a[0:eq] + if not opt.has_option(a): + del argv[i] + continue + i += 1 class _UserAgentHandler(urllib.request.BaseHandler): - def http_request(self, req): - req.add_header('User-Agent', user_agent.repo) - return req + def http_request(self, req): + req.add_header("User-Agent", user_agent.repo) + return req - def https_request(self, req): - req.add_header('User-Agent', user_agent.repo) - return req + def https_request(self, req): + req.add_header("User-Agent", user_agent.repo) + return req def _AddPasswordFromUserInput(handler, msg, req): - # If repo could not find auth info from netrc, try to get it from user input - url = req.get_full_url() - user, password = handler.passwd.find_user_password(None, url) - if user is None: - print(msg) - try: - user = input('User: ') - password = getpass.getpass() - except KeyboardInterrupt: - return - handler.passwd.add_password(None, url, user, password) + # If repo could not find auth info from netrc, try to get it from user input + url = req.get_full_url() + user, password = handler.passwd.find_user_password(None, url) + if user is None: + print(msg) + try: + user = input("User: ") + password = getpass.getpass() + except KeyboardInterrupt: + return + handler.passwd.add_password(None, url, user, password) class _BasicAuthHandler(urllib.request.HTTPBasicAuthHandler): - 
def http_error_401(self, req, fp, code, msg, headers): - _AddPasswordFromUserInput(self, msg, req) - return urllib.request.HTTPBasicAuthHandler.http_error_401( - self, req, fp, code, msg, headers) + def http_error_401(self, req, fp, code, msg, headers): + _AddPasswordFromUserInput(self, msg, req) + return urllib.request.HTTPBasicAuthHandler.http_error_401( + self, req, fp, code, msg, headers + ) - def http_error_auth_reqed(self, authreq, host, req, headers): - try: - old_add_header = req.add_header + def http_error_auth_reqed(self, authreq, host, req, headers): + try: + old_add_header = req.add_header - def _add_header(name, val): - val = val.replace('\n', '') - old_add_header(name, val) - req.add_header = _add_header - return urllib.request.AbstractBasicAuthHandler.http_error_auth_reqed( - self, authreq, host, req, headers) - except Exception: - reset = getattr(self, 'reset_retry_count', None) - if reset is not None: - reset() - elif getattr(self, 'retried', None): - self.retried = 0 - raise + def _add_header(name, val): + val = val.replace("\n", "") + old_add_header(name, val) + + req.add_header = _add_header + return ( + urllib.request.AbstractBasicAuthHandler.http_error_auth_reqed( + self, authreq, host, req, headers + ) + ) + except Exception: + reset = getattr(self, "reset_retry_count", None) + if reset is not None: + reset() + elif getattr(self, "retried", None): + self.retried = 0 + raise class _DigestAuthHandler(urllib.request.HTTPDigestAuthHandler): - def http_error_401(self, req, fp, code, msg, headers): - _AddPasswordFromUserInput(self, msg, req) - return urllib.request.HTTPDigestAuthHandler.http_error_401( - self, req, fp, code, msg, headers) + def http_error_401(self, req, fp, code, msg, headers): + _AddPasswordFromUserInput(self, msg, req) + return urllib.request.HTTPDigestAuthHandler.http_error_401( + self, req, fp, code, msg, headers + ) - def http_error_auth_reqed(self, auth_header, host, req, headers): - try: - old_add_header = req.add_header + 
def http_error_auth_reqed(self, auth_header, host, req, headers): + try: + old_add_header = req.add_header - def _add_header(name, val): - val = val.replace('\n', '') - old_add_header(name, val) - req.add_header = _add_header - return urllib.request.AbstractDigestAuthHandler.http_error_auth_reqed( - self, auth_header, host, req, headers) - except Exception: - reset = getattr(self, 'reset_retry_count', None) - if reset is not None: - reset() - elif getattr(self, 'retried', None): - self.retried = 0 - raise + def _add_header(name, val): + val = val.replace("\n", "") + old_add_header(name, val) + + req.add_header = _add_header + return ( + urllib.request.AbstractDigestAuthHandler.http_error_auth_reqed( + self, auth_header, host, req, headers + ) + ) + except Exception: + reset = getattr(self, "reset_retry_count", None) + if reset is not None: + reset() + elif getattr(self, "retried", None): + self.retried = 0 + raise class _KerberosAuthHandler(urllib.request.BaseHandler): - def __init__(self): - self.retried = 0 - self.context = None - self.handler_order = urllib.request.BaseHandler.handler_order - 50 + def __init__(self): + self.retried = 0 + self.context = None + self.handler_order = urllib.request.BaseHandler.handler_order - 50 - def http_error_401(self, req, fp, code, msg, headers): - host = req.get_host() - retry = self.http_error_auth_reqed('www-authenticate', host, req, headers) - return retry + def http_error_401(self, req, fp, code, msg, headers): + host = req.get_host() + retry = self.http_error_auth_reqed( + "www-authenticate", host, req, headers + ) + return retry - def http_error_auth_reqed(self, auth_header, host, req, headers): - try: - spn = "HTTP@%s" % host - authdata = self._negotiate_get_authdata(auth_header, headers) + def http_error_auth_reqed(self, auth_header, host, req, headers): + try: + spn = "HTTP@%s" % host + authdata = self._negotiate_get_authdata(auth_header, headers) - if self.retried > 3: - raise 
urllib.request.HTTPError(req.get_full_url(), 401, - "Negotiate auth failed", headers, None) - else: - self.retried += 1 + if self.retried > 3: + raise urllib.request.HTTPError( + req.get_full_url(), + 401, + "Negotiate auth failed", + headers, + None, + ) + else: + self.retried += 1 - neghdr = self._negotiate_get_svctk(spn, authdata) - if neghdr is None: + neghdr = self._negotiate_get_svctk(spn, authdata) + if neghdr is None: + return None + + req.add_unredirected_header("Authorization", neghdr) + response = self.parent.open(req) + + srvauth = self._negotiate_get_authdata(auth_header, response.info()) + if self._validate_response(srvauth): + return response + except kerberos.GSSError: + return None + except Exception: + self.reset_retry_count() + raise + finally: + self._clean_context() + + def reset_retry_count(self): + self.retried = 0 + + def _negotiate_get_authdata(self, auth_header, headers): + authhdr = headers.get(auth_header, None) + if authhdr is not None: + for mech_tuple in authhdr.split(","): + mech, __, authdata = mech_tuple.strip().partition(" ") + if mech.lower() == "negotiate": + return authdata.strip() return None - req.add_unredirected_header('Authorization', neghdr) - response = self.parent.open(req) + def _negotiate_get_svctk(self, spn, authdata): + if authdata is None: + return None - srvauth = self._negotiate_get_authdata(auth_header, response.info()) - if self._validate_response(srvauth): - return response - except kerberos.GSSError: - return None - except Exception: - self.reset_retry_count() - raise - finally: - self._clean_context() + result, self.context = kerberos.authGSSClientInit(spn) + if result < kerberos.AUTH_GSS_COMPLETE: + return None - def reset_retry_count(self): - self.retried = 0 + result = kerberos.authGSSClientStep(self.context, authdata) + if result < kerberos.AUTH_GSS_CONTINUE: + return None - def _negotiate_get_authdata(self, auth_header, headers): - authhdr = headers.get(auth_header, None) - if authhdr is not None: - for 
mech_tuple in authhdr.split(","): - mech, __, authdata = mech_tuple.strip().partition(" ") - if mech.lower() == "negotiate": - return authdata.strip() - return None + response = kerberos.authGSSClientResponse(self.context) + return "Negotiate %s" % response - def _negotiate_get_svctk(self, spn, authdata): - if authdata is None: - return None + def _validate_response(self, authdata): + if authdata is None: + return None + result = kerberos.authGSSClientStep(self.context, authdata) + if result == kerberos.AUTH_GSS_COMPLETE: + return True + return None - result, self.context = kerberos.authGSSClientInit(spn) - if result < kerberos.AUTH_GSS_COMPLETE: - return None - - result = kerberos.authGSSClientStep(self.context, authdata) - if result < kerberos.AUTH_GSS_CONTINUE: - return None - - response = kerberos.authGSSClientResponse(self.context) - return "Negotiate %s" % response - - def _validate_response(self, authdata): - if authdata is None: - return None - result = kerberos.authGSSClientStep(self.context, authdata) - if result == kerberos.AUTH_GSS_COMPLETE: - return True - return None - - def _clean_context(self): - if self.context is not None: - kerberos.authGSSClientClean(self.context) - self.context = None + def _clean_context(self): + if self.context is not None: + kerberos.authGSSClientClean(self.context) + self.context = None def init_http(): - handlers = [_UserAgentHandler()] + handlers = [_UserAgentHandler()] - mgr = urllib.request.HTTPPasswordMgrWithDefaultRealm() - try: - n = netrc.netrc() - for host in n.hosts: - p = n.hosts[host] - mgr.add_password(p[1], 'http://%s/' % host, p[0], p[2]) - mgr.add_password(p[1], 'https://%s/' % host, p[0], p[2]) - except netrc.NetrcParseError: - pass - except IOError: - pass - handlers.append(_BasicAuthHandler(mgr)) - handlers.append(_DigestAuthHandler(mgr)) - if kerberos: - handlers.append(_KerberosAuthHandler()) + mgr = urllib.request.HTTPPasswordMgrWithDefaultRealm() + try: + n = netrc.netrc() + for host in n.hosts: + p = 
n.hosts[host] + mgr.add_password(p[1], "http://%s/" % host, p[0], p[2]) + mgr.add_password(p[1], "https://%s/" % host, p[0], p[2]) + except netrc.NetrcParseError: + pass + except IOError: + pass + handlers.append(_BasicAuthHandler(mgr)) + handlers.append(_DigestAuthHandler(mgr)) + if kerberos: + handlers.append(_KerberosAuthHandler()) - if 'http_proxy' in os.environ: - url = os.environ['http_proxy'] - handlers.append(urllib.request.ProxyHandler({'http': url, 'https': url})) - if 'REPO_CURL_VERBOSE' in os.environ: - handlers.append(urllib.request.HTTPHandler(debuglevel=1)) - handlers.append(urllib.request.HTTPSHandler(debuglevel=1)) - urllib.request.install_opener(urllib.request.build_opener(*handlers)) + if "http_proxy" in os.environ: + url = os.environ["http_proxy"] + handlers.append( + urllib.request.ProxyHandler({"http": url, "https": url}) + ) + if "REPO_CURL_VERBOSE" in os.environ: + handlers.append(urllib.request.HTTPHandler(debuglevel=1)) + handlers.append(urllib.request.HTTPSHandler(debuglevel=1)) + urllib.request.install_opener(urllib.request.build_opener(*handlers)) def _Main(argv): - result = 0 + result = 0 - opt = optparse.OptionParser(usage="repo wrapperinfo -- ...") - opt.add_option("--repo-dir", dest="repodir", - help="path to .repo/") - opt.add_option("--wrapper-version", dest="wrapper_version", - help="version of the wrapper script") - opt.add_option("--wrapper-path", dest="wrapper_path", - help="location of the wrapper script") - _PruneOptions(argv, opt) - opt, argv = opt.parse_args(argv) + opt = optparse.OptionParser(usage="repo wrapperinfo -- ...") + opt.add_option("--repo-dir", dest="repodir", help="path to .repo/") + opt.add_option( + "--wrapper-version", + dest="wrapper_version", + help="version of the wrapper script", + ) + opt.add_option( + "--wrapper-path", + dest="wrapper_path", + help="location of the wrapper script", + ) + _PruneOptions(argv, opt) + opt, argv = opt.parse_args(argv) - _CheckWrapperVersion(opt.wrapper_version, 
opt.wrapper_path) - _CheckRepoDir(opt.repodir) + _CheckWrapperVersion(opt.wrapper_version, opt.wrapper_path) + _CheckRepoDir(opt.repodir) - Version.wrapper_version = opt.wrapper_version - Version.wrapper_path = opt.wrapper_path + Version.wrapper_version = opt.wrapper_version + Version.wrapper_path = opt.wrapper_path - repo = _Repo(opt.repodir) + repo = _Repo(opt.repodir) - try: - init_http() - name, gopts, argv = repo._ParseArgs(argv) - - if gopts.trace: - SetTrace() - - if gopts.trace_to_stderr: - SetTraceToStderr() - - result = repo._Run(name, gopts, argv) or 0 - except KeyboardInterrupt: - print('aborted by user', file=sys.stderr) - result = 1 - except ManifestParseError as mpe: - print('fatal: %s' % mpe, file=sys.stderr) - result = 1 - except RepoChangedException as rce: - # If repo changed, re-exec ourselves. - # - argv = list(sys.argv) - argv.extend(rce.extra_args) try: - os.execv(sys.executable, [__file__] + argv) - except OSError as e: - print('fatal: cannot restart repo after upgrade', file=sys.stderr) - print('fatal: %s' % e, file=sys.stderr) - result = 128 + init_http() + name, gopts, argv = repo._ParseArgs(argv) - TerminatePager() - sys.exit(result) + if gopts.trace: + SetTrace() + + if gopts.trace_to_stderr: + SetTraceToStderr() + + result = repo._Run(name, gopts, argv) or 0 + except KeyboardInterrupt: + print("aborted by user", file=sys.stderr) + result = 1 + except ManifestParseError as mpe: + print("fatal: %s" % mpe, file=sys.stderr) + result = 1 + except RepoChangedException as rce: + # If repo changed, re-exec ourselves. 
+ # + argv = list(sys.argv) + argv.extend(rce.extra_args) + try: + os.execv(sys.executable, [__file__] + argv) + except OSError as e: + print("fatal: cannot restart repo after upgrade", file=sys.stderr) + print("fatal: %s" % e, file=sys.stderr) + result = 128 + + TerminatePager() + sys.exit(result) -if __name__ == '__main__': - _Main(sys.argv[1:]) +if __name__ == "__main__": + _Main(sys.argv[1:]) diff --git a/manifest_xml.py b/manifest_xml.py index 5b83f368..9603906f 100644 --- a/manifest_xml.py +++ b/manifest_xml.py @@ -26,415 +26,452 @@ from git_config import GitConfig from git_refs import R_HEADS, HEAD from git_superproject import Superproject import platform_utils -from project import (Annotation, RemoteSpec, Project, RepoProject, - ManifestProject) -from error import (ManifestParseError, ManifestInvalidPathError, - ManifestInvalidRevisionError) +from project import ( + Annotation, + RemoteSpec, + Project, + RepoProject, + ManifestProject, +) +from error import ( + ManifestParseError, + ManifestInvalidPathError, + ManifestInvalidRevisionError, +) from wrapper import Wrapper -MANIFEST_FILE_NAME = 'manifest.xml' -LOCAL_MANIFEST_NAME = 'local_manifest.xml' -LOCAL_MANIFESTS_DIR_NAME = 'local_manifests' -SUBMANIFEST_DIR = 'submanifests' +MANIFEST_FILE_NAME = "manifest.xml" +LOCAL_MANIFEST_NAME = "local_manifest.xml" +LOCAL_MANIFESTS_DIR_NAME = "local_manifests" +SUBMANIFEST_DIR = "submanifests" # Limit submanifests to an arbitrary depth for loop detection. MAX_SUBMANIFEST_DEPTH = 8 # Add all projects from sub manifest into a group. -SUBMANIFEST_GROUP_PREFIX = 'submanifest:' +SUBMANIFEST_GROUP_PREFIX = "submanifest:" # Add all projects from local manifest into a group. -LOCAL_MANIFEST_GROUP_PREFIX = 'local:' +LOCAL_MANIFEST_GROUP_PREFIX = "local:" # ContactInfo has the self-registered bug url, supplied by the manifest authors. 
-ContactInfo = collections.namedtuple('ContactInfo', 'bugurl') +ContactInfo = collections.namedtuple("ContactInfo", "bugurl") # urljoin gets confused if the scheme is not known. -urllib.parse.uses_relative.extend([ - 'ssh', - 'git', - 'persistent-https', - 'sso', - 'rpc']) -urllib.parse.uses_netloc.extend([ - 'ssh', - 'git', - 'persistent-https', - 'sso', - 'rpc']) +urllib.parse.uses_relative.extend( + ["ssh", "git", "persistent-https", "sso", "rpc"] +) +urllib.parse.uses_netloc.extend( + ["ssh", "git", "persistent-https", "sso", "rpc"] +) def XmlBool(node, attr, default=None): - """Determine boolean value of |node|'s |attr|. + """Determine boolean value of |node|'s |attr|. - Invalid values will issue a non-fatal warning. + Invalid values will issue a non-fatal warning. - Args: - node: XML node whose attributes we access. - attr: The attribute to access. - default: If the attribute is not set (value is empty), then use this. + Args: + node: XML node whose attributes we access. + attr: The attribute to access. + default: If the attribute is not set (value is empty), then use this. - Returns: - True if the attribute is a valid string representing true. - False if the attribute is a valid string representing false. - |default| otherwise. - """ - value = node.getAttribute(attr) - s = value.lower() - if s == '': - return default - elif s in {'yes', 'true', '1'}: - return True - elif s in {'no', 'false', '0'}: - return False - else: - print('warning: manifest: %s="%s": ignoring invalid XML boolean' % - (attr, value), file=sys.stderr) - return default + Returns: + True if the attribute is a valid string representing true. + False if the attribute is a valid string representing false. + |default| otherwise. 
+ """ + value = node.getAttribute(attr) + s = value.lower() + if s == "": + return default + elif s in {"yes", "true", "1"}: + return True + elif s in {"no", "false", "0"}: + return False + else: + print( + 'warning: manifest: %s="%s": ignoring invalid XML boolean' + % (attr, value), + file=sys.stderr, + ) + return default def XmlInt(node, attr, default=None): - """Determine integer value of |node|'s |attr|. + """Determine integer value of |node|'s |attr|. - Args: - node: XML node whose attributes we access. - attr: The attribute to access. - default: If the attribute is not set (value is empty), then use this. + Args: + node: XML node whose attributes we access. + attr: The attribute to access. + default: If the attribute is not set (value is empty), then use this. - Returns: - The number if the attribute is a valid number. + Returns: + The number if the attribute is a valid number. - Raises: - ManifestParseError: The number is invalid. - """ - value = node.getAttribute(attr) - if not value: - return default + Raises: + ManifestParseError: The number is invalid. 
+ """ + value = node.getAttribute(attr) + if not value: + return default - try: - return int(value) - except ValueError: - raise ManifestParseError('manifest: invalid %s="%s" integer' % - (attr, value)) + try: + return int(value) + except ValueError: + raise ManifestParseError( + 'manifest: invalid %s="%s" integer' % (attr, value) + ) class _Default(object): - """Project defaults within the manifest.""" + """Project defaults within the manifest.""" - revisionExpr = None - destBranchExpr = None - upstreamExpr = None - remote = None - sync_j = None - sync_c = False - sync_s = False - sync_tags = True + revisionExpr = None + destBranchExpr = None + upstreamExpr = None + remote = None + sync_j = None + sync_c = False + sync_s = False + sync_tags = True - def __eq__(self, other): - if not isinstance(other, _Default): - return False - return self.__dict__ == other.__dict__ + def __eq__(self, other): + if not isinstance(other, _Default): + return False + return self.__dict__ == other.__dict__ - def __ne__(self, other): - if not isinstance(other, _Default): - return True - return self.__dict__ != other.__dict__ + def __ne__(self, other): + if not isinstance(other, _Default): + return True + return self.__dict__ != other.__dict__ class _XmlRemote(object): - def __init__(self, - name, - alias=None, - fetch=None, - pushUrl=None, - manifestUrl=None, - review=None, - revision=None): - self.name = name - self.fetchUrl = fetch - self.pushUrl = pushUrl - self.manifestUrl = manifestUrl - self.remoteAlias = alias - self.reviewUrl = review - self.revision = revision - self.resolvedFetchUrl = self._resolveFetchUrl() - self.annotations = [] + def __init__( + self, + name, + alias=None, + fetch=None, + pushUrl=None, + manifestUrl=None, + review=None, + revision=None, + ): + self.name = name + self.fetchUrl = fetch + self.pushUrl = pushUrl + self.manifestUrl = manifestUrl + self.remoteAlias = alias + self.reviewUrl = review + self.revision = revision + self.resolvedFetchUrl = 
self._resolveFetchUrl() + self.annotations = [] - def __eq__(self, other): - if not isinstance(other, _XmlRemote): - return False - return (sorted(self.annotations) == sorted(other.annotations) and - self.name == other.name and self.fetchUrl == other.fetchUrl and - self.pushUrl == other.pushUrl and self.remoteAlias == other.remoteAlias - and self.reviewUrl == other.reviewUrl and self.revision == other.revision) + def __eq__(self, other): + if not isinstance(other, _XmlRemote): + return False + return ( + sorted(self.annotations) == sorted(other.annotations) + and self.name == other.name + and self.fetchUrl == other.fetchUrl + and self.pushUrl == other.pushUrl + and self.remoteAlias == other.remoteAlias + and self.reviewUrl == other.reviewUrl + and self.revision == other.revision + ) - def __ne__(self, other): - return not self.__eq__(other) + def __ne__(self, other): + return not self.__eq__(other) - def _resolveFetchUrl(self): - if self.fetchUrl is None: - return '' - url = self.fetchUrl.rstrip('/') - manifestUrl = self.manifestUrl.rstrip('/') - # urljoin will gets confused over quite a few things. The ones we care - # about here are: - # * no scheme in the base url, like - # We handle no scheme by replacing it with an obscure protocol, gopher - # and then replacing it with the original when we are done. + def _resolveFetchUrl(self): + if self.fetchUrl is None: + return "" + url = self.fetchUrl.rstrip("/") + manifestUrl = self.manifestUrl.rstrip("/") + # urljoin will gets confused over quite a few things. The ones we care + # about here are: + # * no scheme in the base url, like + # We handle no scheme by replacing it with an obscure protocol, gopher + # and then replacing it with the original when we are done. 
- if manifestUrl.find(':') != manifestUrl.find('/') - 1: - url = urllib.parse.urljoin('gopher://' + manifestUrl, url) - url = re.sub(r'^gopher://', '', url) - else: - url = urllib.parse.urljoin(manifestUrl, url) - return url + if manifestUrl.find(":") != manifestUrl.find("/") - 1: + url = urllib.parse.urljoin("gopher://" + manifestUrl, url) + url = re.sub(r"^gopher://", "", url) + else: + url = urllib.parse.urljoin(manifestUrl, url) + return url - def ToRemoteSpec(self, projectName): - fetchUrl = self.resolvedFetchUrl.rstrip('/') - url = fetchUrl + '/' + projectName - remoteName = self.name - if self.remoteAlias: - remoteName = self.remoteAlias - return RemoteSpec(remoteName, - url=url, - pushUrl=self.pushUrl, - review=self.reviewUrl, - orig_name=self.name, - fetchUrl=self.fetchUrl) + def ToRemoteSpec(self, projectName): + fetchUrl = self.resolvedFetchUrl.rstrip("/") + url = fetchUrl + "/" + projectName + remoteName = self.name + if self.remoteAlias: + remoteName = self.remoteAlias + return RemoteSpec( + remoteName, + url=url, + pushUrl=self.pushUrl, + review=self.reviewUrl, + orig_name=self.name, + fetchUrl=self.fetchUrl, + ) - def AddAnnotation(self, name, value, keep): - self.annotations.append(Annotation(name, value, keep)) + def AddAnnotation(self, name, value, keep): + self.annotations.append(Annotation(name, value, keep)) class _XmlSubmanifest: - """Manage the element specified in the manifest. + """Manage the element specified in the manifest. - Attributes: - name: a string, the name for this submanifest. - remote: a string, the remote.name for this submanifest. - project: a string, the name of the manifest project. - revision: a string, the commitish. - manifestName: a string, the submanifest file name. - groups: a list of strings, the groups to add to all projects in the submanifest. - default_groups: a list of strings, the default groups to sync. - path: a string, the relative path for the submanifest checkout. 
- parent: an XmlManifest, the parent manifest. - annotations: (derived) a list of annotations. - present: (derived) a boolean, whether the sub manifest file is present. - """ - def __init__(self, - name, - remote=None, - project=None, - revision=None, - manifestName=None, - groups=None, - default_groups=None, - path=None, - parent=None): - self.name = name - self.remote = remote - self.project = project - self.revision = revision - self.manifestName = manifestName - self.groups = groups - self.default_groups = default_groups - self.path = path - self.parent = parent - self.annotations = [] - outer_client = parent._outer_client or parent - if self.remote and not self.project: - raise ManifestParseError( - f'Submanifest {name}: must specify project when remote is given.') - # Construct the absolute path to the manifest file using the parent's - # method, so that we can correctly create our repo_client. - manifestFile = parent.SubmanifestInfoDir( - os.path.join(parent.path_prefix, self.relpath), - os.path.join('manifests', manifestName or 'default.xml')) - linkFile = parent.SubmanifestInfoDir( - os.path.join(parent.path_prefix, self.relpath), MANIFEST_FILE_NAME) - rc = self.repo_client = RepoClient( - parent.repodir, linkFile, parent_groups=','.join(groups) or '', - submanifest_path=self.relpath, outer_client=outer_client, - default_groups=default_groups) + Attributes: + name: a string, the name for this submanifest. + remote: a string, the remote.name for this submanifest. + project: a string, the name of the manifest project. + revision: a string, the commitish. + manifestName: a string, the submanifest file name. + groups: a list of strings, the groups to add to all projects in the + submanifest. + default_groups: a list of strings, the default groups to sync. + path: a string, the relative path for the submanifest checkout. + parent: an XmlManifest, the parent manifest. + annotations: (derived) a list of annotations. 
+ present: (derived) a boolean, whether the sub manifest file is present. + """ - self.present = os.path.exists(manifestFile) + def __init__( + self, + name, + remote=None, + project=None, + revision=None, + manifestName=None, + groups=None, + default_groups=None, + path=None, + parent=None, + ): + self.name = name + self.remote = remote + self.project = project + self.revision = revision + self.manifestName = manifestName + self.groups = groups + self.default_groups = default_groups + self.path = path + self.parent = parent + self.annotations = [] + outer_client = parent._outer_client or parent + if self.remote and not self.project: + raise ManifestParseError( + f"Submanifest {name}: must specify project when remote is " + "given." + ) + # Construct the absolute path to the manifest file using the parent's + # method, so that we can correctly create our repo_client. + manifestFile = parent.SubmanifestInfoDir( + os.path.join(parent.path_prefix, self.relpath), + os.path.join("manifests", manifestName or "default.xml"), + ) + linkFile = parent.SubmanifestInfoDir( + os.path.join(parent.path_prefix, self.relpath), MANIFEST_FILE_NAME + ) + self.repo_client = RepoClient( + parent.repodir, + linkFile, + parent_groups=",".join(groups) or "", + submanifest_path=self.relpath, + outer_client=outer_client, + default_groups=default_groups, + ) - def __eq__(self, other): - if not isinstance(other, _XmlSubmanifest): - return False - return ( - self.name == other.name and - self.remote == other.remote and - self.project == other.project and - self.revision == other.revision and - self.manifestName == other.manifestName and - self.groups == other.groups and - self.default_groups == other.default_groups and - self.path == other.path and - sorted(self.annotations) == sorted(other.annotations)) + self.present = os.path.exists(manifestFile) - def __ne__(self, other): - return not self.__eq__(other) + def __eq__(self, other): + if not isinstance(other, _XmlSubmanifest): + return False + 
return ( + self.name == other.name + and self.remote == other.remote + and self.project == other.project + and self.revision == other.revision + and self.manifestName == other.manifestName + and self.groups == other.groups + and self.default_groups == other.default_groups + and self.path == other.path + and sorted(self.annotations) == sorted(other.annotations) + ) - def ToSubmanifestSpec(self): - """Return a SubmanifestSpec object, populating attributes""" - mp = self.parent.manifestProject - remote = self.parent.remotes[self.remote or self.parent.default.remote.name] - # If a project was given, generate the url from the remote and project. - # If not, use this manifestProject's url. - if self.project: - manifestUrl = remote.ToRemoteSpec(self.project).url - else: - manifestUrl = mp.GetRemote().url - manifestName = self.manifestName or 'default.xml' - revision = self.revision or self.name - path = self.path or revision.split('/')[-1] - groups = self.groups or [] - default_groups = self.default_groups or [] + def __ne__(self, other): + return not self.__eq__(other) - return SubmanifestSpec(self.name, manifestUrl, manifestName, revision, path, - groups) + def ToSubmanifestSpec(self): + """Return a SubmanifestSpec object, populating attributes""" + mp = self.parent.manifestProject + remote = self.parent.remotes[ + self.remote or self.parent.default.remote.name + ] + # If a project was given, generate the url from the remote and project. + # If not, use this manifestProject's url. 
+ if self.project: + manifestUrl = remote.ToRemoteSpec(self.project).url + else: + manifestUrl = mp.GetRemote().url + manifestName = self.manifestName or "default.xml" + revision = self.revision or self.name + path = self.path or revision.split("/")[-1] + groups = self.groups or [] - @property - def relpath(self): - """The path of this submanifest relative to the parent manifest.""" - revision = self.revision or self.name - return self.path or revision.split('/')[-1] + return SubmanifestSpec( + self.name, manifestUrl, manifestName, revision, path, groups + ) - def GetGroupsStr(self): - """Returns the `groups` given for this submanifest.""" - if self.groups: - return ','.join(self.groups) - return '' + @property + def relpath(self): + """The path of this submanifest relative to the parent manifest.""" + revision = self.revision or self.name + return self.path or revision.split("/")[-1] - def GetDefaultGroupsStr(self): - """Returns the `default-groups` given for this submanifest.""" - return ','.join(self.default_groups or []) + def GetGroupsStr(self): + """Returns the `groups` given for this submanifest.""" + if self.groups: + return ",".join(self.groups) + return "" - def AddAnnotation(self, name, value, keep): - """Add annotations to the submanifest.""" - self.annotations.append(Annotation(name, value, keep)) + def GetDefaultGroupsStr(self): + """Returns the `default-groups` given for this submanifest.""" + return ",".join(self.default_groups or []) + + def AddAnnotation(self, name, value, keep): + """Add annotations to the submanifest.""" + self.annotations.append(Annotation(name, value, keep)) class SubmanifestSpec: - """The submanifest element, with all fields expanded.""" + """The submanifest element, with all fields expanded.""" - def __init__(self, - name, - manifestUrl, - manifestName, - revision, - path, - groups): - self.name = name - self.manifestUrl = manifestUrl - self.manifestName = manifestName - self.revision = revision - self.path = path - 
self.groups = groups or [] + def __init__(self, name, manifestUrl, manifestName, revision, path, groups): + self.name = name + self.manifestUrl = manifestUrl + self.manifestName = manifestName + self.revision = revision + self.path = path + self.groups = groups or [] class XmlManifest(object): - """manages the repo configuration file""" + """manages the repo configuration file""" - def __init__(self, repodir, manifest_file, local_manifests=None, - outer_client=None, parent_groups='', submanifest_path='', - default_groups=None): - """Initialize. + def __init__( + self, + repodir, + manifest_file, + local_manifests=None, + outer_client=None, + parent_groups="", + submanifest_path="", + default_groups=None, + ): + """Initialize. - Args: - repodir: Path to the .repo/ dir for holding all internal checkout state. - It must be in the top directory of the repo client checkout. - manifest_file: Full path to the manifest file to parse. This will usually - be |repodir|/|MANIFEST_FILE_NAME|. - local_manifests: Full path to the directory of local override manifests. - This will usually be |repodir|/|LOCAL_MANIFESTS_DIR_NAME|. - outer_client: RepoClient of the outer manifest. - parent_groups: a string, the groups to apply to this projects. - submanifest_path: The submanifest root relative to the repo root. - default_groups: a string, the default manifest groups to use. - """ - # TODO(vapier): Move this out of this class. - self.globalConfig = GitConfig.ForUser() + Args: + repodir: Path to the .repo/ dir for holding all internal checkout + state. It must be in the top directory of the repo client + checkout. + manifest_file: Full path to the manifest file to parse. This will + usually be |repodir|/|MANIFEST_FILE_NAME|. + local_manifests: Full path to the directory of local override + manifests. This will usually be + |repodir|/|LOCAL_MANIFESTS_DIR_NAME|. + outer_client: RepoClient of the outer manifest. + parent_groups: a string, the groups to apply to this projects. 
+ submanifest_path: The submanifest root relative to the repo root. + default_groups: a string, the default manifest groups to use. + """ + # TODO(vapier): Move this out of this class. + self.globalConfig = GitConfig.ForUser() - self.repodir = os.path.abspath(repodir) - self._CheckLocalPath(submanifest_path) - self.topdir = os.path.dirname(self.repodir) - if submanifest_path: - # This avoids a trailing os.path.sep when submanifest_path is empty. - self.topdir = os.path.join(self.topdir, submanifest_path) - if manifest_file != os.path.abspath(manifest_file): - raise ManifestParseError('manifest_file must be abspath') - self.manifestFile = manifest_file - if not outer_client or outer_client == self: - # manifestFileOverrides only exists in the outer_client's manifest, since - # that is the only instance left when Unload() is called on the outer - # manifest. - self.manifestFileOverrides = {} - self.local_manifests = local_manifests - self._load_local_manifests = True - self.parent_groups = parent_groups - self.default_groups = default_groups + self.repodir = os.path.abspath(repodir) + self._CheckLocalPath(submanifest_path) + self.topdir = os.path.dirname(self.repodir) + if submanifest_path: + # This avoids a trailing os.path.sep when submanifest_path is empty. + self.topdir = os.path.join(self.topdir, submanifest_path) + if manifest_file != os.path.abspath(manifest_file): + raise ManifestParseError("manifest_file must be abspath") + self.manifestFile = manifest_file + if not outer_client or outer_client == self: + # manifestFileOverrides only exists in the outer_client's manifest, + # since that is the only instance left when Unload() is called on + # the outer manifest. 
+ self.manifestFileOverrides = {} + self.local_manifests = local_manifests + self._load_local_manifests = True + self.parent_groups = parent_groups + self.default_groups = default_groups - if outer_client and self.isGitcClient: - raise ManifestParseError('Multi-manifest is incompatible with `gitc-init`') + if outer_client and self.isGitcClient: + raise ManifestParseError( + "Multi-manifest is incompatible with `gitc-init`" + ) - if submanifest_path and not outer_client: - # If passing a submanifest_path, there must be an outer_client. - raise ManifestParseError(f'Bad call to {self.__class__.__name__}') + if submanifest_path and not outer_client: + # If passing a submanifest_path, there must be an outer_client. + raise ManifestParseError(f"Bad call to {self.__class__.__name__}") - # If self._outer_client is None, this is not a checkout that supports - # multi-tree. - self._outer_client = outer_client or self + # If self._outer_client is None, this is not a checkout that supports + # multi-tree. + self._outer_client = outer_client or self - self.repoProject = RepoProject(self, 'repo', - gitdir=os.path.join(repodir, 'repo/.git'), - worktree=os.path.join(repodir, 'repo')) + self.repoProject = RepoProject( + self, + "repo", + gitdir=os.path.join(repodir, "repo/.git"), + worktree=os.path.join(repodir, "repo"), + ) - mp = self.SubmanifestProject(self.path_prefix) - self.manifestProject = mp + mp = self.SubmanifestProject(self.path_prefix) + self.manifestProject = mp - # This is a bit hacky, but we're in a chicken & egg situation: all the - # normal repo settings live in the manifestProject which we just setup - # above, so we couldn't easily query before that. We assume Project() - # init doesn't care if this changes afterwards. 
- if os.path.exists(mp.gitdir) and mp.use_worktree: - mp.use_git_worktrees = True + # This is a bit hacky, but we're in a chicken & egg situation: all the + # normal repo settings live in the manifestProject which we just setup + # above, so we couldn't easily query before that. We assume Project() + # init doesn't care if this changes afterwards. + if os.path.exists(mp.gitdir) and mp.use_worktree: + mp.use_git_worktrees = True - self.Unload() + self.Unload() - def Override(self, name, load_local_manifests=True): - """Use a different manifest, just for the current instantiation. - """ - path = None + def Override(self, name, load_local_manifests=True): + """Use a different manifest, just for the current instantiation.""" + path = None - # Look for a manifest by path in the filesystem (including the cwd). - if not load_local_manifests: - local_path = os.path.abspath(name) - if os.path.isfile(local_path): - path = local_path + # Look for a manifest by path in the filesystem (including the cwd). + if not load_local_manifests: + local_path = os.path.abspath(name) + if os.path.isfile(local_path): + path = local_path - # Look for manifests by name from the manifests repo. - if path is None: - path = os.path.join(self.manifestProject.worktree, name) - if not os.path.isfile(path): - raise ManifestParseError('manifest %s not found' % name) + # Look for manifests by name from the manifests repo. + if path is None: + path = os.path.join(self.manifestProject.worktree, name) + if not os.path.isfile(path): + raise ManifestParseError("manifest %s not found" % name) - self._load_local_manifests = load_local_manifests - self._outer_client.manifestFileOverrides[self.path_prefix] = path - self.Unload() - self._Load() + self._load_local_manifests = load_local_manifests + self._outer_client.manifestFileOverrides[self.path_prefix] = path + self.Unload() + self._Load() - def Link(self, name): - """Update the repo metadata to use a different manifest. 
- """ - self.Override(name) + def Link(self, name): + """Update the repo metadata to use a different manifest.""" + self.Override(name) - # Old versions of repo would generate symlinks we need to clean up. - platform_utils.remove(self.manifestFile, missing_ok=True) - # This file is interpreted as if it existed inside the manifest repo. - # That allows us to use with the relative file name. - with open(self.manifestFile, 'w') as fp: - fp.write(""" + # Old versions of repo would generate symlinks we need to clean up. + platform_utils.remove(self.manifestFile, missing_ok=True) + # This file is interpreted as if it existed inside the manifest repo. + # That allows us to use with the relative file name. + with open(self.manifestFile, "w") as fp: + fp.write( + """ XYZ - return refs.get(last_pub).split('/')[-2] - except (AttributeError, IndexError): - return "" - - def _UploadAndReport(self, opt, todo, original_people): - have_errors = False - for branch in todo: - try: - people = copy.deepcopy(original_people) - self._AppendAutoList(branch, people) - - # Check if there are local changes that may have been forgotten - changes = branch.project.UncommitedFiles() - if opt.ignore_untracked_files: - untracked = set(branch.project.UntrackedFiles()) - changes = [x for x in changes if x not in untracked] - - if changes: - key = 'review.%s.autoupload' % branch.project.remote.review - answer = branch.project.config.GetBoolean(key) - - # if they want to auto upload, let's not ask because it could be automated - if answer is None: - print() - print('Uncommitted changes in %s (did you forget to amend?):' - % branch.project.name) - print('\n'.join(changes)) - print('Continue uploading? (y/N) ', end='', flush=True) + print("to %s (y/N)? 
" % remote.review, end="", flush=True) if opt.yes: - print('<--yes>') - a = 'yes' + print("<--yes>") + answer = True else: - a = sys.stdin.readline().strip().lower() - if a not in ('y', 'yes', 't', 'true', 'on'): - print("skipping upload", file=sys.stderr) - branch.uploaded = False - branch.error = 'User aborted' - continue + answer = sys.stdin.readline().strip().lower() + answer = answer in ("y", "yes", "1", "true", "t") + if not answer: + _die("upload aborted by user") - # Check if topic branches should be sent to the server during upload - if opt.auto_topic is not True: - key = 'review.%s.uploadtopic' % branch.project.remote.review - opt.auto_topic = branch.project.config.GetBoolean(key) + # Perform some basic safety checks prior to uploading. + if not opt.yes and not _VerifyPendingCommits([branch]): + _die("upload aborted by user") - def _ExpandCommaList(value): - """Split |value| up into comma delimited entries.""" - if not value: - return - for ret in value.split(','): - ret = ret.strip() - if ret: - yield ret + self._UploadAndReport(opt, [branch], people) - # Check if hashtags should be included. - key = 'review.%s.uploadhashtags' % branch.project.remote.review - hashtags = set(_ExpandCommaList(branch.project.config.GetString(key))) - for tag in opt.hashtags: - hashtags.update(_ExpandCommaList(tag)) - if opt.hashtag_branch: - hashtags.add(branch.name) + def _MultipleBranches(self, opt, pending, people): + projects = {} + branches = {} - # Check if labels should be included. - key = 'review.%s.uploadlabels' % branch.project.remote.review - labels = set(_ExpandCommaList(branch.project.config.GetString(key))) - for label in opt.labels: - labels.update(_ExpandCommaList(label)) + script = [] + script.append("# Uncomment the branches to upload:") + for project, avail in pending: + project_path = project.RelPath(local=opt.this_manifest_only) + script.append("#") + script.append(f"# project {project_path}/:") - # Handle e-mail notifications. 
- if opt.notify is False: - notify = 'NONE' + b = {} + for branch in avail: + if branch is None: + continue + name = branch.name + date = branch.date + commit_list = branch.commits + + if b: + script.append("#") + destination = ( + opt.dest_branch + or project.dest_branch + or project.revisionExpr + ) + script.append( + "# branch %s (%2d commit%s, %s) to remote branch %s:" + % ( + name, + len(commit_list), + len(commit_list) != 1 and "s" or "", + date, + destination, + ) + ) + for commit in commit_list: + script.append("# %s" % commit) + b[name] = branch + + projects[project_path] = project + branches[project_path] = b + script.append("") + + script = Editor.EditString("\n".join(script)).split("\n") + + project_re = re.compile(r"^#?\s*project\s*([^\s]+)/:$") + branch_re = re.compile(r"^\s*branch\s*([^\s(]+)\s*\(.*") + + project = None + todo = [] + + for line in script: + m = project_re.match(line) + if m: + name = m.group(1) + project = projects.get(name) + if not project: + _die("project %s not available for upload", name) + continue + + m = branch_re.match(line) + if m: + name = m.group(1) + if not project: + _die("project for branch %s not in script", name) + project_path = project.RelPath(local=opt.this_manifest_only) + branch = branches[project_path].get(name) + if not branch: + _die("branch %s not in %s", name, project_path) + todo.append(branch) + if not todo: + _die("nothing uncommented for upload") + + # Perform some basic safety checks prior to uploading. + if not opt.yes and not _VerifyPendingCommits(todo): + _die("upload aborted by user") + + self._UploadAndReport(opt, todo, people) + + def _AppendAutoList(self, branch, people): + """ + Appends the list of reviewers in the git project's config. + Appends the list of users in the CC list in the git project's config if + a non-empty reviewer list was found. 
+ """ + name = branch.name + project = branch.project + + key = "review.%s.autoreviewer" % project.GetBranch(name).remote.review + raw_list = project.config.GetString(key) + if raw_list is not None: + people[0].extend([entry.strip() for entry in raw_list.split(",")]) + + key = "review.%s.autocopy" % project.GetBranch(name).remote.review + raw_list = project.config.GetString(key) + if raw_list is not None and len(people[0]) > 0: + people[1].extend([entry.strip() for entry in raw_list.split(",")]) + + def _FindGerritChange(self, branch): + last_pub = branch.project.WasPublished(branch.name) + if last_pub is None: + return "" + + refs = branch.GetPublishedRefs() + try: + # refs/changes/XYZ/N --> XYZ + return refs.get(last_pub).split("/")[-2] + except (AttributeError, IndexError): + return "" + + def _UploadAndReport(self, opt, todo, original_people): + have_errors = False + for branch in todo: + try: + people = copy.deepcopy(original_people) + self._AppendAutoList(branch, people) + + # Check if there are local changes that may have been forgotten. + changes = branch.project.UncommitedFiles() + if opt.ignore_untracked_files: + untracked = set(branch.project.UntrackedFiles()) + changes = [x for x in changes if x not in untracked] + + if changes: + key = "review.%s.autoupload" % branch.project.remote.review + answer = branch.project.config.GetBoolean(key) + + # If they want to auto upload, let's not ask because it + # could be automated. + if answer is None: + print() + print( + "Uncommitted changes in %s (did you forget to " + "amend?):" % branch.project.name + ) + print("\n".join(changes)) + print("Continue uploading? 
(y/N) ", end="", flush=True) + if opt.yes: + print("<--yes>") + a = "yes" + else: + a = sys.stdin.readline().strip().lower() + if a not in ("y", "yes", "t", "true", "on"): + print("skipping upload", file=sys.stderr) + branch.uploaded = False + branch.error = "User aborted" + continue + + # Check if topic branches should be sent to the server during + # upload. + if opt.auto_topic is not True: + key = "review.%s.uploadtopic" % branch.project.remote.review + opt.auto_topic = branch.project.config.GetBoolean(key) + + def _ExpandCommaList(value): + """Split |value| up into comma delimited entries.""" + if not value: + return + for ret in value.split(","): + ret = ret.strip() + if ret: + yield ret + + # Check if hashtags should be included. + key = "review.%s.uploadhashtags" % branch.project.remote.review + hashtags = set( + _ExpandCommaList(branch.project.config.GetString(key)) + ) + for tag in opt.hashtags: + hashtags.update(_ExpandCommaList(tag)) + if opt.hashtag_branch: + hashtags.add(branch.name) + + # Check if labels should be included. + key = "review.%s.uploadlabels" % branch.project.remote.review + labels = set( + _ExpandCommaList(branch.project.config.GetString(key)) + ) + for label in opt.labels: + labels.update(_ExpandCommaList(label)) + + # Handle e-mail notifications. + if opt.notify is False: + notify = "NONE" + else: + key = ( + "review.%s.uploadnotify" % branch.project.remote.review + ) + notify = branch.project.config.GetString(key) + + destination = opt.dest_branch or branch.project.dest_branch + + if branch.project.dest_branch and not opt.dest_branch: + merge_branch = self._GetMergeBranch( + branch.project, local_branch=branch.name + ) + + full_dest = destination + if not full_dest.startswith(R_HEADS): + full_dest = R_HEADS + full_dest + + # If the merge branch of the local branch is different from + # the project's revision AND destination, this might not be + # intentional. 
+ if ( + merge_branch + and merge_branch != branch.project.revisionExpr + and merge_branch != full_dest + ): + print( + f"For local branch {branch.name}: merge branch " + f"{merge_branch} does not match destination branch " + f"{destination}" + ) + print("skipping upload.") + print( + f"Please use `--destination {destination}` if this " + "is intentional" + ) + branch.uploaded = False + continue + + branch.UploadForReview( + people, + dryrun=opt.dryrun, + auto_topic=opt.auto_topic, + hashtags=hashtags, + labels=labels, + private=opt.private, + notify=notify, + wip=opt.wip, + ready=opt.ready, + dest_branch=destination, + validate_certs=opt.validate_certs, + push_options=opt.push_options, + ) + + branch.uploaded = True + except UploadError as e: + branch.error = e + branch.uploaded = False + have_errors = True + + print(file=sys.stderr) + print("-" * 70, file=sys.stderr) + + if have_errors: + for branch in todo: + if not branch.uploaded: + if len(str(branch.error)) <= 30: + fmt = " (%s)" + else: + fmt = "\n (%s)" + print( + ("[FAILED] %-15s %-15s" + fmt) + % ( + branch.project.RelPath(local=opt.this_manifest_only) + + "/", + branch.name, + str(branch.error), + ), + file=sys.stderr, + ) + print() + + for branch in todo: + if branch.uploaded: + print( + "[OK ] %-15s %s" + % ( + branch.project.RelPath(local=opt.this_manifest_only) + + "/", + branch.name, + ), + file=sys.stderr, + ) + + if have_errors: + sys.exit(1) + + def _GetMergeBranch(self, project, local_branch=None): + if local_branch is None: + p = GitCommand( + project, + ["rev-parse", "--abbrev-ref", "HEAD"], + capture_stdout=True, + capture_stderr=True, + ) + p.Wait() + local_branch = p.stdout.strip() + p = GitCommand( + project, + ["config", "--get", "branch.%s.merge" % local_branch], + capture_stdout=True, + capture_stderr=True, + ) + p.Wait() + merge_branch = p.stdout.strip() + return merge_branch + + @staticmethod + def _GatherOne(opt, project): + """Figure out the upload status for |project|.""" + if 
opt.current_branch: + cbr = project.CurrentBranch + up_branch = project.GetUploadableBranch(cbr) + avail = [up_branch] if up_branch else None else: - key = 'review.%s.uploadnotify' % branch.project.remote.review - notify = branch.project.config.GetString(key) + avail = project.GetUploadableBranches(opt.branch) + return (project, avail) - destination = opt.dest_branch or branch.project.dest_branch + def Execute(self, opt, args): + projects = self.GetProjects( + args, all_manifests=not opt.this_manifest_only + ) - if branch.project.dest_branch and not opt.dest_branch: + def _ProcessResults(_pool, _out, results): + pending = [] + for result in results: + project, avail = result + if avail is None: + print( + 'repo: error: %s: Unable to upload branch "%s". ' + "You might be able to fix the branch by running:\n" + " git branch --set-upstream-to m/%s" + % ( + project.RelPath(local=opt.this_manifest_only), + project.CurrentBranch, + project.manifest.branch, + ), + file=sys.stderr, + ) + elif avail: + pending.append(result) + return pending - merge_branch = self._GetMergeBranch( - branch.project, local_branch=branch.name) + pending = self.ExecuteInParallel( + opt.jobs, + functools.partial(self._GatherOne, opt), + projects, + callback=_ProcessResults, + ) - full_dest = destination - if not full_dest.startswith(R_HEADS): - full_dest = R_HEADS + full_dest + if not pending: + if opt.branch is None: + print( + "repo: error: no branches ready for upload", file=sys.stderr + ) + else: + print( + 'repo: error: no branches named "%s" ready for upload' + % (opt.branch,), + file=sys.stderr, + ) + return 1 - # If the merge branch of the local branch is different from the - # project's revision AND destination, this might not be intentional. 
- if (merge_branch and merge_branch != branch.project.revisionExpr - and merge_branch != full_dest): - print(f'For local branch {branch.name}: merge branch ' - f'{merge_branch} does not match destination branch ' - f'{destination}') - print('skipping upload.') - print(f'Please use `--destination {destination}` if this is intentional') - branch.uploaded = False - continue + manifests = { + project.manifest.topdir: project.manifest + for (project, available) in pending + } + ret = 0 + for manifest in manifests.values(): + pending_proj_names = [ + project.name + for (project, available) in pending + if project.manifest.topdir == manifest.topdir + ] + pending_worktrees = [ + project.worktree + for (project, available) in pending + if project.manifest.topdir == manifest.topdir + ] + hook = RepoHook.FromSubcmd( + hook_type="pre-upload", + manifest=manifest, + opt=opt, + abort_if_user_denies=True, + ) + if not hook.Run( + project_list=pending_proj_names, worktree_list=pending_worktrees + ): + ret = 1 + if ret: + return ret - branch.UploadForReview(people, - dryrun=opt.dryrun, - auto_topic=opt.auto_topic, - hashtags=hashtags, - labels=labels, - private=opt.private, - notify=notify, - wip=opt.wip, - ready=opt.ready, - dest_branch=destination, - validate_certs=opt.validate_certs, - push_options=opt.push_options) + reviewers = _SplitEmails(opt.reviewers) if opt.reviewers else [] + cc = _SplitEmails(opt.cc) if opt.cc else [] + people = (reviewers, cc) - branch.uploaded = True - except UploadError as e: - branch.error = e - branch.uploaded = False - have_errors = True - - print(file=sys.stderr) - print('----------------------------------------------------------------------', file=sys.stderr) - - if have_errors: - for branch in todo: - if not branch.uploaded: - if len(str(branch.error)) <= 30: - fmt = ' (%s)' - else: - fmt = '\n (%s)' - print(('[FAILED] %-15s %-15s' + fmt) % ( - branch.project.RelPath(local=opt.this_manifest_only) + '/', - branch.name, - str(branch.error)), - 
file=sys.stderr) - print() - - for branch in todo: - if branch.uploaded: - print('[OK ] %-15s %s' % ( - branch.project.RelPath(local=opt.this_manifest_only) + '/', - branch.name), - file=sys.stderr) - - if have_errors: - sys.exit(1) - - def _GetMergeBranch(self, project, local_branch=None): - if local_branch is None: - p = GitCommand(project, - ['rev-parse', '--abbrev-ref', 'HEAD'], - capture_stdout=True, - capture_stderr=True) - p.Wait() - local_branch = p.stdout.strip() - p = GitCommand(project, - ['config', '--get', 'branch.%s.merge' % local_branch], - capture_stdout=True, - capture_stderr=True) - p.Wait() - merge_branch = p.stdout.strip() - return merge_branch - - @staticmethod - def _GatherOne(opt, project): - """Figure out the upload status for |project|.""" - if opt.current_branch: - cbr = project.CurrentBranch - up_branch = project.GetUploadableBranch(cbr) - avail = [up_branch] if up_branch else None - else: - avail = project.GetUploadableBranches(opt.branch) - return (project, avail) - - def Execute(self, opt, args): - projects = self.GetProjects(args, all_manifests=not opt.this_manifest_only) - - def _ProcessResults(_pool, _out, results): - pending = [] - for result in results: - project, avail = result - if avail is None: - print('repo: error: %s: Unable to upload branch "%s". 
' - 'You might be able to fix the branch by running:\n' - ' git branch --set-upstream-to m/%s' % - (project.RelPath(local=opt.this_manifest_only), project.CurrentBranch, - project.manifest.branch), - file=sys.stderr) - elif avail: - pending.append(result) - return pending - - pending = self.ExecuteInParallel( - opt.jobs, - functools.partial(self._GatherOne, opt), - projects, - callback=_ProcessResults) - - if not pending: - if opt.branch is None: - print('repo: error: no branches ready for upload', file=sys.stderr) - else: - print('repo: error: no branches named "%s" ready for upload' % - (opt.branch,), file=sys.stderr) - return 1 - - manifests = {project.manifest.topdir: project.manifest - for (project, available) in pending} - ret = 0 - for manifest in manifests.values(): - pending_proj_names = [project.name for (project, available) in pending - if project.manifest.topdir == manifest.topdir] - pending_worktrees = [project.worktree for (project, available) in pending - if project.manifest.topdir == manifest.topdir] - hook = RepoHook.FromSubcmd( - hook_type='pre-upload', manifest=manifest, - opt=opt, abort_if_user_denies=True) - if not hook.Run(project_list=pending_proj_names, - worktree_list=pending_worktrees): - ret = 1 - if ret: - return ret - - reviewers = _SplitEmails(opt.reviewers) if opt.reviewers else [] - cc = _SplitEmails(opt.cc) if opt.cc else [] - people = (reviewers, cc) - - if len(pending) == 1 and len(pending[0][1]) == 1: - self._SingleBranch(opt, pending[0][1][0], people) - else: - self._MultipleBranches(opt, pending, people) + if len(pending) == 1 and len(pending[0][1]) == 1: + self._SingleBranch(opt, pending[0][1][0], people) + else: + self._MultipleBranches(opt, pending, people) diff --git a/subcmds/version.py b/subcmds/version.py index c68cb0af..c539db63 100644 --- a/subcmds/version.py +++ b/subcmds/version.py @@ -22,45 +22,52 @@ from wrapper import Wrapper class Version(Command, MirrorSafeCommand): - wrapper_version = None - wrapper_path = None 
+ wrapper_version = None + wrapper_path = None - COMMON = False - helpSummary = "Display the version of repo" - helpUsage = """ + COMMON = False + helpSummary = "Display the version of repo" + helpUsage = """ %prog """ - def Execute(self, opt, args): - rp = self.manifest.repoProject - rem = rp.GetRemote() - branch = rp.GetBranch('default') + def Execute(self, opt, args): + rp = self.manifest.repoProject + rem = rp.GetRemote() + branch = rp.GetBranch("default") - # These might not be the same. Report them both. - src_ver = RepoSourceVersion() - rp_ver = rp.bare_git.describe(HEAD) - print('repo version %s' % rp_ver) - print(' (from %s)' % rem.url) - print(' (tracking %s)' % branch.merge) - print(' (%s)' % rp.bare_git.log('-1', '--format=%cD', HEAD)) + # These might not be the same. Report them both. + src_ver = RepoSourceVersion() + rp_ver = rp.bare_git.describe(HEAD) + print("repo version %s" % rp_ver) + print(" (from %s)" % rem.url) + print(" (tracking %s)" % branch.merge) + print(" (%s)" % rp.bare_git.log("-1", "--format=%cD", HEAD)) - if self.wrapper_path is not None: - print('repo launcher version %s' % self.wrapper_version) - print(' (from %s)' % self.wrapper_path) + if self.wrapper_path is not None: + print("repo launcher version %s" % self.wrapper_version) + print(" (from %s)" % self.wrapper_path) - if src_ver != rp_ver: - print(' (currently at %s)' % src_ver) + if src_ver != rp_ver: + print(" (currently at %s)" % src_ver) - print('repo User-Agent %s' % user_agent.repo) - print('git %s' % git.version_tuple().full) - print('git User-Agent %s' % user_agent.git) - print('Python %s' % sys.version) - uname = platform.uname() - if sys.version_info.major < 3: - # Python 3 returns a named tuple, but Python 2 is simpler. 
- print(uname) - else: - print('OS %s %s (%s)' % (uname.system, uname.release, uname.version)) - print('CPU %s (%s)' % - (uname.machine, uname.processor if uname.processor else 'unknown')) - print('Bug reports:', Wrapper().BUG_URL) + print("repo User-Agent %s" % user_agent.repo) + print("git %s" % git.version_tuple().full) + print("git User-Agent %s" % user_agent.git) + print("Python %s" % sys.version) + uname = platform.uname() + if sys.version_info.major < 3: + # Python 3 returns a named tuple, but Python 2 is simpler. + print(uname) + else: + print( + "OS %s %s (%s)" % (uname.system, uname.release, uname.version) + ) + print( + "CPU %s (%s)" + % ( + uname.machine, + uname.processor if uname.processor else "unknown", + ) + ) + print("Bug reports:", Wrapper().BUG_URL) diff --git a/tests/conftest.py b/tests/conftest.py index 3e43f6d3..e1a2292a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,5 +21,5 @@ import repo_trace @pytest.fixture(autouse=True) def disable_repo_trace(tmp_path): - """Set an environment marker to relax certain strict checks for test code.""" - repo_trace._TRACE_FILE = str(tmp_path / 'TRACE_FILE_from_test') + """Set an environment marker to relax certain strict checks for test code.""" # noqa: E501 + repo_trace._TRACE_FILE = str(tmp_path / "TRACE_FILE_from_test") diff --git a/tests/test_editor.py b/tests/test_editor.py index cfd4f5ed..8f5d160e 100644 --- a/tests/test_editor.py +++ b/tests/test_editor.py @@ -20,37 +20,37 @@ from editor import Editor class EditorTestCase(unittest.TestCase): - """Take care of resetting Editor state across tests.""" + """Take care of resetting Editor state across tests.""" - def setUp(self): - self.setEditor(None) + def setUp(self): + self.setEditor(None) - def tearDown(self): - self.setEditor(None) + def tearDown(self): + self.setEditor(None) - @staticmethod - def setEditor(editor): - Editor._editor = editor + @staticmethod + def setEditor(editor): + Editor._editor = editor class 
GetEditor(EditorTestCase): - """Check GetEditor behavior.""" + """Check GetEditor behavior.""" - def test_basic(self): - """Basic checking of _GetEditor.""" - self.setEditor(':') - self.assertEqual(':', Editor._GetEditor()) + def test_basic(self): + """Basic checking of _GetEditor.""" + self.setEditor(":") + self.assertEqual(":", Editor._GetEditor()) class EditString(EditorTestCase): - """Check EditString behavior.""" + """Check EditString behavior.""" - def test_no_editor(self): - """Check behavior when no editor is available.""" - self.setEditor(':') - self.assertEqual('foo', Editor.EditString('foo')) + def test_no_editor(self): + """Check behavior when no editor is available.""" + self.setEditor(":") + self.assertEqual("foo", Editor.EditString("foo")) - def test_cat_editor(self): - """Check behavior when editor is `cat`.""" - self.setEditor('cat') - self.assertEqual('foo', Editor.EditString('foo')) + def test_cat_editor(self): + """Check behavior when editor is `cat`.""" + self.setEditor("cat") + self.assertEqual("foo", Editor.EditString("foo")) diff --git a/tests/test_error.py b/tests/test_error.py index 82b00c24..784e2d57 100644 --- a/tests/test_error.py +++ b/tests/test_error.py @@ -22,32 +22,34 @@ import error class PickleTests(unittest.TestCase): - """Make sure all our custom exceptions can be pickled.""" + """Make sure all our custom exceptions can be pickled.""" - def getExceptions(self): - """Return all our custom exceptions.""" - for name in dir(error): - cls = getattr(error, name) - if isinstance(cls, type) and issubclass(cls, Exception): - yield cls + def getExceptions(self): + """Return all our custom exceptions.""" + for name in dir(error): + cls = getattr(error, name) + if isinstance(cls, type) and issubclass(cls, Exception): + yield cls - def testExceptionLookup(self): - """Make sure our introspection logic works.""" - classes = list(self.getExceptions()) - self.assertIn(error.HookError, classes) - # Don't assert the exact number to avoid being a 
change-detector test. - self.assertGreater(len(classes), 10) + def testExceptionLookup(self): + """Make sure our introspection logic works.""" + classes = list(self.getExceptions()) + self.assertIn(error.HookError, classes) + # Don't assert the exact number to avoid being a change-detector test. + self.assertGreater(len(classes), 10) - def testPickle(self): - """Try to pickle all the exceptions.""" - for cls in self.getExceptions(): - args = inspect.getfullargspec(cls.__init__).args[1:] - obj = cls(*args) - p = pickle.dumps(obj) - try: - newobj = pickle.loads(p) - except Exception as e: # pylint: disable=broad-except - self.fail('Class %s is unable to be pickled: %s\n' - 'Incomplete super().__init__(...) call?' % (cls, e)) - self.assertIsInstance(newobj, cls) - self.assertEqual(str(obj), str(newobj)) + def testPickle(self): + """Try to pickle all the exceptions.""" + for cls in self.getExceptions(): + args = inspect.getfullargspec(cls.__init__).args[1:] + obj = cls(*args) + p = pickle.dumps(obj) + try: + newobj = pickle.loads(p) + except Exception as e: # pylint: disable=broad-except + self.fail( + "Class %s is unable to be pickled: %s\n" + "Incomplete super().__init__(...) call?" 
% (cls, e) + ) + self.assertIsInstance(newobj, cls) + self.assertEqual(str(obj), str(newobj)) diff --git a/tests/test_git_command.py b/tests/test_git_command.py index 96408a23..c4c3a4c5 100644 --- a/tests/test_git_command.py +++ b/tests/test_git_command.py @@ -19,138 +19,146 @@ import os import unittest try: - from unittest import mock + from unittest import mock except ImportError: - import mock + import mock import git_command import wrapper class GitCommandTest(unittest.TestCase): - """Tests the GitCommand class (via git_command.git).""" + """Tests the GitCommand class (via git_command.git).""" - def setUp(self): + def setUp(self): + def realpath_mock(val): + return val - def realpath_mock(val): - return val + mock.patch.object( + os.path, "realpath", side_effect=realpath_mock + ).start() - mock.patch.object(os.path, 'realpath', side_effect=realpath_mock).start() + def tearDown(self): + mock.patch.stopall() - def tearDown(self): - mock.patch.stopall() + def test_alternative_setting_when_matching(self): + r = git_command._build_env( + objdir=os.path.join("zap", "objects"), gitdir="zap" + ) - def test_alternative_setting_when_matching(self): - r = git_command._build_env( - objdir = os.path.join('zap', 'objects'), - gitdir = 'zap' - ) + self.assertIsNone(r.get("GIT_ALTERNATE_OBJECT_DIRECTORIES")) + self.assertEqual( + r.get("GIT_OBJECT_DIRECTORY"), os.path.join("zap", "objects") + ) - self.assertIsNone(r.get('GIT_ALTERNATE_OBJECT_DIRECTORIES')) - self.assertEqual(r.get('GIT_OBJECT_DIRECTORY'), os.path.join('zap', 'objects')) + def test_alternative_setting_when_different(self): + r = git_command._build_env( + objdir=os.path.join("wow", "objects"), gitdir="zap" + ) - def test_alternative_setting_when_different(self): - r = git_command._build_env( - objdir = os.path.join('wow', 'objects'), - gitdir = 'zap' - ) - - self.assertEqual(r.get('GIT_ALTERNATE_OBJECT_DIRECTORIES'), os.path.join('zap', 'objects')) - self.assertEqual(r.get('GIT_OBJECT_DIRECTORY'), 
os.path.join('wow', 'objects')) + self.assertEqual( + r.get("GIT_ALTERNATE_OBJECT_DIRECTORIES"), + os.path.join("zap", "objects"), + ) + self.assertEqual( + r.get("GIT_OBJECT_DIRECTORY"), os.path.join("wow", "objects") + ) class GitCallUnitTest(unittest.TestCase): - """Tests the _GitCall class (via git_command.git).""" + """Tests the _GitCall class (via git_command.git).""" - def test_version_tuple(self): - """Check git.version_tuple() handling.""" - ver = git_command.git.version_tuple() - self.assertIsNotNone(ver) + def test_version_tuple(self): + """Check git.version_tuple() handling.""" + ver = git_command.git.version_tuple() + self.assertIsNotNone(ver) - # We don't dive too deep into the values here to avoid having to update - # whenever git versions change. We do check relative to this min version - # as this is what `repo` itself requires via MIN_GIT_VERSION. - MIN_GIT_VERSION = (2, 10, 2) - self.assertTrue(isinstance(ver.major, int)) - self.assertTrue(isinstance(ver.minor, int)) - self.assertTrue(isinstance(ver.micro, int)) + # We don't dive too deep into the values here to avoid having to update + # whenever git versions change. We do check relative to this min + # version as this is what `repo` itself requires via MIN_GIT_VERSION. 
+ MIN_GIT_VERSION = (2, 10, 2) + self.assertTrue(isinstance(ver.major, int)) + self.assertTrue(isinstance(ver.minor, int)) + self.assertTrue(isinstance(ver.micro, int)) - self.assertGreater(ver.major, MIN_GIT_VERSION[0] - 1) - self.assertGreaterEqual(ver.micro, 0) - self.assertGreaterEqual(ver.major, 0) + self.assertGreater(ver.major, MIN_GIT_VERSION[0] - 1) + self.assertGreaterEqual(ver.micro, 0) + self.assertGreaterEqual(ver.major, 0) - self.assertGreaterEqual(ver, MIN_GIT_VERSION) - self.assertLess(ver, (9999, 9999, 9999)) + self.assertGreaterEqual(ver, MIN_GIT_VERSION) + self.assertLess(ver, (9999, 9999, 9999)) - self.assertNotEqual('', ver.full) + self.assertNotEqual("", ver.full) class UserAgentUnitTest(unittest.TestCase): - """Tests the UserAgent function.""" + """Tests the UserAgent function.""" - def test_smoke_os(self): - """Make sure UA OS setting returns something useful.""" - os_name = git_command.user_agent.os - # We can't dive too deep because of OS/tool differences, but we can check - # the general form. - m = re.match(r'^[^ ]+$', os_name) - self.assertIsNotNone(m) + def test_smoke_os(self): + """Make sure UA OS setting returns something useful.""" + os_name = git_command.user_agent.os + # We can't dive too deep because of OS/tool differences, but we can + # check the general form. + m = re.match(r"^[^ ]+$", os_name) + self.assertIsNotNone(m) - def test_smoke_repo(self): - """Make sure repo UA returns something useful.""" - ua = git_command.user_agent.repo - # We can't dive too deep because of OS/tool differences, but we can check - # the general form. - m = re.match(r'^git-repo/[^ ]+ ([^ ]+) git/[^ ]+ Python/[0-9.]+', ua) - self.assertIsNotNone(m) + def test_smoke_repo(self): + """Make sure repo UA returns something useful.""" + ua = git_command.user_agent.repo + # We can't dive too deep because of OS/tool differences, but we can + # check the general form. 
+ m = re.match(r"^git-repo/[^ ]+ ([^ ]+) git/[^ ]+ Python/[0-9.]+", ua) + self.assertIsNotNone(m) - def test_smoke_git(self): - """Make sure git UA returns something useful.""" - ua = git_command.user_agent.git - # We can't dive too deep because of OS/tool differences, but we can check - # the general form. - m = re.match(r'^git/[^ ]+ ([^ ]+) git-repo/[^ ]+', ua) - self.assertIsNotNone(m) + def test_smoke_git(self): + """Make sure git UA returns something useful.""" + ua = git_command.user_agent.git + # We can't dive too deep because of OS/tool differences, but we can + # check the general form. + m = re.match(r"^git/[^ ]+ ([^ ]+) git-repo/[^ ]+", ua) + self.assertIsNotNone(m) class GitRequireTests(unittest.TestCase): - """Test the git_require helper.""" + """Test the git_require helper.""" - def setUp(self): - self.wrapper = wrapper.Wrapper() - ver = self.wrapper.GitVersion(1, 2, 3, 4) - mock.patch.object(git_command.git, 'version_tuple', return_value=ver).start() + def setUp(self): + self.wrapper = wrapper.Wrapper() + ver = self.wrapper.GitVersion(1, 2, 3, 4) + mock.patch.object( + git_command.git, "version_tuple", return_value=ver + ).start() - def tearDown(self): - mock.patch.stopall() + def tearDown(self): + mock.patch.stopall() - def test_older_nonfatal(self): - """Test non-fatal require calls with old versions.""" - self.assertFalse(git_command.git_require((2,))) - self.assertFalse(git_command.git_require((1, 3))) - self.assertFalse(git_command.git_require((1, 2, 4))) - self.assertFalse(git_command.git_require((1, 2, 3, 5))) + def test_older_nonfatal(self): + """Test non-fatal require calls with old versions.""" + self.assertFalse(git_command.git_require((2,))) + self.assertFalse(git_command.git_require((1, 3))) + self.assertFalse(git_command.git_require((1, 2, 4))) + self.assertFalse(git_command.git_require((1, 2, 3, 5))) - def test_newer_nonfatal(self): - """Test non-fatal require calls with newer versions.""" - 
self.assertTrue(git_command.git_require((0,))) - self.assertTrue(git_command.git_require((1, 0))) - self.assertTrue(git_command.git_require((1, 2, 0))) - self.assertTrue(git_command.git_require((1, 2, 3, 0))) + def test_newer_nonfatal(self): + """Test non-fatal require calls with newer versions.""" + self.assertTrue(git_command.git_require((0,))) + self.assertTrue(git_command.git_require((1, 0))) + self.assertTrue(git_command.git_require((1, 2, 0))) + self.assertTrue(git_command.git_require((1, 2, 3, 0))) - def test_equal_nonfatal(self): - """Test require calls with equal values.""" - self.assertTrue(git_command.git_require((1, 2, 3, 4), fail=False)) - self.assertTrue(git_command.git_require((1, 2, 3, 4), fail=True)) + def test_equal_nonfatal(self): + """Test require calls with equal values.""" + self.assertTrue(git_command.git_require((1, 2, 3, 4), fail=False)) + self.assertTrue(git_command.git_require((1, 2, 3, 4), fail=True)) - def test_older_fatal(self): - """Test fatal require calls with old versions.""" - with self.assertRaises(SystemExit) as e: - git_command.git_require((2,), fail=True) - self.assertNotEqual(0, e.code) + def test_older_fatal(self): + """Test fatal require calls with old versions.""" + with self.assertRaises(SystemExit) as e: + git_command.git_require((2,), fail=True) + self.assertNotEqual(0, e.code) - def test_older_fatal_msg(self): - """Test fatal require calls with old versions and message.""" - with self.assertRaises(SystemExit) as e: - git_command.git_require((2,), fail=True, msg='so sad') - self.assertNotEqual(0, e.code) + def test_older_fatal_msg(self): + """Test fatal require calls with old versions and message.""" + with self.assertRaises(SystemExit) as e: + git_command.git_require((2,), fail=True, msg="so sad") + self.assertNotEqual(0, e.code) diff --git a/tests/test_git_config.py b/tests/test_git_config.py index 3b0aa8b4..a44dca0f 100644 --- a/tests/test_git_config.py +++ b/tests/test_git_config.py @@ -22,167 +22,169 @@ import 
git_config def fixture(*paths): - """Return a path relative to test/fixtures. - """ - return os.path.join(os.path.dirname(__file__), 'fixtures', *paths) + """Return a path relative to test/fixtures.""" + return os.path.join(os.path.dirname(__file__), "fixtures", *paths) class GitConfigReadOnlyTests(unittest.TestCase): - """Read-only tests of the GitConfig class.""" + """Read-only tests of the GitConfig class.""" - def setUp(self): - """Create a GitConfig object using the test.gitconfig fixture. - """ - config_fixture = fixture('test.gitconfig') - self.config = git_config.GitConfig(config_fixture) + def setUp(self): + """Create a GitConfig object using the test.gitconfig fixture.""" + config_fixture = fixture("test.gitconfig") + self.config = git_config.GitConfig(config_fixture) - def test_GetString_with_empty_config_values(self): - """ - Test config entries with no value. + def test_GetString_with_empty_config_values(self): + """ + Test config entries with no value. - [section] - empty + [section] + empty - """ - val = self.config.GetString('section.empty') - self.assertEqual(val, None) + """ + val = self.config.GetString("section.empty") + self.assertEqual(val, None) - def test_GetString_with_true_value(self): - """ - Test config entries with a string value. + def test_GetString_with_true_value(self): + """ + Test config entries with a string value. 
- [section] - nonempty = true + [section] + nonempty = true - """ - val = self.config.GetString('section.nonempty') - self.assertEqual(val, 'true') + """ + val = self.config.GetString("section.nonempty") + self.assertEqual(val, "true") - def test_GetString_from_missing_file(self): - """ - Test missing config file - """ - config_fixture = fixture('not.present.gitconfig') - config = git_config.GitConfig(config_fixture) - val = config.GetString('empty') - self.assertEqual(val, None) + def test_GetString_from_missing_file(self): + """ + Test missing config file + """ + config_fixture = fixture("not.present.gitconfig") + config = git_config.GitConfig(config_fixture) + val = config.GetString("empty") + self.assertEqual(val, None) - def test_GetBoolean_undefined(self): - """Test GetBoolean on key that doesn't exist.""" - self.assertIsNone(self.config.GetBoolean('section.missing')) + def test_GetBoolean_undefined(self): + """Test GetBoolean on key that doesn't exist.""" + self.assertIsNone(self.config.GetBoolean("section.missing")) - def test_GetBoolean_invalid(self): - """Test GetBoolean on invalid boolean value.""" - self.assertIsNone(self.config.GetBoolean('section.boolinvalid')) + def test_GetBoolean_invalid(self): + """Test GetBoolean on invalid boolean value.""" + self.assertIsNone(self.config.GetBoolean("section.boolinvalid")) - def test_GetBoolean_true(self): - """Test GetBoolean on valid true boolean.""" - self.assertTrue(self.config.GetBoolean('section.booltrue')) + def test_GetBoolean_true(self): + """Test GetBoolean on valid true boolean.""" + self.assertTrue(self.config.GetBoolean("section.booltrue")) - def test_GetBoolean_false(self): - """Test GetBoolean on valid false boolean.""" - self.assertFalse(self.config.GetBoolean('section.boolfalse')) + def test_GetBoolean_false(self): + """Test GetBoolean on valid false boolean.""" + self.assertFalse(self.config.GetBoolean("section.boolfalse")) - def test_GetInt_undefined(self): - """Test GetInt on key that doesn't 
exist.""" - self.assertIsNone(self.config.GetInt('section.missing')) + def test_GetInt_undefined(self): + """Test GetInt on key that doesn't exist.""" + self.assertIsNone(self.config.GetInt("section.missing")) - def test_GetInt_invalid(self): - """Test GetInt on invalid integer value.""" - self.assertIsNone(self.config.GetBoolean('section.intinvalid')) + def test_GetInt_invalid(self): + """Test GetInt on invalid integer value.""" + self.assertIsNone(self.config.GetBoolean("section.intinvalid")) - def test_GetInt_valid(self): - """Test GetInt on valid integers.""" - TESTS = ( - ('inthex', 16), - ('inthexk', 16384), - ('int', 10), - ('intk', 10240), - ('intm', 10485760), - ('intg', 10737418240), - ) - for key, value in TESTS: - self.assertEqual(value, self.config.GetInt('section.%s' % (key,))) + def test_GetInt_valid(self): + """Test GetInt on valid integers.""" + TESTS = ( + ("inthex", 16), + ("inthexk", 16384), + ("int", 10), + ("intk", 10240), + ("intm", 10485760), + ("intg", 10737418240), + ) + for key, value in TESTS: + self.assertEqual(value, self.config.GetInt("section.%s" % (key,))) class GitConfigReadWriteTests(unittest.TestCase): - """Read/write tests of the GitConfig class.""" + """Read/write tests of the GitConfig class.""" - def setUp(self): - self.tmpfile = tempfile.NamedTemporaryFile() - self.config = self.get_config() + def setUp(self): + self.tmpfile = tempfile.NamedTemporaryFile() + self.config = self.get_config() - def get_config(self): - """Get a new GitConfig instance.""" - return git_config.GitConfig(self.tmpfile.name) + def get_config(self): + """Get a new GitConfig instance.""" + return git_config.GitConfig(self.tmpfile.name) - def test_SetString(self): - """Test SetString behavior.""" - # Set a value. - self.assertIsNone(self.config.GetString('foo.bar')) - self.config.SetString('foo.bar', 'val') - self.assertEqual('val', self.config.GetString('foo.bar')) + def test_SetString(self): + """Test SetString behavior.""" + # Set a value. 
+ self.assertIsNone(self.config.GetString("foo.bar")) + self.config.SetString("foo.bar", "val") + self.assertEqual("val", self.config.GetString("foo.bar")) - # Make sure the value was actually written out. - config = self.get_config() - self.assertEqual('val', config.GetString('foo.bar')) + # Make sure the value was actually written out. + config = self.get_config() + self.assertEqual("val", config.GetString("foo.bar")) - # Update the value. - self.config.SetString('foo.bar', 'valll') - self.assertEqual('valll', self.config.GetString('foo.bar')) - config = self.get_config() - self.assertEqual('valll', config.GetString('foo.bar')) + # Update the value. + self.config.SetString("foo.bar", "valll") + self.assertEqual("valll", self.config.GetString("foo.bar")) + config = self.get_config() + self.assertEqual("valll", config.GetString("foo.bar")) - # Delete the value. - self.config.SetString('foo.bar', None) - self.assertIsNone(self.config.GetString('foo.bar')) - config = self.get_config() - self.assertIsNone(config.GetString('foo.bar')) + # Delete the value. + self.config.SetString("foo.bar", None) + self.assertIsNone(self.config.GetString("foo.bar")) + config = self.get_config() + self.assertIsNone(config.GetString("foo.bar")) - def test_SetBoolean(self): - """Test SetBoolean behavior.""" - # Set a true value. - self.assertIsNone(self.config.GetBoolean('foo.bar')) - for val in (True, 1): - self.config.SetBoolean('foo.bar', val) - self.assertTrue(self.config.GetBoolean('foo.bar')) + def test_SetBoolean(self): + """Test SetBoolean behavior.""" + # Set a true value. + self.assertIsNone(self.config.GetBoolean("foo.bar")) + for val in (True, 1): + self.config.SetBoolean("foo.bar", val) + self.assertTrue(self.config.GetBoolean("foo.bar")) - # Make sure the value was actually written out. - config = self.get_config() - self.assertTrue(config.GetBoolean('foo.bar')) - self.assertEqual('true', config.GetString('foo.bar')) + # Make sure the value was actually written out. 
+ config = self.get_config() + self.assertTrue(config.GetBoolean("foo.bar")) + self.assertEqual("true", config.GetString("foo.bar")) - # Set a false value. - for val in (False, 0): - self.config.SetBoolean('foo.bar', val) - self.assertFalse(self.config.GetBoolean('foo.bar')) + # Set a false value. + for val in (False, 0): + self.config.SetBoolean("foo.bar", val) + self.assertFalse(self.config.GetBoolean("foo.bar")) - # Make sure the value was actually written out. - config = self.get_config() - self.assertFalse(config.GetBoolean('foo.bar')) - self.assertEqual('false', config.GetString('foo.bar')) + # Make sure the value was actually written out. + config = self.get_config() + self.assertFalse(config.GetBoolean("foo.bar")) + self.assertEqual("false", config.GetString("foo.bar")) - # Delete the value. - self.config.SetBoolean('foo.bar', None) - self.assertIsNone(self.config.GetBoolean('foo.bar')) - config = self.get_config() - self.assertIsNone(config.GetBoolean('foo.bar')) + # Delete the value. 
+ self.config.SetBoolean("foo.bar", None) + self.assertIsNone(self.config.GetBoolean("foo.bar")) + config = self.get_config() + self.assertIsNone(config.GetBoolean("foo.bar")) - def test_GetSyncAnalysisStateData(self): - """Test config entries with a sync state analysis data.""" - superproject_logging_data = {} - superproject_logging_data['test'] = False - options = type('options', (object,), {})() - options.verbose = 'true' - options.mp_update = 'false' - TESTS = ( - ('superproject.test', 'false'), - ('options.verbose', 'true'), - ('options.mpupdate', 'false'), - ('main.version', '1'), - ) - self.config.UpdateSyncAnalysisState(options, superproject_logging_data) - sync_data = self.config.GetSyncAnalysisStateData() - for key, value in TESTS: - self.assertEqual(sync_data[f'{git_config.SYNC_STATE_PREFIX}{key}'], value) - self.assertTrue(sync_data[f'{git_config.SYNC_STATE_PREFIX}main.synctime']) + def test_GetSyncAnalysisStateData(self): + """Test config entries with a sync state analysis data.""" + superproject_logging_data = {} + superproject_logging_data["test"] = False + options = type("options", (object,), {})() + options.verbose = "true" + options.mp_update = "false" + TESTS = ( + ("superproject.test", "false"), + ("options.verbose", "true"), + ("options.mpupdate", "false"), + ("main.version", "1"), + ) + self.config.UpdateSyncAnalysisState(options, superproject_logging_data) + sync_data = self.config.GetSyncAnalysisStateData() + for key, value in TESTS: + self.assertEqual( + sync_data[f"{git_config.SYNC_STATE_PREFIX}{key}"], value + ) + self.assertTrue( + sync_data[f"{git_config.SYNC_STATE_PREFIX}main.synctime"] + ) diff --git a/tests/test_git_superproject.py b/tests/test_git_superproject.py index b9b597a6..eb542c60 100644 --- a/tests/test_git_superproject.py +++ b/tests/test_git_superproject.py @@ -28,297 +28,369 @@ from test_manifest_xml import sort_attributes class SuperprojectTestCase(unittest.TestCase): - """TestCase for the Superproject module.""" + 
"""TestCase for the Superproject module.""" - PARENT_SID_KEY = 'GIT_TRACE2_PARENT_SID' - PARENT_SID_VALUE = 'parent_sid' - SELF_SID_REGEX = r'repo-\d+T\d+Z-.*' - FULL_SID_REGEX = r'^%s/%s' % (PARENT_SID_VALUE, SELF_SID_REGEX) + PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID" + PARENT_SID_VALUE = "parent_sid" + SELF_SID_REGEX = r"repo-\d+T\d+Z-.*" + FULL_SID_REGEX = r"^%s/%s" % (PARENT_SID_VALUE, SELF_SID_REGEX) - def setUp(self): - """Set up superproject every time.""" - self.tempdirobj = tempfile.TemporaryDirectory(prefix='repo_tests') - self.tempdir = self.tempdirobj.name - self.repodir = os.path.join(self.tempdir, '.repo') - self.manifest_file = os.path.join( - self.repodir, manifest_xml.MANIFEST_FILE_NAME) - os.mkdir(self.repodir) - self.platform = platform.system().lower() + def setUp(self): + """Set up superproject every time.""" + self.tempdirobj = tempfile.TemporaryDirectory(prefix="repo_tests") + self.tempdir = self.tempdirobj.name + self.repodir = os.path.join(self.tempdir, ".repo") + self.manifest_file = os.path.join( + self.repodir, manifest_xml.MANIFEST_FILE_NAME + ) + os.mkdir(self.repodir) + self.platform = platform.system().lower() - # By default we initialize with the expected case where - # repo launches us (so GIT_TRACE2_PARENT_SID is set). - env = { - self.PARENT_SID_KEY: self.PARENT_SID_VALUE, - } - self.git_event_log = git_trace2_event_log.EventLog(env=env) + # By default we initialize with the expected case where + # repo launches us (so GIT_TRACE2_PARENT_SID is set). + env = { + self.PARENT_SID_KEY: self.PARENT_SID_VALUE, + } + self.git_event_log = git_trace2_event_log.EventLog(env=env) - # The manifest parsing really wants a git repo currently. - gitdir = os.path.join(self.repodir, 'manifests.git') - os.mkdir(gitdir) - with open(os.path.join(gitdir, 'config'), 'w') as fp: - fp.write("""[remote "origin"] + # The manifest parsing really wants a git repo currently. 
+ gitdir = os.path.join(self.repodir, "manifests.git") + os.mkdir(gitdir) + with open(os.path.join(gitdir, "config"), "w") as fp: + fp.write( + """[remote "origin"] url = https://localhost:0/manifest -""") +""" + ) - manifest = self.getXmlManifest(""" + manifest = self.getXmlManifest( + """ - -""") - self._superproject = git_superproject.Superproject( - manifest, name='superproject', - remote=manifest.remotes.get('default-remote').ToRemoteSpec('superproject'), - revision='refs/heads/main') +""" + ) + self._superproject = git_superproject.Superproject( + manifest, + name="superproject", + remote=manifest.remotes.get("default-remote").ToRemoteSpec( + "superproject" + ), + revision="refs/heads/main", + ) - def tearDown(self): - """Tear down superproject every time.""" - self.tempdirobj.cleanup() + def tearDown(self): + """Tear down superproject every time.""" + self.tempdirobj.cleanup() - def getXmlManifest(self, data): - """Helper to initialize a manifest for testing.""" - with open(self.manifest_file, 'w') as fp: - fp.write(data) - return manifest_xml.XmlManifest(self.repodir, self.manifest_file) + def getXmlManifest(self, data): + """Helper to initialize a manifest for testing.""" + with open(self.manifest_file, "w") as fp: + fp.write(data) + return manifest_xml.XmlManifest(self.repodir, self.manifest_file) - def verifyCommonKeys(self, log_entry, expected_event_name, full_sid=True): - """Helper function to verify common event log keys.""" - self.assertIn('event', log_entry) - self.assertIn('sid', log_entry) - self.assertIn('thread', log_entry) - self.assertIn('time', log_entry) + def verifyCommonKeys(self, log_entry, expected_event_name, full_sid=True): + """Helper function to verify common event log keys.""" + self.assertIn("event", log_entry) + self.assertIn("sid", log_entry) + self.assertIn("thread", log_entry) + self.assertIn("time", log_entry) - # Do basic data format validation. 
- self.assertEqual(expected_event_name, log_entry['event']) - if full_sid: - self.assertRegex(log_entry['sid'], self.FULL_SID_REGEX) - else: - self.assertRegex(log_entry['sid'], self.SELF_SID_REGEX) - self.assertRegex(log_entry['time'], r'^\d+-\d+-\d+T\d+:\d+:\d+\.\d+Z$') + # Do basic data format validation. + self.assertEqual(expected_event_name, log_entry["event"]) + if full_sid: + self.assertRegex(log_entry["sid"], self.FULL_SID_REGEX) + else: + self.assertRegex(log_entry["sid"], self.SELF_SID_REGEX) + self.assertRegex(log_entry["time"], r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+Z$") - def readLog(self, log_path): - """Helper function to read log data into a list.""" - log_data = [] - with open(log_path, mode='rb') as f: - for line in f: - log_data.append(json.loads(line)) - return log_data + def readLog(self, log_path): + """Helper function to read log data into a list.""" + log_data = [] + with open(log_path, mode="rb") as f: + for line in f: + log_data.append(json.loads(line)) + return log_data - def verifyErrorEvent(self): - """Helper to verify that error event is written.""" + def verifyErrorEvent(self): + """Helper to verify that error event is written.""" - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self.git_event_log.Write(path=tempdir) - self.log_data = self.readLog(log_path) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self.git_event_log.Write(path=tempdir) + self.log_data = self.readLog(log_path) - self.assertEqual(len(self.log_data), 2) - error_event = self.log_data[1] - self.verifyCommonKeys(self.log_data[0], expected_event_name='version') - self.verifyCommonKeys(error_event, expected_event_name='error') - # Check for 'error' event specific fields. 
- self.assertIn('msg', error_event) - self.assertIn('fmt', error_event) + self.assertEqual(len(self.log_data), 2) + error_event = self.log_data[1] + self.verifyCommonKeys(self.log_data[0], expected_event_name="version") + self.verifyCommonKeys(error_event, expected_event_name="error") + # Check for 'error' event specific fields. + self.assertIn("msg", error_event) + self.assertIn("fmt", error_event) - def test_superproject_get_superproject_no_superproject(self): - """Test with no url.""" - manifest = self.getXmlManifest(""" + def test_superproject_get_superproject_no_superproject(self): + """Test with no url.""" + manifest = self.getXmlManifest( + """ -""") - self.assertIsNone(manifest.superproject) +""" + ) + self.assertIsNone(manifest.superproject) - def test_superproject_get_superproject_invalid_url(self): - """Test with an invalid url.""" - manifest = self.getXmlManifest(""" + def test_superproject_get_superproject_invalid_url(self): + """Test with an invalid url.""" + manifest = self.getXmlManifest( + """ -""") - superproject = git_superproject.Superproject( - manifest, name='superproject', - remote=manifest.remotes.get('test-remote').ToRemoteSpec('superproject'), - revision='refs/heads/main') - sync_result = superproject.Sync(self.git_event_log) - self.assertFalse(sync_result.success) - self.assertTrue(sync_result.fatal) - - def test_superproject_get_superproject_invalid_branch(self): - """Test with an invalid branch.""" - manifest = self.getXmlManifest(""" - - - - - -""") - self._superproject = git_superproject.Superproject( - manifest, name='superproject', - remote=manifest.remotes.get('test-remote').ToRemoteSpec('superproject'), - revision='refs/heads/main') - with mock.patch.object(self._superproject, '_branch', 'junk'): - sync_result = self._superproject.Sync(self.git_event_log) - self.assertFalse(sync_result.success) - self.assertTrue(sync_result.fatal) - self.verifyErrorEvent() - - def test_superproject_get_superproject_mock_init(self): - """Test with 
_Init failing.""" - with mock.patch.object(self._superproject, '_Init', return_value=False): - sync_result = self._superproject.Sync(self.git_event_log) - self.assertFalse(sync_result.success) - self.assertTrue(sync_result.fatal) - - def test_superproject_get_superproject_mock_fetch(self): - """Test with _Fetch failing.""" - with mock.patch.object(self._superproject, '_Init', return_value=True): - os.mkdir(self._superproject._superproject_path) - with mock.patch.object(self._superproject, '_Fetch', return_value=False): - sync_result = self._superproject.Sync(self.git_event_log) +""" + ) + superproject = git_superproject.Superproject( + manifest, + name="superproject", + remote=manifest.remotes.get("test-remote").ToRemoteSpec( + "superproject" + ), + revision="refs/heads/main", + ) + sync_result = superproject.Sync(self.git_event_log) self.assertFalse(sync_result.success) self.assertTrue(sync_result.fatal) - def test_superproject_get_all_project_commit_ids_mock_ls_tree(self): - """Test with LsTree being a mock.""" - data = ('120000 blob 158258bdf146f159218e2b90f8b699c4d85b5804\tAndroid.bp\x00' - '160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00' - '160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tbootable/recovery\x00' - '120000 blob acc2cbdf438f9d2141f0ae424cec1d8fc4b5d97f\tbootstrap.bash\x00' - '160000 commit ade9b7a0d874e25fff4bf2552488825c6f111928\tbuild/bazel\x00') - with mock.patch.object(self._superproject, '_Init', return_value=True): - with mock.patch.object(self._superproject, '_Fetch', return_value=True): - with mock.patch.object(self._superproject, '_LsTree', return_value=data): - commit_ids_result = self._superproject._GetAllProjectsCommitIds() - self.assertEqual(commit_ids_result.commit_ids, { - 'art': '2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea', - 'bootable/recovery': 'e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06', - 'build/bazel': 'ade9b7a0d874e25fff4bf2552488825c6f111928' - }) - self.assertFalse(commit_ids_result.fatal) + def 
test_superproject_get_superproject_invalid_branch(self): + """Test with an invalid branch.""" + manifest = self.getXmlManifest( + """ + + + + + +""" + ) + self._superproject = git_superproject.Superproject( + manifest, + name="superproject", + remote=manifest.remotes.get("test-remote").ToRemoteSpec( + "superproject" + ), + revision="refs/heads/main", + ) + with mock.patch.object(self._superproject, "_branch", "junk"): + sync_result = self._superproject.Sync(self.git_event_log) + self.assertFalse(sync_result.success) + self.assertTrue(sync_result.fatal) + self.verifyErrorEvent() - def test_superproject_write_manifest_file(self): - """Test with writing manifest to a file after setting revisionId.""" - self.assertEqual(len(self._superproject._manifest.projects), 1) - project = self._superproject._manifest.projects[0] - project.SetRevisionId('ABCDEF') - # Create temporary directory so that it can write the file. - os.mkdir(self._superproject._superproject_path) - manifest_path = self._superproject._WriteManifestFile() - self.assertIsNotNone(manifest_path) - with open(manifest_path, 'r') as fp: - manifest_xml_data = fp.read() - self.assertEqual( - sort_attributes(manifest_xml_data), - '' - '' - '' - '' - '' - '') + def test_superproject_get_superproject_mock_init(self): + """Test with _Init failing.""" + with mock.patch.object(self._superproject, "_Init", return_value=False): + sync_result = self._superproject.Sync(self.git_event_log) + self.assertFalse(sync_result.success) + self.assertTrue(sync_result.fatal) - def test_superproject_update_project_revision_id(self): - """Test with LsTree being a mock.""" - self.assertEqual(len(self._superproject._manifest.projects), 1) - projects = self._superproject._manifest.projects - data = ('160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00' - '160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tbootable/recovery\x00') - with mock.patch.object(self._superproject, '_Init', return_value=True): - with 
mock.patch.object(self._superproject, '_Fetch', return_value=True): - with mock.patch.object(self._superproject, - '_LsTree', - return_value=data): - # Create temporary directory so that it can write the file. - os.mkdir(self._superproject._superproject_path) - update_result = self._superproject.UpdateProjectsRevisionId(projects, self.git_event_log) - self.assertIsNotNone(update_result.manifest_path) - self.assertFalse(update_result.fatal) - with open(update_result.manifest_path, 'r') as fp: + def test_superproject_get_superproject_mock_fetch(self): + """Test with _Fetch failing.""" + with mock.patch.object(self._superproject, "_Init", return_value=True): + os.mkdir(self._superproject._superproject_path) + with mock.patch.object( + self._superproject, "_Fetch", return_value=False + ): + sync_result = self._superproject.Sync(self.git_event_log) + self.assertFalse(sync_result.success) + self.assertTrue(sync_result.fatal) + + def test_superproject_get_all_project_commit_ids_mock_ls_tree(self): + """Test with LsTree being a mock.""" + data = ( + "120000 blob 158258bdf146f159218e2b90f8b699c4d85b5804\tAndroid.bp\x00" + "160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00" + "160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tbootable/recovery\x00" + "120000 blob acc2cbdf438f9d2141f0ae424cec1d8fc4b5d97f\tbootstrap.bash\x00" + "160000 commit ade9b7a0d874e25fff4bf2552488825c6f111928\tbuild/bazel\x00" + ) + with mock.patch.object(self._superproject, "_Init", return_value=True): + with mock.patch.object( + self._superproject, "_Fetch", return_value=True + ): + with mock.patch.object( + self._superproject, "_LsTree", return_value=data + ): + commit_ids_result = ( + self._superproject._GetAllProjectsCommitIds() + ) + self.assertEqual( + commit_ids_result.commit_ids, + { + "art": "2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea", + "bootable/recovery": "e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06", + "build/bazel": "ade9b7a0d874e25fff4bf2552488825c6f111928", + }, + ) + 
self.assertFalse(commit_ids_result.fatal) + + def test_superproject_write_manifest_file(self): + """Test with writing manifest to a file after setting revisionId.""" + self.assertEqual(len(self._superproject._manifest.projects), 1) + project = self._superproject._manifest.projects[0] + project.SetRevisionId("ABCDEF") + # Create temporary directory so that it can write the file. + os.mkdir(self._superproject._superproject_path) + manifest_path = self._superproject._WriteManifestFile() + self.assertIsNotNone(manifest_path) + with open(manifest_path, "r") as fp: manifest_xml_data = fp.read() - self.assertEqual( - sort_attributes(manifest_xml_data), - '' - '' - '' - '' - '' - '') + self.assertEqual( + sort_attributes(manifest_xml_data), + '' + '' + '' + '' + '' + "", + ) - def test_superproject_update_project_revision_id_no_superproject_tag(self): - """Test update of commit ids of a manifest without superproject tag.""" - manifest = self.getXmlManifest(""" + def test_superproject_update_project_revision_id(self): + """Test with LsTree being a mock.""" + self.assertEqual(len(self._superproject._manifest.projects), 1) + projects = self._superproject._manifest.projects + data = ( + "160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00" + "160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tbootable/recovery\x00" + ) + with mock.patch.object(self._superproject, "_Init", return_value=True): + with mock.patch.object( + self._superproject, "_Fetch", return_value=True + ): + with mock.patch.object( + self._superproject, "_LsTree", return_value=data + ): + # Create temporary directory so that it can write the file. 
+ os.mkdir(self._superproject._superproject_path) + update_result = self._superproject.UpdateProjectsRevisionId( + projects, self.git_event_log + ) + self.assertIsNotNone(update_result.manifest_path) + self.assertFalse(update_result.fatal) + with open(update_result.manifest_path, "r") as fp: + manifest_xml_data = fp.read() + self.assertEqual( + sort_attributes(manifest_xml_data), + '' + '' + '' + '' + '' + "", + ) + + def test_superproject_update_project_revision_id_no_superproject_tag(self): + """Test update of commit ids of a manifest without superproject tag.""" + manifest = self.getXmlManifest( + """ -""") - self.maxDiff = None - self.assertIsNone(manifest.superproject) - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') +""" + ) + self.maxDiff = None + self.assertIsNone(manifest.superproject) + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' + "", + ) - def test_superproject_update_project_revision_id_from_local_manifest_group(self): - """Test update of commit ids of a manifest that have local manifest no superproject group.""" - local_group = manifest_xml.LOCAL_MANIFEST_GROUP_PREFIX + ':local' - manifest = self.getXmlManifest(""" + def test_superproject_update_project_revision_id_from_local_manifest_group( + self, + ): + """Test update of commit ids of a manifest that have local manifest no superproject group.""" + local_group = manifest_xml.LOCAL_MANIFEST_GROUP_PREFIX + ":local" + manifest = self.getXmlManifest( + """ - -""") - self.maxDiff = None - self._superproject = git_superproject.Superproject( - manifest, name='superproject', - remote=manifest.remotes.get('default-remote').ToRemoteSpec('superproject'), - revision='refs/heads/main') - self.assertEqual(len(self._superproject._manifest.projects), 2) - projects = self._superproject._manifest.projects - data = ('160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00') - with mock.patch.object(self._superproject, '_Init', 
return_value=True): - with mock.patch.object(self._superproject, '_Fetch', return_value=True): - with mock.patch.object(self._superproject, - '_LsTree', - return_value=data): - # Create temporary directory so that it can write the file. - os.mkdir(self._superproject._superproject_path) - update_result = self._superproject.UpdateProjectsRevisionId(projects, self.git_event_log) - self.assertIsNotNone(update_result.manifest_path) - self.assertFalse(update_result.fatal) - with open(update_result.manifest_path, 'r') as fp: - manifest_xml_data = fp.read() - # Verify platform/vendor/x's project revision hasn't changed. - self.assertEqual( - sort_attributes(manifest_xml_data), - '' - '' - '' - '' - '' - '' - '') +""" + ) + self.maxDiff = None + self._superproject = git_superproject.Superproject( + manifest, + name="superproject", + remote=manifest.remotes.get("default-remote").ToRemoteSpec( + "superproject" + ), + revision="refs/heads/main", + ) + self.assertEqual(len(self._superproject._manifest.projects), 2) + projects = self._superproject._manifest.projects + data = "160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00" + with mock.patch.object(self._superproject, "_Init", return_value=True): + with mock.patch.object( + self._superproject, "_Fetch", return_value=True + ): + with mock.patch.object( + self._superproject, "_LsTree", return_value=data + ): + # Create temporary directory so that it can write the file. + os.mkdir(self._superproject._superproject_path) + update_result = self._superproject.UpdateProjectsRevisionId( + projects, self.git_event_log + ) + self.assertIsNotNone(update_result.manifest_path) + self.assertFalse(update_result.fatal) + with open(update_result.manifest_path, "r") as fp: + manifest_xml_data = fp.read() + # Verify platform/vendor/x's project revision hasn't + # changed. 
+ self.assertEqual( + sort_attributes(manifest_xml_data), + '' + '' + '' + '' + '' + '' + "", + ) - def test_superproject_update_project_revision_id_with_pinned_manifest(self): - """Test update of commit ids of a pinned manifest.""" - manifest = self.getXmlManifest(""" + def test_superproject_update_project_revision_id_with_pinned_manifest(self): + """Test update of commit ids of a pinned manifest.""" + manifest = self.getXmlManifest( + """ @@ -326,80 +398,132 @@ class SuperprojectTestCase(unittest.TestCase): - -""") - self.maxDiff = None - self._superproject = git_superproject.Superproject( - manifest, name='superproject', - remote=manifest.remotes.get('default-remote').ToRemoteSpec('superproject'), - revision='refs/heads/main') - self.assertEqual(len(self._superproject._manifest.projects), 3) - projects = self._superproject._manifest.projects - data = ('160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00' - '160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tvendor/x\x00') - with mock.patch.object(self._superproject, '_Init', return_value=True): - with mock.patch.object(self._superproject, '_Fetch', return_value=True): - with mock.patch.object(self._superproject, - '_LsTree', - return_value=data): - # Create temporary directory so that it can write the file. - os.mkdir(self._superproject._superproject_path) - update_result = self._superproject.UpdateProjectsRevisionId(projects, self.git_event_log) - self.assertIsNotNone(update_result.manifest_path) - self.assertFalse(update_result.fatal) - with open(update_result.manifest_path, 'r') as fp: - manifest_xml_data = fp.read() - # Verify platform/vendor/x's project revision hasn't changed. 
- self.assertEqual( - sort_attributes(manifest_xml_data), - '' - '' - '' - '' - '' - '' - '' - '') +""" + ) + self.maxDiff = None + self._superproject = git_superproject.Superproject( + manifest, + name="superproject", + remote=manifest.remotes.get("default-remote").ToRemoteSpec( + "superproject" + ), + revision="refs/heads/main", + ) + self.assertEqual(len(self._superproject._manifest.projects), 3) + projects = self._superproject._manifest.projects + data = ( + "160000 commit 2c2724cb36cd5a9cec6c852c681efc3b7c6b86ea\tart\x00" + "160000 commit e9d25da64d8d365dbba7c8ee00fe8c4473fe9a06\tvendor/x\x00" + ) + with mock.patch.object(self._superproject, "_Init", return_value=True): + with mock.patch.object( + self._superproject, "_Fetch", return_value=True + ): + with mock.patch.object( + self._superproject, "_LsTree", return_value=data + ): + # Create temporary directory so that it can write the file. + os.mkdir(self._superproject._superproject_path) + update_result = self._superproject.UpdateProjectsRevisionId( + projects, self.git_event_log + ) + self.assertIsNotNone(update_result.manifest_path) + self.assertFalse(update_result.fatal) + with open(update_result.manifest_path, "r") as fp: + manifest_xml_data = fp.read() + # Verify platform/vendor/x's project revision hasn't + # changed. 
+ self.assertEqual( + sort_attributes(manifest_xml_data), + '' + '' + '' + '' + '' + '' + '' + "", + ) - def test_Fetch(self): - manifest = self.getXmlManifest(""" + def test_Fetch(self): + manifest = self.getXmlManifest( + """ " /> -""") - self.maxDiff = None - self._superproject = git_superproject.Superproject( - manifest, name='superproject', - remote=manifest.remotes.get('default-remote').ToRemoteSpec('superproject'), - revision='refs/heads/main') - os.mkdir(self._superproject._superproject_path) - os.mkdir(self._superproject._work_git) - with mock.patch.object(self._superproject, '_Init', return_value=True): - with mock.patch('git_superproject.GitCommand', autospec=True) as mock_git_command: - with mock.patch('git_superproject.GitRefs.get', autospec=True) as mock_git_refs: - instance = mock_git_command.return_value - instance.Wait.return_value = 0 - mock_git_refs.side_effect = ['', '1234'] +""" + ) + self.maxDiff = None + self._superproject = git_superproject.Superproject( + manifest, + name="superproject", + remote=manifest.remotes.get("default-remote").ToRemoteSpec( + "superproject" + ), + revision="refs/heads/main", + ) + os.mkdir(self._superproject._superproject_path) + os.mkdir(self._superproject._work_git) + with mock.patch.object(self._superproject, "_Init", return_value=True): + with mock.patch( + "git_superproject.GitCommand", autospec=True + ) as mock_git_command: + with mock.patch( + "git_superproject.GitRefs.get", autospec=True + ) as mock_git_refs: + instance = mock_git_command.return_value + instance.Wait.return_value = 0 + mock_git_refs.side_effect = ["", "1234"] - self.assertTrue(self._superproject._Fetch()) - self.assertEqual(mock_git_command.call_args.args,(None, [ - 'fetch', 'http://localhost/superproject', '--depth', '1', - '--force', '--no-tags', '--filter', 'blob:none', - 'refs/heads/main:refs/heads/main' - ])) + self.assertTrue(self._superproject._Fetch()) + self.assertEqual( + mock_git_command.call_args.args, + ( + None, + [ + "fetch", 
+ "http://localhost/superproject", + "--depth", + "1", + "--force", + "--no-tags", + "--filter", + "blob:none", + "refs/heads/main:refs/heads/main", + ], + ), + ) - # If branch for revision exists, set as --negotiation-tip. - self.assertTrue(self._superproject._Fetch()) - self.assertEqual(mock_git_command.call_args.args,(None, [ - 'fetch', 'http://localhost/superproject', '--depth', '1', - '--force', '--no-tags', '--filter', 'blob:none', - '--negotiation-tip', '1234', - 'refs/heads/main:refs/heads/main' - ])) + # If branch for revision exists, set as --negotiation-tip. + self.assertTrue(self._superproject._Fetch()) + self.assertEqual( + mock_git_command.call_args.args, + ( + None, + [ + "fetch", + "http://localhost/superproject", + "--depth", + "1", + "--force", + "--no-tags", + "--filter", + "blob:none", + "--negotiation-tip", + "1234", + "refs/heads/main:refs/heads/main", + ], + ), + ) diff --git a/tests/test_git_trace2_event_log.py b/tests/test_git_trace2_event_log.py index 7e7dfb7a..a6078d38 100644 --- a/tests/test_git_trace2_event_log.py +++ b/tests/test_git_trace2_event_log.py @@ -27,361 +27,382 @@ import platform_utils def serverLoggingThread(socket_path, server_ready, received_traces): - """Helper function to receive logs over a Unix domain socket. + """Helper function to receive logs over a Unix domain socket. - Appends received messages on the provided socket and appends to received_traces. + Appends received messages on the provided socket and appends to + received_traces. - Args: - socket_path: path to a Unix domain socket on which to listen for traces - server_ready: a threading.Condition used to signal to the caller that this thread is ready to - accept connections - received_traces: a list to which received traces will be appended (after decoding to a utf-8 - string). 
- """ - platform_utils.remove(socket_path, missing_ok=True) - data = b'' - with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: - sock.bind(socket_path) - sock.listen(0) - with server_ready: - server_ready.notify() - with sock.accept()[0] as conn: - while True: - recved = conn.recv(4096) - if not recved: - break - data += recved - received_traces.extend(data.decode('utf-8').splitlines()) + Args: + socket_path: path to a Unix domain socket on which to listen for traces + server_ready: a threading.Condition used to signal to the caller that + this thread is ready to accept connections + received_traces: a list to which received traces will be appended (after + decoding to a utf-8 string). + """ + platform_utils.remove(socket_path, missing_ok=True) + data = b"" + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: + sock.bind(socket_path) + sock.listen(0) + with server_ready: + server_ready.notify() + with sock.accept()[0] as conn: + while True: + recved = conn.recv(4096) + if not recved: + break + data += recved + received_traces.extend(data.decode("utf-8").splitlines()) class EventLogTestCase(unittest.TestCase): - """TestCase for the EventLog module.""" + """TestCase for the EventLog module.""" - PARENT_SID_KEY = 'GIT_TRACE2_PARENT_SID' - PARENT_SID_VALUE = 'parent_sid' - SELF_SID_REGEX = r'repo-\d+T\d+Z-.*' - FULL_SID_REGEX = r'^%s/%s' % (PARENT_SID_VALUE, SELF_SID_REGEX) + PARENT_SID_KEY = "GIT_TRACE2_PARENT_SID" + PARENT_SID_VALUE = "parent_sid" + SELF_SID_REGEX = r"repo-\d+T\d+Z-.*" + FULL_SID_REGEX = r"^%s/%s" % (PARENT_SID_VALUE, SELF_SID_REGEX) - def setUp(self): - """Load the event_log module every time.""" - self._event_log_module = None - # By default we initialize with the expected case where - # repo launches us (so GIT_TRACE2_PARENT_SID is set). 
- env = { - self.PARENT_SID_KEY: self.PARENT_SID_VALUE, - } - self._event_log_module = git_trace2_event_log.EventLog(env=env) - self._log_data = None + def setUp(self): + """Load the event_log module every time.""" + self._event_log_module = None + # By default we initialize with the expected case where + # repo launches us (so GIT_TRACE2_PARENT_SID is set). + env = { + self.PARENT_SID_KEY: self.PARENT_SID_VALUE, + } + self._event_log_module = git_trace2_event_log.EventLog(env=env) + self._log_data = None - def verifyCommonKeys(self, log_entry, expected_event_name=None, full_sid=True): - """Helper function to verify common event log keys.""" - self.assertIn('event', log_entry) - self.assertIn('sid', log_entry) - self.assertIn('thread', log_entry) - self.assertIn('time', log_entry) + def verifyCommonKeys( + self, log_entry, expected_event_name=None, full_sid=True + ): + """Helper function to verify common event log keys.""" + self.assertIn("event", log_entry) + self.assertIn("sid", log_entry) + self.assertIn("thread", log_entry) + self.assertIn("time", log_entry) - # Do basic data format validation. - if expected_event_name: - self.assertEqual(expected_event_name, log_entry['event']) - if full_sid: - self.assertRegex(log_entry['sid'], self.FULL_SID_REGEX) - else: - self.assertRegex(log_entry['sid'], self.SELF_SID_REGEX) - self.assertRegex(log_entry['time'], r'^\d+-\d+-\d+T\d+:\d+:\d+\.\d+Z$') + # Do basic data format validation. 
+ if expected_event_name: + self.assertEqual(expected_event_name, log_entry["event"]) + if full_sid: + self.assertRegex(log_entry["sid"], self.FULL_SID_REGEX) + else: + self.assertRegex(log_entry["sid"], self.SELF_SID_REGEX) + self.assertRegex(log_entry["time"], r"^\d+-\d+-\d+T\d+:\d+:\d+\.\d+Z$") - def readLog(self, log_path): - """Helper function to read log data into a list.""" - log_data = [] - with open(log_path, mode='rb') as f: - for line in f: - log_data.append(json.loads(line)) - return log_data + def readLog(self, log_path): + """Helper function to read log data into a list.""" + log_data = [] + with open(log_path, mode="rb") as f: + for line in f: + log_data.append(json.loads(line)) + return log_data - def remove_prefix(self, s, prefix): - """Return a copy string after removing |prefix| from |s|, if present or the original string.""" - if s.startswith(prefix): - return s[len(prefix):] - else: - return s + def remove_prefix(self, s, prefix): + """Return a copy string after removing |prefix| from |s|, if present or + the original string.""" + if s.startswith(prefix): + return s[len(prefix) :] + else: + return s - def test_initial_state_with_parent_sid(self): - """Test initial state when 'GIT_TRACE2_PARENT_SID' is set by parent.""" - self.assertRegex(self._event_log_module.full_sid, self.FULL_SID_REGEX) + def test_initial_state_with_parent_sid(self): + """Test initial state when 'GIT_TRACE2_PARENT_SID' is set by parent.""" + self.assertRegex(self._event_log_module.full_sid, self.FULL_SID_REGEX) - def test_initial_state_no_parent_sid(self): - """Test initial state when 'GIT_TRACE2_PARENT_SID' is not set.""" - # Setup an empty environment dict (no parent sid). - self._event_log_module = git_trace2_event_log.EventLog(env={}) - self.assertRegex(self._event_log_module.full_sid, self.SELF_SID_REGEX) + def test_initial_state_no_parent_sid(self): + """Test initial state when 'GIT_TRACE2_PARENT_SID' is not set.""" + # Setup an empty environment dict (no parent sid). 
+ self._event_log_module = git_trace2_event_log.EventLog(env={}) + self.assertRegex(self._event_log_module.full_sid, self.SELF_SID_REGEX) - def test_version_event(self): - """Test 'version' event data is valid. + def test_version_event(self): + """Test 'version' event data is valid. - Verify that the 'version' event is written even when no other - events are addded. + Verify that the 'version' event is written even when no other + events are addded. - Expected event log: - - """ - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) + Expected event log: + + """ + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) - # A log with no added events should only have the version entry. - self.assertEqual(len(self._log_data), 1) - version_event = self._log_data[0] - self.verifyCommonKeys(version_event, expected_event_name='version') - # Check for 'version' event specific fields. - self.assertIn('evt', version_event) - self.assertIn('exe', version_event) - # Verify "evt" version field is a string. - self.assertIsInstance(version_event['evt'], str) + # A log with no added events should only have the version entry. + self.assertEqual(len(self._log_data), 1) + version_event = self._log_data[0] + self.verifyCommonKeys(version_event, expected_event_name="version") + # Check for 'version' event specific fields. + self.assertIn("evt", version_event) + self.assertIn("exe", version_event) + # Verify "evt" version field is a string. + self.assertIsInstance(version_event["evt"], str) - def test_start_event(self): - """Test and validate 'start' event data is valid. 
- - Expected event log: - - - """ - self._event_log_module.StartEvent() - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - start_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - self.verifyCommonKeys(start_event, expected_event_name='start') - # Check for 'start' event specific fields. - self.assertIn('argv', start_event) - self.assertTrue(isinstance(start_event['argv'], list)) - - def test_exit_event_result_none(self): - """Test 'exit' event data is valid when result is None. - - We expect None result to be converted to 0 in the exit event data. - - Expected event log: - - - """ - self._event_log_module.ExitEvent(None) - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - exit_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - self.verifyCommonKeys(exit_event, expected_event_name='exit') - # Check for 'exit' event specific fields. - self.assertIn('code', exit_event) - # 'None' result should convert to 0 (successful) return code. - self.assertEqual(exit_event['code'], 0) - - def test_exit_event_result_integer(self): - """Test 'exit' event data is valid when result is an integer. 
- - Expected event log: - - - """ - self._event_log_module.ExitEvent(2) - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - exit_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - self.verifyCommonKeys(exit_event, expected_event_name='exit') - # Check for 'exit' event specific fields. - self.assertIn('code', exit_event) - self.assertEqual(exit_event['code'], 2) - - def test_command_event(self): - """Test and validate 'command' event data is valid. - - Expected event log: - - - """ - name = 'repo' - subcommands = ['init' 'this'] - self._event_log_module.CommandEvent(name='repo', subcommands=subcommands) - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - command_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - self.verifyCommonKeys(command_event, expected_event_name='command') - # Check for 'command' event specific fields. - self.assertIn('name', command_event) - self.assertIn('subcommands', command_event) - self.assertEqual(command_event['name'], name) - self.assertEqual(command_event['subcommands'], subcommands) - - def test_def_params_event_repo_config(self): - """Test 'def_params' event data outputs only repo config keys. 
- - Expected event log: - - - - """ - config = { - 'git.foo': 'bar', - 'repo.partialclone': 'true', - 'repo.partialclonefilter': 'blob:none', - } - self._event_log_module.DefParamRepoEvents(config) - - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 3) - def_param_events = self._log_data[1:] - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - - for event in def_param_events: - self.verifyCommonKeys(event, expected_event_name='def_param') - # Check for 'def_param' event specific fields. - self.assertIn('param', event) - self.assertIn('value', event) - self.assertTrue(event['param'].startswith('repo.')) - - def test_def_params_event_no_repo_config(self): - """Test 'def_params' event data won't output non-repo config keys. - - Expected event log: - - """ - config = { - 'git.foo': 'bar', - 'git.core.foo2': 'baz', - } - self._event_log_module.DefParamRepoEvents(config) - - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 1) - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - - def test_data_event_config(self): - """Test 'data' event data outputs all config keys. 
- - Expected event log: - - - - """ - config = { - 'git.foo': 'bar', - 'repo.partialclone': 'false', - 'repo.syncstate.superproject.hassuperprojecttag': 'true', - 'repo.syncstate.superproject.sys.argv': ['--', 'sync', 'protobuf'], - } - prefix_value = 'prefix' - self._event_log_module.LogDataConfigEvents(config, prefix_value) - - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 5) - data_events = self._log_data[1:] - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - - for event in data_events: - self.verifyCommonKeys(event) - # Check for 'data' event specific fields. - self.assertIn('key', event) - self.assertIn('value', event) - key = event['key'] - key = self.remove_prefix(key, f'{prefix_value}/') - value = event['value'] - self.assertEqual(self._event_log_module.GetDataEventName(value), event['event']) - self.assertTrue(key in config and value == config[key]) - - def test_error_event(self): - """Test and validate 'error' event data is valid. - - Expected event log: - - - """ - msg = 'invalid option: --cahced' - fmt = 'invalid option: %s' - self._event_log_module.ErrorEvent(msg, fmt) - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - log_path = self._event_log_module.Write(path=tempdir) - self._log_data = self.readLog(log_path) - - self.assertEqual(len(self._log_data), 2) - error_event = self._log_data[1] - self.verifyCommonKeys(self._log_data[0], expected_event_name='version') - self.verifyCommonKeys(error_event, expected_event_name='error') - # Check for 'error' event specific fields. 
- self.assertIn('msg', error_event) - self.assertIn('fmt', error_event) - self.assertEqual(error_event['msg'], msg) - self.assertEqual(error_event['fmt'], fmt) - - def test_write_with_filename(self): - """Test Write() with a path to a file exits with None.""" - self.assertIsNone(self._event_log_module.Write(path='path/to/file')) - - def test_write_with_git_config(self): - """Test Write() uses the git config path when 'git config' call succeeds.""" - with tempfile.TemporaryDirectory(prefix='event_log_tests') as tempdir: - with mock.patch.object(self._event_log_module, - '_GetEventTargetPath', return_value=tempdir): - self.assertEqual(os.path.dirname(self._event_log_module.Write()), tempdir) - - def test_write_no_git_config(self): - """Test Write() with no git config variable present exits with None.""" - with mock.patch.object(self._event_log_module, - '_GetEventTargetPath', return_value=None): - self.assertIsNone(self._event_log_module.Write()) - - def test_write_non_string(self): - """Test Write() with non-string type for |path| throws TypeError.""" - with self.assertRaises(TypeError): - self._event_log_module.Write(path=1234) - - def test_write_socket(self): - """Test Write() with Unix domain socket for |path| and validate received traces.""" - received_traces = [] - with tempfile.TemporaryDirectory(prefix='test_server_sockets') as tempdir: - socket_path = os.path.join(tempdir, "server.sock") - server_ready = threading.Condition() - # Start "server" listening on Unix domain socket at socket_path. - try: - server_thread = threading.Thread( - target=serverLoggingThread, - args=(socket_path, server_ready, received_traces)) - server_thread.start() - - with server_ready: - server_ready.wait(timeout=120) + def test_start_event(self): + """Test and validate 'start' event data is valid. 
+ Expected event log: + + + """ self._event_log_module.StartEvent() - path = self._event_log_module.Write(path=f'af_unix:{socket_path}') - finally: - server_thread.join(timeout=5) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) - self.assertEqual(path, f'af_unix:stream:{socket_path}') - self.assertEqual(len(received_traces), 2) - version_event = json.loads(received_traces[0]) - start_event = json.loads(received_traces[1]) - self.verifyCommonKeys(version_event, expected_event_name='version') - self.verifyCommonKeys(start_event, expected_event_name='start') - # Check for 'start' event specific fields. - self.assertIn('argv', start_event) - self.assertIsInstance(start_event['argv'], list) + self.assertEqual(len(self._log_data), 2) + start_event = self._log_data[1] + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + self.verifyCommonKeys(start_event, expected_event_name="start") + # Check for 'start' event specific fields. + self.assertIn("argv", start_event) + self.assertTrue(isinstance(start_event["argv"], list)) + + def test_exit_event_result_none(self): + """Test 'exit' event data is valid when result is None. + + We expect None result to be converted to 0 in the exit event data. + + Expected event log: + + + """ + self._event_log_module.ExitEvent(None) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 2) + exit_event = self._log_data[1] + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + self.verifyCommonKeys(exit_event, expected_event_name="exit") + # Check for 'exit' event specific fields. + self.assertIn("code", exit_event) + # 'None' result should convert to 0 (successful) return code. 
+ self.assertEqual(exit_event["code"], 0) + + def test_exit_event_result_integer(self): + """Test 'exit' event data is valid when result is an integer. + + Expected event log: + + + """ + self._event_log_module.ExitEvent(2) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 2) + exit_event = self._log_data[1] + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + self.verifyCommonKeys(exit_event, expected_event_name="exit") + # Check for 'exit' event specific fields. + self.assertIn("code", exit_event) + self.assertEqual(exit_event["code"], 2) + + def test_command_event(self): + """Test and validate 'command' event data is valid. + + Expected event log: + + + """ + name = "repo" + subcommands = ["init" "this"] + self._event_log_module.CommandEvent( + name="repo", subcommands=subcommands + ) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 2) + command_event = self._log_data[1] + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + self.verifyCommonKeys(command_event, expected_event_name="command") + # Check for 'command' event specific fields. + self.assertIn("name", command_event) + self.assertIn("subcommands", command_event) + self.assertEqual(command_event["name"], name) + self.assertEqual(command_event["subcommands"], subcommands) + + def test_def_params_event_repo_config(self): + """Test 'def_params' event data outputs only repo config keys. 
+ + Expected event log: + + + + """ + config = { + "git.foo": "bar", + "repo.partialclone": "true", + "repo.partialclonefilter": "blob:none", + } + self._event_log_module.DefParamRepoEvents(config) + + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 3) + def_param_events = self._log_data[1:] + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + + for event in def_param_events: + self.verifyCommonKeys(event, expected_event_name="def_param") + # Check for 'def_param' event specific fields. + self.assertIn("param", event) + self.assertIn("value", event) + self.assertTrue(event["param"].startswith("repo.")) + + def test_def_params_event_no_repo_config(self): + """Test 'def_params' event data won't output non-repo config keys. + + Expected event log: + + """ + config = { + "git.foo": "bar", + "git.core.foo2": "baz", + } + self._event_log_module.DefParamRepoEvents(config) + + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 1) + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + + def test_data_event_config(self): + """Test 'data' event data outputs all config keys. 
+ + Expected event log: + + + + """ + config = { + "git.foo": "bar", + "repo.partialclone": "false", + "repo.syncstate.superproject.hassuperprojecttag": "true", + "repo.syncstate.superproject.sys.argv": ["--", "sync", "protobuf"], + } + prefix_value = "prefix" + self._event_log_module.LogDataConfigEvents(config, prefix_value) + + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 5) + data_events = self._log_data[1:] + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + + for event in data_events: + self.verifyCommonKeys(event) + # Check for 'data' event specific fields. + self.assertIn("key", event) + self.assertIn("value", event) + key = event["key"] + key = self.remove_prefix(key, f"{prefix_value}/") + value = event["value"] + self.assertEqual( + self._event_log_module.GetDataEventName(value), event["event"] + ) + self.assertTrue(key in config and value == config[key]) + + def test_error_event(self): + """Test and validate 'error' event data is valid. + + Expected event log: + + + """ + msg = "invalid option: --cahced" + fmt = "invalid option: %s" + self._event_log_module.ErrorEvent(msg, fmt) + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + log_path = self._event_log_module.Write(path=tempdir) + self._log_data = self.readLog(log_path) + + self.assertEqual(len(self._log_data), 2) + error_event = self._log_data[1] + self.verifyCommonKeys(self._log_data[0], expected_event_name="version") + self.verifyCommonKeys(error_event, expected_event_name="error") + # Check for 'error' event specific fields. 
+ self.assertIn("msg", error_event) + self.assertIn("fmt", error_event) + self.assertEqual(error_event["msg"], msg) + self.assertEqual(error_event["fmt"], fmt) + + def test_write_with_filename(self): + """Test Write() with a path to a file exits with None.""" + self.assertIsNone(self._event_log_module.Write(path="path/to/file")) + + def test_write_with_git_config(self): + """Test Write() uses the git config path when 'git config' call + succeeds.""" + with tempfile.TemporaryDirectory(prefix="event_log_tests") as tempdir: + with mock.patch.object( + self._event_log_module, + "_GetEventTargetPath", + return_value=tempdir, + ): + self.assertEqual( + os.path.dirname(self._event_log_module.Write()), tempdir + ) + + def test_write_no_git_config(self): + """Test Write() with no git config variable present exits with None.""" + with mock.patch.object( + self._event_log_module, "_GetEventTargetPath", return_value=None + ): + self.assertIsNone(self._event_log_module.Write()) + + def test_write_non_string(self): + """Test Write() with non-string type for |path| throws TypeError.""" + with self.assertRaises(TypeError): + self._event_log_module.Write(path=1234) + + def test_write_socket(self): + """Test Write() with Unix domain socket for |path| and validate received + traces.""" + received_traces = [] + with tempfile.TemporaryDirectory( + prefix="test_server_sockets" + ) as tempdir: + socket_path = os.path.join(tempdir, "server.sock") + server_ready = threading.Condition() + # Start "server" listening on Unix domain socket at socket_path. 
+ try: + server_thread = threading.Thread( + target=serverLoggingThread, + args=(socket_path, server_ready, received_traces), + ) + server_thread.start() + + with server_ready: + server_ready.wait(timeout=120) + + self._event_log_module.StartEvent() + path = self._event_log_module.Write( + path=f"af_unix:{socket_path}" + ) + finally: + server_thread.join(timeout=5) + + self.assertEqual(path, f"af_unix:stream:{socket_path}") + self.assertEqual(len(received_traces), 2) + version_event = json.loads(received_traces[0]) + start_event = json.loads(received_traces[1]) + self.verifyCommonKeys(version_event, expected_event_name="version") + self.verifyCommonKeys(start_event, expected_event_name="start") + # Check for 'start' event specific fields. + self.assertIn("argv", start_event) + self.assertIsInstance(start_event["argv"], list) diff --git a/tests/test_hooks.py b/tests/test_hooks.py index 6632b3e5..78277128 100644 --- a/tests/test_hooks.py +++ b/tests/test_hooks.py @@ -17,39 +17,38 @@ import hooks import unittest + class RepoHookShebang(unittest.TestCase): - """Check shebang parsing in RepoHook.""" + """Check shebang parsing in RepoHook.""" - def test_no_shebang(self): - """Lines w/out shebangs should be rejected.""" - DATA = ( - '', - '#\n# foo\n', - '# Bad shebang in script\n#!/foo\n' - ) - for data in DATA: - self.assertIsNone(hooks.RepoHook._ExtractInterpFromShebang(data)) + def test_no_shebang(self): + """Lines w/out shebangs should be rejected.""" + DATA = ("", "#\n# foo\n", "# Bad shebang in script\n#!/foo\n") + for data in DATA: + self.assertIsNone(hooks.RepoHook._ExtractInterpFromShebang(data)) - def test_direct_interp(self): - """Lines whose shebang points directly to the interpreter.""" - DATA = ( - ('#!/foo', '/foo'), - ('#! /foo', '/foo'), - ('#!/bin/foo ', '/bin/foo'), - ('#! /usr/foo ', '/usr/foo'), - ('#! 
/usr/foo -args', '/usr/foo'), - ) - for shebang, interp in DATA: - self.assertEqual(hooks.RepoHook._ExtractInterpFromShebang(shebang), - interp) + def test_direct_interp(self): + """Lines whose shebang points directly to the interpreter.""" + DATA = ( + ("#!/foo", "/foo"), + ("#! /foo", "/foo"), + ("#!/bin/foo ", "/bin/foo"), + ("#! /usr/foo ", "/usr/foo"), + ("#! /usr/foo -args", "/usr/foo"), + ) + for shebang, interp in DATA: + self.assertEqual( + hooks.RepoHook._ExtractInterpFromShebang(shebang), interp + ) - def test_env_interp(self): - """Lines whose shebang launches through `env`.""" - DATA = ( - ('#!/usr/bin/env foo', 'foo'), - ('#!/bin/env foo', 'foo'), - ('#! /bin/env /bin/foo ', '/bin/foo'), - ) - for shebang, interp in DATA: - self.assertEqual(hooks.RepoHook._ExtractInterpFromShebang(shebang), - interp) + def test_env_interp(self): + """Lines whose shebang launches through `env`.""" + DATA = ( + ("#!/usr/bin/env foo", "foo"), + ("#!/bin/env foo", "foo"), + ("#! /bin/env /bin/foo ", "/bin/foo"), + ) + for shebang, interp in DATA: + self.assertEqual( + hooks.RepoHook._ExtractInterpFromShebang(shebang), interp + ) diff --git a/tests/test_manifest_xml.py b/tests/test_manifest_xml.py index 3634701f..648acde8 100644 --- a/tests/test_manifest_xml.py +++ b/tests/test_manifest_xml.py @@ -27,291 +27,318 @@ import manifest_xml # Invalid paths that we don't want in the filesystem. INVALID_FS_PATHS = ( - '', - '.', - '..', - '../', - './', - './/', - 'foo/', - './foo', - '../foo', - 'foo/./bar', - 'foo/../../bar', - '/foo', - './../foo', - '.git/foo', + "", + ".", + "..", + "../", + "./", + ".//", + "foo/", + "./foo", + "../foo", + "foo/./bar", + "foo/../../bar", + "/foo", + "./../foo", + ".git/foo", # Check case folding. - '.GIT/foo', - 'blah/.git/foo', - '.repo/foo', - '.repoconfig', + ".GIT/foo", + "blah/.git/foo", + ".repo/foo", + ".repoconfig", # Block ~ due to 8.3 filenames on Windows filesystems. 
- '~', - 'foo~', - 'blah/foo~', + "~", + "foo~", + "blah/foo~", # Block Unicode characters that get normalized out by filesystems. - u'foo\u200Cbar', + "foo\u200Cbar", # Block newlines. - 'f\n/bar', - 'f\r/bar', + "f\n/bar", + "f\r/bar", ) # Make sure platforms that use path separators (e.g. Windows) are also # rejected properly. -if os.path.sep != '/': - INVALID_FS_PATHS += tuple(x.replace('/', os.path.sep) for x in INVALID_FS_PATHS) +if os.path.sep != "/": + INVALID_FS_PATHS += tuple( + x.replace("/", os.path.sep) for x in INVALID_FS_PATHS + ) def sort_attributes(manifest): - """Sort the attributes of all elements alphabetically. + """Sort the attributes of all elements alphabetically. - This is needed because different versions of the toxml() function from - xml.dom.minidom outputs the attributes of elements in different orders. - Before Python 3.8 they were output alphabetically, later versions preserve - the order specified by the user. + This is needed because different versions of the toxml() function from + xml.dom.minidom outputs the attributes of elements in different orders. + Before Python 3.8 they were output alphabetically, later versions preserve + the order specified by the user. - Args: - manifest: String containing an XML manifest. + Args: + manifest: String containing an XML manifest. - Returns: - The XML manifest with the attributes of all elements sorted alphabetically. - """ - new_manifest = '' - # This will find every element in the XML manifest, whether they have - # attributes or not. This simplifies recreating the manifest below. - matches = re.findall(r'(<[/?]?[a-z-]+\s*)((?:\S+?="[^"]+"\s*?)*)(\s*[/?]?>)', manifest) - for head, attrs, tail in matches: - m = re.findall(r'\S+?="[^"]+"', attrs) - new_manifest += head + ' '.join(sorted(m)) + tail - return new_manifest + Returns: + The XML manifest with the attributes of all elements sorted + alphabetically. 
+ """ + new_manifest = "" + # This will find every element in the XML manifest, whether they have + # attributes or not. This simplifies recreating the manifest below. + matches = re.findall( + r'(<[/?]?[a-z-]+\s*)((?:\S+?="[^"]+"\s*?)*)(\s*[/?]?>)', manifest + ) + for head, attrs, tail in matches: + m = re.findall(r'\S+?="[^"]+"', attrs) + new_manifest += head + " ".join(sorted(m)) + tail + return new_manifest class ManifestParseTestCase(unittest.TestCase): - """TestCase for parsing manifests.""" + """TestCase for parsing manifests.""" - def setUp(self): - self.tempdirobj = tempfile.TemporaryDirectory(prefix='repo_tests') - self.tempdir = self.tempdirobj.name - self.repodir = os.path.join(self.tempdir, '.repo') - self.manifest_dir = os.path.join(self.repodir, 'manifests') - self.manifest_file = os.path.join( - self.repodir, manifest_xml.MANIFEST_FILE_NAME) - self.local_manifest_dir = os.path.join( - self.repodir, manifest_xml.LOCAL_MANIFESTS_DIR_NAME) - os.mkdir(self.repodir) - os.mkdir(self.manifest_dir) + def setUp(self): + self.tempdirobj = tempfile.TemporaryDirectory(prefix="repo_tests") + self.tempdir = self.tempdirobj.name + self.repodir = os.path.join(self.tempdir, ".repo") + self.manifest_dir = os.path.join(self.repodir, "manifests") + self.manifest_file = os.path.join( + self.repodir, manifest_xml.MANIFEST_FILE_NAME + ) + self.local_manifest_dir = os.path.join( + self.repodir, manifest_xml.LOCAL_MANIFESTS_DIR_NAME + ) + os.mkdir(self.repodir) + os.mkdir(self.manifest_dir) - # The manifest parsing really wants a git repo currently. - gitdir = os.path.join(self.repodir, 'manifests.git') - os.mkdir(gitdir) - with open(os.path.join(gitdir, 'config'), 'w') as fp: - fp.write("""[remote "origin"] + # The manifest parsing really wants a git repo currently. 
+ gitdir = os.path.join(self.repodir, "manifests.git") + os.mkdir(gitdir) + with open(os.path.join(gitdir, "config"), "w") as fp: + fp.write( + """[remote "origin"] url = https://localhost:0/manifest -""") +""" + ) - def tearDown(self): - self.tempdirobj.cleanup() + def tearDown(self): + self.tempdirobj.cleanup() - def getXmlManifest(self, data): - """Helper to initialize a manifest for testing.""" - with open(self.manifest_file, 'w', encoding="utf-8") as fp: - fp.write(data) - return manifest_xml.XmlManifest(self.repodir, self.manifest_file) + def getXmlManifest(self, data): + """Helper to initialize a manifest for testing.""" + with open(self.manifest_file, "w", encoding="utf-8") as fp: + fp.write(data) + return manifest_xml.XmlManifest(self.repodir, self.manifest_file) - @staticmethod - def encodeXmlAttr(attr): - """Encode |attr| using XML escape rules.""" - return attr.replace('\r', ' ').replace('\n', ' ') + @staticmethod + def encodeXmlAttr(attr): + """Encode |attr| using XML escape rules.""" + return attr.replace("\r", " ").replace("\n", " ") class ManifestValidateFilePaths(unittest.TestCase): - """Check _ValidateFilePaths helper. + """Check _ValidateFilePaths helper. - This doesn't access a real filesystem. - """ + This doesn't access a real filesystem. 
+ """ - def check_both(self, *args): - manifest_xml.XmlManifest._ValidateFilePaths('copyfile', *args) - manifest_xml.XmlManifest._ValidateFilePaths('linkfile', *args) + def check_both(self, *args): + manifest_xml.XmlManifest._ValidateFilePaths("copyfile", *args) + manifest_xml.XmlManifest._ValidateFilePaths("linkfile", *args) - def test_normal_path(self): - """Make sure good paths are accepted.""" - self.check_both('foo', 'bar') - self.check_both('foo/bar', 'bar') - self.check_both('foo', 'bar/bar') - self.check_both('foo/bar', 'bar/bar') + def test_normal_path(self): + """Make sure good paths are accepted.""" + self.check_both("foo", "bar") + self.check_both("foo/bar", "bar") + self.check_both("foo", "bar/bar") + self.check_both("foo/bar", "bar/bar") - def test_symlink_targets(self): - """Some extra checks for symlinks.""" - def check(*args): - manifest_xml.XmlManifest._ValidateFilePaths('linkfile', *args) + def test_symlink_targets(self): + """Some extra checks for symlinks.""" - # We allow symlinks to end in a slash since we allow them to point to dirs - # in general. Technically the slash isn't necessary. - check('foo/', 'bar') - # We allow a single '.' to get a reference to the project itself. - check('.', 'bar') + def check(*args): + manifest_xml.XmlManifest._ValidateFilePaths("linkfile", *args) - def test_bad_paths(self): - """Make sure bad paths (src & dest) are rejected.""" - for path in INVALID_FS_PATHS: - self.assertRaises( - error.ManifestInvalidPathError, self.check_both, path, 'a') - self.assertRaises( - error.ManifestInvalidPathError, self.check_both, 'a', path) + # We allow symlinks to end in a slash since we allow them to point to + # dirs in general. Technically the slash isn't necessary. + check("foo/", "bar") + # We allow a single '.' to get a reference to the project itself. 
+ check(".", "bar") + + def test_bad_paths(self): + """Make sure bad paths (src & dest) are rejected.""" + for path in INVALID_FS_PATHS: + self.assertRaises( + error.ManifestInvalidPathError, self.check_both, path, "a" + ) + self.assertRaises( + error.ManifestInvalidPathError, self.check_both, "a", path + ) class ValueTests(unittest.TestCase): - """Check utility parsing code.""" + """Check utility parsing code.""" - def _get_node(self, text): - return xml.dom.minidom.parseString(text).firstChild + def _get_node(self, text): + return xml.dom.minidom.parseString(text).firstChild - def test_bool_default(self): - """Check XmlBool default handling.""" - node = self._get_node('') - self.assertIsNone(manifest_xml.XmlBool(node, 'a')) - self.assertIsNone(manifest_xml.XmlBool(node, 'a', None)) - self.assertEqual(123, manifest_xml.XmlBool(node, 'a', 123)) + def test_bool_default(self): + """Check XmlBool default handling.""" + node = self._get_node("") + self.assertIsNone(manifest_xml.XmlBool(node, "a")) + self.assertIsNone(manifest_xml.XmlBool(node, "a", None)) + self.assertEqual(123, manifest_xml.XmlBool(node, "a", 123)) - node = self._get_node('') - self.assertIsNone(manifest_xml.XmlBool(node, 'a')) + node = self._get_node('') + self.assertIsNone(manifest_xml.XmlBool(node, "a")) - def test_bool_invalid(self): - """Check XmlBool invalid handling.""" - node = self._get_node('') - self.assertEqual(123, manifest_xml.XmlBool(node, 'a', 123)) + def test_bool_invalid(self): + """Check XmlBool invalid handling.""" + node = self._get_node('') + self.assertEqual(123, manifest_xml.XmlBool(node, "a", 123)) - def test_bool_true(self): - """Check XmlBool true values.""" - for value in ('yes', 'true', '1'): - node = self._get_node('' % (value,)) - self.assertTrue(manifest_xml.XmlBool(node, 'a')) + def test_bool_true(self): + """Check XmlBool true values.""" + for value in ("yes", "true", "1"): + node = self._get_node('' % (value,)) + self.assertTrue(manifest_xml.XmlBool(node, "a")) - def 
test_bool_false(self): - """Check XmlBool false values.""" - for value in ('no', 'false', '0'): - node = self._get_node('' % (value,)) - self.assertFalse(manifest_xml.XmlBool(node, 'a')) + def test_bool_false(self): + """Check XmlBool false values.""" + for value in ("no", "false", "0"): + node = self._get_node('' % (value,)) + self.assertFalse(manifest_xml.XmlBool(node, "a")) - def test_int_default(self): - """Check XmlInt default handling.""" - node = self._get_node('') - self.assertIsNone(manifest_xml.XmlInt(node, 'a')) - self.assertIsNone(manifest_xml.XmlInt(node, 'a', None)) - self.assertEqual(123, manifest_xml.XmlInt(node, 'a', 123)) + def test_int_default(self): + """Check XmlInt default handling.""" + node = self._get_node("") + self.assertIsNone(manifest_xml.XmlInt(node, "a")) + self.assertIsNone(manifest_xml.XmlInt(node, "a", None)) + self.assertEqual(123, manifest_xml.XmlInt(node, "a", 123)) - node = self._get_node('') - self.assertIsNone(manifest_xml.XmlInt(node, 'a')) + node = self._get_node('') + self.assertIsNone(manifest_xml.XmlInt(node, "a")) - def test_int_good(self): - """Check XmlInt numeric handling.""" - for value in (-1, 0, 1, 50000): - node = self._get_node('' % (value,)) - self.assertEqual(value, manifest_xml.XmlInt(node, 'a')) + def test_int_good(self): + """Check XmlInt numeric handling.""" + for value in (-1, 0, 1, 50000): + node = self._get_node('' % (value,)) + self.assertEqual(value, manifest_xml.XmlInt(node, "a")) - def test_int_invalid(self): - """Check XmlInt invalid handling.""" - with self.assertRaises(error.ManifestParseError): - node = self._get_node('') - manifest_xml.XmlInt(node, 'a') + def test_int_invalid(self): + """Check XmlInt invalid handling.""" + with self.assertRaises(error.ManifestParseError): + node = self._get_node('') + manifest_xml.XmlInt(node, "a") class XmlManifestTests(ManifestParseTestCase): - """Check manifest processing.""" + """Check manifest processing.""" - def test_empty(self): - """Parse an 'empty' 
manifest file.""" - manifest = self.getXmlManifest( - '' - '') - self.assertEqual(manifest.remotes, {}) - self.assertEqual(manifest.projects, []) + def test_empty(self): + """Parse an 'empty' manifest file.""" + manifest = self.getXmlManifest( + '' "" + ) + self.assertEqual(manifest.remotes, {}) + self.assertEqual(manifest.projects, []) - def test_link(self): - """Verify Link handling with new names.""" - manifest = manifest_xml.XmlManifest(self.repodir, self.manifest_file) - with open(os.path.join(self.manifest_dir, 'foo.xml'), 'w') as fp: - fp.write('') - manifest.Link('foo.xml') - with open(self.manifest_file) as fp: - self.assertIn('', fp.read()) + def test_link(self): + """Verify Link handling with new names.""" + manifest = manifest_xml.XmlManifest(self.repodir, self.manifest_file) + with open(os.path.join(self.manifest_dir, "foo.xml"), "w") as fp: + fp.write("") + manifest.Link("foo.xml") + with open(self.manifest_file) as fp: + self.assertIn('', fp.read()) - def test_toxml_empty(self): - """Verify the ToXml() helper.""" - manifest = self.getXmlManifest( - '' - '') - self.assertEqual(manifest.ToXml().toxml(), '') + def test_toxml_empty(self): + """Verify the ToXml() helper.""" + manifest = self.getXmlManifest( + '' "" + ) + self.assertEqual( + manifest.ToXml().toxml(), '' + ) - def test_todict_empty(self): - """Verify the ToDict() helper.""" - manifest = self.getXmlManifest( - '' - '') - self.assertEqual(manifest.ToDict(), {}) + def test_todict_empty(self): + """Verify the ToDict() helper.""" + manifest = self.getXmlManifest( + '' "" + ) + self.assertEqual(manifest.ToDict(), {}) - def test_toxml_omit_local(self): - """Does not include local_manifests projects when omit_local=True.""" - manifest = self.getXmlManifest( - '' - '' - '' - '' - '' - '') - self.assertEqual( - sort_attributes(manifest.ToXml(omit_local=True).toxml()), - '' - '' - '') + def test_toxml_omit_local(self): + """Does not include local_manifests projects when omit_local=True.""" + manifest 
= self.getXmlManifest( + '' + '' + '' + '' + '' + "" + ) + self.assertEqual( + sort_attributes(manifest.ToXml(omit_local=True).toxml()), + '' + '' + '', + ) - def test_toxml_with_local(self): - """Does include local_manifests projects when omit_local=False.""" - manifest = self.getXmlManifest( - '' - '' - '' - '' - '' - '') - self.assertEqual( - sort_attributes(manifest.ToXml(omit_local=False).toxml()), - '' - '' - '' - '') + def test_toxml_with_local(self): + """Does include local_manifests projects when omit_local=False.""" + manifest = self.getXmlManifest( + '' + '' + '' + '' + '' + "" + ) + self.assertEqual( + sort_attributes(manifest.ToXml(omit_local=False).toxml()), + '' + '' + '' + '', + ) - def test_repo_hooks(self): - """Check repo-hooks settings.""" - manifest = self.getXmlManifest(""" + def test_repo_hooks(self): + """Check repo-hooks settings.""" + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.repo_hooks_project.name, 'repohooks') - self.assertEqual(manifest.repo_hooks_project.enabled_repo_hooks, ['a', 'b']) +""" + ) + self.assertEqual(manifest.repo_hooks_project.name, "repohooks") + self.assertEqual( + manifest.repo_hooks_project.enabled_repo_hooks, ["a", "b"] + ) - def test_repo_hooks_unordered(self): - """Check repo-hooks settings work even if the project def comes second.""" - manifest = self.getXmlManifest(""" + def test_repo_hooks_unordered(self): + """Check repo-hooks settings work even if the project def comes second.""" # noqa: E501 + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.repo_hooks_project.name, 'repohooks') - self.assertEqual(manifest.repo_hooks_project.enabled_repo_hooks, ['a', 'b']) +""" + ) + self.assertEqual(manifest.repo_hooks_project.name, "repohooks") + self.assertEqual( + manifest.repo_hooks_project.enabled_repo_hooks, ["a", "b"] + ) - def test_unknown_tags(self): - """Check superproject settings.""" - manifest = self.getXmlManifest(""" + def test_unknown_tags(self): + 
"""Check superproject settings.""" + manifest = self.getXmlManifest( + """ @@ -319,44 +346,54 @@ class XmlManifestTests(ManifestParseTestCase): X tags are always ignored -""") - self.assertEqual(manifest.superproject.name, 'superproject') - self.assertEqual(manifest.superproject.remote.name, 'test-remote') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') +""" + ) + self.assertEqual(manifest.superproject.name, "superproject") + self.assertEqual(manifest.superproject.remote.name, "test-remote") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' + "", + ) - def test_remote_annotations(self): - """Check remote settings.""" - manifest = self.getXmlManifest(""" + def test_remote_annotations(self): + """Check remote settings.""" + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.remotes['test-remote'].annotations[0].name, 'foo') - self.assertEqual(manifest.remotes['test-remote'].annotations[0].value, 'bar') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') +""" + ) + self.assertEqual( + manifest.remotes["test-remote"].annotations[0].name, "foo" + ) + self.assertEqual( + manifest.remotes["test-remote"].annotations[0].value, "bar" + ) + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + "" + "", + ) class IncludeElementTests(ManifestParseTestCase): - """Tests for .""" + """Tests for .""" - def test_group_levels(self): - root_m = os.path.join(self.manifest_dir, 'root.xml') - with open(root_m, 'w') as fp: - fp.write(""" + def test_group_levels(self): + root_m = os.path.join(self.manifest_dir, "root.xml") + with open(root_m, "w") as fp: + fp.write( + """ @@ -364,438 +401,524 @@ class IncludeElementTests(ManifestParseTestCase): -""") - with open(os.path.join(self.manifest_dir, 'level1.xml'), 'w') as fp: - fp.write(""" +""" + ) + with open(os.path.join(self.manifest_dir, "level1.xml"), "w") as fp: + fp.write( + 
""" -""") - with open(os.path.join(self.manifest_dir, 'level2.xml'), 'w') as fp: - fp.write(""" +""" + ) + with open(os.path.join(self.manifest_dir, "level2.xml"), "w") as fp: + fp.write( + """ -""") - include_m = manifest_xml.XmlManifest(self.repodir, root_m) - for proj in include_m.projects: - if proj.name == 'root-name1': - # Check include group not set on root level proj. - self.assertNotIn('level1-group', proj.groups) - if proj.name == 'root-name2': - # Check root proj group not removed. - self.assertIn('r2g1', proj.groups) - if proj.name == 'level1-name1': - # Check level1 proj has inherited group level 1. - self.assertIn('level1-group', proj.groups) - if proj.name == 'level2-name1': - # Check level2 proj has inherited group levels 1 and 2. - self.assertIn('level1-group', proj.groups) - self.assertIn('level2-group', proj.groups) - # Check level2 proj group not removed. - self.assertIn('l2g1', proj.groups) +""" + ) + include_m = manifest_xml.XmlManifest(self.repodir, root_m) + for proj in include_m.projects: + if proj.name == "root-name1": + # Check include group not set on root level proj. + self.assertNotIn("level1-group", proj.groups) + if proj.name == "root-name2": + # Check root proj group not removed. + self.assertIn("r2g1", proj.groups) + if proj.name == "level1-name1": + # Check level1 proj has inherited group level 1. + self.assertIn("level1-group", proj.groups) + if proj.name == "level2-name1": + # Check level2 proj has inherited group levels 1 and 2. + self.assertIn("level1-group", proj.groups) + self.assertIn("level2-group", proj.groups) + # Check level2 proj group not removed. 
+ self.assertIn("l2g1", proj.groups) - def test_allow_bad_name_from_user(self): - """Check handling of bad name attribute from the user's input.""" - def parse(name): - name = self.encodeXmlAttr(name) - manifest = self.getXmlManifest(f""" + def test_allow_bad_name_from_user(self): + """Check handling of bad name attribute from the user's input.""" + + def parse(name): + name = self.encodeXmlAttr(name) + manifest = self.getXmlManifest( + f""" -""") - # Force the manifest to be parsed. - manifest.ToXml() +""" + ) + # Force the manifest to be parsed. + manifest.ToXml() - # Setup target of the include. - target = os.path.join(self.tempdir, 'target.xml') - with open(target, 'w') as fp: - fp.write('') + # Setup target of the include. + target = os.path.join(self.tempdir, "target.xml") + with open(target, "w") as fp: + fp.write("") - # Include with absolute path. - parse(os.path.abspath(target)) + # Include with absolute path. + parse(os.path.abspath(target)) - # Include with relative path. - parse(os.path.relpath(target, self.manifest_dir)) + # Include with relative path. + parse(os.path.relpath(target, self.manifest_dir)) - def test_bad_name_checks(self): - """Check handling of bad name attribute.""" - def parse(name): - name = self.encodeXmlAttr(name) - # Setup target of the include. - with open(os.path.join(self.manifest_dir, 'target.xml'), 'w', encoding="utf-8") as fp: - fp.write(f'') + def test_bad_name_checks(self): + """Check handling of bad name attribute.""" - manifest = self.getXmlManifest(""" + def parse(name): + name = self.encodeXmlAttr(name) + # Setup target of the include. + with open( + os.path.join(self.manifest_dir, "target.xml"), + "w", + encoding="utf-8", + ) as fp: + fp.write(f'') + + manifest = self.getXmlManifest( + """ -""") - # Force the manifest to be parsed. - manifest.ToXml() +""" + ) + # Force the manifest to be parsed. + manifest.ToXml() - # Handle empty name explicitly because a different codepath rejects it. 
- with self.assertRaises(error.ManifestParseError): - parse('') + # Handle empty name explicitly because a different codepath rejects it. + with self.assertRaises(error.ManifestParseError): + parse("") - for path in INVALID_FS_PATHS: - if not path: - continue + for path in INVALID_FS_PATHS: + if not path: + continue - with self.assertRaises(error.ManifestInvalidPathError): - parse(path) + with self.assertRaises(error.ManifestInvalidPathError): + parse(path) class ProjectElementTests(ManifestParseTestCase): - """Tests for .""" + """Tests for .""" - def test_group(self): - """Check project group settings.""" - manifest = self.getXmlManifest(""" + def test_group(self): + """Check project group settings.""" + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(len(manifest.projects), 2) - # Ordering isn't guaranteed. - result = { - manifest.projects[0].name: manifest.projects[0].groups, - manifest.projects[1].name: manifest.projects[1].groups, - } - project = manifest.projects[0] - self.assertCountEqual( - result['test-name'], - ['name:test-name', 'all', 'path:test-path']) - self.assertCountEqual( - result['extras'], - ['g1', 'g2', 'g1', 'name:extras', 'all', 'path:path']) - groupstr = 'default,platform-' + platform.system().lower() - self.assertEqual(groupstr, manifest.GetGroupsStr()) - groupstr = 'g1,g2,g1' - manifest.manifestProject.config.SetString('manifest.groups', groupstr) - self.assertEqual(groupstr, manifest.GetGroupsStr()) +""" + ) + self.assertEqual(len(manifest.projects), 2) + # Ordering isn't guaranteed. 
+ result = { + manifest.projects[0].name: manifest.projects[0].groups, + manifest.projects[1].name: manifest.projects[1].groups, + } + self.assertCountEqual( + result["test-name"], ["name:test-name", "all", "path:test-path"] + ) + self.assertCountEqual( + result["extras"], + ["g1", "g2", "g1", "name:extras", "all", "path:path"], + ) + groupstr = "default,platform-" + platform.system().lower() + self.assertEqual(groupstr, manifest.GetGroupsStr()) + groupstr = "g1,g2,g1" + manifest.manifestProject.config.SetString("manifest.groups", groupstr) + self.assertEqual(groupstr, manifest.GetGroupsStr()) - def test_set_revision_id(self): - """Check setting of project's revisionId.""" - manifest = self.getXmlManifest(""" + def test_set_revision_id(self): + """Check setting of project's revisionId.""" + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(len(manifest.projects), 1) - project = manifest.projects[0] - project.SetRevisionId('ABCDEF') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') +""" + ) + self.assertEqual(len(manifest.projects), 1) + project = manifest.projects[0] + project.SetRevisionId("ABCDEF") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' # noqa: E501 + "", + ) - def test_trailing_slash(self): - """Check handling of trailing slashes in attributes.""" - def parse(name, path): - name = self.encodeXmlAttr(name) - path = self.encodeXmlAttr(path) - return self.getXmlManifest(f""" + def test_trailing_slash(self): + """Check handling of trailing slashes in attributes.""" + + def parse(name, path): + name = self.encodeXmlAttr(name) + path = self.encodeXmlAttr(path) + return self.getXmlManifest( + f""" -""") +""" + ) - manifest = parse('a/path/', 'foo') - self.assertEqual(os.path.normpath(manifest.projects[0].gitdir), - os.path.join(self.tempdir, '.repo', 'projects', 'foo.git')) - self.assertEqual(os.path.normpath(manifest.projects[0].objdir), - os.path.join(self.tempdir, 
'.repo', 'project-objects', 'a', 'path.git')) + manifest = parse("a/path/", "foo") + self.assertEqual( + os.path.normpath(manifest.projects[0].gitdir), + os.path.join(self.tempdir, ".repo", "projects", "foo.git"), + ) + self.assertEqual( + os.path.normpath(manifest.projects[0].objdir), + os.path.join( + self.tempdir, ".repo", "project-objects", "a", "path.git" + ), + ) - manifest = parse('a/path', 'foo/') - self.assertEqual(os.path.normpath(manifest.projects[0].gitdir), - os.path.join(self.tempdir, '.repo', 'projects', 'foo.git')) - self.assertEqual(os.path.normpath(manifest.projects[0].objdir), - os.path.join(self.tempdir, '.repo', 'project-objects', 'a', 'path.git')) + manifest = parse("a/path", "foo/") + self.assertEqual( + os.path.normpath(manifest.projects[0].gitdir), + os.path.join(self.tempdir, ".repo", "projects", "foo.git"), + ) + self.assertEqual( + os.path.normpath(manifest.projects[0].objdir), + os.path.join( + self.tempdir, ".repo", "project-objects", "a", "path.git" + ), + ) - manifest = parse('a/path', 'foo//////') - self.assertEqual(os.path.normpath(manifest.projects[0].gitdir), - os.path.join(self.tempdir, '.repo', 'projects', 'foo.git')) - self.assertEqual(os.path.normpath(manifest.projects[0].objdir), - os.path.join(self.tempdir, '.repo', 'project-objects', 'a', 'path.git')) + manifest = parse("a/path", "foo//////") + self.assertEqual( + os.path.normpath(manifest.projects[0].gitdir), + os.path.join(self.tempdir, ".repo", "projects", "foo.git"), + ) + self.assertEqual( + os.path.normpath(manifest.projects[0].objdir), + os.path.join( + self.tempdir, ".repo", "project-objects", "a", "path.git" + ), + ) - def test_toplevel_path(self): - """Check handling of path=. specially.""" - def parse(name, path): - name = self.encodeXmlAttr(name) - path = self.encodeXmlAttr(path) - return self.getXmlManifest(f""" + def test_toplevel_path(self): + """Check handling of path=. 
specially.""" + + def parse(name, path): + name = self.encodeXmlAttr(name) + path = self.encodeXmlAttr(path) + return self.getXmlManifest( + f""" -""") +""" + ) - for path in ('.', './', './/', './//'): - manifest = parse('server/path', path) - self.assertEqual(os.path.normpath(manifest.projects[0].gitdir), - os.path.join(self.tempdir, '.repo', 'projects', '..git')) + for path in (".", "./", ".//", ".///"): + manifest = parse("server/path", path) + self.assertEqual( + os.path.normpath(manifest.projects[0].gitdir), + os.path.join(self.tempdir, ".repo", "projects", "..git"), + ) - def test_bad_path_name_checks(self): - """Check handling of bad path & name attributes.""" - def parse(name, path): - name = self.encodeXmlAttr(name) - path = self.encodeXmlAttr(path) - manifest = self.getXmlManifest(f""" + def test_bad_path_name_checks(self): + """Check handling of bad path & name attributes.""" + + def parse(name, path): + name = self.encodeXmlAttr(name) + path = self.encodeXmlAttr(path) + manifest = self.getXmlManifest( + f""" -""") - # Force the manifest to be parsed. - manifest.ToXml() +""" + ) + # Force the manifest to be parsed. + manifest.ToXml() - # Verify the parser is valid by default to avoid buggy tests below. - parse('ok', 'ok') + # Verify the parser is valid by default to avoid buggy tests below. + parse("ok", "ok") - # Handle empty name explicitly because a different codepath rejects it. - # Empty path is OK because it defaults to the name field. - with self.assertRaises(error.ManifestParseError): - parse('', 'ok') + # Handle empty name explicitly because a different codepath rejects it. + # Empty path is OK because it defaults to the name field. 
+ with self.assertRaises(error.ManifestParseError): + parse("", "ok") - for path in INVALID_FS_PATHS: - if not path or path.endswith('/') or path.endswith(os.path.sep): - continue + for path in INVALID_FS_PATHS: + if not path or path.endswith("/") or path.endswith(os.path.sep): + continue - with self.assertRaises(error.ManifestInvalidPathError): - parse(path, 'ok') + with self.assertRaises(error.ManifestInvalidPathError): + parse(path, "ok") - # We have a dedicated test for path=".". - if path not in {'.'}: - with self.assertRaises(error.ManifestInvalidPathError): - parse('ok', path) + # We have a dedicated test for path=".". + if path not in {"."}: + with self.assertRaises(error.ManifestInvalidPathError): + parse("ok", path) class SuperProjectElementTests(ManifestParseTestCase): - """Tests for .""" + """Tests for .""" - def test_superproject(self): - """Check superproject settings.""" - manifest = self.getXmlManifest(""" + def test_superproject(self): + """Check superproject settings.""" + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.superproject.name, 'superproject') - self.assertEqual(manifest.superproject.remote.name, 'test-remote') - self.assertEqual(manifest.superproject.remote.url, 'http://localhost/superproject') - self.assertEqual(manifest.superproject.revision, 'refs/heads/main') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') +""" + ) + self.assertEqual(manifest.superproject.name, "superproject") + self.assertEqual(manifest.superproject.remote.name, "test-remote") + self.assertEqual( + manifest.superproject.remote.url, "http://localhost/superproject" + ) + self.assertEqual(manifest.superproject.revision, "refs/heads/main") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' + "", + ) - def test_superproject_revision(self): - """Check superproject settings with a different revision attribute""" - self.maxDiff = None - manifest = self.getXmlManifest(""" + 
def test_superproject_revision(self): + """Check superproject settings with a different revision attribute""" + self.maxDiff = None + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.superproject.name, 'superproject') - self.assertEqual(manifest.superproject.remote.name, 'test-remote') - self.assertEqual(manifest.superproject.remote.url, 'http://localhost/superproject') - self.assertEqual(manifest.superproject.revision, 'refs/heads/stable') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') +""" + ) + self.assertEqual(manifest.superproject.name, "superproject") + self.assertEqual(manifest.superproject.remote.name, "test-remote") + self.assertEqual( + manifest.superproject.remote.url, "http://localhost/superproject" + ) + self.assertEqual(manifest.superproject.revision, "refs/heads/stable") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' + "", + ) - def test_superproject_revision_default_negative(self): - """Check superproject settings with a same revision attribute""" - self.maxDiff = None - manifest = self.getXmlManifest(""" + def test_superproject_revision_default_negative(self): + """Check superproject settings with a same revision attribute""" + self.maxDiff = None + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.superproject.name, 'superproject') - self.assertEqual(manifest.superproject.remote.name, 'test-remote') - self.assertEqual(manifest.superproject.remote.url, 'http://localhost/superproject') - self.assertEqual(manifest.superproject.revision, 'refs/heads/stable') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') +""" + ) + self.assertEqual(manifest.superproject.name, "superproject") + self.assertEqual(manifest.superproject.remote.name, "test-remote") + self.assertEqual( + manifest.superproject.remote.url, "http://localhost/superproject" + ) + self.assertEqual(manifest.superproject.revision, 
"refs/heads/stable") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' + "", + ) - def test_superproject_revision_remote(self): - """Check superproject settings with a same revision attribute""" - self.maxDiff = None - manifest = self.getXmlManifest(""" + def test_superproject_revision_remote(self): + """Check superproject settings with a same revision attribute""" + self.maxDiff = None + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.superproject.name, 'superproject') - self.assertEqual(manifest.superproject.remote.name, 'test-remote') - self.assertEqual(manifest.superproject.remote.url, 'http://localhost/superproject') - self.assertEqual(manifest.superproject.revision, 'refs/heads/stable') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') +""" # noqa: E501 + ) + self.assertEqual(manifest.superproject.name, "superproject") + self.assertEqual(manifest.superproject.remote.name, "test-remote") + self.assertEqual( + manifest.superproject.remote.url, "http://localhost/superproject" + ) + self.assertEqual(manifest.superproject.revision, "refs/heads/stable") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' # noqa: E501 + '' + '' + "", + ) - def test_remote(self): - """Check superproject settings with a remote.""" - manifest = self.getXmlManifest(""" + def test_remote(self): + """Check superproject settings with a remote.""" + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.superproject.name, 'platform/superproject') - self.assertEqual(manifest.superproject.remote.name, 'superproject-remote') - self.assertEqual(manifest.superproject.remote.url, 'http://localhost/platform/superproject') - self.assertEqual(manifest.superproject.revision, 'refs/heads/main') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '' - '') +""" + ) + self.assertEqual(manifest.superproject.name, 
"platform/superproject") + self.assertEqual( + manifest.superproject.remote.name, "superproject-remote" + ) + self.assertEqual( + manifest.superproject.remote.url, + "http://localhost/platform/superproject", + ) + self.assertEqual(manifest.superproject.revision, "refs/heads/main") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' + '' # noqa: E501 + "", + ) - def test_defalut_remote(self): - """Check superproject settings with a default remote.""" - manifest = self.getXmlManifest(""" + def test_defalut_remote(self): + """Check superproject settings with a default remote.""" + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.superproject.name, 'superproject') - self.assertEqual(manifest.superproject.remote.name, 'default-remote') - self.assertEqual(manifest.superproject.revision, 'refs/heads/main') - self.assertEqual( - sort_attributes(manifest.ToXml().toxml()), - '' - '' - '' - '' - '') +""" + ) + self.assertEqual(manifest.superproject.name, "superproject") + self.assertEqual(manifest.superproject.remote.name, "default-remote") + self.assertEqual(manifest.superproject.revision, "refs/heads/main") + self.assertEqual( + sort_attributes(manifest.ToXml().toxml()), + '' + '' + '' + '' + "", + ) class ContactinfoElementTests(ManifestParseTestCase): - """Tests for .""" + """Tests for .""" - def test_contactinfo(self): - """Check contactinfo settings.""" - bugurl = 'http://localhost/contactinfo' - manifest = self.getXmlManifest(f""" + def test_contactinfo(self): + """Check contactinfo settings.""" + bugurl = "http://localhost/contactinfo" + manifest = self.getXmlManifest( + f""" -""") - self.assertEqual(manifest.contactinfo.bugurl, bugurl) - self.assertEqual( - manifest.ToXml().toxml(), - '' - f'' - '') +""" + ) + self.assertEqual(manifest.contactinfo.bugurl, bugurl) + self.assertEqual( + manifest.ToXml().toxml(), + '' + f'' + "", + ) class DefaultElementTests(ManifestParseTestCase): - """Tests for .""" + """Tests 
for .""" - def test_default(self): - """Check default settings.""" - a = manifest_xml._Default() - a.revisionExpr = 'foo' - a.remote = manifest_xml._XmlRemote(name='remote') - b = manifest_xml._Default() - b.revisionExpr = 'bar' - self.assertEqual(a, a) - self.assertNotEqual(a, b) - self.assertNotEqual(b, a.remote) - self.assertNotEqual(a, 123) - self.assertNotEqual(a, None) + def test_default(self): + """Check default settings.""" + a = manifest_xml._Default() + a.revisionExpr = "foo" + a.remote = manifest_xml._XmlRemote(name="remote") + b = manifest_xml._Default() + b.revisionExpr = "bar" + self.assertEqual(a, a) + self.assertNotEqual(a, b) + self.assertNotEqual(b, a.remote) + self.assertNotEqual(a, 123) + self.assertNotEqual(a, None) class RemoteElementTests(ManifestParseTestCase): - """Tests for .""" + """Tests for .""" - def test_remote(self): - """Check remote settings.""" - a = manifest_xml._XmlRemote(name='foo') - a.AddAnnotation('key1', 'value1', 'true') - b = manifest_xml._XmlRemote(name='foo') - b.AddAnnotation('key2', 'value1', 'true') - c = manifest_xml._XmlRemote(name='foo') - c.AddAnnotation('key1', 'value2', 'true') - d = manifest_xml._XmlRemote(name='foo') - d.AddAnnotation('key1', 'value1', 'false') - self.assertEqual(a, a) - self.assertNotEqual(a, b) - self.assertNotEqual(a, c) - self.assertNotEqual(a, d) - self.assertNotEqual(a, manifest_xml._Default()) - self.assertNotEqual(a, 123) - self.assertNotEqual(a, None) + def test_remote(self): + """Check remote settings.""" + a = manifest_xml._XmlRemote(name="foo") + a.AddAnnotation("key1", "value1", "true") + b = manifest_xml._XmlRemote(name="foo") + b.AddAnnotation("key2", "value1", "true") + c = manifest_xml._XmlRemote(name="foo") + c.AddAnnotation("key1", "value2", "true") + d = manifest_xml._XmlRemote(name="foo") + d.AddAnnotation("key1", "value1", "false") + self.assertEqual(a, a) + self.assertNotEqual(a, b) + self.assertNotEqual(a, c) + self.assertNotEqual(a, d) + self.assertNotEqual(a, 
manifest_xml._Default()) + self.assertNotEqual(a, 123) + self.assertNotEqual(a, None) class RemoveProjectElementTests(ManifestParseTestCase): - """Tests for .""" + """Tests for .""" - def test_remove_one_project(self): - manifest = self.getXmlManifest(""" + def test_remove_one_project(self): + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.projects, []) +""" + ) + self.assertEqual(manifest.projects, []) - def test_remove_one_project_one_remains(self): - manifest = self.getXmlManifest(""" + def test_remove_one_project_one_remains(self): + manifest = self.getXmlManifest( + """ @@ -803,51 +926,59 @@ class RemoveProjectElementTests(ManifestParseTestCase): -""") +""" + ) - self.assertEqual(len(manifest.projects), 1) - self.assertEqual(manifest.projects[0].name, 'yourproject') + self.assertEqual(len(manifest.projects), 1) + self.assertEqual(manifest.projects[0].name, "yourproject") - def test_remove_one_project_doesnt_exist(self): - with self.assertRaises(manifest_xml.ManifestParseError): - manifest = self.getXmlManifest(""" + def test_remove_one_project_doesnt_exist(self): + with self.assertRaises(manifest_xml.ManifestParseError): + manifest = self.getXmlManifest( + """ -""") - manifest.projects +""" + ) + manifest.projects - def test_remove_one_optional_project_doesnt_exist(self): - manifest = self.getXmlManifest(""" + def test_remove_one_optional_project_doesnt_exist(self): + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(manifest.projects, []) +""" + ) + self.assertEqual(manifest.projects, []) class ExtendProjectElementTests(ManifestParseTestCase): - """Tests for .""" + """Tests for .""" - def test_extend_project_dest_path_single_match(self): - manifest = self.getXmlManifest(""" + def test_extend_project_dest_path_single_match(self): + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(len(manifest.projects), 1) - self.assertEqual(manifest.projects[0].relpath, 'bar') +""" + ) + 
self.assertEqual(len(manifest.projects), 1) + self.assertEqual(manifest.projects[0].relpath, "bar") - def test_extend_project_dest_path_multi_match(self): - with self.assertRaises(manifest_xml.ManifestParseError): - manifest = self.getXmlManifest(""" + def test_extend_project_dest_path_multi_match(self): + with self.assertRaises(manifest_xml.ManifestParseError): + manifest = self.getXmlManifest( + """ @@ -855,11 +986,13 @@ class ExtendProjectElementTests(ManifestParseTestCase): -""") - manifest.projects +""" + ) + manifest.projects - def test_extend_project_dest_path_multi_match_path_specified(self): - manifest = self.getXmlManifest(""" + def test_extend_project_dest_path_multi_match_path_specified(self): + manifest = self.getXmlManifest( + """ @@ -867,34 +1000,39 @@ class ExtendProjectElementTests(ManifestParseTestCase): -""") - self.assertEqual(len(manifest.projects), 2) - if manifest.projects[0].relpath == 'y': - self.assertEqual(manifest.projects[1].relpath, 'bar') - else: - self.assertEqual(manifest.projects[0].relpath, 'bar') - self.assertEqual(manifest.projects[1].relpath, 'y') +""" + ) + self.assertEqual(len(manifest.projects), 2) + if manifest.projects[0].relpath == "y": + self.assertEqual(manifest.projects[1].relpath, "bar") + else: + self.assertEqual(manifest.projects[0].relpath, "bar") + self.assertEqual(manifest.projects[1].relpath, "y") - def test_extend_project_dest_branch(self): - manifest = self.getXmlManifest(""" + def test_extend_project_dest_branch(self): + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(len(manifest.projects), 1) - self.assertEqual(manifest.projects[0].dest_branch, 'bar') +""" # noqa: E501 + ) + self.assertEqual(len(manifest.projects), 1) + self.assertEqual(manifest.projects[0].dest_branch, "bar") - def test_extend_project_upstream(self): - manifest = self.getXmlManifest(""" + def test_extend_project_upstream(self): + manifest = self.getXmlManifest( + """ -""") - self.assertEqual(len(manifest.projects), 1) - 
self.assertEqual(manifest.projects[0].upstream, 'bar') +""" + ) + self.assertEqual(len(manifest.projects), 1) + self.assertEqual(manifest.projects[0].upstream, "bar") diff --git a/tests/test_platform_utils.py b/tests/test_platform_utils.py index 55b7805c..7a42de01 100644 --- a/tests/test_platform_utils.py +++ b/tests/test_platform_utils.py @@ -22,29 +22,31 @@ import platform_utils class RemoveTests(unittest.TestCase): - """Check remove() helper.""" + """Check remove() helper.""" - def testMissingOk(self): - """Check missing_ok handling.""" - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, 'test') + def testMissingOk(self): + """Check missing_ok handling.""" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "test") - # Should not fail. - platform_utils.remove(path, missing_ok=True) + # Should not fail. + platform_utils.remove(path, missing_ok=True) - # Should fail. - self.assertRaises(OSError, platform_utils.remove, path) - self.assertRaises(OSError, platform_utils.remove, path, missing_ok=False) + # Should fail. + self.assertRaises(OSError, platform_utils.remove, path) + self.assertRaises( + OSError, platform_utils.remove, path, missing_ok=False + ) - # Should not fail if it exists. - open(path, 'w').close() - platform_utils.remove(path, missing_ok=True) - self.assertFalse(os.path.exists(path)) + # Should not fail if it exists. 
+ open(path, "w").close() + platform_utils.remove(path, missing_ok=True) + self.assertFalse(os.path.exists(path)) - open(path, 'w').close() - platform_utils.remove(path) - self.assertFalse(os.path.exists(path)) + open(path, "w").close() + platform_utils.remove(path) + self.assertFalse(os.path.exists(path)) - open(path, 'w').close() - platform_utils.remove(path, missing_ok=False) - self.assertFalse(os.path.exists(path)) + open(path, "w").close() + platform_utils.remove(path, missing_ok=False) + self.assertFalse(os.path.exists(path)) diff --git a/tests/test_project.py b/tests/test_project.py index c50d9940..bc8330b2 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -31,452 +31,493 @@ import project @contextlib.contextmanager def TempGitTree(): - """Create a new empty git checkout for testing.""" - with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: - # Tests need to assume, that main is default branch at init, - # which is not supported in config until 2.28. - cmd = ['git', 'init'] - if git_command.git_require((2, 28, 0)): - cmd += ['--initial-branch=main'] - else: - # Use template dir for init. - templatedir = tempfile.mkdtemp(prefix='.test-template') - with open(os.path.join(templatedir, 'HEAD'), 'w') as fp: - fp.write('ref: refs/heads/main\n') - cmd += ['--template', templatedir] - subprocess.check_call(cmd, cwd=tempdir) - yield tempdir + """Create a new empty git checkout for testing.""" + with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: + # Tests need to assume, that main is default branch at init, + # which is not supported in config until 2.28. + cmd = ["git", "init"] + if git_command.git_require((2, 28, 0)): + cmd += ["--initial-branch=main"] + else: + # Use template dir for init. 
+ templatedir = tempfile.mkdtemp(prefix=".test-template") + with open(os.path.join(templatedir, "HEAD"), "w") as fp: + fp.write("ref: refs/heads/main\n") + cmd += ["--template", templatedir] + subprocess.check_call(cmd, cwd=tempdir) + yield tempdir class FakeProject(object): - """A fake for Project for basic functionality.""" + """A fake for Project for basic functionality.""" - def __init__(self, worktree): - self.worktree = worktree - self.gitdir = os.path.join(worktree, '.git') - self.name = 'fakeproject' - self.work_git = project.Project._GitGetByExec( - self, bare=False, gitdir=self.gitdir) - self.bare_git = project.Project._GitGetByExec( - self, bare=True, gitdir=self.gitdir) - self.config = git_config.GitConfig.ForRepository(gitdir=self.gitdir) + def __init__(self, worktree): + self.worktree = worktree + self.gitdir = os.path.join(worktree, ".git") + self.name = "fakeproject" + self.work_git = project.Project._GitGetByExec( + self, bare=False, gitdir=self.gitdir + ) + self.bare_git = project.Project._GitGetByExec( + self, bare=True, gitdir=self.gitdir + ) + self.config = git_config.GitConfig.ForRepository(gitdir=self.gitdir) class ReviewableBranchTests(unittest.TestCase): - """Check ReviewableBranch behavior.""" + """Check ReviewableBranch behavior.""" - def test_smoke(self): - """A quick run through everything.""" - with TempGitTree() as tempdir: - fakeproj = FakeProject(tempdir) + def test_smoke(self): + """A quick run through everything.""" + with TempGitTree() as tempdir: + fakeproj = FakeProject(tempdir) - # Generate some commits. - with open(os.path.join(tempdir, 'readme'), 'w') as fp: - fp.write('txt') - fakeproj.work_git.add('readme') - fakeproj.work_git.commit('-mAdd file') - fakeproj.work_git.checkout('-b', 'work') - fakeproj.work_git.rm('-f', 'readme') - fakeproj.work_git.commit('-mDel file') + # Generate some commits. 
+ with open(os.path.join(tempdir, "readme"), "w") as fp: + fp.write("txt") + fakeproj.work_git.add("readme") + fakeproj.work_git.commit("-mAdd file") + fakeproj.work_git.checkout("-b", "work") + fakeproj.work_git.rm("-f", "readme") + fakeproj.work_git.commit("-mDel file") - # Start off with the normal details. - rb = project.ReviewableBranch( - fakeproj, fakeproj.config.GetBranch('work'), 'main') - self.assertEqual('work', rb.name) - self.assertEqual(1, len(rb.commits)) - self.assertIn('Del file', rb.commits[0]) - d = rb.unabbrev_commits - self.assertEqual(1, len(d)) - short, long = next(iter(d.items())) - self.assertTrue(long.startswith(short)) - self.assertTrue(rb.base_exists) - # Hard to assert anything useful about this. - self.assertTrue(rb.date) + # Start off with the normal details. + rb = project.ReviewableBranch( + fakeproj, fakeproj.config.GetBranch("work"), "main" + ) + self.assertEqual("work", rb.name) + self.assertEqual(1, len(rb.commits)) + self.assertIn("Del file", rb.commits[0]) + d = rb.unabbrev_commits + self.assertEqual(1, len(d)) + short, long = next(iter(d.items())) + self.assertTrue(long.startswith(short)) + self.assertTrue(rb.base_exists) + # Hard to assert anything useful about this. + self.assertTrue(rb.date) - # Now delete the tracking branch! - fakeproj.work_git.branch('-D', 'main') - rb = project.ReviewableBranch( - fakeproj, fakeproj.config.GetBranch('work'), 'main') - self.assertEqual(0, len(rb.commits)) - self.assertFalse(rb.base_exists) - # Hard to assert anything useful about this. - self.assertTrue(rb.date) + # Now delete the tracking branch! + fakeproj.work_git.branch("-D", "main") + rb = project.ReviewableBranch( + fakeproj, fakeproj.config.GetBranch("work"), "main" + ) + self.assertEqual(0, len(rb.commits)) + self.assertFalse(rb.base_exists) + # Hard to assert anything useful about this. + self.assertTrue(rb.date) class CopyLinkTestCase(unittest.TestCase): - """TestCase for stub repo client checkouts. 
+ """TestCase for stub repo client checkouts. - It'll have a layout like this: - tempdir/ # self.tempdir - checkout/ # self.topdir - git-project/ # self.worktree + It'll have a layout like this: + tempdir/ # self.tempdir + checkout/ # self.topdir + git-project/ # self.worktree - Attributes: - tempdir: A dedicated temporary directory. - worktree: The top of the repo client checkout. - topdir: The top of a project checkout. - """ + Attributes: + tempdir: A dedicated temporary directory. + worktree: The top of the repo client checkout. + topdir: The top of a project checkout. + """ - def setUp(self): - self.tempdirobj = tempfile.TemporaryDirectory(prefix='repo_tests') - self.tempdir = self.tempdirobj.name - self.topdir = os.path.join(self.tempdir, 'checkout') - self.worktree = os.path.join(self.topdir, 'git-project') - os.makedirs(self.topdir) - os.makedirs(self.worktree) + def setUp(self): + self.tempdirobj = tempfile.TemporaryDirectory(prefix="repo_tests") + self.tempdir = self.tempdirobj.name + self.topdir = os.path.join(self.tempdir, "checkout") + self.worktree = os.path.join(self.topdir, "git-project") + os.makedirs(self.topdir) + os.makedirs(self.worktree) - def tearDown(self): - self.tempdirobj.cleanup() + def tearDown(self): + self.tempdirobj.cleanup() - @staticmethod - def touch(path): - with open(path, 'w'): - pass + @staticmethod + def touch(path): + with open(path, "w"): + pass - def assertExists(self, path, msg=None): - """Make sure |path| exists.""" - if os.path.exists(path): - return + def assertExists(self, path, msg=None): + """Make sure |path| exists.""" + if os.path.exists(path): + return - if msg is None: - msg = ['path is missing: %s' % path] - while path != '/': - path = os.path.dirname(path) - if not path: - # If we're given something like "foo", abort once we get to "". 
- break - result = os.path.exists(path) - msg.append('\tos.path.exists(%s): %s' % (path, result)) - if result: - msg.append('\tcontents: %r' % os.listdir(path)) - break - msg = '\n'.join(msg) + if msg is None: + msg = ["path is missing: %s" % path] + while path != "/": + path = os.path.dirname(path) + if not path: + # If we're given something like "foo", abort once we get to + # "". + break + result = os.path.exists(path) + msg.append("\tos.path.exists(%s): %s" % (path, result)) + if result: + msg.append("\tcontents: %r" % os.listdir(path)) + break + msg = "\n".join(msg) - raise self.failureException(msg) + raise self.failureException(msg) class CopyFile(CopyLinkTestCase): - """Check _CopyFile handling.""" + """Check _CopyFile handling.""" - def CopyFile(self, src, dest): - return project._CopyFile(self.worktree, src, self.topdir, dest) + def CopyFile(self, src, dest): + return project._CopyFile(self.worktree, src, self.topdir, dest) - def test_basic(self): - """Basic test of copying a file from a project to the toplevel.""" - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - cf = self.CopyFile('foo.txt', 'foo') - cf._Copy() - self.assertExists(os.path.join(self.topdir, 'foo')) + def test_basic(self): + """Basic test of copying a file from a project to the toplevel.""" + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + cf = self.CopyFile("foo.txt", "foo") + cf._Copy() + self.assertExists(os.path.join(self.topdir, "foo")) - def test_src_subdir(self): - """Copy a file from a subdir of a project.""" - src = os.path.join(self.worktree, 'bar', 'foo.txt') - os.makedirs(os.path.dirname(src)) - self.touch(src) - cf = self.CopyFile('bar/foo.txt', 'new.txt') - cf._Copy() - self.assertExists(os.path.join(self.topdir, 'new.txt')) + def test_src_subdir(self): + """Copy a file from a subdir of a project.""" + src = os.path.join(self.worktree, "bar", "foo.txt") + os.makedirs(os.path.dirname(src)) + self.touch(src) + cf = self.CopyFile("bar/foo.txt", 
"new.txt") + cf._Copy() + self.assertExists(os.path.join(self.topdir, "new.txt")) - def test_dest_subdir(self): - """Copy a file to a subdir of a checkout.""" - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - cf = self.CopyFile('foo.txt', 'sub/dir/new.txt') - self.assertFalse(os.path.exists(os.path.join(self.topdir, 'sub'))) - cf._Copy() - self.assertExists(os.path.join(self.topdir, 'sub', 'dir', 'new.txt')) + def test_dest_subdir(self): + """Copy a file to a subdir of a checkout.""" + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + cf = self.CopyFile("foo.txt", "sub/dir/new.txt") + self.assertFalse(os.path.exists(os.path.join(self.topdir, "sub"))) + cf._Copy() + self.assertExists(os.path.join(self.topdir, "sub", "dir", "new.txt")) - def test_update(self): - """Make sure changed files get copied again.""" - src = os.path.join(self.worktree, 'foo.txt') - dest = os.path.join(self.topdir, 'bar') - with open(src, 'w') as f: - f.write('1st') - cf = self.CopyFile('foo.txt', 'bar') - cf._Copy() - self.assertExists(dest) - with open(dest) as f: - self.assertEqual(f.read(), '1st') + def test_update(self): + """Make sure changed files get copied again.""" + src = os.path.join(self.worktree, "foo.txt") + dest = os.path.join(self.topdir, "bar") + with open(src, "w") as f: + f.write("1st") + cf = self.CopyFile("foo.txt", "bar") + cf._Copy() + self.assertExists(dest) + with open(dest) as f: + self.assertEqual(f.read(), "1st") - with open(src, 'w') as f: - f.write('2nd!') - cf._Copy() - with open(dest) as f: - self.assertEqual(f.read(), '2nd!') + with open(src, "w") as f: + f.write("2nd!") + cf._Copy() + with open(dest) as f: + self.assertEqual(f.read(), "2nd!") - def test_src_block_symlink(self): - """Do not allow reading from a symlinked path.""" - src = os.path.join(self.worktree, 'foo.txt') - sym = os.path.join(self.worktree, 'sym') - self.touch(src) - platform_utils.symlink('foo.txt', sym) - self.assertExists(sym) - cf = self.CopyFile('sym', 
'foo') - self.assertRaises(error.ManifestInvalidPathError, cf._Copy) + def test_src_block_symlink(self): + """Do not allow reading from a symlinked path.""" + src = os.path.join(self.worktree, "foo.txt") + sym = os.path.join(self.worktree, "sym") + self.touch(src) + platform_utils.symlink("foo.txt", sym) + self.assertExists(sym) + cf = self.CopyFile("sym", "foo") + self.assertRaises(error.ManifestInvalidPathError, cf._Copy) - def test_src_block_symlink_traversal(self): - """Do not allow reading through a symlink dir.""" - realfile = os.path.join(self.tempdir, 'file.txt') - self.touch(realfile) - src = os.path.join(self.worktree, 'bar', 'file.txt') - platform_utils.symlink(self.tempdir, os.path.join(self.worktree, 'bar')) - self.assertExists(src) - cf = self.CopyFile('bar/file.txt', 'foo') - self.assertRaises(error.ManifestInvalidPathError, cf._Copy) + def test_src_block_symlink_traversal(self): + """Do not allow reading through a symlink dir.""" + realfile = os.path.join(self.tempdir, "file.txt") + self.touch(realfile) + src = os.path.join(self.worktree, "bar", "file.txt") + platform_utils.symlink(self.tempdir, os.path.join(self.worktree, "bar")) + self.assertExists(src) + cf = self.CopyFile("bar/file.txt", "foo") + self.assertRaises(error.ManifestInvalidPathError, cf._Copy) - def test_src_block_copy_from_dir(self): - """Do not allow copying from a directory.""" - src = os.path.join(self.worktree, 'dir') - os.makedirs(src) - cf = self.CopyFile('dir', 'foo') - self.assertRaises(error.ManifestInvalidPathError, cf._Copy) + def test_src_block_copy_from_dir(self): + """Do not allow copying from a directory.""" + src = os.path.join(self.worktree, "dir") + os.makedirs(src) + cf = self.CopyFile("dir", "foo") + self.assertRaises(error.ManifestInvalidPathError, cf._Copy) - def test_dest_block_symlink(self): - """Do not allow writing to a symlink.""" - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - platform_utils.symlink('dest', os.path.join(self.topdir, 
'sym')) - cf = self.CopyFile('foo.txt', 'sym') - self.assertRaises(error.ManifestInvalidPathError, cf._Copy) + def test_dest_block_symlink(self): + """Do not allow writing to a symlink.""" + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + platform_utils.symlink("dest", os.path.join(self.topdir, "sym")) + cf = self.CopyFile("foo.txt", "sym") + self.assertRaises(error.ManifestInvalidPathError, cf._Copy) - def test_dest_block_symlink_traversal(self): - """Do not allow writing through a symlink dir.""" - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - platform_utils.symlink(tempfile.gettempdir(), - os.path.join(self.topdir, 'sym')) - cf = self.CopyFile('foo.txt', 'sym/foo.txt') - self.assertRaises(error.ManifestInvalidPathError, cf._Copy) + def test_dest_block_symlink_traversal(self): + """Do not allow writing through a symlink dir.""" + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + platform_utils.symlink( + tempfile.gettempdir(), os.path.join(self.topdir, "sym") + ) + cf = self.CopyFile("foo.txt", "sym/foo.txt") + self.assertRaises(error.ManifestInvalidPathError, cf._Copy) - def test_src_block_copy_to_dir(self): - """Do not allow copying to a directory.""" - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - os.makedirs(os.path.join(self.topdir, 'dir')) - cf = self.CopyFile('foo.txt', 'dir') - self.assertRaises(error.ManifestInvalidPathError, cf._Copy) + def test_src_block_copy_to_dir(self): + """Do not allow copying to a directory.""" + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + os.makedirs(os.path.join(self.topdir, "dir")) + cf = self.CopyFile("foo.txt", "dir") + self.assertRaises(error.ManifestInvalidPathError, cf._Copy) class LinkFile(CopyLinkTestCase): - """Check _LinkFile handling.""" + """Check _LinkFile handling.""" - def LinkFile(self, src, dest): - return project._LinkFile(self.worktree, src, self.topdir, dest) + def LinkFile(self, src, dest): + return 
project._LinkFile(self.worktree, src, self.topdir, dest) - def test_basic(self): - """Basic test of linking a file from a project into the toplevel.""" - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - lf = self.LinkFile('foo.txt', 'foo') - lf._Link() - dest = os.path.join(self.topdir, 'foo') - self.assertExists(dest) - self.assertTrue(os.path.islink(dest)) - self.assertEqual(os.path.join('git-project', 'foo.txt'), os.readlink(dest)) + def test_basic(self): + """Basic test of linking a file from a project into the toplevel.""" + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + lf = self.LinkFile("foo.txt", "foo") + lf._Link() + dest = os.path.join(self.topdir, "foo") + self.assertExists(dest) + self.assertTrue(os.path.islink(dest)) + self.assertEqual( + os.path.join("git-project", "foo.txt"), os.readlink(dest) + ) - def test_src_subdir(self): - """Link to a file in a subdir of a project.""" - src = os.path.join(self.worktree, 'bar', 'foo.txt') - os.makedirs(os.path.dirname(src)) - self.touch(src) - lf = self.LinkFile('bar/foo.txt', 'foo') - lf._Link() - self.assertExists(os.path.join(self.topdir, 'foo')) + def test_src_subdir(self): + """Link to a file in a subdir of a project.""" + src = os.path.join(self.worktree, "bar", "foo.txt") + os.makedirs(os.path.dirname(src)) + self.touch(src) + lf = self.LinkFile("bar/foo.txt", "foo") + lf._Link() + self.assertExists(os.path.join(self.topdir, "foo")) - def test_src_self(self): - """Link to the project itself.""" - dest = os.path.join(self.topdir, 'foo', 'bar') - lf = self.LinkFile('.', 'foo/bar') - lf._Link() - self.assertExists(dest) - self.assertEqual(os.path.join('..', 'git-project'), os.readlink(dest)) + def test_src_self(self): + """Link to the project itself.""" + dest = os.path.join(self.topdir, "foo", "bar") + lf = self.LinkFile(".", "foo/bar") + lf._Link() + self.assertExists(dest) + self.assertEqual(os.path.join("..", "git-project"), os.readlink(dest)) - def 
test_dest_subdir(self): - """Link a file to a subdir of a checkout.""" - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - lf = self.LinkFile('foo.txt', 'sub/dir/foo/bar') - self.assertFalse(os.path.exists(os.path.join(self.topdir, 'sub'))) - lf._Link() - self.assertExists(os.path.join(self.topdir, 'sub', 'dir', 'foo', 'bar')) + def test_dest_subdir(self): + """Link a file to a subdir of a checkout.""" + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + lf = self.LinkFile("foo.txt", "sub/dir/foo/bar") + self.assertFalse(os.path.exists(os.path.join(self.topdir, "sub"))) + lf._Link() + self.assertExists(os.path.join(self.topdir, "sub", "dir", "foo", "bar")) - def test_src_block_relative(self): - """Do not allow relative symlinks.""" - BAD_SOURCES = ( - './', - '..', - '../', - 'foo/.', - 'foo/./bar', - 'foo/..', - 'foo/../foo', - ) - for src in BAD_SOURCES: - lf = self.LinkFile(src, 'foo') - self.assertRaises(error.ManifestInvalidPathError, lf._Link) + def test_src_block_relative(self): + """Do not allow relative symlinks.""" + BAD_SOURCES = ( + "./", + "..", + "../", + "foo/.", + "foo/./bar", + "foo/..", + "foo/../foo", + ) + for src in BAD_SOURCES: + lf = self.LinkFile(src, "foo") + self.assertRaises(error.ManifestInvalidPathError, lf._Link) - def test_update(self): - """Make sure changed targets get updated.""" - dest = os.path.join(self.topdir, 'sym') + def test_update(self): + """Make sure changed targets get updated.""" + dest = os.path.join(self.topdir, "sym") - src = os.path.join(self.worktree, 'foo.txt') - self.touch(src) - lf = self.LinkFile('foo.txt', 'sym') - lf._Link() - self.assertEqual(os.path.join('git-project', 'foo.txt'), os.readlink(dest)) + src = os.path.join(self.worktree, "foo.txt") + self.touch(src) + lf = self.LinkFile("foo.txt", "sym") + lf._Link() + self.assertEqual( + os.path.join("git-project", "foo.txt"), os.readlink(dest) + ) - # Point the symlink somewhere else. 
- os.unlink(dest) - platform_utils.symlink(self.tempdir, dest) - lf._Link() - self.assertEqual(os.path.join('git-project', 'foo.txt'), os.readlink(dest)) + # Point the symlink somewhere else. + os.unlink(dest) + platform_utils.symlink(self.tempdir, dest) + lf._Link() + self.assertEqual( + os.path.join("git-project", "foo.txt"), os.readlink(dest) + ) class MigrateWorkTreeTests(unittest.TestCase): - """Check _MigrateOldWorkTreeGitDir handling.""" + """Check _MigrateOldWorkTreeGitDir handling.""" - _SYMLINKS = { - 'config', 'description', 'hooks', 'info', 'logs', 'objects', - 'packed-refs', 'refs', 'rr-cache', 'shallow', 'svn', - } - _FILES = { - 'COMMIT_EDITMSG', 'FETCH_HEAD', 'HEAD', 'index', 'ORIG_HEAD', - 'unknown-file-should-be-migrated', - } - _CLEAN_FILES = { - 'a-vim-temp-file~', '#an-emacs-temp-file#', - } + _SYMLINKS = { + "config", + "description", + "hooks", + "info", + "logs", + "objects", + "packed-refs", + "refs", + "rr-cache", + "shallow", + "svn", + } + _FILES = { + "COMMIT_EDITMSG", + "FETCH_HEAD", + "HEAD", + "index", + "ORIG_HEAD", + "unknown-file-should-be-migrated", + } + _CLEAN_FILES = { + "a-vim-temp-file~", + "#an-emacs-temp-file#", + } - @classmethod - @contextlib.contextmanager - def _simple_layout(cls): - """Create a simple repo client checkout to test against.""" - with tempfile.TemporaryDirectory() as tempdir: - tempdir = Path(tempdir) + @classmethod + @contextlib.contextmanager + def _simple_layout(cls): + """Create a simple repo client checkout to test against.""" + with tempfile.TemporaryDirectory() as tempdir: + tempdir = Path(tempdir) - gitdir = tempdir / '.repo/projects/src/test.git' - gitdir.mkdir(parents=True) - cmd = ['git', 'init', '--bare', str(gitdir)] - subprocess.check_call(cmd) + gitdir = tempdir / ".repo/projects/src/test.git" + gitdir.mkdir(parents=True) + cmd = ["git", "init", "--bare", str(gitdir)] + subprocess.check_call(cmd) - dotgit = tempdir / 'src/test/.git' - dotgit.mkdir(parents=True) - for name in cls._SYMLINKS: 
- (dotgit / name).symlink_to(f'../../../.repo/projects/src/test.git/{name}') - for name in cls._FILES | cls._CLEAN_FILES: - (dotgit / name).write_text(name) + dotgit = tempdir / "src/test/.git" + dotgit.mkdir(parents=True) + for name in cls._SYMLINKS: + (dotgit / name).symlink_to( + f"../../../.repo/projects/src/test.git/{name}" + ) + for name in cls._FILES | cls._CLEAN_FILES: + (dotgit / name).write_text(name) - yield tempdir + yield tempdir - def test_standard(self): - """Migrate a standard checkout that we expect.""" - with self._simple_layout() as tempdir: - dotgit = tempdir / 'src/test/.git' - project.Project._MigrateOldWorkTreeGitDir(str(dotgit)) + def test_standard(self): + """Migrate a standard checkout that we expect.""" + with self._simple_layout() as tempdir: + dotgit = tempdir / "src/test/.git" + project.Project._MigrateOldWorkTreeGitDir(str(dotgit)) - # Make sure the dir was transformed into a symlink. - self.assertTrue(dotgit.is_symlink()) - self.assertEqual(os.readlink(dotgit), os.path.normpath('../../.repo/projects/src/test.git')) + # Make sure the dir was transformed into a symlink. + self.assertTrue(dotgit.is_symlink()) + self.assertEqual( + os.readlink(dotgit), + os.path.normpath("../../.repo/projects/src/test.git"), + ) - # Make sure files were moved over. - gitdir = tempdir / '.repo/projects/src/test.git' - for name in self._FILES: - self.assertEqual(name, (gitdir / name).read_text()) - # Make sure files were removed. - for name in self._CLEAN_FILES: - self.assertFalse((gitdir / name).exists()) + # Make sure files were moved over. + gitdir = tempdir / ".repo/projects/src/test.git" + for name in self._FILES: + self.assertEqual(name, (gitdir / name).read_text()) + # Make sure files were removed. 
+ for name in self._CLEAN_FILES: + self.assertFalse((gitdir / name).exists()) - def test_unknown(self): - """A checkout with unknown files should abort.""" - with self._simple_layout() as tempdir: - dotgit = tempdir / 'src/test/.git' - (tempdir / '.repo/projects/src/test.git/random-file').write_text('one') - (dotgit / 'random-file').write_text('two') - with self.assertRaises(error.GitError): - project.Project._MigrateOldWorkTreeGitDir(str(dotgit)) + def test_unknown(self): + """A checkout with unknown files should abort.""" + with self._simple_layout() as tempdir: + dotgit = tempdir / "src/test/.git" + (tempdir / ".repo/projects/src/test.git/random-file").write_text( + "one" + ) + (dotgit / "random-file").write_text("two") + with self.assertRaises(error.GitError): + project.Project._MigrateOldWorkTreeGitDir(str(dotgit)) - # Make sure no content was actually changed. - self.assertTrue(dotgit.is_dir()) - for name in self._FILES: - self.assertTrue((dotgit / name).is_file()) - for name in self._CLEAN_FILES: - self.assertTrue((dotgit / name).is_file()) - for name in self._SYMLINKS: - self.assertTrue((dotgit / name).is_symlink()) + # Make sure no content was actually changed. 
+ self.assertTrue(dotgit.is_dir()) + for name in self._FILES: + self.assertTrue((dotgit / name).is_file()) + for name in self._CLEAN_FILES: + self.assertTrue((dotgit / name).is_file()) + for name in self._SYMLINKS: + self.assertTrue((dotgit / name).is_symlink()) class ManifestPropertiesFetchedCorrectly(unittest.TestCase): - """Ensure properties are fetched properly.""" + """Ensure properties are fetched properly.""" - def setUpManifest(self, tempdir): - repodir = os.path.join(tempdir, '.repo') - manifest_dir = os.path.join(repodir, 'manifests') - manifest_file = os.path.join( - repodir, manifest_xml.MANIFEST_FILE_NAME) - local_manifest_dir = os.path.join( - repodir, manifest_xml.LOCAL_MANIFESTS_DIR_NAME) - os.mkdir(repodir) - os.mkdir(manifest_dir) - manifest = manifest_xml.XmlManifest(repodir, manifest_file) + def setUpManifest(self, tempdir): + repodir = os.path.join(tempdir, ".repo") + manifest_dir = os.path.join(repodir, "manifests") + manifest_file = os.path.join(repodir, manifest_xml.MANIFEST_FILE_NAME) + os.mkdir(repodir) + os.mkdir(manifest_dir) + manifest = manifest_xml.XmlManifest(repodir, manifest_file) - return project.ManifestProject( - manifest, 'test/manifest', os.path.join(tempdir, '.git'), tempdir) + return project.ManifestProject( + manifest, "test/manifest", os.path.join(tempdir, ".git"), tempdir + ) - def test_manifest_config_properties(self): - """Test we are fetching the manifest config properties correctly.""" + def test_manifest_config_properties(self): + """Test we are fetching the manifest config properties correctly.""" - with TempGitTree() as tempdir: - fakeproj = self.setUpManifest(tempdir) + with TempGitTree() as tempdir: + fakeproj = self.setUpManifest(tempdir) - # Set property using the expected Set method, then ensure - # the porperty functions are using the correct Get methods. 
- fakeproj.config.SetString( - 'manifest.standalone', 'https://chicken/manifest.git') - self.assertEqual( - fakeproj.standalone_manifest_url, 'https://chicken/manifest.git') + # Set property using the expected Set method, then ensure + # the property functions are using the correct Get methods. + fakeproj.config.SetString( + "manifest.standalone", "https://chicken/manifest.git" + ) + self.assertEqual( + fakeproj.standalone_manifest_url, "https://chicken/manifest.git" + ) - fakeproj.config.SetString('manifest.groups', 'test-group, admin-group') - self.assertEqual(fakeproj.manifest_groups, 'test-group, admin-group') + fakeproj.config.SetString( + "manifest.groups", "test-group, admin-group" + ) + self.assertEqual( + fakeproj.manifest_groups, "test-group, admin-group" + ) - fakeproj.config.SetString('repo.reference', 'mirror/ref') - self.assertEqual(fakeproj.reference, 'mirror/ref') + fakeproj.config.SetString("repo.reference", "mirror/ref") + self.assertEqual(fakeproj.reference, "mirror/ref") - fakeproj.config.SetBoolean('repo.dissociate', False) - self.assertFalse(fakeproj.dissociate) + fakeproj.config.SetBoolean("repo.dissociate", False) + self.assertFalse(fakeproj.dissociate) - fakeproj.config.SetBoolean('repo.archive', False) - self.assertFalse(fakeproj.archive) + fakeproj.config.SetBoolean("repo.archive", False) + self.assertFalse(fakeproj.archive) - fakeproj.config.SetBoolean('repo.mirror', False) - self.assertFalse(fakeproj.mirror) + fakeproj.config.SetBoolean("repo.mirror", False) + self.assertFalse(fakeproj.mirror) - fakeproj.config.SetBoolean('repo.worktree', False) - self.assertFalse(fakeproj.use_worktree) + fakeproj.config.SetBoolean("repo.worktree", False) + self.assertFalse(fakeproj.use_worktree) - fakeproj.config.SetBoolean('repo.clonebundle', False) - self.assertFalse(fakeproj.clone_bundle) + fakeproj.config.SetBoolean("repo.clonebundle", False) + self.assertFalse(fakeproj.clone_bundle) - fakeproj.config.SetBoolean('repo.submodules', False) - 
self.assertFalse(fakeproj.submodules) + fakeproj.config.SetBoolean("repo.submodules", False) + self.assertFalse(fakeproj.submodules) - fakeproj.config.SetBoolean('repo.git-lfs', False) - self.assertFalse(fakeproj.git_lfs) + fakeproj.config.SetBoolean("repo.git-lfs", False) + self.assertFalse(fakeproj.git_lfs) - fakeproj.config.SetBoolean('repo.superproject', False) - self.assertFalse(fakeproj.use_superproject) + fakeproj.config.SetBoolean("repo.superproject", False) + self.assertFalse(fakeproj.use_superproject) - fakeproj.config.SetBoolean('repo.partialclone', False) - self.assertFalse(fakeproj.partial_clone) + fakeproj.config.SetBoolean("repo.partialclone", False) + self.assertFalse(fakeproj.partial_clone) - fakeproj.config.SetString('repo.depth', '48') - self.assertEqual(fakeproj.depth, '48') + fakeproj.config.SetString("repo.depth", "48") + self.assertEqual(fakeproj.depth, "48") - fakeproj.config.SetString('repo.clonefilter', 'blob:limit=10M') - self.assertEqual(fakeproj.clone_filter, 'blob:limit=10M') + fakeproj.config.SetString("repo.clonefilter", "blob:limit=10M") + self.assertEqual(fakeproj.clone_filter, "blob:limit=10M") - fakeproj.config.SetString('repo.partialcloneexclude', 'third_party/big_repo') - self.assertEqual(fakeproj.partial_clone_exclude, 'third_party/big_repo') + fakeproj.config.SetString( + "repo.partialcloneexclude", "third_party/big_repo" + ) + self.assertEqual( + fakeproj.partial_clone_exclude, "third_party/big_repo" + ) - fakeproj.config.SetString('manifest.platform', 'auto') - self.assertEqual(fakeproj.manifest_platform, 'auto') + fakeproj.config.SetString("manifest.platform", "auto") + self.assertEqual(fakeproj.manifest_platform, "auto") diff --git a/tests/test_repo_trace.py b/tests/test_repo_trace.py index 5faf2938..e4aeb5de 100644 --- a/tests/test_repo_trace.py +++ b/tests/test_repo_trace.py @@ -22,35 +22,39 @@ import repo_trace class TraceTests(unittest.TestCase): - """Check Trace behavior.""" + """Check Trace behavior.""" - def 
testTrace_MaxSizeEnforced(self): - content = 'git chicken' + def testTrace_MaxSizeEnforced(self): + content = "git chicken" - with repo_trace.Trace(content, first_trace=True): - pass - first_trace_size = os.path.getsize(repo_trace._TRACE_FILE) + with repo_trace.Trace(content, first_trace=True): + pass + first_trace_size = os.path.getsize(repo_trace._TRACE_FILE) - with repo_trace.Trace(content): - pass - self.assertGreater( - os.path.getsize(repo_trace._TRACE_FILE), first_trace_size) + with repo_trace.Trace(content): + pass + self.assertGreater( + os.path.getsize(repo_trace._TRACE_FILE), first_trace_size + ) - # Check we clear everything is the last chunk is larger than _MAX_SIZE. - with mock.patch('repo_trace._MAX_SIZE', 0): - with repo_trace.Trace(content, first_trace=True): - pass - self.assertEqual(first_trace_size, - os.path.getsize(repo_trace._TRACE_FILE)) + # Check we clear everything if the last chunk is larger than _MAX_SIZE. + with mock.patch("repo_trace._MAX_SIZE", 0): + with repo_trace.Trace(content, first_trace=True): + pass + self.assertEqual( + first_trace_size, os.path.getsize(repo_trace._TRACE_FILE) + ) - # Check we only clear the chunks we need to. 
+ repo_trace._MAX_SIZE = (first_trace_size + 1) / (1024 * 1024) + with repo_trace.Trace(content, first_trace=True): + pass + self.assertEqual( + first_trace_size * 2, os.path.getsize(repo_trace._TRACE_FILE) + ) - with repo_trace.Trace(content, first_trace=True): - pass - self.assertEqual(first_trace_size * 2, - os.path.getsize(repo_trace._TRACE_FILE)) + with repo_trace.Trace(content, first_trace=True): + pass + self.assertEqual( + first_trace_size * 2, os.path.getsize(repo_trace._TRACE_FILE) + ) diff --git a/tests/test_ssh.py b/tests/test_ssh.py index ffb5cb94..a9c1be7f 100644 --- a/tests/test_ssh.py +++ b/tests/test_ssh.py @@ -23,52 +23,56 @@ import ssh class SshTests(unittest.TestCase): - """Tests the ssh functions.""" + """Tests the ssh functions.""" - def test_parse_ssh_version(self): - """Check _parse_ssh_version() handling.""" - ver = ssh._parse_ssh_version('Unknown\n') - self.assertEqual(ver, ()) - ver = ssh._parse_ssh_version('OpenSSH_1.0\n') - self.assertEqual(ver, (1, 0)) - ver = ssh._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n') - self.assertEqual(ver, (6, 6, 1)) - ver = ssh._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n') - self.assertEqual(ver, (7, 6)) + def test_parse_ssh_version(self): + """Check _parse_ssh_version() handling.""" + ver = ssh._parse_ssh_version("Unknown\n") + self.assertEqual(ver, ()) + ver = ssh._parse_ssh_version("OpenSSH_1.0\n") + self.assertEqual(ver, (1, 0)) + ver = ssh._parse_ssh_version( + "OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n" + ) + self.assertEqual(ver, (6, 6, 1)) + ver = ssh._parse_ssh_version( + "OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n" + ) + self.assertEqual(ver, (7, 6)) - def test_version(self): - """Check version() handling.""" - with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'): - self.assertEqual(ssh.version(), (1, 2)) + def test_version(self): + """Check version() handling.""" + 
with mock.patch("ssh._run_ssh_version", return_value="OpenSSH_1.2\n"): + self.assertEqual(ssh.version(), (1, 2)) - def test_context_manager_empty(self): - """Verify context manager with no clients works correctly.""" - with multiprocessing.Manager() as manager: - with ssh.ProxyManager(manager): - pass + def test_context_manager_empty(self): + """Verify context manager with no clients works correctly.""" + with multiprocessing.Manager() as manager: + with ssh.ProxyManager(manager): + pass - def test_context_manager_child_cleanup(self): - """Verify orphaned clients & masters get cleaned up.""" - with multiprocessing.Manager() as manager: - with ssh.ProxyManager(manager) as ssh_proxy: - client = subprocess.Popen(['sleep', '964853320']) - ssh_proxy.add_client(client) - master = subprocess.Popen(['sleep', '964853321']) - ssh_proxy.add_master(master) - # If the process still exists, these will throw timeout errors. - client.wait(0) - master.wait(0) + def test_context_manager_child_cleanup(self): + """Verify orphaned clients & masters get cleaned up.""" + with multiprocessing.Manager() as manager: + with ssh.ProxyManager(manager) as ssh_proxy: + client = subprocess.Popen(["sleep", "964853320"]) + ssh_proxy.add_client(client) + master = subprocess.Popen(["sleep", "964853321"]) + ssh_proxy.add_master(master) + # If the process still exists, these will throw timeout errors. + client.wait(0) + master.wait(0) - def test_ssh_sock(self): - """Check sock() function.""" - manager = multiprocessing.Manager() - proxy = ssh.ProxyManager(manager) - with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'): - # old ssh version uses port - with mock.patch('ssh.version', return_value=(6, 6)): - self.assertTrue(proxy.sock().endswith('%p')) + def test_ssh_sock(self): + """Check sock() function.""" + manager = multiprocessing.Manager() + proxy = ssh.ProxyManager(manager) + with mock.patch("tempfile.mkdtemp", return_value="/tmp/foo"): + # Old ssh version uses port. 
+ with mock.patch("ssh.version", return_value=(6, 6)): + self.assertTrue(proxy.sock().endswith("%p")) - proxy._sock_path = None - # new ssh version uses hash - with mock.patch('ssh.version', return_value=(6, 7)): - self.assertTrue(proxy.sock().endswith('%C')) + proxy._sock_path = None + # New ssh version uses hash. + with mock.patch("ssh.version", return_value=(6, 7)): + self.assertTrue(proxy.sock().endswith("%C")) diff --git a/tests/test_subcmds.py b/tests/test_subcmds.py index bc53051a..73b66e3f 100644 --- a/tests/test_subcmds.py +++ b/tests/test_subcmds.py @@ -21,53 +21,57 @@ import subcmds class AllCommands(unittest.TestCase): - """Check registered all_commands.""" + """Check registered all_commands.""" - def test_required_basic(self): - """Basic checking of registered commands.""" - # NB: We don't test all subcommands as we want to avoid "change detection" - # tests, so we just look for the most common/important ones here that are - # unlikely to ever change. - for cmd in {'cherry-pick', 'help', 'init', 'start', 'sync', 'upload'}: - self.assertIn(cmd, subcmds.all_commands) + def test_required_basic(self): + """Basic checking of registered commands.""" + # NB: We don't test all subcommands as we want to avoid "change + # detection" tests, so we just look for the most common/important ones + # here that are unlikely to ever change. + for cmd in {"cherry-pick", "help", "init", "start", "sync", "upload"}: + self.assertIn(cmd, subcmds.all_commands) - def test_naming(self): - """Verify we don't add things that we shouldn't.""" - for cmd in subcmds.all_commands: - # Reject filename suffixes like "help.py". - self.assertNotIn('.', cmd) + def test_naming(self): + """Verify we don't add things that we shouldn't.""" + for cmd in subcmds.all_commands: + # Reject filename suffixes like "help.py". + self.assertNotIn(".", cmd) - # Make sure all '_' were converted to '-'. - self.assertNotIn('_', cmd) + # Make sure all '_' were converted to '-'. 
+ self.assertNotIn("_", cmd) - # Reject internal python paths like "__init__". - self.assertFalse(cmd.startswith('__')) + # Reject internal python paths like "__init__". + self.assertFalse(cmd.startswith("__")) - def test_help_desc_style(self): - """Force some consistency in option descriptions. + def test_help_desc_style(self): + """Force some consistency in option descriptions. - Python's optparse & argparse has a few default options like --help. Their - option description text uses lowercase sentence fragments, so enforce our - options follow the same style so UI is consistent. + Python's optparse & argparse has a few default options like --help. + Their option description text uses lowercase sentence fragments, so + enforce our options follow the same style so UI is consistent. - We enforce: - * Text starts with lowercase. - * Text doesn't end with period. - """ - for name, cls in subcmds.all_commands.items(): - cmd = cls() - parser = cmd.OptionParser - for option in parser.option_list: - if option.help == optparse.SUPPRESS_HELP: - continue + We enforce: + * Text starts with lowercase. + * Text doesn't end with period. 
+ """ + for name, cls in subcmds.all_commands.items(): + cmd = cls() + parser = cmd.OptionParser + for option in parser.option_list: + if option.help == optparse.SUPPRESS_HELP: + continue - c = option.help[0] - self.assertEqual( - c.lower(), c, - msg=f'subcmds/{name}.py: {option.get_opt_string()}: help text ' - f'should start with lowercase: "{option.help}"') + c = option.help[0] + self.assertEqual( + c.lower(), + c, + msg=f"subcmds/{name}.py: {option.get_opt_string()}: " + f'help text should start with lowercase: "{option.help}"', + ) - self.assertNotEqual( - option.help[-1], '.', - msg=f'subcmds/{name}.py: {option.get_opt_string()}: help text ' - f'should not end in a period: "{option.help}"') + self.assertNotEqual( + option.help[-1], + ".", + msg=f"subcmds/{name}.py: {option.get_opt_string()}: " + f'help text should not end in a period: "{option.help}"', + ) diff --git a/tests/test_subcmds_init.py b/tests/test_subcmds_init.py index af4346de..25e5be56 100644 --- a/tests/test_subcmds_init.py +++ b/tests/test_subcmds_init.py @@ -20,30 +20,27 @@ from subcmds import init class InitCommand(unittest.TestCase): - """Check registered all_commands.""" + """Check registered all_commands.""" - def setUp(self): - self.cmd = init.Init() + def setUp(self): + self.cmd = init.Init() - def test_cli_parser_good(self): - """Check valid command line options.""" - ARGV = ( - [], - ) - for argv in ARGV: - opts, args = self.cmd.OptionParser.parse_args(argv) - self.cmd.ValidateOptions(opts, args) + def test_cli_parser_good(self): + """Check valid command line options.""" + ARGV = ([],) + for argv in ARGV: + opts, args = self.cmd.OptionParser.parse_args(argv) + self.cmd.ValidateOptions(opts, args) - def test_cli_parser_bad(self): - """Check invalid command line options.""" - ARGV = ( - # Too many arguments. - ['url', 'asdf'], - - # Conflicting options. 
- ['--mirror', '--archive'], - ) - for argv in ARGV: - opts, args = self.cmd.OptionParser.parse_args(argv) - with self.assertRaises(SystemExit): - self.cmd.ValidateOptions(opts, args) + def test_cli_parser_bad(self): + """Check invalid command line options.""" + ARGV = ( + # Too many arguments. + ["url", "asdf"], + # Conflicting options. + ["--mirror", "--archive"], + ) + for argv in ARGV: + opts, args = self.cmd.OptionParser.parse_args(argv) + with self.assertRaises(SystemExit): + self.cmd.ValidateOptions(opts, args) diff --git a/tests/test_subcmds_sync.py b/tests/test_subcmds_sync.py index 236d54e5..5c8e606e 100644 --- a/tests/test_subcmds_sync.py +++ b/tests/test_subcmds_sync.py @@ -23,111 +23,138 @@ import command from subcmds import sync -@pytest.mark.parametrize('use_superproject, cli_args, result', [ - (True, ['--current-branch'], True), - (True, ['--no-current-branch'], True), - (True, [], True), - (False, ['--current-branch'], True), - (False, ['--no-current-branch'], False), - (False, [], None), -]) +@pytest.mark.parametrize( + "use_superproject, cli_args, result", + [ + (True, ["--current-branch"], True), + (True, ["--no-current-branch"], True), + (True, [], True), + (False, ["--current-branch"], True), + (False, ["--no-current-branch"], False), + (False, [], None), + ], +) def test_get_current_branch_only(use_superproject, cli_args, result): - """Test Sync._GetCurrentBranchOnly logic. + """Test Sync._GetCurrentBranchOnly logic. - Sync._GetCurrentBranchOnly should return True if a superproject is requested, - and otherwise the value of the current_branch_only option. - """ - cmd = sync.Sync() - opts, _ = cmd.OptionParser.parse_args(cli_args) + Sync._GetCurrentBranchOnly should return True if a superproject is + requested, and otherwise the value of the current_branch_only option. 
+ """ + cmd = sync.Sync() + opts, _ = cmd.OptionParser.parse_args(cli_args) - with mock.patch('git_superproject.UseSuperproject', - return_value=use_superproject): - assert cmd._GetCurrentBranchOnly(opts, cmd.manifest) == result + with mock.patch( + "git_superproject.UseSuperproject", return_value=use_superproject + ): + assert cmd._GetCurrentBranchOnly(opts, cmd.manifest) == result # Used to patch os.cpu_count() for reliable results. OS_CPU_COUNT = 24 -@pytest.mark.parametrize('argv, jobs_manifest, jobs, jobs_net, jobs_check', [ - # No user or manifest settings. - ([], None, OS_CPU_COUNT, 1, command.DEFAULT_LOCAL_JOBS), - # No user settings, so manifest settings control. - ([], 3, 3, 3, 3), - # User settings, but no manifest. - (['--jobs=4'], None, 4, 4, 4), - (['--jobs=4', '--jobs-network=5'], None, 4, 5, 4), - (['--jobs=4', '--jobs-checkout=6'], None, 4, 4, 6), - (['--jobs=4', '--jobs-network=5', '--jobs-checkout=6'], None, 4, 5, 6), - (['--jobs-network=5'], None, OS_CPU_COUNT, 5, command.DEFAULT_LOCAL_JOBS), - (['--jobs-checkout=6'], None, OS_CPU_COUNT, 1, 6), - (['--jobs-network=5', '--jobs-checkout=6'], None, OS_CPU_COUNT, 5, 6), - # User settings with manifest settings. - (['--jobs=4'], 3, 4, 4, 4), - (['--jobs=4', '--jobs-network=5'], 3, 4, 5, 4), - (['--jobs=4', '--jobs-checkout=6'], 3, 4, 4, 6), - (['--jobs=4', '--jobs-network=5', '--jobs-checkout=6'], 3, 4, 5, 6), - (['--jobs-network=5'], 3, 3, 5, 3), - (['--jobs-checkout=6'], 3, 3, 3, 6), - (['--jobs-network=5', '--jobs-checkout=6'], 3, 3, 5, 6), - # Settings that exceed rlimits get capped. - (['--jobs=1000000'], None, 83, 83, 83), - ([], 1000000, 83, 83, 83), -]) + +@pytest.mark.parametrize( + "argv, jobs_manifest, jobs, jobs_net, jobs_check", + [ + # No user or manifest settings. + ([], None, OS_CPU_COUNT, 1, command.DEFAULT_LOCAL_JOBS), + # No user settings, so manifest settings control. + ([], 3, 3, 3, 3), + # User settings, but no manifest. 
+ (["--jobs=4"], None, 4, 4, 4), + (["--jobs=4", "--jobs-network=5"], None, 4, 5, 4), + (["--jobs=4", "--jobs-checkout=6"], None, 4, 4, 6), + (["--jobs=4", "--jobs-network=5", "--jobs-checkout=6"], None, 4, 5, 6), + ( + ["--jobs-network=5"], + None, + OS_CPU_COUNT, + 5, + command.DEFAULT_LOCAL_JOBS, + ), + (["--jobs-checkout=6"], None, OS_CPU_COUNT, 1, 6), + (["--jobs-network=5", "--jobs-checkout=6"], None, OS_CPU_COUNT, 5, 6), + # User settings with manifest settings. + (["--jobs=4"], 3, 4, 4, 4), + (["--jobs=4", "--jobs-network=5"], 3, 4, 5, 4), + (["--jobs=4", "--jobs-checkout=6"], 3, 4, 4, 6), + (["--jobs=4", "--jobs-network=5", "--jobs-checkout=6"], 3, 4, 5, 6), + (["--jobs-network=5"], 3, 3, 5, 3), + (["--jobs-checkout=6"], 3, 3, 3, 6), + (["--jobs-network=5", "--jobs-checkout=6"], 3, 3, 5, 6), + # Settings that exceed rlimits get capped. + (["--jobs=1000000"], None, 83, 83, 83), + ([], 1000000, 83, 83, 83), + ], +) def test_cli_jobs(argv, jobs_manifest, jobs, jobs_net, jobs_check): - """Tests --jobs option behavior.""" - mp = mock.MagicMock() - mp.manifest.default.sync_j = jobs_manifest + """Tests --jobs option behavior.""" + mp = mock.MagicMock() + mp.manifest.default.sync_j = jobs_manifest - cmd = sync.Sync() - opts, args = cmd.OptionParser.parse_args(argv) - cmd.ValidateOptions(opts, args) + cmd = sync.Sync() + opts, args = cmd.OptionParser.parse_args(argv) + cmd.ValidateOptions(opts, args) - with mock.patch.object(sync, '_rlimit_nofile', return_value=(256, 256)): - with mock.patch.object(os, 'cpu_count', return_value=OS_CPU_COUNT): - cmd._ValidateOptionsWithManifest(opts, mp) - assert opts.jobs == jobs - assert opts.jobs_network == jobs_net - assert opts.jobs_checkout == jobs_check + with mock.patch.object(sync, "_rlimit_nofile", return_value=(256, 256)): + with mock.patch.object(os, "cpu_count", return_value=OS_CPU_COUNT): + cmd._ValidateOptionsWithManifest(opts, mp) + assert opts.jobs == jobs + assert opts.jobs_network == jobs_net + assert 
opts.jobs_checkout == jobs_check class GetPreciousObjectsState(unittest.TestCase): - """Tests for _GetPreciousObjectsState.""" + """Tests for _GetPreciousObjectsState.""" - def setUp(self): - """Common setup.""" - self.cmd = sync.Sync() - self.project = p = mock.MagicMock(use_git_worktrees=False, - UseAlternates=False) - p.manifest.GetProjectsWithName.return_value = [p] + def setUp(self): + """Common setup.""" + self.cmd = sync.Sync() + self.project = p = mock.MagicMock( + use_git_worktrees=False, UseAlternates=False + ) + p.manifest.GetProjectsWithName.return_value = [p] - self.opt = mock.Mock(spec_set=['this_manifest_only']) - self.opt.this_manifest_only = False + self.opt = mock.Mock(spec_set=["this_manifest_only"]) + self.opt.this_manifest_only = False - def test_worktrees(self): - """False for worktrees.""" - self.project.use_git_worktrees = True - self.assertFalse(self.cmd._GetPreciousObjectsState(self.project, self.opt)) + def test_worktrees(self): + """False for worktrees.""" + self.project.use_git_worktrees = True + self.assertFalse( + self.cmd._GetPreciousObjectsState(self.project, self.opt) + ) - def test_not_shared(self): - """Singleton project.""" - self.assertFalse(self.cmd._GetPreciousObjectsState(self.project, self.opt)) + def test_not_shared(self): + """Singleton project.""" + self.assertFalse( + self.cmd._GetPreciousObjectsState(self.project, self.opt) + ) - def test_shared(self): - """Shared project.""" - self.project.manifest.GetProjectsWithName.return_value = [ - self.project, self.project - ] - self.assertTrue(self.cmd._GetPreciousObjectsState(self.project, self.opt)) + def test_shared(self): + """Shared project.""" + self.project.manifest.GetProjectsWithName.return_value = [ + self.project, + self.project, + ] + self.assertTrue( + self.cmd._GetPreciousObjectsState(self.project, self.opt) + ) - def test_shared_with_alternates(self): - """Shared project, with alternates.""" - self.project.manifest.GetProjectsWithName.return_value = [ - 
self.project, self.project - ] - self.project.UseAlternates = True - self.assertFalse(self.cmd._GetPreciousObjectsState(self.project, self.opt)) + def test_shared_with_alternates(self): + """Shared project, with alternates.""" + self.project.manifest.GetProjectsWithName.return_value = [ + self.project, + self.project, + ] + self.project.UseAlternates = True + self.assertFalse( + self.cmd._GetPreciousObjectsState(self.project, self.opt) + ) - def test_not_found(self): - """Project not found in manifest.""" - self.project.manifest.GetProjectsWithName.return_value = [] - self.assertFalse(self.cmd._GetPreciousObjectsState(self.project, self.opt)) + def test_not_found(self): + """Project not found in manifest.""" + self.project.manifest.GetProjectsWithName.return_value = [] + self.assertFalse( + self.cmd._GetPreciousObjectsState(self.project, self.opt) + ) diff --git a/tests/test_update_manpages.py b/tests/test_update_manpages.py index 0de85be9..12b19ec4 100644 --- a/tests/test_update_manpages.py +++ b/tests/test_update_manpages.py @@ -20,9 +20,9 @@ from release import update_manpages class UpdateManpagesTest(unittest.TestCase): - """Tests the update-manpages code.""" + """Tests the update-manpages code.""" - def test_replace_regex(self): - """Check that replace_regex works.""" - data = '\n\033[1mSummary\033[m\n' - self.assertEqual(update_manpages.replace_regex(data),'\nSummary\n') + def test_replace_regex(self): + """Check that replace_regex works.""" + data = "\n\033[1mSummary\033[m\n" + self.assertEqual(update_manpages.replace_regex(data), "\nSummary\n") diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index ef879a5d..21fa094d 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -28,528 +28,615 @@ import wrapper def fixture(*paths): - """Return a path relative to tests/fixtures. 
- """ - return os.path.join(os.path.dirname(__file__), 'fixtures', *paths) + """Return a path relative to tests/fixtures.""" + return os.path.join(os.path.dirname(__file__), "fixtures", *paths) class RepoWrapperTestCase(unittest.TestCase): - """TestCase for the wrapper module.""" + """TestCase for the wrapper module.""" - def setUp(self): - """Load the wrapper module every time.""" - wrapper.Wrapper.cache_clear() - self.wrapper = wrapper.Wrapper() + def setUp(self): + """Load the wrapper module every time.""" + wrapper.Wrapper.cache_clear() + self.wrapper = wrapper.Wrapper() class RepoWrapperUnitTest(RepoWrapperTestCase): - """Tests helper functions in the repo wrapper - """ + """Tests helper functions in the repo wrapper""" - def test_version(self): - """Make sure _Version works.""" - with self.assertRaises(SystemExit) as e: - with mock.patch('sys.stdout', new_callable=StringIO) as stdout: - with mock.patch('sys.stderr', new_callable=StringIO) as stderr: - self.wrapper._Version() - self.assertEqual(0, e.exception.code) - self.assertEqual('', stderr.getvalue()) - self.assertIn('repo launcher version', stdout.getvalue()) + def test_version(self): + """Make sure _Version works.""" + with self.assertRaises(SystemExit) as e: + with mock.patch("sys.stdout", new_callable=StringIO) as stdout: + with mock.patch("sys.stderr", new_callable=StringIO) as stderr: + self.wrapper._Version() + self.assertEqual(0, e.exception.code) + self.assertEqual("", stderr.getvalue()) + self.assertIn("repo launcher version", stdout.getvalue()) - def test_python_constraints(self): - """The launcher should never require newer than main.py.""" - self.assertGreaterEqual(main.MIN_PYTHON_VERSION_HARD, - self.wrapper.MIN_PYTHON_VERSION_HARD) - self.assertGreaterEqual(main.MIN_PYTHON_VERSION_SOFT, - self.wrapper.MIN_PYTHON_VERSION_SOFT) - # Make sure the versions are themselves in sync. 
- self.assertGreaterEqual(self.wrapper.MIN_PYTHON_VERSION_SOFT, - self.wrapper.MIN_PYTHON_VERSION_HARD) + def test_python_constraints(self): + """The launcher should never require newer than main.py.""" + self.assertGreaterEqual( + main.MIN_PYTHON_VERSION_HARD, self.wrapper.MIN_PYTHON_VERSION_HARD + ) + self.assertGreaterEqual( + main.MIN_PYTHON_VERSION_SOFT, self.wrapper.MIN_PYTHON_VERSION_SOFT + ) + # Make sure the versions are themselves in sync. + self.assertGreaterEqual( + self.wrapper.MIN_PYTHON_VERSION_SOFT, + self.wrapper.MIN_PYTHON_VERSION_HARD, + ) - def test_init_parser(self): - """Make sure 'init' GetParser works.""" - parser = self.wrapper.GetParser(gitc_init=False) - opts, args = parser.parse_args([]) - self.assertEqual([], args) - self.assertIsNone(opts.manifest_url) + def test_init_parser(self): + """Make sure 'init' GetParser works.""" + parser = self.wrapper.GetParser(gitc_init=False) + opts, args = parser.parse_args([]) + self.assertEqual([], args) + self.assertIsNone(opts.manifest_url) - def test_gitc_init_parser(self): - """Make sure 'gitc-init' GetParser works.""" - parser = self.wrapper.GetParser(gitc_init=True) - opts, args = parser.parse_args([]) - self.assertEqual([], args) - self.assertIsNone(opts.manifest_file) + def test_gitc_init_parser(self): + """Make sure 'gitc-init' GetParser works.""" + parser = self.wrapper.GetParser(gitc_init=True) + opts, args = parser.parse_args([]) + self.assertEqual([], args) + self.assertIsNone(opts.manifest_file) - def test_get_gitc_manifest_dir_no_gitc(self): - """ - Test reading a missing gitc config file - """ - self.wrapper.GITC_CONFIG_FILE = fixture('missing_gitc_config') - val = self.wrapper.get_gitc_manifest_dir() - self.assertEqual(val, '') + def test_get_gitc_manifest_dir_no_gitc(self): + """ + Test reading a missing gitc config file + """ + self.wrapper.GITC_CONFIG_FILE = fixture("missing_gitc_config") + val = self.wrapper.get_gitc_manifest_dir() + self.assertEqual(val, "") - def 
test_get_gitc_manifest_dir(self): - """ - Test reading the gitc config file and parsing the directory - """ - self.wrapper.GITC_CONFIG_FILE = fixture('gitc_config') - val = self.wrapper.get_gitc_manifest_dir() - self.assertEqual(val, '/test/usr/local/google/gitc') + def test_get_gitc_manifest_dir(self): + """ + Test reading the gitc config file and parsing the directory + """ + self.wrapper.GITC_CONFIG_FILE = fixture("gitc_config") + val = self.wrapper.get_gitc_manifest_dir() + self.assertEqual(val, "/test/usr/local/google/gitc") - def test_gitc_parse_clientdir_no_gitc(self): - """ - Test parsing the gitc clientdir without gitc running - """ - self.wrapper.GITC_CONFIG_FILE = fixture('missing_gitc_config') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/something'), None) - self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/test'), 'test') + def test_gitc_parse_clientdir_no_gitc(self): + """ + Test parsing the gitc clientdir without gitc running + """ + self.wrapper.GITC_CONFIG_FILE = fixture("missing_gitc_config") + self.assertEqual(self.wrapper.gitc_parse_clientdir("/something"), None) + self.assertEqual( + self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/test"), "test" + ) - def test_gitc_parse_clientdir(self): - """ - Test parsing the gitc clientdir - """ - self.wrapper.GITC_CONFIG_FILE = fixture('gitc_config') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/something'), None) - self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/test'), 'test') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/test/'), 'test') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/test/extra'), 'test') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test'), 'test') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test/'), 'test') - 
self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/test/extra'), - 'test') - self.assertEqual(self.wrapper.gitc_parse_clientdir('/gitc/manifest-rw/'), None) - self.assertEqual(self.wrapper.gitc_parse_clientdir('/test/usr/local/google/gitc/'), None) + def test_gitc_parse_clientdir(self): + """ + Test parsing the gitc clientdir + """ + self.wrapper.GITC_CONFIG_FILE = fixture("gitc_config") + self.assertEqual(self.wrapper.gitc_parse_clientdir("/something"), None) + self.assertEqual( + self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/test"), "test" + ) + self.assertEqual( + self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/test/"), "test" + ) + self.assertEqual( + self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/test/extra"), + "test", + ) + self.assertEqual( + self.wrapper.gitc_parse_clientdir( + "/test/usr/local/google/gitc/test" + ), + "test", + ) + self.assertEqual( + self.wrapper.gitc_parse_clientdir( + "/test/usr/local/google/gitc/test/" + ), + "test", + ) + self.assertEqual( + self.wrapper.gitc_parse_clientdir( + "/test/usr/local/google/gitc/test/extra" + ), + "test", + ) + self.assertEqual( + self.wrapper.gitc_parse_clientdir("/gitc/manifest-rw/"), None + ) + self.assertEqual( + self.wrapper.gitc_parse_clientdir("/test/usr/local/google/gitc/"), + None, + ) class SetGitTrace2ParentSid(RepoWrapperTestCase): - """Check SetGitTrace2ParentSid behavior.""" + """Check SetGitTrace2ParentSid behavior.""" - KEY = 'GIT_TRACE2_PARENT_SID' - VALID_FORMAT = re.compile(r'^repo-[0-9]{8}T[0-9]{6}Z-P[0-9a-f]{8}$') + KEY = "GIT_TRACE2_PARENT_SID" + VALID_FORMAT = re.compile(r"^repo-[0-9]{8}T[0-9]{6}Z-P[0-9a-f]{8}$") - def test_first_set(self): - """Test env var not yet set.""" - env = {} - self.wrapper.SetGitTrace2ParentSid(env) - self.assertIn(self.KEY, env) - value = env[self.KEY] - self.assertRegex(value, self.VALID_FORMAT) + def test_first_set(self): + """Test env var not yet set.""" + env = {} + self.wrapper.SetGitTrace2ParentSid(env) 
+ self.assertIn(self.KEY, env) + value = env[self.KEY] + self.assertRegex(value, self.VALID_FORMAT) - def test_append(self): - """Test env var is appended.""" - env = {self.KEY: 'pfx'} - self.wrapper.SetGitTrace2ParentSid(env) - self.assertIn(self.KEY, env) - value = env[self.KEY] - self.assertTrue(value.startswith('pfx/')) - self.assertRegex(value[4:], self.VALID_FORMAT) + def test_append(self): + """Test env var is appended.""" + env = {self.KEY: "pfx"} + self.wrapper.SetGitTrace2ParentSid(env) + self.assertIn(self.KEY, env) + value = env[self.KEY] + self.assertTrue(value.startswith("pfx/")) + self.assertRegex(value[4:], self.VALID_FORMAT) - def test_global_context(self): - """Check os.environ gets updated by default.""" - os.environ.pop(self.KEY, None) - self.wrapper.SetGitTrace2ParentSid() - self.assertIn(self.KEY, os.environ) - value = os.environ[self.KEY] - self.assertRegex(value, self.VALID_FORMAT) + def test_global_context(self): + """Check os.environ gets updated by default.""" + os.environ.pop(self.KEY, None) + self.wrapper.SetGitTrace2ParentSid() + self.assertIn(self.KEY, os.environ) + value = os.environ[self.KEY] + self.assertRegex(value, self.VALID_FORMAT) class RunCommand(RepoWrapperTestCase): - """Check run_command behavior.""" + """Check run_command behavior.""" - def test_capture(self): - """Check capture_output handling.""" - ret = self.wrapper.run_command(['echo', 'hi'], capture_output=True) - # echo command appends OS specific linesep, but on Windows + Git Bash - # we get UNIX ending, so we allow both. - self.assertIn(ret.stdout, ['hi' + os.linesep, 'hi\n']) + def test_capture(self): + """Check capture_output handling.""" + ret = self.wrapper.run_command(["echo", "hi"], capture_output=True) + # echo command appends OS specific linesep, but on Windows + Git Bash + # we get UNIX ending, so we allow both. 
+ self.assertIn(ret.stdout, ["hi" + os.linesep, "hi\n"]) - def test_check(self): - """Check check handling.""" - self.wrapper.run_command(['true'], check=False) - self.wrapper.run_command(['true'], check=True) - self.wrapper.run_command(['false'], check=False) - with self.assertRaises(self.wrapper.RunError): - self.wrapper.run_command(['false'], check=True) + def test_check(self): + """Check check handling.""" + self.wrapper.run_command(["true"], check=False) + self.wrapper.run_command(["true"], check=True) + self.wrapper.run_command(["false"], check=False) + with self.assertRaises(self.wrapper.RunError): + self.wrapper.run_command(["false"], check=True) class RunGit(RepoWrapperTestCase): - """Check run_git behavior.""" + """Check run_git behavior.""" - def test_capture(self): - """Check capture_output handling.""" - ret = self.wrapper.run_git('--version') - self.assertIn('git', ret.stdout) + def test_capture(self): + """Check capture_output handling.""" + ret = self.wrapper.run_git("--version") + self.assertIn("git", ret.stdout) - def test_check(self): - """Check check handling.""" - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper.run_git('--version-asdfasdf') - self.wrapper.run_git('--version-asdfasdf', check=False) + def test_check(self): + """Check check handling.""" + with self.assertRaises(self.wrapper.CloneFailure): + self.wrapper.run_git("--version-asdfasdf") + self.wrapper.run_git("--version-asdfasdf", check=False) class ParseGitVersion(RepoWrapperTestCase): - """Check ParseGitVersion behavior.""" + """Check ParseGitVersion behavior.""" - def test_autoload(self): - """Check we can load the version from the live git.""" - ret = self.wrapper.ParseGitVersion() - self.assertIsNotNone(ret) + def test_autoload(self): + """Check we can load the version from the live git.""" + ret = self.wrapper.ParseGitVersion() + self.assertIsNotNone(ret) - def test_bad_ver(self): - """Check handling of bad git versions.""" - ret = 
self.wrapper.ParseGitVersion(ver_str='asdf') - self.assertIsNone(ret) + def test_bad_ver(self): + """Check handling of bad git versions.""" + ret = self.wrapper.ParseGitVersion(ver_str="asdf") + self.assertIsNone(ret) - def test_normal_ver(self): - """Check handling of normal git versions.""" - ret = self.wrapper.ParseGitVersion(ver_str='git version 2.25.1') - self.assertEqual(2, ret.major) - self.assertEqual(25, ret.minor) - self.assertEqual(1, ret.micro) - self.assertEqual('2.25.1', ret.full) + def test_normal_ver(self): + """Check handling of normal git versions.""" + ret = self.wrapper.ParseGitVersion(ver_str="git version 2.25.1") + self.assertEqual(2, ret.major) + self.assertEqual(25, ret.minor) + self.assertEqual(1, ret.micro) + self.assertEqual("2.25.1", ret.full) - def test_extended_ver(self): - """Check handling of extended distro git versions.""" - ret = self.wrapper.ParseGitVersion( - ver_str='git version 1.30.50.696.g5e7596f4ac-goog') - self.assertEqual(1, ret.major) - self.assertEqual(30, ret.minor) - self.assertEqual(50, ret.micro) - self.assertEqual('1.30.50.696.g5e7596f4ac-goog', ret.full) + def test_extended_ver(self): + """Check handling of extended distro git versions.""" + ret = self.wrapper.ParseGitVersion( + ver_str="git version 1.30.50.696.g5e7596f4ac-goog" + ) + self.assertEqual(1, ret.major) + self.assertEqual(30, ret.minor) + self.assertEqual(50, ret.micro) + self.assertEqual("1.30.50.696.g5e7596f4ac-goog", ret.full) class CheckGitVersion(RepoWrapperTestCase): - """Check _CheckGitVersion behavior.""" + """Check _CheckGitVersion behavior.""" - def test_unknown(self): - """Unknown versions should abort.""" - with mock.patch.object(self.wrapper, 'ParseGitVersion', return_value=None): - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper._CheckGitVersion() + def test_unknown(self): + """Unknown versions should abort.""" + with mock.patch.object( + self.wrapper, "ParseGitVersion", return_value=None + ): + with 
self.assertRaises(self.wrapper.CloneFailure): + self.wrapper._CheckGitVersion() - def test_old(self): - """Old versions should abort.""" - with mock.patch.object( - self.wrapper, 'ParseGitVersion', - return_value=self.wrapper.GitVersion(1, 0, 0, '1.0.0')): - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper._CheckGitVersion() + def test_old(self): + """Old versions should abort.""" + with mock.patch.object( + self.wrapper, + "ParseGitVersion", + return_value=self.wrapper.GitVersion(1, 0, 0, "1.0.0"), + ): + with self.assertRaises(self.wrapper.CloneFailure): + self.wrapper._CheckGitVersion() - def test_new(self): - """Newer versions should run fine.""" - with mock.patch.object( - self.wrapper, 'ParseGitVersion', - return_value=self.wrapper.GitVersion(100, 0, 0, '100.0.0')): - self.wrapper._CheckGitVersion() + def test_new(self): + """Newer versions should run fine.""" + with mock.patch.object( + self.wrapper, + "ParseGitVersion", + return_value=self.wrapper.GitVersion(100, 0, 0, "100.0.0"), + ): + self.wrapper._CheckGitVersion() class Requirements(RepoWrapperTestCase): - """Check Requirements handling.""" + """Check Requirements handling.""" - def test_missing_file(self): - """Don't crash if the file is missing (old version).""" - testdir = os.path.dirname(os.path.realpath(__file__)) - self.assertIsNone(self.wrapper.Requirements.from_dir(testdir)) - self.assertIsNone(self.wrapper.Requirements.from_file( - os.path.join(testdir, 'xxxxxxxxxxxxxxxxxxxxxxxx'))) + def test_missing_file(self): + """Don't crash if the file is missing (old version).""" + testdir = os.path.dirname(os.path.realpath(__file__)) + self.assertIsNone(self.wrapper.Requirements.from_dir(testdir)) + self.assertIsNone( + self.wrapper.Requirements.from_file( + os.path.join(testdir, "xxxxxxxxxxxxxxxxxxxxxxxx") + ) + ) - def test_corrupt_data(self): - """If the file can't be parsed, don't blow up.""" - self.assertIsNone(self.wrapper.Requirements.from_file(__file__)) - 
self.assertIsNone(self.wrapper.Requirements.from_data(b'x')) + def test_corrupt_data(self): + """If the file can't be parsed, don't blow up.""" + self.assertIsNone(self.wrapper.Requirements.from_file(__file__)) + self.assertIsNone(self.wrapper.Requirements.from_data(b"x")) - def test_valid_data(self): - """Make sure we can parse the file we ship.""" - self.assertIsNotNone(self.wrapper.Requirements.from_data(b'{}')) - rootdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) - self.assertIsNotNone(self.wrapper.Requirements.from_dir(rootdir)) - self.assertIsNotNone(self.wrapper.Requirements.from_file(os.path.join( - rootdir, 'requirements.json'))) + def test_valid_data(self): + """Make sure we can parse the file we ship.""" + self.assertIsNotNone(self.wrapper.Requirements.from_data(b"{}")) + rootdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + self.assertIsNotNone(self.wrapper.Requirements.from_dir(rootdir)) + self.assertIsNotNone( + self.wrapper.Requirements.from_file( + os.path.join(rootdir, "requirements.json") + ) + ) - def test_format_ver(self): - """Check format_ver can format.""" - self.assertEqual('1.2.3', self.wrapper.Requirements._format_ver((1, 2, 3))) - self.assertEqual('1', self.wrapper.Requirements._format_ver([1])) + def test_format_ver(self): + """Check format_ver can format.""" + self.assertEqual( + "1.2.3", self.wrapper.Requirements._format_ver((1, 2, 3)) + ) + self.assertEqual("1", self.wrapper.Requirements._format_ver([1])) - def test_assert_all_unknown(self): - """Check assert_all works with incompatible file.""" - reqs = self.wrapper.Requirements({}) - reqs.assert_all() + def test_assert_all_unknown(self): + """Check assert_all works with incompatible file.""" + reqs = self.wrapper.Requirements({}) + reqs.assert_all() - def test_assert_all_new_repo(self): - """Check assert_all accepts new enough repo.""" - reqs = self.wrapper.Requirements({'repo': {'hard': [1, 0]}}) - reqs.assert_all() + def 
test_assert_all_new_repo(self): + """Check assert_all accepts new enough repo.""" + reqs = self.wrapper.Requirements({"repo": {"hard": [1, 0]}}) + reqs.assert_all() - def test_assert_all_old_repo(self): - """Check assert_all rejects old repo.""" - reqs = self.wrapper.Requirements({'repo': {'hard': [99999, 0]}}) - with self.assertRaises(SystemExit): - reqs.assert_all() + def test_assert_all_old_repo(self): + """Check assert_all rejects old repo.""" + reqs = self.wrapper.Requirements({"repo": {"hard": [99999, 0]}}) + with self.assertRaises(SystemExit): + reqs.assert_all() - def test_assert_all_new_python(self): - """Check assert_all accepts new enough python.""" - reqs = self.wrapper.Requirements({'python': {'hard': sys.version_info}}) - reqs.assert_all() + def test_assert_all_new_python(self): + """Check assert_all accepts new enough python.""" + reqs = self.wrapper.Requirements({"python": {"hard": sys.version_info}}) + reqs.assert_all() - def test_assert_all_old_python(self): - """Check assert_all rejects old python.""" - reqs = self.wrapper.Requirements({'python': {'hard': [99999, 0]}}) - with self.assertRaises(SystemExit): - reqs.assert_all() + def test_assert_all_old_python(self): + """Check assert_all rejects old python.""" + reqs = self.wrapper.Requirements({"python": {"hard": [99999, 0]}}) + with self.assertRaises(SystemExit): + reqs.assert_all() - def test_assert_ver_unknown(self): - """Check assert_ver works with incompatible file.""" - reqs = self.wrapper.Requirements({}) - reqs.assert_ver('xxx', (1, 0)) + def test_assert_ver_unknown(self): + """Check assert_ver works with incompatible file.""" + reqs = self.wrapper.Requirements({}) + reqs.assert_ver("xxx", (1, 0)) - def test_assert_ver_new(self): - """Check assert_ver allows new enough versions.""" - reqs = self.wrapper.Requirements({'git': {'hard': [1, 0], 'soft': [2, 0]}}) - reqs.assert_ver('git', (1, 0)) - reqs.assert_ver('git', (1, 5)) - reqs.assert_ver('git', (2, 0)) - reqs.assert_ver('git', (2, 5)) 
+ def test_assert_ver_new(self): + """Check assert_ver allows new enough versions.""" + reqs = self.wrapper.Requirements( + {"git": {"hard": [1, 0], "soft": [2, 0]}} + ) + reqs.assert_ver("git", (1, 0)) + reqs.assert_ver("git", (1, 5)) + reqs.assert_ver("git", (2, 0)) + reqs.assert_ver("git", (2, 5)) - def test_assert_ver_old(self): - """Check assert_ver rejects old versions.""" - reqs = self.wrapper.Requirements({'git': {'hard': [1, 0], 'soft': [2, 0]}}) - with self.assertRaises(SystemExit): - reqs.assert_ver('git', (0, 5)) + def test_assert_ver_old(self): + """Check assert_ver rejects old versions.""" + reqs = self.wrapper.Requirements( + {"git": {"hard": [1, 0], "soft": [2, 0]}} + ) + with self.assertRaises(SystemExit): + reqs.assert_ver("git", (0, 5)) class NeedSetupGnuPG(RepoWrapperTestCase): - """Check NeedSetupGnuPG behavior.""" + """Check NeedSetupGnuPG behavior.""" - def test_missing_dir(self): - """The ~/.repoconfig tree doesn't exist yet.""" - with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: - self.wrapper.home_dot_repo = os.path.join(tempdir, 'foo') - self.assertTrue(self.wrapper.NeedSetupGnuPG()) + def test_missing_dir(self): + """The ~/.repoconfig tree doesn't exist yet.""" + with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: + self.wrapper.home_dot_repo = os.path.join(tempdir, "foo") + self.assertTrue(self.wrapper.NeedSetupGnuPG()) - def test_missing_keyring(self): - """The keyring-version file doesn't exist yet.""" - with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: - self.wrapper.home_dot_repo = tempdir - self.assertTrue(self.wrapper.NeedSetupGnuPG()) + def test_missing_keyring(self): + """The keyring-version file doesn't exist yet.""" + with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: + self.wrapper.home_dot_repo = tempdir + self.assertTrue(self.wrapper.NeedSetupGnuPG()) - def test_empty_keyring(self): - """The keyring-version file exists, but is empty.""" - with 
tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: - self.wrapper.home_dot_repo = tempdir - with open(os.path.join(tempdir, 'keyring-version'), 'w'): - pass - self.assertTrue(self.wrapper.NeedSetupGnuPG()) + def test_empty_keyring(self): + """The keyring-version file exists, but is empty.""" + with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: + self.wrapper.home_dot_repo = tempdir + with open(os.path.join(tempdir, "keyring-version"), "w"): + pass + self.assertTrue(self.wrapper.NeedSetupGnuPG()) - def test_old_keyring(self): - """The keyring-version file exists, but it's old.""" - with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: - self.wrapper.home_dot_repo = tempdir - with open(os.path.join(tempdir, 'keyring-version'), 'w') as fp: - fp.write('1.0\n') - self.assertTrue(self.wrapper.NeedSetupGnuPG()) + def test_old_keyring(self): + """The keyring-version file exists, but it's old.""" + with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: + self.wrapper.home_dot_repo = tempdir + with open(os.path.join(tempdir, "keyring-version"), "w") as fp: + fp.write("1.0\n") + self.assertTrue(self.wrapper.NeedSetupGnuPG()) - def test_new_keyring(self): - """The keyring-version file exists, and is up-to-date.""" - with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: - self.wrapper.home_dot_repo = tempdir - with open(os.path.join(tempdir, 'keyring-version'), 'w') as fp: - fp.write('1000.0\n') - self.assertFalse(self.wrapper.NeedSetupGnuPG()) + def test_new_keyring(self): + """The keyring-version file exists, and is up-to-date.""" + with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: + self.wrapper.home_dot_repo = tempdir + with open(os.path.join(tempdir, "keyring-version"), "w") as fp: + fp.write("1000.0\n") + self.assertFalse(self.wrapper.NeedSetupGnuPG()) class SetupGnuPG(RepoWrapperTestCase): - """Check SetupGnuPG behavior.""" + """Check SetupGnuPG behavior.""" - def test_full(self): - """Make sure 
it works completely.""" - with tempfile.TemporaryDirectory(prefix='repo-tests') as tempdir: - self.wrapper.home_dot_repo = tempdir - self.wrapper.gpg_dir = os.path.join(self.wrapper.home_dot_repo, 'gnupg') - self.assertTrue(self.wrapper.SetupGnuPG(True)) - with open(os.path.join(tempdir, 'keyring-version'), 'r') as fp: - data = fp.read() - self.assertEqual('.'.join(str(x) for x in self.wrapper.KEYRING_VERSION), - data.strip()) + def test_full(self): + """Make sure it works completely.""" + with tempfile.TemporaryDirectory(prefix="repo-tests") as tempdir: + self.wrapper.home_dot_repo = tempdir + self.wrapper.gpg_dir = os.path.join( + self.wrapper.home_dot_repo, "gnupg" + ) + self.assertTrue(self.wrapper.SetupGnuPG(True)) + with open(os.path.join(tempdir, "keyring-version"), "r") as fp: + data = fp.read() + self.assertEqual( + ".".join(str(x) for x in self.wrapper.KEYRING_VERSION), + data.strip(), + ) class VerifyRev(RepoWrapperTestCase): - """Check verify_rev behavior.""" + """Check verify_rev behavior.""" - def test_verify_passes(self): - """Check when we have a valid signed tag.""" - desc_result = self.wrapper.RunResult(0, 'v1.0\n', '') - gpg_result = self.wrapper.RunResult(0, '', '') - with mock.patch.object(self.wrapper, 'run_git', - side_effect=(desc_result, gpg_result)): - ret = self.wrapper.verify_rev('/', 'refs/heads/stable', '1234', True) - self.assertEqual('v1.0^0', ret) + def test_verify_passes(self): + """Check when we have a valid signed tag.""" + desc_result = self.wrapper.RunResult(0, "v1.0\n", "") + gpg_result = self.wrapper.RunResult(0, "", "") + with mock.patch.object( + self.wrapper, "run_git", side_effect=(desc_result, gpg_result) + ): + ret = self.wrapper.verify_rev( + "/", "refs/heads/stable", "1234", True + ) + self.assertEqual("v1.0^0", ret) - def test_unsigned_commit(self): - """Check we fall back to signed tag when we have an unsigned commit.""" - desc_result = self.wrapper.RunResult(0, 'v1.0-10-g1234\n', '') - gpg_result = 
self.wrapper.RunResult(0, '', '') - with mock.patch.object(self.wrapper, 'run_git', - side_effect=(desc_result, gpg_result)): - ret = self.wrapper.verify_rev('/', 'refs/heads/stable', '1234', True) - self.assertEqual('v1.0^0', ret) + def test_unsigned_commit(self): + """Check we fall back to signed tag when we have an unsigned commit.""" + desc_result = self.wrapper.RunResult(0, "v1.0-10-g1234\n", "") + gpg_result = self.wrapper.RunResult(0, "", "") + with mock.patch.object( + self.wrapper, "run_git", side_effect=(desc_result, gpg_result) + ): + ret = self.wrapper.verify_rev( + "/", "refs/heads/stable", "1234", True + ) + self.assertEqual("v1.0^0", ret) - def test_verify_fails(self): - """Check we fall back to signed tag when we have an unsigned commit.""" - desc_result = self.wrapper.RunResult(0, 'v1.0-10-g1234\n', '') - gpg_result = Exception - with mock.patch.object(self.wrapper, 'run_git', - side_effect=(desc_result, gpg_result)): - with self.assertRaises(Exception): - self.wrapper.verify_rev('/', 'refs/heads/stable', '1234', True) + def test_verify_fails(self): + """Check we fall back to signed tag when we have an unsigned commit.""" + desc_result = self.wrapper.RunResult(0, "v1.0-10-g1234\n", "") + gpg_result = Exception + with mock.patch.object( + self.wrapper, "run_git", side_effect=(desc_result, gpg_result) + ): + with self.assertRaises(Exception): + self.wrapper.verify_rev("/", "refs/heads/stable", "1234", True) class GitCheckoutTestCase(RepoWrapperTestCase): - """Tests that use a real/small git checkout.""" + """Tests that use a real/small git checkout.""" - GIT_DIR = None - REV_LIST = None + GIT_DIR = None + REV_LIST = None - @classmethod - def setUpClass(cls): - # Create a repo to operate on, but do it once per-class. 
- cls.tempdirobj = tempfile.TemporaryDirectory(prefix='repo-rev-tests') - cls.GIT_DIR = cls.tempdirobj.name - run_git = wrapper.Wrapper().run_git + @classmethod + def setUpClass(cls): + # Create a repo to operate on, but do it once per-class. + cls.tempdirobj = tempfile.TemporaryDirectory(prefix="repo-rev-tests") + cls.GIT_DIR = cls.tempdirobj.name + run_git = wrapper.Wrapper().run_git - remote = os.path.join(cls.GIT_DIR, 'remote') - os.mkdir(remote) + remote = os.path.join(cls.GIT_DIR, "remote") + os.mkdir(remote) - # Tests need to assume, that main is default branch at init, - # which is not supported in config until 2.28. - if git_command.git_require((2, 28, 0)): - initstr = '--initial-branch=main' - else: - # Use template dir for init. - templatedir = tempfile.mkdtemp(prefix='.test-template') - with open(os.path.join(templatedir, 'HEAD'), 'w') as fp: - fp.write('ref: refs/heads/main\n') - initstr = '--template=' + templatedir + # Tests need to assume, that main is default branch at init, + # which is not supported in config until 2.28. + if git_command.git_require((2, 28, 0)): + initstr = "--initial-branch=main" + else: + # Use template dir for init. 
+ templatedir = tempfile.mkdtemp(prefix=".test-template") + with open(os.path.join(templatedir, "HEAD"), "w") as fp: + fp.write("ref: refs/heads/main\n") + initstr = "--template=" + templatedir - run_git('init', initstr, cwd=remote) - run_git('commit', '--allow-empty', '-minit', cwd=remote) - run_git('branch', 'stable', cwd=remote) - run_git('tag', 'v1.0', cwd=remote) - run_git('commit', '--allow-empty', '-m2nd commit', cwd=remote) - cls.REV_LIST = run_git('rev-list', 'HEAD', cwd=remote).stdout.splitlines() + run_git("init", initstr, cwd=remote) + run_git("commit", "--allow-empty", "-minit", cwd=remote) + run_git("branch", "stable", cwd=remote) + run_git("tag", "v1.0", cwd=remote) + run_git("commit", "--allow-empty", "-m2nd commit", cwd=remote) + cls.REV_LIST = run_git( + "rev-list", "HEAD", cwd=remote + ).stdout.splitlines() - run_git('init', cwd=cls.GIT_DIR) - run_git('fetch', remote, '+refs/heads/*:refs/remotes/origin/*', cwd=cls.GIT_DIR) + run_git("init", cwd=cls.GIT_DIR) + run_git( + "fetch", + remote, + "+refs/heads/*:refs/remotes/origin/*", + cwd=cls.GIT_DIR, + ) - @classmethod - def tearDownClass(cls): - if not cls.tempdirobj: - return + @classmethod + def tearDownClass(cls): + if not cls.tempdirobj: + return - cls.tempdirobj.cleanup() + cls.tempdirobj.cleanup() class ResolveRepoRev(GitCheckoutTestCase): - """Check resolve_repo_rev behavior.""" + """Check resolve_repo_rev behavior.""" - def test_explicit_branch(self): - """Check refs/heads/branch argument.""" - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/heads/stable') - self.assertEqual('refs/heads/stable', rrev) - self.assertEqual(self.REV_LIST[1], lrev) + def test_explicit_branch(self): + """Check refs/heads/branch argument.""" + rrev, lrev = self.wrapper.resolve_repo_rev( + self.GIT_DIR, "refs/heads/stable" + ) + self.assertEqual("refs/heads/stable", rrev) + self.assertEqual(self.REV_LIST[1], lrev) - with self.assertRaises(self.wrapper.CloneFailure): - 
self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/heads/unknown') + with self.assertRaises(self.wrapper.CloneFailure): + self.wrapper.resolve_repo_rev(self.GIT_DIR, "refs/heads/unknown") - def test_explicit_tag(self): - """Check refs/tags/tag argument.""" - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/tags/v1.0') - self.assertEqual('refs/tags/v1.0', rrev) - self.assertEqual(self.REV_LIST[1], lrev) + def test_explicit_tag(self): + """Check refs/tags/tag argument.""" + rrev, lrev = self.wrapper.resolve_repo_rev( + self.GIT_DIR, "refs/tags/v1.0" + ) + self.assertEqual("refs/tags/v1.0", rrev) + self.assertEqual(self.REV_LIST[1], lrev) - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper.resolve_repo_rev(self.GIT_DIR, 'refs/tags/unknown') + with self.assertRaises(self.wrapper.CloneFailure): + self.wrapper.resolve_repo_rev(self.GIT_DIR, "refs/tags/unknown") - def test_branch_name(self): - """Check branch argument.""" - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'stable') - self.assertEqual('refs/heads/stable', rrev) - self.assertEqual(self.REV_LIST[1], lrev) + def test_branch_name(self): + """Check branch argument.""" + rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "stable") + self.assertEqual("refs/heads/stable", rrev) + self.assertEqual(self.REV_LIST[1], lrev) - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'main') - self.assertEqual('refs/heads/main', rrev) - self.assertEqual(self.REV_LIST[0], lrev) + rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "main") + self.assertEqual("refs/heads/main", rrev) + self.assertEqual(self.REV_LIST[0], lrev) - def test_tag_name(self): - """Check tag argument.""" - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, 'v1.0') - self.assertEqual('refs/tags/v1.0', rrev) - self.assertEqual(self.REV_LIST[1], lrev) + def test_tag_name(self): + """Check tag argument.""" + rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, "v1.0") + 
self.assertEqual("refs/tags/v1.0", rrev) + self.assertEqual(self.REV_LIST[1], lrev) - def test_full_commit(self): - """Check specific commit argument.""" - commit = self.REV_LIST[0] - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit) - self.assertEqual(commit, rrev) - self.assertEqual(commit, lrev) + def test_full_commit(self): + """Check specific commit argument.""" + commit = self.REV_LIST[0] + rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit) + self.assertEqual(commit, rrev) + self.assertEqual(commit, lrev) - def test_partial_commit(self): - """Check specific (partial) commit argument.""" - commit = self.REV_LIST[0][0:20] - rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit) - self.assertEqual(self.REV_LIST[0], rrev) - self.assertEqual(self.REV_LIST[0], lrev) + def test_partial_commit(self): + """Check specific (partial) commit argument.""" + commit = self.REV_LIST[0][0:20] + rrev, lrev = self.wrapper.resolve_repo_rev(self.GIT_DIR, commit) + self.assertEqual(self.REV_LIST[0], rrev) + self.assertEqual(self.REV_LIST[0], lrev) - def test_unknown(self): - """Check unknown ref/commit argument.""" - with self.assertRaises(self.wrapper.CloneFailure): - self.wrapper.resolve_repo_rev(self.GIT_DIR, 'boooooooya') + def test_unknown(self): + """Check unknown ref/commit argument.""" + with self.assertRaises(self.wrapper.CloneFailure): + self.wrapper.resolve_repo_rev(self.GIT_DIR, "boooooooya") class CheckRepoVerify(RepoWrapperTestCase): - """Check check_repo_verify behavior.""" + """Check check_repo_verify behavior.""" - def test_no_verify(self): - """Always fail with --no-repo-verify.""" - self.assertFalse(self.wrapper.check_repo_verify(False)) + def test_no_verify(self): + """Always fail with --no-repo-verify.""" + self.assertFalse(self.wrapper.check_repo_verify(False)) - def test_gpg_initialized(self): - """Should pass if gpg is setup already.""" - with mock.patch.object(self.wrapper, 'NeedSetupGnuPG', return_value=False): - 
self.assertTrue(self.wrapper.check_repo_verify(True)) + def test_gpg_initialized(self): + """Should pass if gpg is setup already.""" + with mock.patch.object( + self.wrapper, "NeedSetupGnuPG", return_value=False + ): + self.assertTrue(self.wrapper.check_repo_verify(True)) - def test_need_gpg_setup(self): - """Should pass/fail based on gpg setup.""" - with mock.patch.object(self.wrapper, 'NeedSetupGnuPG', return_value=True): - with mock.patch.object(self.wrapper, 'SetupGnuPG') as m: - m.return_value = True - self.assertTrue(self.wrapper.check_repo_verify(True)) + def test_need_gpg_setup(self): + """Should pass/fail based on gpg setup.""" + with mock.patch.object( + self.wrapper, "NeedSetupGnuPG", return_value=True + ): + with mock.patch.object(self.wrapper, "SetupGnuPG") as m: + m.return_value = True + self.assertTrue(self.wrapper.check_repo_verify(True)) - m.return_value = False - self.assertFalse(self.wrapper.check_repo_verify(True)) + m.return_value = False + self.assertFalse(self.wrapper.check_repo_verify(True)) class CheckRepoRev(GitCheckoutTestCase): - """Check check_repo_rev behavior.""" + """Check check_repo_rev behavior.""" - def test_verify_works(self): - """Should pass when verification passes.""" - with mock.patch.object(self.wrapper, 'check_repo_verify', return_value=True): - with mock.patch.object(self.wrapper, 'verify_rev', return_value='12345'): - rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, 'stable') - self.assertEqual('refs/heads/stable', rrev) - self.assertEqual('12345', lrev) + def test_verify_works(self): + """Should pass when verification passes.""" + with mock.patch.object( + self.wrapper, "check_repo_verify", return_value=True + ): + with mock.patch.object( + self.wrapper, "verify_rev", return_value="12345" + ): + rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, "stable") + self.assertEqual("refs/heads/stable", rrev) + self.assertEqual("12345", lrev) - def test_verify_fails(self): - """Should fail when verification fails.""" - 
with mock.patch.object(self.wrapper, 'check_repo_verify', return_value=True): - with mock.patch.object(self.wrapper, 'verify_rev', side_effect=Exception): - with self.assertRaises(Exception): - self.wrapper.check_repo_rev(self.GIT_DIR, 'stable') + def test_verify_fails(self): + """Should fail when verification fails.""" + with mock.patch.object( + self.wrapper, "check_repo_verify", return_value=True + ): + with mock.patch.object( + self.wrapper, "verify_rev", side_effect=Exception + ): + with self.assertRaises(Exception): + self.wrapper.check_repo_rev(self.GIT_DIR, "stable") - def test_verify_ignore(self): - """Should pass when verification is disabled.""" - with mock.patch.object(self.wrapper, 'verify_rev', side_effect=Exception): - rrev, lrev = self.wrapper.check_repo_rev(self.GIT_DIR, 'stable', repo_verify=False) - self.assertEqual('refs/heads/stable', rrev) - self.assertEqual(self.REV_LIST[1], lrev) + def test_verify_ignore(self): + """Should pass when verification is disabled.""" + with mock.patch.object( + self.wrapper, "verify_rev", side_effect=Exception + ): + rrev, lrev = self.wrapper.check_repo_rev( + self.GIT_DIR, "stable", repo_verify=False + ) + self.assertEqual("refs/heads/stable", rrev) + self.assertEqual(self.REV_LIST[1], lrev) diff --git a/tox.ini b/tox.ini index 8d3cc43c..2575a713 100644 --- a/tox.ini +++ b/tox.ini @@ -27,6 +27,7 @@ python = [testenv] deps = + black pytest pytest-timeout commands = {envpython} run_tests {posargs} diff --git a/wrapper.py b/wrapper.py index 3099ad5d..d8823368 100644 --- a/wrapper.py +++ b/wrapper.py @@ -19,14 +19,14 @@ import os def WrapperPath(): - return os.path.join(os.path.dirname(__file__), 'repo') + return os.path.join(os.path.dirname(__file__), "repo") @functools.lru_cache(maxsize=None) def Wrapper(): - modname = 'wrapper' - loader = importlib.machinery.SourceFileLoader(modname, WrapperPath()) - spec = importlib.util.spec_from_loader(modname, loader) - module = importlib.util.module_from_spec(spec) - 
loader.exec_module(module) - return module + modname = "wrapper" + loader = importlib.machinery.SourceFileLoader(modname, WrapperPath()) + spec = importlib.util.spec_from_loader(modname, loader) + module = importlib.util.module_from_spec(spec) + loader.exec_module(module) + return module