# Copyright (C) 2008 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common SSH management logic."""

import functools
import os
import re
import signal
import subprocess
import sys
import tempfile
try:
    import threading as _threading
except ImportError:
    import dummy_threading as _threading
import time

import platform_utils
from repo_trace import Trace


# Path to the git_ssh wrapper next to this file; computed lazily by proxy().
_ssh_proxy_path = None
# ControlPath template for ssh master sockets; created lazily by sock().
_ssh_sock_path = None
# Client processes registered via add_client(); terminated by close().
_ssh_clients = []


def _run_ssh_version():
    """Run `ssh -V` and return its output as text.

    stderr is merged into stdout so the version line is captured whichever
    stream ssh writes it to.

    Raises:
        subprocess.CalledProcessError: ssh exited with a non-zero status.
        OSError: ssh could not be executed (e.g. FileNotFoundError).
    """
    return subprocess.check_output(
        ['ssh', '-V'], stderr=subprocess.STDOUT).decode()


def _parse_ssh_version(ver_str=None):
    """Parse an `ssh -V` version string into a tuple of ints.

    Args:
        ver_str: The version text to parse; when None, `ssh -V` is run.

    Returns:
        The numeric OpenSSH version, e.g. (7, 4) for "OpenSSH_7.4p1, ...",
        or () if the text is not recognized as OpenSSH version output.
    """
    if ver_str is None:
        ver_str = _run_ssh_version()
    # The version token may be followed by whitespace
    # ("OpenSSH_8.2p1 Ubuntu-4ubuntu0.5, OpenSSL ...") or directly by a comma
    # ("OpenSSH_7.4p1, OpenSSL 1.0.2k-fips ..."), so accept both.  The numeric
    # group only admits dot-separated digit runs, so int() below cannot be
    # handed an empty component.
    m = re.match(r'^OpenSSH_([0-9]+(?:\.[0-9]+)*)(p[0-9]+)?[\s,]', ver_str)
    if not m:
        return ()
    return tuple(int(x) for x in m.group(1).split('.'))


@functools.lru_cache(maxsize=None)
def version():
    """Return the local ssh version as a tuple (cached after the first call).

    Exits the process with status 1 if ssh cannot be run or its version
    cannot be queried.
    """
    try:
        return _parse_ssh_version()
    except FileNotFoundError:
        print('fatal: ssh not installed', file=sys.stderr)
        sys.exit(1)
    except subprocess.CalledProcessError:
        print('fatal: unable to detect ssh version', file=sys.stderr)
        sys.exit(1)


def proxy():
    """Return the path of the git_ssh wrapper located next to this module."""
    global _ssh_proxy_path
    if _ssh_proxy_path is None:
        _ssh_proxy_path = os.path.join(
            os.path.dirname(__file__),
            'git_ssh')
    return _ssh_proxy_path


def add_client(p):
    """Register process `p` so close() will terminate it."""
    _ssh_clients.append(p)


def remove_client(p):
    """Unregister process `p`; a no-op if it was never registered."""
    try:
        _ssh_clients.remove(p)
    except ValueError:
        pass


def _terminate_clients():
    """SIGTERM and reap every registered client, then forget them all."""
    global _ssh_clients
    for p in _ssh_clients:
        try:
            os.kill(p.pid, signal.SIGTERM)
            p.wait()
        except OSError:
            # Best-effort: the process may already have exited.
            pass
    _ssh_clients = []


# ssh ControlMaster processes started by _open_ssh(); terminated by close().
_master_processes = []
# Keys ('host' or 'host:port') for which a master is believed to be running.
_master_keys = set()
# Set to False once starting a master fails, so later calls don't retry.
_ssh_master = True
# Serializes _open_ssh(); created by init() and dropped by close().
_master_keys_lock = None


def init():
    """Should be called once at the start of repo to init ssh master handling.

    At the moment, all we do is to create our lock.
    """
    global _master_keys_lock
    assert _master_keys_lock is None, "Should only call init once"
    _master_keys_lock = _threading.Lock()


def _open_ssh(host, port=None):
    """Ensure an ssh ControlMaster connection to host[:port] is running.

    Requires init() to have been called.

    Args:
        host: The ssh destination, possibly with a "user@" prefix.
        port: Optional port (string or int); None lets ssh pick its default.

    Returns:
        True if a master is (believed to be) running for this destination;
        False on Windows/Cygwin, when GIT_SSH is set, after an earlier failure
        to start a master, or if the new master exits within a second.
    """
    global _ssh_master

    # Bail before grabbing the lock if we already know that we aren't going to
    # try creating new masters below.
    if sys.platform in ('win32', 'cygwin'):
        return False

    # Hold the lock across the whole check-and-start sequence. This is needed
    # to prevent opening multiple masters for the same host when we're running
    # "repo sync -jN" (for N > 1) _and_ the manifest specifies a different host
    # from the one that was passed to repo init.
    with _master_keys_lock:
        # Check to see whether we already think that the master is running; if
        # we think it's already running, return right away.
        if port is not None:
            key = '%s:%s' % (host, port)
        else:
            key = host

        if key in _master_keys:
            return True

        if not _ssh_master or 'GIT_SSH' in os.environ:
            # Failed earlier (or GIT_SSH is set), so don't retry.
            return False

        # We will make two calls to ssh; this is the common part of both calls.
        command_base = ['ssh', '-o', 'ControlPath %s' % sock(), host]
        if port is not None:
            command_base[1:1] = ['-p', str(port)]

        # Since the key wasn't in _master_keys, we think that master isn't
        # running... but before actually starting a master, double-check. This
        # matters because we can't tell that 'git@myhost.com' is the same as
        # 'myhost.com' where "User git" is set up in the user's ~/.ssh/config.
        check_command = command_base + ['-O', 'check']
        try:
            Trace(': %s', ' '.join(check_command))
            check_process = subprocess.Popen(check_command,
                                             stdout=subprocess.PIPE,
                                             stderr=subprocess.PIPE)
            check_process.communicate()  # Read output, but ignore it.
            isnt_running = check_process.wait()

            if not isnt_running:
                # Our double-check found that the master _was_ in fact running.
                # Add it to the list of keys.
                _master_keys.add(key)
                return True
        except Exception:
            # Deliberately best-effort: ignore exceptions here. We fall back to
            # starting a master below, which reports its own errors.
            pass

        # Same command with "-M -N" (master mode, no remote command) inserted
        # right after "ssh".
        command = command_base[:1] + ['-M', '-N'] + command_base[1:]
        try:
            Trace(': %s', ' '.join(command))
            p = subprocess.Popen(command)
        except Exception as e:
            _ssh_master = False
            print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
                  % (host, port, str(e)), file=sys.stderr)
            return False

        # Give ssh a moment; if it has already exited, the master failed.
        time.sleep(1)
        ssh_died = (p.poll() is not None)
        if ssh_died:
            return False

        _master_processes.append(p)
        _master_keys.add(key)
        return True


def close():
    """Tear down ssh state: clients, masters, and the control socket dir.

    The lock is dropped at the end, so init() must be called again before
    _open_ssh()/preconnect() can be used.
    """
    global _master_keys_lock
    global _ssh_sock_path

    _terminate_clients()

    for p in _master_processes:
        try:
            os.kill(p.pid, signal.SIGTERM)
            p.wait()
        except OSError:
            # Best-effort: the master may already have exited.
            pass
    del _master_processes[:]
    _master_keys.clear()

    d = sock(create=False)
    if d:
        try:
            platform_utils.rmdir(os.path.dirname(d))
        except OSError:
            pass
        # The socket directory is gone (or abandoned); forget its path so a
        # later sock() creates a fresh directory instead of reusing this one.
        _ssh_sock_path = None

    # We're done with the lock, so we can delete it.
    _master_keys_lock = None


# scp-style "[user@]host:path" destinations.
URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
# "scheme://[user@]host[:port]/..." URLs.
URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')


def _split_host_port(netloc):
    """Split '[user@]host[:port]' into (host, port); port is None if absent.

    Only a colon that is not inside brackets introduces a port, so bracketed
    IPv6 literals ('[::1]', 'user@[::1]:22') split cleanly. An empty port
    ('host:') is treated as absent.
    """
    host, sep, port = netloc.rpartition(':')
    if not sep or ']' in port:
        return netloc, None
    return host, port or None


def preconnect(url):
    """Start an ssh master for `url` if it names an ssh destination.

    Args:
        url: A remote URL: either "scheme://[user@]host[:port]/..." or an
            scp-style "[user@]host:path".

    Returns:
        The result of _open_ssh() for ssh/git+ssh/ssh+git URLs and scp-style
        paths; False for any other URL.
    """
    m = URI_ALL.match(url)
    if m:
        scheme = m.group(1)
        host, port = _split_host_port(m.group(2))
        if scheme in ('ssh', 'git+ssh', 'ssh+git'):
            return _open_ssh(host, port)
        return False

    m = URI_SCP.match(url)
    if m:
        host = m.group(1)
        return _open_ssh(host)

    return False


def sock(create=True):
    """Return the ssh ControlPath template, creating its directory on demand.

    Args:
        create: When False and no path has been created yet, return None
            instead of creating a temporary directory.

    Returns:
        A path of the form '<tmpdir>/ssh-<random>/master-<tokens>', or None
        (see `create`).
    """
    global _ssh_sock_path
    if _ssh_sock_path is None:
        if not create:
            return None
        # Prefer /tmp (presumably to keep the socket path short); fall back to
        # the platform temp dir when /tmp does not exist.
        tmp_dir = '/tmp'
        if not os.path.exists(tmp_dir):
            tmp_dir = tempfile.gettempdir()
        if version() < (6, 7):
            tokens = '%r@%h:%p'
        else:
            tokens = '%C'  # hash of %l%h%p%r
        _ssh_sock_path = os.path.join(
            tempfile.mkdtemp('', 'ssh-', tmp_dir),
            'master-' + tokens)
    return _ssh_sock_path