ssh: move all ssh logic to a common place

We had ssh logic sprinkled between two git modules, and neither was
quite the right home for it.  This largely moves the logic as-is to
its new home.  We'll leave major refactoring to followup commits.

Bug: https://crbug.com/gerrit/12389
Change-Id: I300a8f7dba74f2bd132232a5eb1e856a8490e0e9
Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/305483
Reviewed-by: Chris Mcdonald <cjmcdonald@google.com>
Tested-by: Mike Frysinger <vapier@google.com>
This commit is contained in:
Mike Frysinger 2021-05-05 15:53:03 -04:00
parent 8e768eaaa7
commit 5291eafa41
6 changed files with 320 additions and 275 deletions

View File

@ -14,16 +14,14 @@
import functools import functools
import os import os
import re
import sys import sys
import subprocess import subprocess
import tempfile
from signal import SIGTERM
from error import GitError from error import GitError
from git_refs import HEAD from git_refs import HEAD
import platform_utils import platform_utils
from repo_trace import REPO_TRACE, IsTrace, Trace from repo_trace import REPO_TRACE, IsTrace, Trace
import ssh
from wrapper import Wrapper from wrapper import Wrapper
GIT = 'git' GIT = 'git'
@ -43,85 +41,6 @@ GIT_DIR = 'GIT_DIR'
LAST_GITDIR = None LAST_GITDIR = None
LAST_CWD = None LAST_CWD = None
_ssh_proxy_path = None
_ssh_sock_path = None
_ssh_clients = []
def _run_ssh_version():
"""run ssh -V to display the version number"""
return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode()
def _parse_ssh_version(ver_str=None):
"""parse a ssh version string into a tuple"""
if ver_str is None:
ver_str = _run_ssh_version()
m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str)
if m:
return tuple(int(x) for x in m.group(1).split('.'))
else:
return ()
@functools.lru_cache(maxsize=None)
def ssh_version():
"""return ssh version as a tuple"""
try:
return _parse_ssh_version()
except subprocess.CalledProcessError:
print('fatal: unable to detect ssh version', file=sys.stderr)
sys.exit(1)
def ssh_sock(create=True):
global _ssh_sock_path
if _ssh_sock_path is None:
if not create:
return None
tmp_dir = '/tmp'
if not os.path.exists(tmp_dir):
tmp_dir = tempfile.gettempdir()
if ssh_version() < (6, 7):
tokens = '%r@%h:%p'
else:
tokens = '%C' # hash of %l%h%p%r
_ssh_sock_path = os.path.join(
tempfile.mkdtemp('', 'ssh-', tmp_dir),
'master-' + tokens)
return _ssh_sock_path
def _ssh_proxy():
global _ssh_proxy_path
if _ssh_proxy_path is None:
_ssh_proxy_path = os.path.join(
os.path.dirname(__file__),
'git_ssh')
return _ssh_proxy_path
def _add_ssh_client(p):
_ssh_clients.append(p)
def _remove_ssh_client(p):
try:
_ssh_clients.remove(p)
except ValueError:
pass
def terminate_ssh_clients():
global _ssh_clients
for p in _ssh_clients:
try:
os.kill(p.pid, SIGTERM)
p.wait()
except OSError:
pass
_ssh_clients = []
class _GitCall(object): class _GitCall(object):
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
@ -256,8 +175,8 @@ class GitCommand(object):
if disable_editor: if disable_editor:
env['GIT_EDITOR'] = ':' env['GIT_EDITOR'] = ':'
if ssh_proxy: if ssh_proxy:
env['REPO_SSH_SOCK'] = ssh_sock() env['REPO_SSH_SOCK'] = ssh.sock()
env['GIT_SSH'] = _ssh_proxy() env['GIT_SSH'] = ssh.proxy()
env['GIT_SSH_VARIANT'] = 'ssh' env['GIT_SSH_VARIANT'] = 'ssh'
if 'http_proxy' in env and 'darwin' == sys.platform: if 'http_proxy' in env and 'darwin' == sys.platform:
s = "'http.proxy=%s'" % (env['http_proxy'],) s = "'http.proxy=%s'" % (env['http_proxy'],)
@ -340,7 +259,7 @@ class GitCommand(object):
raise GitError('%s: %s' % (command[1], e)) raise GitError('%s: %s' % (command[1], e))
if ssh_proxy: if ssh_proxy:
_add_ssh_client(p) ssh.add_client(p)
self.process = p self.process = p
if input: if input:
@ -352,7 +271,7 @@ class GitCommand(object):
try: try:
self.stdout, self.stderr = p.communicate() self.stdout, self.stderr = p.communicate()
finally: finally:
_remove_ssh_client(p) ssh.remove_client(p)
self.rc = p.wait() self.rc = p.wait()
@staticmethod @staticmethod

View File

@ -18,25 +18,17 @@ from http.client import HTTPException
import json import json
import os import os
import re import re
import signal
import ssl import ssl
import subprocess import subprocess
import sys import sys
try:
import threading as _threading
except ImportError:
import dummy_threading as _threading
import time
import urllib.error import urllib.error
import urllib.request import urllib.request
from error import GitError, UploadError from error import GitError, UploadError
import platform_utils import platform_utils
from repo_trace import Trace from repo_trace import Trace
import ssh
from git_command import GitCommand from git_command import GitCommand
from git_command import ssh_sock
from git_command import terminate_ssh_clients
from git_refs import R_CHANGES, R_HEADS, R_TAGS from git_refs import R_CHANGES, R_HEADS, R_TAGS
ID_RE = re.compile(r'^[0-9a-f]{40}$') ID_RE = re.compile(r'^[0-9a-f]{40}$')
@ -440,129 +432,6 @@ class RefSpec(object):
return s return s
_master_processes = []
_master_keys = set()
_ssh_master = True
_master_keys_lock = None
def init_ssh():
"""Should be called once at the start of repo to init ssh master handling.
At the moment, all we do is to create our lock.
"""
global _master_keys_lock
assert _master_keys_lock is None, "Should only call init_ssh once"
_master_keys_lock = _threading.Lock()
def _open_ssh(host, port=None):
global _ssh_master
# Bail before grabbing the lock if we already know that we aren't going to
# try creating new masters below.
if sys.platform in ('win32', 'cygwin'):
return False
# Acquire the lock. This is needed to prevent opening multiple masters for
# the same host when we're running "repo sync -jN" (for N > 1) _and_ the
# manifest <remote fetch="ssh://xyz"> specifies a different host from the
# one that was passed to repo init.
_master_keys_lock.acquire()
try:
# Check to see whether we already think that the master is running; if we
# think it's already running, return right away.
if port is not None:
key = '%s:%s' % (host, port)
else:
key = host
if key in _master_keys:
return True
if not _ssh_master or 'GIT_SSH' in os.environ:
# Failed earlier, so don't retry.
return False
# We will make two calls to ssh; this is the common part of both calls.
command_base = ['ssh',
'-o', 'ControlPath %s' % ssh_sock(),
host]
if port is not None:
command_base[1:1] = ['-p', str(port)]
# Since the key wasn't in _master_keys, we think that master isn't running.
# ...but before actually starting a master, we'll double-check. This can
# be important because we can't tell that that 'git@myhost.com' is the same
# as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file.
check_command = command_base + ['-O', 'check']
try:
Trace(': %s', ' '.join(check_command))
check_process = subprocess.Popen(check_command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
check_process.communicate() # read output, but ignore it...
isnt_running = check_process.wait()
if not isnt_running:
# Our double-check found that the master _was_ infact running. Add to
# the list of keys.
_master_keys.add(key)
return True
except Exception:
# Ignore excpetions. We we will fall back to the normal command and print
# to the log there.
pass
command = command_base[:1] + ['-M', '-N'] + command_base[1:]
try:
Trace(': %s', ' '.join(command))
p = subprocess.Popen(command)
except Exception as e:
_ssh_master = False
print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
% (host, port, str(e)), file=sys.stderr)
return False
time.sleep(1)
ssh_died = (p.poll() is not None)
if ssh_died:
return False
_master_processes.append(p)
_master_keys.add(key)
return True
finally:
_master_keys_lock.release()
def close_ssh():
global _master_keys_lock
terminate_ssh_clients()
for p in _master_processes:
try:
os.kill(p.pid, signal.SIGTERM)
p.wait()
except OSError:
pass
del _master_processes[:]
_master_keys.clear()
d = ssh_sock(create=False)
if d:
try:
platform_utils.rmdir(os.path.dirname(d))
except OSError:
pass
# We're done with the lock, so we can delete it.
_master_keys_lock = None
URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/') URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
@ -614,27 +483,6 @@ def GetUrlCookieFile(url, quiet):
yield cookiefile, None yield cookiefile, None
def _preconnect(url):
m = URI_ALL.match(url)
if m:
scheme = m.group(1)
host = m.group(2)
if ':' in host:
host, port = host.split(':')
else:
port = None
if scheme in ('ssh', 'git+ssh', 'ssh+git'):
return _open_ssh(host, port)
return False
m = URI_SCP.match(url)
if m:
host = m.group(1)
return _open_ssh(host)
return False
class Remote(object): class Remote(object):
"""Configuration options related to a remote. """Configuration options related to a remote.
""" """
@ -673,7 +521,7 @@ class Remote(object):
def PreConnectFetch(self): def PreConnectFetch(self):
connectionUrl = self._InsteadOf() connectionUrl = self._InsteadOf()
return _preconnect(connectionUrl) return ssh.preconnect(connectionUrl)
def ReviewUrl(self, userEmail, validate_certs): def ReviewUrl(self, userEmail, validate_certs):
if self._review_url is None: if self._review_url is None:

View File

@ -39,7 +39,7 @@ from color import SetDefaultColoring
import event_log import event_log
from repo_trace import SetTrace from repo_trace import SetTrace
from git_command import user_agent from git_command import user_agent
from git_config import init_ssh, close_ssh, RepoConfig from git_config import RepoConfig
from git_trace2_event_log import EventLog from git_trace2_event_log import EventLog
from command import InteractiveCommand from command import InteractiveCommand
from command import MirrorSafeCommand from command import MirrorSafeCommand
@ -56,6 +56,7 @@ from error import RepoChangedException
import gitc_utils import gitc_utils
from manifest_xml import GitcClient, RepoClient from manifest_xml import GitcClient, RepoClient
from pager import RunPager, TerminatePager from pager import RunPager, TerminatePager
import ssh
from wrapper import WrapperPath, Wrapper from wrapper import WrapperPath, Wrapper
from subcmds import all_commands from subcmds import all_commands
@ -592,7 +593,7 @@ def _Main(argv):
repo = _Repo(opt.repodir) repo = _Repo(opt.repodir)
try: try:
try: try:
init_ssh() ssh.init()
init_http() init_http()
name, gopts, argv = repo._ParseArgs(argv) name, gopts, argv = repo._ParseArgs(argv)
run = lambda: repo._Run(name, gopts, argv) or 0 run = lambda: repo._Run(name, gopts, argv) or 0
@ -604,7 +605,7 @@ def _Main(argv):
else: else:
result = run() result = run()
finally: finally:
close_ssh() ssh.close()
except KeyboardInterrupt: except KeyboardInterrupt:
print('aborted by user', file=sys.stderr) print('aborted by user', file=sys.stderr)
result = 1 result = 1

257
ssh.py Normal file
View File

@ -0,0 +1,257 @@
# Copyright (C) 2008 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common SSH management logic."""
import functools
import os
import re
import signal
import subprocess
import sys
import tempfile
try:
import threading as _threading
except ImportError:
import dummy_threading as _threading
import time
import platform_utils
from repo_trace import Trace
_ssh_proxy_path = None
_ssh_sock_path = None
_ssh_clients = []
def _run_ssh_version():
"""run ssh -V to display the version number"""
return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode()
def _parse_ssh_version(ver_str=None):
"""parse a ssh version string into a tuple"""
if ver_str is None:
ver_str = _run_ssh_version()
m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str)
if m:
return tuple(int(x) for x in m.group(1).split('.'))
else:
return ()
@functools.lru_cache(maxsize=None)
def version():
"""return ssh version as a tuple"""
try:
return _parse_ssh_version()
except subprocess.CalledProcessError:
print('fatal: unable to detect ssh version', file=sys.stderr)
sys.exit(1)
def proxy():
global _ssh_proxy_path
if _ssh_proxy_path is None:
_ssh_proxy_path = os.path.join(
os.path.dirname(__file__),
'git_ssh')
return _ssh_proxy_path
def add_client(p):
_ssh_clients.append(p)
def remove_client(p):
try:
_ssh_clients.remove(p)
except ValueError:
pass
def _terminate_clients():
global _ssh_clients
for p in _ssh_clients:
try:
os.kill(p.pid, signal.SIGTERM)
p.wait()
except OSError:
pass
_ssh_clients = []
_master_processes = []
_master_keys = set()
_ssh_master = True
_master_keys_lock = None
def init():
"""Should be called once at the start of repo to init ssh master handling.
At the moment, all we do is to create our lock.
"""
global _master_keys_lock
assert _master_keys_lock is None, "Should only call init once"
_master_keys_lock = _threading.Lock()
def _open_ssh(host, port=None):
global _ssh_master
# Bail before grabbing the lock if we already know that we aren't going to
# try creating new masters below.
if sys.platform in ('win32', 'cygwin'):
return False
# Acquire the lock. This is needed to prevent opening multiple masters for
# the same host when we're running "repo sync -jN" (for N > 1) _and_ the
# manifest <remote fetch="ssh://xyz"> specifies a different host from the
# one that was passed to repo init.
_master_keys_lock.acquire()
try:
# Check to see whether we already think that the master is running; if we
# think it's already running, return right away.
if port is not None:
key = '%s:%s' % (host, port)
else:
key = host
if key in _master_keys:
return True
if not _ssh_master or 'GIT_SSH' in os.environ:
# Failed earlier, so don't retry.
return False
# We will make two calls to ssh; this is the common part of both calls.
command_base = ['ssh',
'-o', 'ControlPath %s' % sock(),
host]
if port is not None:
command_base[1:1] = ['-p', str(port)]
# Since the key wasn't in _master_keys, we think that master isn't running.
# ...but before actually starting a master, we'll double-check. This can
# be important because we can't tell that that 'git@myhost.com' is the same
# as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file.
check_command = command_base + ['-O', 'check']
try:
Trace(': %s', ' '.join(check_command))
check_process = subprocess.Popen(check_command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
check_process.communicate() # read output, but ignore it...
isnt_running = check_process.wait()
if not isnt_running:
# Our double-check found that the master _was_ infact running. Add to
# the list of keys.
_master_keys.add(key)
return True
except Exception:
# Ignore excpetions. We we will fall back to the normal command and print
# to the log there.
pass
command = command_base[:1] + ['-M', '-N'] + command_base[1:]
try:
Trace(': %s', ' '.join(command))
p = subprocess.Popen(command)
except Exception as e:
_ssh_master = False
print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
% (host, port, str(e)), file=sys.stderr)
return False
time.sleep(1)
ssh_died = (p.poll() is not None)
if ssh_died:
return False
_master_processes.append(p)
_master_keys.add(key)
return True
finally:
_master_keys_lock.release()
def close():
global _master_keys_lock
_terminate_clients()
for p in _master_processes:
try:
os.kill(p.pid, signal.SIGTERM)
p.wait()
except OSError:
pass
del _master_processes[:]
_master_keys.clear()
d = sock(create=False)
if d:
try:
platform_utils.rmdir(os.path.dirname(d))
except OSError:
pass
# We're done with the lock, so we can delete it.
_master_keys_lock = None
URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
def preconnect(url):
m = URI_ALL.match(url)
if m:
scheme = m.group(1)
host = m.group(2)
if ':' in host:
host, port = host.split(':')
else:
port = None
if scheme in ('ssh', 'git+ssh', 'ssh+git'):
return _open_ssh(host, port)
return False
m = URI_SCP.match(url)
if m:
host = m.group(1)
return _open_ssh(host)
return False
def sock(create=True):
global _ssh_sock_path
if _ssh_sock_path is None:
if not create:
return None
tmp_dir = '/tmp'
if not os.path.exists(tmp_dir):
tmp_dir = tempfile.gettempdir()
if version() < (6, 7):
tokens = '%r@%h:%p'
else:
tokens = '%C' # hash of %l%h%p%r
_ssh_sock_path = os.path.join(
tempfile.mkdtemp('', 'ssh-', tmp_dir),
'master-' + tokens)
return _ssh_sock_path

View File

@ -26,38 +26,6 @@ import git_command
import wrapper import wrapper
class SSHUnitTest(unittest.TestCase):
"""Tests the ssh functions."""
def test_parse_ssh_version(self):
"""Check parse_ssh_version() handling."""
ver = git_command._parse_ssh_version('Unknown\n')
self.assertEqual(ver, ())
ver = git_command._parse_ssh_version('OpenSSH_1.0\n')
self.assertEqual(ver, (1, 0))
ver = git_command._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n')
self.assertEqual(ver, (6, 6, 1))
ver = git_command._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n')
self.assertEqual(ver, (7, 6))
def test_ssh_version(self):
"""Check ssh_version() handling."""
with mock.patch('git_command._run_ssh_version', return_value='OpenSSH_1.2\n'):
self.assertEqual(git_command.ssh_version(), (1, 2))
def test_ssh_sock(self):
"""Check ssh_sock() function."""
with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):
# old ssh version uses port
with mock.patch('git_command.ssh_version', return_value=(6, 6)):
self.assertTrue(git_command.ssh_sock().endswith('%p'))
git_command._ssh_sock_path = None
# new ssh version uses hash
with mock.patch('git_command.ssh_version', return_value=(6, 7)):
self.assertTrue(git_command.ssh_sock().endswith('%C'))
git_command._ssh_sock_path = None
class GitCallUnitTest(unittest.TestCase): class GitCallUnitTest(unittest.TestCase):
"""Tests the _GitCall class (via git_command.git).""" """Tests the _GitCall class (via git_command.git)."""

52
tests/test_ssh.py Normal file
View File

@ -0,0 +1,52 @@
# Copyright 2019 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unittests for the ssh.py module."""
import unittest
from unittest import mock
import ssh
class SshTests(unittest.TestCase):
"""Tests the ssh functions."""
def test_parse_ssh_version(self):
"""Check _parse_ssh_version() handling."""
ver = ssh._parse_ssh_version('Unknown\n')
self.assertEqual(ver, ())
ver = ssh._parse_ssh_version('OpenSSH_1.0\n')
self.assertEqual(ver, (1, 0))
ver = ssh._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n')
self.assertEqual(ver, (6, 6, 1))
ver = ssh._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n')
self.assertEqual(ver, (7, 6))
def test_version(self):
"""Check version() handling."""
with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'):
self.assertEqual(ssh.version(), (1, 2))
def test_ssh_sock(self):
"""Check sock() function."""
with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):
# old ssh version uses port
with mock.patch('ssh.version', return_value=(6, 6)):
self.assertTrue(ssh.sock().endswith('%p'))
ssh._ssh_sock_path = None
# new ssh version uses hash
with mock.patch('ssh.version', return_value=(6, 7)):
self.assertTrue(ssh.sock().endswith('%C'))
ssh._ssh_sock_path = None