git_command: switch version caches to functools

Simplifies the code a bit to use the stdlib cache helper.

Change-Id: I778e90100ce748a71cc3a5a5d67dda403334315e
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/305482
Reviewed-by: Raman Tenneti <rtenneti@google.com>
Tested-by: Mike Frysinger <vapier@google.com>
This commit is contained in:
Mike Frysinger 2021-05-06 00:28:32 -04:00
parent 2f8fdbecde
commit 8e768eaaa7
2 changed files with 20 additions and 21 deletions

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import functools
import os import os
import re import re
import sys import sys
@ -45,7 +46,6 @@ LAST_CWD = None
_ssh_proxy_path = None _ssh_proxy_path = None
_ssh_sock_path = None _ssh_sock_path = None
_ssh_clients = [] _ssh_clients = []
_ssh_version = None
def _run_ssh_version(): def _run_ssh_version():
@ -64,16 +64,14 @@ def _parse_ssh_version(ver_str=None):
return () return ()
@functools.lru_cache(maxsize=None)
def ssh_version(): def ssh_version():
"""return ssh version as a tuple""" """return ssh version as a tuple"""
global _ssh_version try:
if _ssh_version is None: return _parse_ssh_version()
try: except subprocess.CalledProcessError:
_ssh_version = _parse_ssh_version() print('fatal: unable to detect ssh version', file=sys.stderr)
except subprocess.CalledProcessError: sys.exit(1)
print('fatal: unable to detect ssh version', file=sys.stderr)
sys.exit(1)
return _ssh_version
def ssh_sock(create=True): def ssh_sock(create=True):
@ -125,18 +123,14 @@ def terminate_ssh_clients():
_ssh_clients = [] _ssh_clients = []
_git_version = None
class _GitCall(object): class _GitCall(object):
@functools.lru_cache(maxsize=None)
def version_tuple(self): def version_tuple(self):
global _git_version ret = Wrapper().ParseGitVersion()
if _git_version is None: if ret is None:
_git_version = Wrapper().ParseGitVersion() print('fatal: unable to detect git version', file=sys.stderr)
if _git_version is None: sys.exit(1)
print('fatal: unable to detect git version', file=sys.stderr) return ret
sys.exit(1)
return _git_version
def __getattr__(self, name): def __getattr__(self, name):
name = name.replace('_', '-') name = name.replace('_', '-')

View File

@ -29,8 +29,8 @@ import wrapper
class SSHUnitTest(unittest.TestCase): class SSHUnitTest(unittest.TestCase):
"""Tests the ssh functions.""" """Tests the ssh functions."""
def test_ssh_version(self): def test_parse_ssh_version(self):
"""Check ssh_version() handling.""" """Check parse_ssh_version() handling."""
ver = git_command._parse_ssh_version('Unknown\n') ver = git_command._parse_ssh_version('Unknown\n')
self.assertEqual(ver, ()) self.assertEqual(ver, ())
ver = git_command._parse_ssh_version('OpenSSH_1.0\n') ver = git_command._parse_ssh_version('OpenSSH_1.0\n')
@ -40,6 +40,11 @@ class SSHUnitTest(unittest.TestCase):
ver = git_command._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n') ver = git_command._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n')
self.assertEqual(ver, (7, 6)) self.assertEqual(ver, (7, 6))
def test_ssh_version(self):
"""Check ssh_version() handling."""
with mock.patch('git_command._run_ssh_version', return_value='OpenSSH_1.2\n'):
self.assertEqual(git_command.ssh_version(), (1, 2))
def test_ssh_sock(self): def test_ssh_sock(self):
"""Check ssh_sock() function.""" """Check ssh_sock() function."""
with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'): with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):