# Copyright (C) 2009 The Android Open Source Project # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import platform_utils from repo_trace import Trace HEAD = "HEAD" R_CHANGES = "refs/changes/" R_HEADS = "refs/heads/" R_TAGS = "refs/tags/" R_PUB = "refs/published/" R_WORKTREE = "refs/worktree/" R_WORKTREE_M = R_WORKTREE + "m/" R_M = "refs/remotes/m/" class GitRefs: def __init__(self, gitdir): self._gitdir = gitdir self._phyref = None self._symref = None self._mtime = {} @property def all(self): self._EnsureLoaded() return self._phyref def get(self, name): try: return self.all[name] except KeyError: return "" def deleted(self, name): if self._phyref is not None: if name in self._phyref: del self._phyref[name] if name in self._symref: del self._symref[name] if name in self._mtime: del self._mtime[name] def symref(self, name): try: self._EnsureLoaded() return self._symref[name] except KeyError: return "" def _EnsureLoaded(self): if self._phyref is None or self._NeedUpdate(): self._LoadAll() def _NeedUpdate(self): with Trace(": scan refs %s", self._gitdir): for name, mtime in self._mtime.items(): try: if mtime != os.path.getmtime( os.path.join(self._gitdir, name) ): return True except OSError: return True return False def _LoadAll(self): with Trace(": load refs %s", self._gitdir): self._phyref = {} self._symref = {} self._mtime = {} self._ReadPackedRefs() self._ReadLoose("refs/") self._ReadLoose1(os.path.join(self._gitdir, HEAD), HEAD) scan = self._symref attempts = 0 while scan and attempts < 5: scan_next = {} for name, dest in scan.items(): if dest in self._phyref: self._phyref[name] = self._phyref[dest] else: scan_next[name] = dest scan = scan_next attempts += 1 def _ReadPackedRefs(self): path = os.path.join(self._gitdir, "packed-refs") try: fd = open(path) mtime = os.path.getmtime(path) except IOError: return except OSError: return try: for line in fd: line = str(line) if line[0] == "#": continue if line[0] == "^": continue line = line[:-1] p = line.split(" ") ref_id = p[0] name = p[1] self._phyref[name] = ref_id finally: fd.close() self._mtime["packed-refs"] = mtime def _ReadLoose(self, prefix): base = os.path.join(self._gitdir, prefix) for name in platform_utils.listdir(base): p = os.path.join(base, name) # We don't implement the full ref validation algorithm, just the # simple rules that would show up in local filesystems. # https://git-scm.com/docs/git-check-ref-format if name.startswith(".") or name.endswith(".lock"): pass elif platform_utils.isdir(p): self._mtime[prefix] = os.path.getmtime(base) self._ReadLoose(prefix + name + "/") else: self._ReadLoose1(p, prefix + name) def _ReadLoose1(self, path, name): try: with open(path) as fd: mtime = os.path.getmtime(path) ref_id = fd.readline() except (OSError, UnicodeError): return try: ref_id = ref_id.decode() except AttributeError: pass if not ref_id: return ref_id = ref_id[:-1] if ref_id.startswith("ref: "): self._symref[name] = ref_id[5:] else: self._phyref[name] = ref_id self._mtime[name] = mtime