diff --git a/git_command.py b/git_command.py index d06fc77c..f8cb280c 100644 --- a/git_command.py +++ b/git_command.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import os import re import sys @@ -45,7 +46,6 @@ LAST_CWD = None _ssh_proxy_path = None _ssh_sock_path = None _ssh_clients = [] -_ssh_version = None def _run_ssh_version(): @@ -64,16 +64,14 @@ def _parse_ssh_version(ver_str=None): return () +@functools.lru_cache(maxsize=None) def ssh_version(): """return ssh version as a tuple""" - global _ssh_version - if _ssh_version is None: - try: - _ssh_version = _parse_ssh_version() - except subprocess.CalledProcessError: - print('fatal: unable to detect ssh version', file=sys.stderr) - sys.exit(1) - return _ssh_version + try: + return _parse_ssh_version() + except subprocess.CalledProcessError: + print('fatal: unable to detect ssh version', file=sys.stderr) + sys.exit(1) def ssh_sock(create=True): @@ -125,18 +123,14 @@ def terminate_ssh_clients(): _ssh_clients = [] -_git_version = None - - class _GitCall(object): + @functools.lru_cache(maxsize=None) def version_tuple(self): - global _git_version - if _git_version is None: - _git_version = Wrapper().ParseGitVersion() - if _git_version is None: - print('fatal: unable to detect git version', file=sys.stderr) - sys.exit(1) - return _git_version + ret = Wrapper().ParseGitVersion() + if ret is None: + print('fatal: unable to detect git version', file=sys.stderr) + sys.exit(1) + return ret def __getattr__(self, name): name = name.replace('_', '-') diff --git a/tests/test_git_command.py b/tests/test_git_command.py index 912a9dbe..76c092f4 100644 --- a/tests/test_git_command.py +++ b/tests/test_git_command.py @@ -29,8 +29,8 @@ import wrapper class SSHUnitTest(unittest.TestCase): """Tests the ssh functions.""" - def test_ssh_version(self): - """Check ssh_version() handling.""" + def test_parse_ssh_version(self): + """Check parse_ssh_version() handling.""" ver = git_command._parse_ssh_version('Unknown\n') self.assertEqual(ver, ()) ver = git_command._parse_ssh_version('OpenSSH_1.0\n') @@ -40,6 +40,11 @@ class SSHUnitTest(unittest.TestCase): ver = git_command._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n') self.assertEqual(ver, (7, 6)) + def test_ssh_version(self): + """Check ssh_version() handling.""" + with mock.patch('git_command._run_ssh_version', return_value='OpenSSH_1.2\n'): + self.assertEqual(git_command.ssh_version(), (1, 2)) + def test_ssh_sock(self): """Check ssh_sock() function.""" with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):