mirror of
https://gerrit.googlesource.com/git-repo
synced 2025-01-08 16:14:26 +00:00
ssh: move all ssh logic to a common place
We had ssh logic sprinkled between two git modules, and neither was quite the right home for it. This largely moves the logic as-is to its new home. We'll leave major refactoring to followup commits. Bug: https://crbug.com/gerrit/12389 Change-Id: I300a8f7dba74f2bd132232a5eb1e856a8490e0e9 Reviewed-on: https://gerrit-review.googlesource.com/c/git-repo/+/305483 Reviewed-by: Chris Mcdonald <cjmcdonald@google.com> Tested-by: Mike Frysinger <vapier@google.com>
This commit is contained in:
parent
8e768eaaa7
commit
5291eafa41
@ -14,16 +14,14 @@
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import sys
|
import sys
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
|
||||||
from signal import SIGTERM
|
|
||||||
|
|
||||||
from error import GitError
|
from error import GitError
|
||||||
from git_refs import HEAD
|
from git_refs import HEAD
|
||||||
import platform_utils
|
import platform_utils
|
||||||
from repo_trace import REPO_TRACE, IsTrace, Trace
|
from repo_trace import REPO_TRACE, IsTrace, Trace
|
||||||
|
import ssh
|
||||||
from wrapper import Wrapper
|
from wrapper import Wrapper
|
||||||
|
|
||||||
GIT = 'git'
|
GIT = 'git'
|
||||||
@ -43,85 +41,6 @@ GIT_DIR = 'GIT_DIR'
|
|||||||
LAST_GITDIR = None
|
LAST_GITDIR = None
|
||||||
LAST_CWD = None
|
LAST_CWD = None
|
||||||
|
|
||||||
_ssh_proxy_path = None
|
|
||||||
_ssh_sock_path = None
|
|
||||||
_ssh_clients = []
|
|
||||||
|
|
||||||
|
|
||||||
def _run_ssh_version():
|
|
||||||
"""run ssh -V to display the version number"""
|
|
||||||
return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode()
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_ssh_version(ver_str=None):
|
|
||||||
"""parse a ssh version string into a tuple"""
|
|
||||||
if ver_str is None:
|
|
||||||
ver_str = _run_ssh_version()
|
|
||||||
m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str)
|
|
||||||
if m:
|
|
||||||
return tuple(int(x) for x in m.group(1).split('.'))
|
|
||||||
else:
|
|
||||||
return ()
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=None)
|
|
||||||
def ssh_version():
|
|
||||||
"""return ssh version as a tuple"""
|
|
||||||
try:
|
|
||||||
return _parse_ssh_version()
|
|
||||||
except subprocess.CalledProcessError:
|
|
||||||
print('fatal: unable to detect ssh version', file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def ssh_sock(create=True):
|
|
||||||
global _ssh_sock_path
|
|
||||||
if _ssh_sock_path is None:
|
|
||||||
if not create:
|
|
||||||
return None
|
|
||||||
tmp_dir = '/tmp'
|
|
||||||
if not os.path.exists(tmp_dir):
|
|
||||||
tmp_dir = tempfile.gettempdir()
|
|
||||||
if ssh_version() < (6, 7):
|
|
||||||
tokens = '%r@%h:%p'
|
|
||||||
else:
|
|
||||||
tokens = '%C' # hash of %l%h%p%r
|
|
||||||
_ssh_sock_path = os.path.join(
|
|
||||||
tempfile.mkdtemp('', 'ssh-', tmp_dir),
|
|
||||||
'master-' + tokens)
|
|
||||||
return _ssh_sock_path
|
|
||||||
|
|
||||||
|
|
||||||
def _ssh_proxy():
|
|
||||||
global _ssh_proxy_path
|
|
||||||
if _ssh_proxy_path is None:
|
|
||||||
_ssh_proxy_path = os.path.join(
|
|
||||||
os.path.dirname(__file__),
|
|
||||||
'git_ssh')
|
|
||||||
return _ssh_proxy_path
|
|
||||||
|
|
||||||
|
|
||||||
def _add_ssh_client(p):
|
|
||||||
_ssh_clients.append(p)
|
|
||||||
|
|
||||||
|
|
||||||
def _remove_ssh_client(p):
|
|
||||||
try:
|
|
||||||
_ssh_clients.remove(p)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def terminate_ssh_clients():
|
|
||||||
global _ssh_clients
|
|
||||||
for p in _ssh_clients:
|
|
||||||
try:
|
|
||||||
os.kill(p.pid, SIGTERM)
|
|
||||||
p.wait()
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
_ssh_clients = []
|
|
||||||
|
|
||||||
|
|
||||||
class _GitCall(object):
|
class _GitCall(object):
|
||||||
@functools.lru_cache(maxsize=None)
|
@functools.lru_cache(maxsize=None)
|
||||||
@ -256,8 +175,8 @@ class GitCommand(object):
|
|||||||
if disable_editor:
|
if disable_editor:
|
||||||
env['GIT_EDITOR'] = ':'
|
env['GIT_EDITOR'] = ':'
|
||||||
if ssh_proxy:
|
if ssh_proxy:
|
||||||
env['REPO_SSH_SOCK'] = ssh_sock()
|
env['REPO_SSH_SOCK'] = ssh.sock()
|
||||||
env['GIT_SSH'] = _ssh_proxy()
|
env['GIT_SSH'] = ssh.proxy()
|
||||||
env['GIT_SSH_VARIANT'] = 'ssh'
|
env['GIT_SSH_VARIANT'] = 'ssh'
|
||||||
if 'http_proxy' in env and 'darwin' == sys.platform:
|
if 'http_proxy' in env and 'darwin' == sys.platform:
|
||||||
s = "'http.proxy=%s'" % (env['http_proxy'],)
|
s = "'http.proxy=%s'" % (env['http_proxy'],)
|
||||||
@ -340,7 +259,7 @@ class GitCommand(object):
|
|||||||
raise GitError('%s: %s' % (command[1], e))
|
raise GitError('%s: %s' % (command[1], e))
|
||||||
|
|
||||||
if ssh_proxy:
|
if ssh_proxy:
|
||||||
_add_ssh_client(p)
|
ssh.add_client(p)
|
||||||
|
|
||||||
self.process = p
|
self.process = p
|
||||||
if input:
|
if input:
|
||||||
@ -352,7 +271,7 @@ class GitCommand(object):
|
|||||||
try:
|
try:
|
||||||
self.stdout, self.stderr = p.communicate()
|
self.stdout, self.stderr = p.communicate()
|
||||||
finally:
|
finally:
|
||||||
_remove_ssh_client(p)
|
ssh.remove_client(p)
|
||||||
self.rc = p.wait()
|
self.rc = p.wait()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
156
git_config.py
156
git_config.py
@ -18,25 +18,17 @@ from http.client import HTTPException
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import signal
|
|
||||||
import ssl
|
import ssl
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
try:
|
|
||||||
import threading as _threading
|
|
||||||
except ImportError:
|
|
||||||
import dummy_threading as _threading
|
|
||||||
import time
|
|
||||||
import urllib.error
|
import urllib.error
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
|
||||||
from error import GitError, UploadError
|
from error import GitError, UploadError
|
||||||
import platform_utils
|
import platform_utils
|
||||||
from repo_trace import Trace
|
from repo_trace import Trace
|
||||||
|
import ssh
|
||||||
from git_command import GitCommand
|
from git_command import GitCommand
|
||||||
from git_command import ssh_sock
|
|
||||||
from git_command import terminate_ssh_clients
|
|
||||||
from git_refs import R_CHANGES, R_HEADS, R_TAGS
|
from git_refs import R_CHANGES, R_HEADS, R_TAGS
|
||||||
|
|
||||||
ID_RE = re.compile(r'^[0-9a-f]{40}$')
|
ID_RE = re.compile(r'^[0-9a-f]{40}$')
|
||||||
@ -440,129 +432,6 @@ class RefSpec(object):
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
_master_processes = []
|
|
||||||
_master_keys = set()
|
|
||||||
_ssh_master = True
|
|
||||||
_master_keys_lock = None
|
|
||||||
|
|
||||||
|
|
||||||
def init_ssh():
|
|
||||||
"""Should be called once at the start of repo to init ssh master handling.
|
|
||||||
|
|
||||||
At the moment, all we do is to create our lock.
|
|
||||||
"""
|
|
||||||
global _master_keys_lock
|
|
||||||
assert _master_keys_lock is None, "Should only call init_ssh once"
|
|
||||||
_master_keys_lock = _threading.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
def _open_ssh(host, port=None):
|
|
||||||
global _ssh_master
|
|
||||||
|
|
||||||
# Bail before grabbing the lock if we already know that we aren't going to
|
|
||||||
# try creating new masters below.
|
|
||||||
if sys.platform in ('win32', 'cygwin'):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Acquire the lock. This is needed to prevent opening multiple masters for
|
|
||||||
# the same host when we're running "repo sync -jN" (for N > 1) _and_ the
|
|
||||||
# manifest <remote fetch="ssh://xyz"> specifies a different host from the
|
|
||||||
# one that was passed to repo init.
|
|
||||||
_master_keys_lock.acquire()
|
|
||||||
try:
|
|
||||||
|
|
||||||
# Check to see whether we already think that the master is running; if we
|
|
||||||
# think it's already running, return right away.
|
|
||||||
if port is not None:
|
|
||||||
key = '%s:%s' % (host, port)
|
|
||||||
else:
|
|
||||||
key = host
|
|
||||||
|
|
||||||
if key in _master_keys:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if not _ssh_master or 'GIT_SSH' in os.environ:
|
|
||||||
# Failed earlier, so don't retry.
|
|
||||||
return False
|
|
||||||
|
|
||||||
# We will make two calls to ssh; this is the common part of both calls.
|
|
||||||
command_base = ['ssh',
|
|
||||||
'-o', 'ControlPath %s' % ssh_sock(),
|
|
||||||
host]
|
|
||||||
if port is not None:
|
|
||||||
command_base[1:1] = ['-p', str(port)]
|
|
||||||
|
|
||||||
# Since the key wasn't in _master_keys, we think that master isn't running.
|
|
||||||
# ...but before actually starting a master, we'll double-check. This can
|
|
||||||
# be important because we can't tell that that 'git@myhost.com' is the same
|
|
||||||
# as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file.
|
|
||||||
check_command = command_base + ['-O', 'check']
|
|
||||||
try:
|
|
||||||
Trace(': %s', ' '.join(check_command))
|
|
||||||
check_process = subprocess.Popen(check_command,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.PIPE)
|
|
||||||
check_process.communicate() # read output, but ignore it...
|
|
||||||
isnt_running = check_process.wait()
|
|
||||||
|
|
||||||
if not isnt_running:
|
|
||||||
# Our double-check found that the master _was_ infact running. Add to
|
|
||||||
# the list of keys.
|
|
||||||
_master_keys.add(key)
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
# Ignore excpetions. We we will fall back to the normal command and print
|
|
||||||
# to the log there.
|
|
||||||
pass
|
|
||||||
|
|
||||||
command = command_base[:1] + ['-M', '-N'] + command_base[1:]
|
|
||||||
try:
|
|
||||||
Trace(': %s', ' '.join(command))
|
|
||||||
p = subprocess.Popen(command)
|
|
||||||
except Exception as e:
|
|
||||||
_ssh_master = False
|
|
||||||
print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
|
|
||||||
% (host, port, str(e)), file=sys.stderr)
|
|
||||||
return False
|
|
||||||
|
|
||||||
time.sleep(1)
|
|
||||||
ssh_died = (p.poll() is not None)
|
|
||||||
if ssh_died:
|
|
||||||
return False
|
|
||||||
|
|
||||||
_master_processes.append(p)
|
|
||||||
_master_keys.add(key)
|
|
||||||
return True
|
|
||||||
finally:
|
|
||||||
_master_keys_lock.release()
|
|
||||||
|
|
||||||
|
|
||||||
def close_ssh():
|
|
||||||
global _master_keys_lock
|
|
||||||
|
|
||||||
terminate_ssh_clients()
|
|
||||||
|
|
||||||
for p in _master_processes:
|
|
||||||
try:
|
|
||||||
os.kill(p.pid, signal.SIGTERM)
|
|
||||||
p.wait()
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
del _master_processes[:]
|
|
||||||
_master_keys.clear()
|
|
||||||
|
|
||||||
d = ssh_sock(create=False)
|
|
||||||
if d:
|
|
||||||
try:
|
|
||||||
platform_utils.rmdir(os.path.dirname(d))
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# We're done with the lock, so we can delete it.
|
|
||||||
_master_keys_lock = None
|
|
||||||
|
|
||||||
|
|
||||||
URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
|
|
||||||
URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
|
URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
|
||||||
|
|
||||||
|
|
||||||
@ -614,27 +483,6 @@ def GetUrlCookieFile(url, quiet):
|
|||||||
yield cookiefile, None
|
yield cookiefile, None
|
||||||
|
|
||||||
|
|
||||||
def _preconnect(url):
|
|
||||||
m = URI_ALL.match(url)
|
|
||||||
if m:
|
|
||||||
scheme = m.group(1)
|
|
||||||
host = m.group(2)
|
|
||||||
if ':' in host:
|
|
||||||
host, port = host.split(':')
|
|
||||||
else:
|
|
||||||
port = None
|
|
||||||
if scheme in ('ssh', 'git+ssh', 'ssh+git'):
|
|
||||||
return _open_ssh(host, port)
|
|
||||||
return False
|
|
||||||
|
|
||||||
m = URI_SCP.match(url)
|
|
||||||
if m:
|
|
||||||
host = m.group(1)
|
|
||||||
return _open_ssh(host)
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class Remote(object):
|
class Remote(object):
|
||||||
"""Configuration options related to a remote.
|
"""Configuration options related to a remote.
|
||||||
"""
|
"""
|
||||||
@ -673,7 +521,7 @@ class Remote(object):
|
|||||||
|
|
||||||
def PreConnectFetch(self):
|
def PreConnectFetch(self):
|
||||||
connectionUrl = self._InsteadOf()
|
connectionUrl = self._InsteadOf()
|
||||||
return _preconnect(connectionUrl)
|
return ssh.preconnect(connectionUrl)
|
||||||
|
|
||||||
def ReviewUrl(self, userEmail, validate_certs):
|
def ReviewUrl(self, userEmail, validate_certs):
|
||||||
if self._review_url is None:
|
if self._review_url is None:
|
||||||
|
7
main.py
7
main.py
@ -39,7 +39,7 @@ from color import SetDefaultColoring
|
|||||||
import event_log
|
import event_log
|
||||||
from repo_trace import SetTrace
|
from repo_trace import SetTrace
|
||||||
from git_command import user_agent
|
from git_command import user_agent
|
||||||
from git_config import init_ssh, close_ssh, RepoConfig
|
from git_config import RepoConfig
|
||||||
from git_trace2_event_log import EventLog
|
from git_trace2_event_log import EventLog
|
||||||
from command import InteractiveCommand
|
from command import InteractiveCommand
|
||||||
from command import MirrorSafeCommand
|
from command import MirrorSafeCommand
|
||||||
@ -56,6 +56,7 @@ from error import RepoChangedException
|
|||||||
import gitc_utils
|
import gitc_utils
|
||||||
from manifest_xml import GitcClient, RepoClient
|
from manifest_xml import GitcClient, RepoClient
|
||||||
from pager import RunPager, TerminatePager
|
from pager import RunPager, TerminatePager
|
||||||
|
import ssh
|
||||||
from wrapper import WrapperPath, Wrapper
|
from wrapper import WrapperPath, Wrapper
|
||||||
|
|
||||||
from subcmds import all_commands
|
from subcmds import all_commands
|
||||||
@ -592,7 +593,7 @@ def _Main(argv):
|
|||||||
repo = _Repo(opt.repodir)
|
repo = _Repo(opt.repodir)
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
init_ssh()
|
ssh.init()
|
||||||
init_http()
|
init_http()
|
||||||
name, gopts, argv = repo._ParseArgs(argv)
|
name, gopts, argv = repo._ParseArgs(argv)
|
||||||
run = lambda: repo._Run(name, gopts, argv) or 0
|
run = lambda: repo._Run(name, gopts, argv) or 0
|
||||||
@ -604,7 +605,7 @@ def _Main(argv):
|
|||||||
else:
|
else:
|
||||||
result = run()
|
result = run()
|
||||||
finally:
|
finally:
|
||||||
close_ssh()
|
ssh.close()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print('aborted by user', file=sys.stderr)
|
print('aborted by user', file=sys.stderr)
|
||||||
result = 1
|
result = 1
|
||||||
|
257
ssh.py
Normal file
257
ssh.py
Normal file
@ -0,0 +1,257 @@
|
|||||||
|
# Copyright (C) 2008 The Android Open Source Project
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Common SSH management logic."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
try:
|
||||||
|
import threading as _threading
|
||||||
|
except ImportError:
|
||||||
|
import dummy_threading as _threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import platform_utils
|
||||||
|
from repo_trace import Trace
|
||||||
|
|
||||||
|
|
||||||
|
_ssh_proxy_path = None
|
||||||
|
_ssh_sock_path = None
|
||||||
|
_ssh_clients = []
|
||||||
|
|
||||||
|
|
||||||
|
def _run_ssh_version():
|
||||||
|
"""run ssh -V to display the version number"""
|
||||||
|
return subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_ssh_version(ver_str=None):
|
||||||
|
"""parse a ssh version string into a tuple"""
|
||||||
|
if ver_str is None:
|
||||||
|
ver_str = _run_ssh_version()
|
||||||
|
m = re.match(r'^OpenSSH_([0-9.]+)(p[0-9]+)?\s', ver_str)
|
||||||
|
if m:
|
||||||
|
return tuple(int(x) for x in m.group(1).split('.'))
|
||||||
|
else:
|
||||||
|
return ()
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=None)
|
||||||
|
def version():
|
||||||
|
"""return ssh version as a tuple"""
|
||||||
|
try:
|
||||||
|
return _parse_ssh_version()
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
print('fatal: unable to detect ssh version', file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def proxy():
|
||||||
|
global _ssh_proxy_path
|
||||||
|
if _ssh_proxy_path is None:
|
||||||
|
_ssh_proxy_path = os.path.join(
|
||||||
|
os.path.dirname(__file__),
|
||||||
|
'git_ssh')
|
||||||
|
return _ssh_proxy_path
|
||||||
|
|
||||||
|
|
||||||
|
def add_client(p):
|
||||||
|
_ssh_clients.append(p)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_client(p):
|
||||||
|
try:
|
||||||
|
_ssh_clients.remove(p)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _terminate_clients():
|
||||||
|
global _ssh_clients
|
||||||
|
for p in _ssh_clients:
|
||||||
|
try:
|
||||||
|
os.kill(p.pid, signal.SIGTERM)
|
||||||
|
p.wait()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
_ssh_clients = []
|
||||||
|
|
||||||
|
|
||||||
|
_master_processes = []
|
||||||
|
_master_keys = set()
|
||||||
|
_ssh_master = True
|
||||||
|
_master_keys_lock = None
|
||||||
|
|
||||||
|
|
||||||
|
def init():
|
||||||
|
"""Should be called once at the start of repo to init ssh master handling.
|
||||||
|
|
||||||
|
At the moment, all we do is to create our lock.
|
||||||
|
"""
|
||||||
|
global _master_keys_lock
|
||||||
|
assert _master_keys_lock is None, "Should only call init once"
|
||||||
|
_master_keys_lock = _threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _open_ssh(host, port=None):
|
||||||
|
global _ssh_master
|
||||||
|
|
||||||
|
# Bail before grabbing the lock if we already know that we aren't going to
|
||||||
|
# try creating new masters below.
|
||||||
|
if sys.platform in ('win32', 'cygwin'):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Acquire the lock. This is needed to prevent opening multiple masters for
|
||||||
|
# the same host when we're running "repo sync -jN" (for N > 1) _and_ the
|
||||||
|
# manifest <remote fetch="ssh://xyz"> specifies a different host from the
|
||||||
|
# one that was passed to repo init.
|
||||||
|
_master_keys_lock.acquire()
|
||||||
|
try:
|
||||||
|
|
||||||
|
# Check to see whether we already think that the master is running; if we
|
||||||
|
# think it's already running, return right away.
|
||||||
|
if port is not None:
|
||||||
|
key = '%s:%s' % (host, port)
|
||||||
|
else:
|
||||||
|
key = host
|
||||||
|
|
||||||
|
if key in _master_keys:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not _ssh_master or 'GIT_SSH' in os.environ:
|
||||||
|
# Failed earlier, so don't retry.
|
||||||
|
return False
|
||||||
|
|
||||||
|
# We will make two calls to ssh; this is the common part of both calls.
|
||||||
|
command_base = ['ssh',
|
||||||
|
'-o', 'ControlPath %s' % sock(),
|
||||||
|
host]
|
||||||
|
if port is not None:
|
||||||
|
command_base[1:1] = ['-p', str(port)]
|
||||||
|
|
||||||
|
# Since the key wasn't in _master_keys, we think that master isn't running.
|
||||||
|
# ...but before actually starting a master, we'll double-check. This can
|
||||||
|
# be important because we can't tell that that 'git@myhost.com' is the same
|
||||||
|
# as 'myhost.com' where "User git" is setup in the user's ~/.ssh/config file.
|
||||||
|
check_command = command_base + ['-O', 'check']
|
||||||
|
try:
|
||||||
|
Trace(': %s', ' '.join(check_command))
|
||||||
|
check_process = subprocess.Popen(check_command,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE)
|
||||||
|
check_process.communicate() # read output, but ignore it...
|
||||||
|
isnt_running = check_process.wait()
|
||||||
|
|
||||||
|
if not isnt_running:
|
||||||
|
# Our double-check found that the master _was_ infact running. Add to
|
||||||
|
# the list of keys.
|
||||||
|
_master_keys.add(key)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
# Ignore excpetions. We we will fall back to the normal command and print
|
||||||
|
# to the log there.
|
||||||
|
pass
|
||||||
|
|
||||||
|
command = command_base[:1] + ['-M', '-N'] + command_base[1:]
|
||||||
|
try:
|
||||||
|
Trace(': %s', ' '.join(command))
|
||||||
|
p = subprocess.Popen(command)
|
||||||
|
except Exception as e:
|
||||||
|
_ssh_master = False
|
||||||
|
print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
|
||||||
|
% (host, port, str(e)), file=sys.stderr)
|
||||||
|
return False
|
||||||
|
|
||||||
|
time.sleep(1)
|
||||||
|
ssh_died = (p.poll() is not None)
|
||||||
|
if ssh_died:
|
||||||
|
return False
|
||||||
|
|
||||||
|
_master_processes.append(p)
|
||||||
|
_master_keys.add(key)
|
||||||
|
return True
|
||||||
|
finally:
|
||||||
|
_master_keys_lock.release()
|
||||||
|
|
||||||
|
|
||||||
|
def close():
|
||||||
|
global _master_keys_lock
|
||||||
|
|
||||||
|
_terminate_clients()
|
||||||
|
|
||||||
|
for p in _master_processes:
|
||||||
|
try:
|
||||||
|
os.kill(p.pid, signal.SIGTERM)
|
||||||
|
p.wait()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
del _master_processes[:]
|
||||||
|
_master_keys.clear()
|
||||||
|
|
||||||
|
d = sock(create=False)
|
||||||
|
if d:
|
||||||
|
try:
|
||||||
|
platform_utils.rmdir(os.path.dirname(d))
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# We're done with the lock, so we can delete it.
|
||||||
|
_master_keys_lock = None
|
||||||
|
|
||||||
|
|
||||||
|
URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
|
||||||
|
URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
|
||||||
|
|
||||||
|
|
||||||
|
def preconnect(url):
|
||||||
|
m = URI_ALL.match(url)
|
||||||
|
if m:
|
||||||
|
scheme = m.group(1)
|
||||||
|
host = m.group(2)
|
||||||
|
if ':' in host:
|
||||||
|
host, port = host.split(':')
|
||||||
|
else:
|
||||||
|
port = None
|
||||||
|
if scheme in ('ssh', 'git+ssh', 'ssh+git'):
|
||||||
|
return _open_ssh(host, port)
|
||||||
|
return False
|
||||||
|
|
||||||
|
m = URI_SCP.match(url)
|
||||||
|
if m:
|
||||||
|
host = m.group(1)
|
||||||
|
return _open_ssh(host)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def sock(create=True):
|
||||||
|
global _ssh_sock_path
|
||||||
|
if _ssh_sock_path is None:
|
||||||
|
if not create:
|
||||||
|
return None
|
||||||
|
tmp_dir = '/tmp'
|
||||||
|
if not os.path.exists(tmp_dir):
|
||||||
|
tmp_dir = tempfile.gettempdir()
|
||||||
|
if version() < (6, 7):
|
||||||
|
tokens = '%r@%h:%p'
|
||||||
|
else:
|
||||||
|
tokens = '%C' # hash of %l%h%p%r
|
||||||
|
_ssh_sock_path = os.path.join(
|
||||||
|
tempfile.mkdtemp('', 'ssh-', tmp_dir),
|
||||||
|
'master-' + tokens)
|
||||||
|
return _ssh_sock_path
|
@ -26,38 +26,6 @@ import git_command
|
|||||||
import wrapper
|
import wrapper
|
||||||
|
|
||||||
|
|
||||||
class SSHUnitTest(unittest.TestCase):
|
|
||||||
"""Tests the ssh functions."""
|
|
||||||
|
|
||||||
def test_parse_ssh_version(self):
|
|
||||||
"""Check parse_ssh_version() handling."""
|
|
||||||
ver = git_command._parse_ssh_version('Unknown\n')
|
|
||||||
self.assertEqual(ver, ())
|
|
||||||
ver = git_command._parse_ssh_version('OpenSSH_1.0\n')
|
|
||||||
self.assertEqual(ver, (1, 0))
|
|
||||||
ver = git_command._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n')
|
|
||||||
self.assertEqual(ver, (6, 6, 1))
|
|
||||||
ver = git_command._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n')
|
|
||||||
self.assertEqual(ver, (7, 6))
|
|
||||||
|
|
||||||
def test_ssh_version(self):
|
|
||||||
"""Check ssh_version() handling."""
|
|
||||||
with mock.patch('git_command._run_ssh_version', return_value='OpenSSH_1.2\n'):
|
|
||||||
self.assertEqual(git_command.ssh_version(), (1, 2))
|
|
||||||
|
|
||||||
def test_ssh_sock(self):
|
|
||||||
"""Check ssh_sock() function."""
|
|
||||||
with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):
|
|
||||||
# old ssh version uses port
|
|
||||||
with mock.patch('git_command.ssh_version', return_value=(6, 6)):
|
|
||||||
self.assertTrue(git_command.ssh_sock().endswith('%p'))
|
|
||||||
git_command._ssh_sock_path = None
|
|
||||||
# new ssh version uses hash
|
|
||||||
with mock.patch('git_command.ssh_version', return_value=(6, 7)):
|
|
||||||
self.assertTrue(git_command.ssh_sock().endswith('%C'))
|
|
||||||
git_command._ssh_sock_path = None
|
|
||||||
|
|
||||||
|
|
||||||
class GitCallUnitTest(unittest.TestCase):
|
class GitCallUnitTest(unittest.TestCase):
|
||||||
"""Tests the _GitCall class (via git_command.git)."""
|
"""Tests the _GitCall class (via git_command.git)."""
|
||||||
|
|
||||||
|
52
tests/test_ssh.py
Normal file
52
tests/test_ssh.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
# Copyright 2019 The Android Open Source Project
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Unittests for the ssh.py module."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import ssh
|
||||||
|
|
||||||
|
|
||||||
|
class SshTests(unittest.TestCase):
|
||||||
|
"""Tests the ssh functions."""
|
||||||
|
|
||||||
|
def test_parse_ssh_version(self):
|
||||||
|
"""Check _parse_ssh_version() handling."""
|
||||||
|
ver = ssh._parse_ssh_version('Unknown\n')
|
||||||
|
self.assertEqual(ver, ())
|
||||||
|
ver = ssh._parse_ssh_version('OpenSSH_1.0\n')
|
||||||
|
self.assertEqual(ver, (1, 0))
|
||||||
|
ver = ssh._parse_ssh_version('OpenSSH_6.6.1p1 Ubuntu-2ubuntu2.13, OpenSSL 1.0.1f 6 Jan 2014\n')
|
||||||
|
self.assertEqual(ver, (6, 6, 1))
|
||||||
|
ver = ssh._parse_ssh_version('OpenSSH_7.6p1 Ubuntu-4ubuntu0.3, OpenSSL 1.0.2n 7 Dec 2017\n')
|
||||||
|
self.assertEqual(ver, (7, 6))
|
||||||
|
|
||||||
|
def test_version(self):
|
||||||
|
"""Check version() handling."""
|
||||||
|
with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'):
|
||||||
|
self.assertEqual(ssh.version(), (1, 2))
|
||||||
|
|
||||||
|
def test_ssh_sock(self):
|
||||||
|
"""Check sock() function."""
|
||||||
|
with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):
|
||||||
|
# old ssh version uses port
|
||||||
|
with mock.patch('ssh.version', return_value=(6, 6)):
|
||||||
|
self.assertTrue(ssh.sock().endswith('%p'))
|
||||||
|
ssh._ssh_sock_path = None
|
||||||
|
# new ssh version uses hash
|
||||||
|
with mock.patch('ssh.version', return_value=(6, 7)):
|
||||||
|
self.assertTrue(ssh.sock().endswith('%C'))
|
||||||
|
ssh._ssh_sock_path = None
|
Loading…
Reference in New Issue
Block a user