diff --git a/git_command.py b/git_command.py
index fabad0e0..04953f38 100644
--- a/git_command.py
+++ b/git_command.py
@@ -21,7 +21,6 @@
 from error import GitError
 from git_refs import HEAD
 import platform_utils
 from repo_trace import REPO_TRACE, IsTrace, Trace
-import ssh
 from wrapper import Wrapper

 GIT = 'git'
@@ -167,7 +166,7 @@ class GitCommand(object):
                capture_stderr=False,
                merge_output=False,
                disable_editor=False,
-               ssh_proxy=False,
+               ssh_proxy=None,
                cwd=None,
                gitdir=None):
     env = self._GetBasicEnv()
@@ -175,8 +174,8 @@ class GitCommand(object):
     if disable_editor:
       env['GIT_EDITOR'] = ':'
     if ssh_proxy:
-      env['REPO_SSH_SOCK'] = ssh.sock()
-      env['GIT_SSH'] = ssh.proxy()
+      env['REPO_SSH_SOCK'] = ssh_proxy.sock()
+      env['GIT_SSH'] = ssh_proxy.proxy
       env['GIT_SSH_VARIANT'] = 'ssh'
     if 'http_proxy' in env and 'darwin' == sys.platform:
       s = "'http.proxy=%s'" % (env['http_proxy'],)
@@ -259,7 +258,7 @@ class GitCommand(object):
         raise GitError('%s: %s' % (command[1], e))

       if ssh_proxy:
-        ssh.add_client(p)
+        ssh_proxy.add_client(p)

       self.process = p
       if input:
@@ -271,7 +270,8 @@ class GitCommand(object):
       try:
         self.stdout, self.stderr = p.communicate()
       finally:
-        ssh.remove_client(p)
+        if ssh_proxy:
+          ssh_proxy.remove_client(p)
       self.rc = p.wait()

   @staticmethod
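With the module-level `ssh` helpers gone, callers now hand `GitCommand` a live `ssh.ProxyManager` instance: the command pulls `REPO_SSH_SOCK`/`GIT_SSH` from it and brackets the spawned git process with `add_client()`/`remove_client()` so a crash cannot leak the session. A minimal sketch of the new calling convention (illustrative only; the `fetch` invocation and `project=None` are assumptions, not part of the patch):

```python
import multiprocessing

import ssh
from git_command import GitCommand

with multiprocessing.Manager() as manager:
  with ssh.ProxyManager(manager) as ssh_proxy:
    # GitCommand reads ssh_proxy.sock() and ssh_proxy.proxy to set
    # REPO_SSH_SOCK/GIT_SSH, and registers the child via add_client().
    cmd = GitCommand(None, ['fetch', 'origin'], ssh_proxy=ssh_proxy)
    print(cmd.rc)  # in this version the command runs during construction
```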
""" + if not ssh_proxy: + return True + connectionUrl = self._InsteadOf() - return ssh.preconnect(connectionUrl) + return ssh_proxy.preconnect(connectionUrl) def ReviewUrl(self, userEmail, validate_certs): if self._review_url is None: diff --git a/project.py b/project.py index 37558061..2f83d796 100644 --- a/project.py +++ b/project.py @@ -2045,8 +2045,8 @@ class Project(object): name = self.remote.name remote = self.GetRemote(name) - if not remote.PreConnectFetch(): - ssh_proxy = False + if not remote.PreConnectFetch(ssh_proxy): + ssh_proxy = None if initial: if alt_dir and 'objects' == os.path.basename(alt_dir): diff --git a/ssh.py b/ssh.py index d06c4eb2..0ae8d120 100644 --- a/ssh.py +++ b/ssh.py @@ -15,25 +15,20 @@ """Common SSH management logic.""" import functools +import multiprocessing import os import re import signal import subprocess import sys import tempfile -try: - import threading as _threading -except ImportError: - import dummy_threading as _threading import time import platform_utils from repo_trace import Trace -_ssh_proxy_path = None -_ssh_sock_path = None -_ssh_clients = [] +PROXY_PATH = os.path.join(os.path.dirname(__file__), 'git_ssh') def _run_ssh_version(): @@ -62,68 +57,104 @@ def version(): sys.exit(1) -def proxy(): - global _ssh_proxy_path - if _ssh_proxy_path is None: - _ssh_proxy_path = os.path.join( - os.path.dirname(__file__), - 'git_ssh') - return _ssh_proxy_path +URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):') +URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/') -def add_client(p): - _ssh_clients.append(p) +class ProxyManager: + """Manage various ssh clients & masters that we spawn. + This will take care of sharing state between multiprocessing children, and + make sure that if we crash, we don't leak any of the ssh sessions. -def remove_client(p): - try: - _ssh_clients.remove(p) - except ValueError: - pass - - -def _terminate_clients(): - global _ssh_clients - for p in _ssh_clients: - try: - os.kill(p.pid, signal.SIGTERM) - p.wait() - except OSError: - pass - _ssh_clients = [] - - -_master_processes = [] -_master_keys = set() -_ssh_master = True -_master_keys_lock = None - - -def init(): - """Should be called once at the start of repo to init ssh master handling. - - At the moment, all we do is to create our lock. + The code should work with a single-process scenario too, and not add too much + overhead due to the manager. """ - global _master_keys_lock - assert _master_keys_lock is None, "Should only call init once" - _master_keys_lock = _threading.Lock() + # Path to the ssh program to run which will pass our master settings along. + # Set here more as a convenience API. + proxy = PROXY_PATH -def _open_ssh(host, port=None): - global _ssh_master + def __init__(self, manager): + # Protect access to the list of active masters. + self._lock = multiprocessing.Lock() + # List of active masters (pid). These will be spawned on demand, and we are + # responsible for shutting them all down at the end. + self._masters = manager.list() + # Set of active masters indexed by "host:port" information. + # The value isn't used, but multiprocessing doesn't provide a set class. + self._master_keys = manager.dict() + # Whether ssh masters are known to be broken, so we give up entirely. + self._master_broken = manager.Value('b', False) + # List of active ssh sesssions. Clients will be added & removed as + # connections finish, so this list is just for safety & cleanup if we crash. + self._clients = manager.list() + # Path to directory for holding master sockets. 
diff --git a/ssh.py b/ssh.py
index d06c4eb2..0ae8d120 100644
--- a/ssh.py
+++ b/ssh.py
@@ -15,25 +15,20 @@
 """Common SSH management logic."""

 import functools
+import multiprocessing
 import os
 import re
 import signal
 import subprocess
 import sys
 import tempfile
-try:
-  import threading as _threading
-except ImportError:
-  import dummy_threading as _threading
 import time

 import platform_utils
 from repo_trace import Trace


-_ssh_proxy_path = None
-_ssh_sock_path = None
-_ssh_clients = []
+PROXY_PATH = os.path.join(os.path.dirname(__file__), 'git_ssh')


 def _run_ssh_version():
@@ -62,68 +57,104 @@ def version():
     sys.exit(1)


-def proxy():
-  global _ssh_proxy_path
-  if _ssh_proxy_path is None:
-    _ssh_proxy_path = os.path.join(
-        os.path.dirname(__file__),
-        'git_ssh')
-  return _ssh_proxy_path
+URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
+URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')


-def add_client(p):
-  _ssh_clients.append(p)
+class ProxyManager:
+  """Manage various ssh clients & masters that we spawn.

+  This will take care of sharing state between multiprocessing children, and
+  make sure that if we crash, we don't leak any of the ssh sessions.

-def remove_client(p):
-  try:
-    _ssh_clients.remove(p)
-  except ValueError:
-    pass
-
-
-def _terminate_clients():
-  global _ssh_clients
-  for p in _ssh_clients:
-    try:
-      os.kill(p.pid, signal.SIGTERM)
-      p.wait()
-    except OSError:
-      pass
-  _ssh_clients = []
-
-
-_master_processes = []
-_master_keys = set()
-_ssh_master = True
-_master_keys_lock = None
-
-
-def init():
-  """Should be called once at the start of repo to init ssh master handling.
-
-  At the moment, all we do is to create our lock.
+  The code should work with a single-process scenario too, and not add too much
+  overhead due to the manager.
   """
-  global _master_keys_lock
-  assert _master_keys_lock is None, "Should only call init once"
-  _master_keys_lock = _threading.Lock()

+  # Path to the ssh program to run which will pass our master settings along.
+  # Set here more as a convenience API.
+  proxy = PROXY_PATH

-def _open_ssh(host, port=None):
-  global _ssh_master
+  def __init__(self, manager):
+    # Protect access to the list of active masters.
+    self._lock = multiprocessing.Lock()
+    # List of active masters (pid). These will be spawned on demand, and we are
+    # responsible for shutting them all down at the end.
+    self._masters = manager.list()
+    # Set of active masters indexed by "host:port" information.
+    # The value isn't used, but multiprocessing doesn't provide a set class.
+    self._master_keys = manager.dict()
+    # Whether ssh masters are known to be broken, so we give up entirely.
+    self._master_broken = manager.Value('b', False)
+    # List of active ssh sessions. Clients will be added & removed as
+    # connections finish, so this list is just for safety & cleanup if we crash.
+    self._clients = manager.list()
+    # Path to directory for holding master sockets.
+    self._sock_path = None

+  def __enter__(self):
+    """Enter a new context."""
+    return self

-  # Bail before grabbing the lock if we already know that we aren't going to
-  # try creating new masters below.
-  if sys.platform in ('win32', 'cygwin'):
-    return False
+  def __exit__(self, exc_type, exc_value, traceback):
+    """Exit a context & clean up all resources."""
+    self.close()

-  # Acquire the lock.  This is needed to prevent opening multiple masters for
-  # the same host when we're running "repo sync -jN" (for N > 1) _and_ the
-  # manifest specifies a different host from the
-  # one that was passed to repo init.
-  _master_keys_lock.acquire()
-  try:
+  def add_client(self, proc):
+    """Track a new ssh session."""
+    self._clients.append(proc.pid)
+
+  def remove_client(self, proc):
+    """Remove a completed ssh session."""
+    try:
+      self._clients.remove(proc.pid)
+    except ValueError:
+      pass
+
+  def add_master(self, proc):
+    """Track a new master connection."""
+    self._masters.append(proc.pid)
+
+  def _terminate(self, procs):
+    """Kill all |procs|."""
+    for pid in procs:
+      try:
+        os.kill(pid, signal.SIGTERM)
+        os.waitpid(pid, 0)
+      except OSError:
+        pass
+
+    # The multiprocessing manager.list() API doesn't provide many standard
+    # list() methods, so we have to manually clear the list.
+    while True:
+      try:
+        procs.pop(0)
+      except IndexError:
+        break
+
+  def close(self):
+    """Close all active ssh sessions.
+
+    Kill all ssh clients & masters we created, and nuke the socket dir.
+    """
+    self._terminate(self._clients)
+    self._terminate(self._masters)
+
+    d = self.sock(create=False)
+    if d:
+      try:
+        platform_utils.rmdir(os.path.dirname(d))
+      except OSError:
+        pass
+
+  def _open_unlocked(self, host, port=None):
+    """Make sure an ssh master session exists for |host| & |port|.
+
+    If one doesn't exist already, we'll create it.
+
+    We won't grab any locks, so the caller has to do that.  This helps keep the
+    business logic of actually creating the master separate from grabbing locks.
+    """
     # Check to see whether we already think that the master is running; if we
     # think it's already running, return right away.
     if port is not None:
@@ -131,17 +162,15 @@ def _open_ssh(host, port=None):
     else:
       key = host

-    if key in _master_keys:
+    if key in self._master_keys:
       return True

-    if not _ssh_master or 'GIT_SSH' in os.environ:
+    if self._master_broken.value or 'GIT_SSH' in os.environ:
       # Failed earlier, so don't retry.
       return False

     # We will make two calls to ssh; this is the common part of both calls.
-    command_base = ['ssh',
-                    '-o', 'ControlPath %s' % sock(),
-                    host]
+    command_base = ['ssh', '-o', 'ControlPath %s' % self.sock(), host]
     if port is not None:
       command_base[1:1] = ['-p', str(port)]

@@ -161,7 +190,7 @@ def _open_ssh(host, port=None):
       if not isnt_running:
         # Our double-check found that the master _was_ in fact running.  Add to
         # the list of keys.
-        _master_keys.add(key)
+        self._master_keys[key] = True
         return True
     except Exception:
       # Ignore exceptions.  We will fall back to the normal command and print
@@ -173,7 +202,7 @@ def _open_ssh(host, port=None):
       Trace(': %s', ' '.join(command))
       p = subprocess.Popen(command)
     except Exception as e:
-      _ssh_master = False
+      self._master_broken.value = True
       print('\nwarn: cannot enable ssh control master for %s:%s\n%s'
             % (host, port, str(e)), file=sys.stderr)
       return False
@@ -183,75 +212,66 @@ def _open_ssh(host, port=None):
     if ssh_died:
       return False

-    _master_processes.append(p)
-    _master_keys.add(key)
+    self.add_master(p)
+    self._master_keys[key] = True
     return True
-  finally:
-    _master_keys_lock.release()

+  def _open(self, host, port=None):
+    """Make sure an ssh master session exists for |host| & |port|.

-def close():
-  global _master_keys_lock
+    If one doesn't exist already, we'll create it.

-  _terminate_clients()
+    This will obtain any necessary locks to avoid inter-process races.
+    """
+    # Bail before grabbing the lock if we already know that we aren't going to
+    # try creating new masters below.
+    if sys.platform in ('win32', 'cygwin'):
+      return False

-  for p in _master_processes:
-    try:
-      os.kill(p.pid, signal.SIGTERM)
-      p.wait()
-    except OSError:
-      pass
-  del _master_processes[:]
-  _master_keys.clear()
+    # Acquire the lock.  This is needed to prevent opening multiple masters
+    # for the same host when we're running "repo sync -jN" (for N > 1) _and_
+    # the manifest specifies a different host from the one that was passed
+    # to repo init.
+    with self._lock:
+      return self._open_unlocked(host, port)

-  d = sock(create=False)
-  if d:
-    try:
-      platform_utils.rmdir(os.path.dirname(d))
-    except OSError:
-      pass
+  def preconnect(self, url):
+    """If |url| will create an ssh connection, set up the ssh master for it."""
+    m = URI_ALL.match(url)
+    if m:
+      scheme = m.group(1)
+      host = m.group(2)
+      if ':' in host:
+        host, port = host.split(':')
+      else:
+        port = None
+      if scheme in ('ssh', 'git+ssh', 'ssh+git'):
+        return self._open(host, port)
+      return False

-  # We're done with the lock, so we can delete it.
-  _master_keys_lock = None
+    m = URI_SCP.match(url)
+    if m:
+      host = m.group(1)
+      return self._open(host)

-
-URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
-URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')
-
-
-def preconnect(url):
-  m = URI_ALL.match(url)
-  if m:
-    scheme = m.group(1)
-    host = m.group(2)
-    if ':' in host:
-      host, port = host.split(':')
-    else:
-      port = None
-    if scheme in ('ssh', 'git+ssh', 'ssh+git'):
-      return _open_ssh(host, port)
     return False

-  m = URI_SCP.match(url)
-  if m:
-    host = m.group(1)
-    return _open_ssh(host)
+  def sock(self, create=True):
+    """Return the path to the ssh socket dir.

-  return False
-
-
-def sock(create=True):
-  global _ssh_sock_path
-  if _ssh_sock_path is None:
-    if not create:
-      return None
-    tmp_dir = '/tmp'
-    if not os.path.exists(tmp_dir):
-      tmp_dir = tempfile.gettempdir()
-    if version() < (6, 7):
-      tokens = '%r@%h:%p'
-    else:
-      tokens = '%C'  # hash of %l%h%p%r
-    _ssh_sock_path = os.path.join(
-        tempfile.mkdtemp('', 'ssh-', tmp_dir),
-        'master-' + tokens)
-  return _ssh_sock_path
+    This has all the master sockets so clients can talk to them.
+ """ + if self._sock_path is None: + if not create: + return None + tmp_dir = '/tmp' + if not os.path.exists(tmp_dir): + tmp_dir = tempfile.gettempdir() + if version() < (6, 7): + tokens = '%r@%h:%p' + else: + tokens = '%C' # hash of %l%h%p%r + self._sock_path = os.path.join( + tempfile.mkdtemp('', 'ssh-', tmp_dir), + 'master-' + tokens) + return self._sock_path diff --git a/subcmds/sync.py b/subcmds/sync.py index 28568062..fb25c221 100644 --- a/subcmds/sync.py +++ b/subcmds/sync.py @@ -358,7 +358,7 @@ later is required to fix a server side protocol bug. optimized_fetch=opt.optimized_fetch, retry_fetches=opt.retry_fetches, prune=opt.prune, - ssh_proxy=True, + ssh_proxy=self.ssh_proxy, clone_filter=self.manifest.CloneFilter, partial_clone_exclude=self.manifest.PartialCloneExclude) @@ -380,7 +380,11 @@ later is required to fix a server side protocol bug. finish = time.time() return (success, project, start, finish) - def _Fetch(self, projects, opt, err_event): + @classmethod + def _FetchInitChild(cls, ssh_proxy): + cls.ssh_proxy = ssh_proxy + + def _Fetch(self, projects, opt, err_event, ssh_proxy): ret = True jobs = opt.jobs_network if opt.jobs_network else self.jobs @@ -410,8 +414,14 @@ later is required to fix a server side protocol bug. break return ret + # We pass the ssh proxy settings via the class. This allows multiprocessing + # to pickle it up when spawning children. We can't pass it as an argument + # to _FetchProjectList below as multiprocessing is unable to pickle those. + Sync.ssh_proxy = None + # NB: Multiprocessing is heavy, so don't spin it up for one job. if len(projects_list) == 1 or jobs == 1: + self._FetchInitChild(ssh_proxy) if not _ProcessResults(self._FetchProjectList(opt, x) for x in projects_list): ret = False else: @@ -429,7 +439,8 @@ later is required to fix a server side protocol bug. else: pm.update(inc=0, msg='warming up') chunksize = 4 - with multiprocessing.Pool(jobs) as pool: + with multiprocessing.Pool( + jobs, initializer=self._FetchInitChild, initargs=(ssh_proxy,)) as pool: results = pool.imap_unordered( functools.partial(self._FetchProjectList, opt), projects_list, @@ -438,6 +449,11 @@ later is required to fix a server side protocol bug. ret = False pool.close() + # Cleanup the reference now that we're done with it, and we're going to + # release any resources it points to. If we don't, later multiprocessing + # usage (e.g. checkouts) will try to pickle and then crash. + del Sync.ssh_proxy + pm.end() self._fetch_times.Save() @@ -447,7 +463,7 @@ later is required to fix a server side protocol bug. return (ret, fetched) def _FetchMain(self, opt, args, all_projects, err_event, manifest_name, - load_local_manifests): + load_local_manifests, ssh_proxy): """The main network fetch loop. Args: @@ -457,6 +473,7 @@ later is required to fix a server side protocol bug. err_event: Whether an error was hit while processing. manifest_name: Manifest file to be reloaded. load_local_manifests: Whether to load local manifests. + ssh_proxy: SSH manager for clients & masters. """ rp = self.manifest.repoProject @@ -467,7 +484,7 @@ later is required to fix a server side protocol bug. to_fetch.extend(all_projects) to_fetch.sort(key=self._fetch_times.Get, reverse=True) - success, fetched = self._Fetch(to_fetch, opt, err_event) + success, fetched = self._Fetch(to_fetch, opt, err_event, ssh_proxy) if not success: err_event.set() @@ -498,7 +515,7 @@ later is required to fix a server side protocol bug. 
diff --git a/subcmds/sync.py b/subcmds/sync.py
index 28568062..fb25c221 100644
--- a/subcmds/sync.py
+++ b/subcmds/sync.py
@@ -358,7 +358,7 @@ later is required to fix a server side protocol bug.
         optimized_fetch=opt.optimized_fetch,
         retry_fetches=opt.retry_fetches,
         prune=opt.prune,
-        ssh_proxy=True,
+        ssh_proxy=self.ssh_proxy,
         clone_filter=self.manifest.CloneFilter,
         partial_clone_exclude=self.manifest.PartialCloneExclude)

@@ -380,7 +380,11 @@ later is required to fix a server side protocol bug.
     finish = time.time()
     return (success, project, start, finish)

-  def _Fetch(self, projects, opt, err_event):
+  @classmethod
+  def _FetchInitChild(cls, ssh_proxy):
+    cls.ssh_proxy = ssh_proxy
+
+  def _Fetch(self, projects, opt, err_event, ssh_proxy):
     ret = True

     jobs = opt.jobs_network if opt.jobs_network else self.jobs
@@ -410,8 +414,14 @@ later is required to fix a server side protocol bug.
           break
       return ret

+    # We pass the ssh proxy settings via the class.  This allows multiprocessing
+    # to pick it up when spawning children.  We can't pass it as an argument
+    # to _FetchProjectList below as multiprocessing is unable to pickle those.
+    Sync.ssh_proxy = None
+
     # NB: Multiprocessing is heavy, so don't spin it up for one job.
     if len(projects_list) == 1 or jobs == 1:
+      self._FetchInitChild(ssh_proxy)
       if not _ProcessResults(self._FetchProjectList(opt, x) for x in projects_list):
         ret = False
     else:
@@ -429,7 +439,8 @@ later is required to fix a server side protocol bug.
       else:
         pm.update(inc=0, msg='warming up')
       chunksize = 4
-      with multiprocessing.Pool(jobs) as pool:
+      with multiprocessing.Pool(
+          jobs, initializer=self._FetchInitChild, initargs=(ssh_proxy,)) as pool:
         results = pool.imap_unordered(
             functools.partial(self._FetchProjectList, opt),
             projects_list,
@@ -438,6 +449,11 @@ later is required to fix a server side protocol bug.
           ret = False
       pool.close()

+    # Clean up the reference now that we're done with it, and we're going to
+    # release any resources it points to.  If we don't, later multiprocessing
+    # usage (e.g. checkouts) will try to pickle and then crash.
+    del Sync.ssh_proxy
+
     pm.end()
     self._fetch_times.Save()

@@ -447,7 +463,7 @@ later is required to fix a server side protocol bug.
     return (ret, fetched)

   def _FetchMain(self, opt, args, all_projects, err_event, manifest_name,
-                 load_local_manifests):
+                 load_local_manifests, ssh_proxy):
     """The main network fetch loop.

     Args:
@@ -457,6 +473,7 @@ later is required to fix a server side protocol bug.
       err_event: Whether an error was hit while processing.
       manifest_name: Manifest file to be reloaded.
       load_local_manifests: Whether to load local manifests.
+      ssh_proxy: SSH manager for clients & masters.
     """
     rp = self.manifest.repoProject

@@ -467,7 +484,7 @@ later is required to fix a server side protocol bug.
     to_fetch.extend(all_projects)
     to_fetch.sort(key=self._fetch_times.Get, reverse=True)

-    success, fetched = self._Fetch(to_fetch, opt, err_event)
+    success, fetched = self._Fetch(to_fetch, opt, err_event, ssh_proxy)
     if not success:
       err_event.set()

@@ -498,7 +515,7 @@ later is required to fix a server side protocol bug.
       if previously_missing_set == missing_set:
         break
       previously_missing_set = missing_set
-      success, new_fetched = self._Fetch(missing, opt, err_event)
+      success, new_fetched = self._Fetch(missing, opt, err_event, ssh_proxy)
       if not success:
         err_event.set()
       fetched.update(new_fetched)
@@ -985,12 +1002,15 @@ later is required to fix a server side protocol bug.
     self._fetch_times = _FetchTimes(self.manifest)
     if not opt.local_only:
-      try:
-        ssh.init()
-        self._FetchMain(opt, args, all_projects, err_event, manifest_name,
-                        load_local_manifests)
-      finally:
-        ssh.close()
+      with multiprocessing.Manager() as manager:
+        with ssh.ProxyManager(manager) as ssh_proxy:
+          # Initialize the socket dir once in the parent.
+          ssh_proxy.sock()
+          self._FetchMain(opt, args, all_projects, err_event, manifest_name,
+                          load_local_manifests, ssh_proxy)
+
+    if opt.network_only:
+      return

     # If we saw an error, exit with code 1 so that other scripts can check.
     if err_event.is_set():
diff --git a/tests/test_ssh.py b/tests/test_ssh.py
index 5a4f27e4..ffb5cb94 100644
--- a/tests/test_ssh.py
+++ b/tests/test_ssh.py
@@ -14,6 +14,8 @@

 """Unittests for the ssh.py module."""

+import multiprocessing
+import subprocess
 import unittest
 from unittest import mock

@@ -39,14 +41,34 @@ class SshTests(unittest.TestCase):
     with mock.patch('ssh._run_ssh_version', return_value='OpenSSH_1.2\n'):
       self.assertEqual(ssh.version(), (1, 2))

+  def test_context_manager_empty(self):
+    """Verify context manager with no clients works correctly."""
+    with multiprocessing.Manager() as manager:
+      with ssh.ProxyManager(manager):
+        pass
+
+  def test_context_manager_child_cleanup(self):
+    """Verify orphaned clients & masters get cleaned up."""
+    with multiprocessing.Manager() as manager:
+      with ssh.ProxyManager(manager) as ssh_proxy:
+        client = subprocess.Popen(['sleep', '964853320'])
+        ssh_proxy.add_client(client)
+        master = subprocess.Popen(['sleep', '964853321'])
+        ssh_proxy.add_master(master)
+    # If the processes still exist, these will throw timeout errors.
+    client.wait(0)
+    master.wait(0)
+
   def test_ssh_sock(self):
     """Check sock() function."""
+    manager = multiprocessing.Manager()
+    proxy = ssh.ProxyManager(manager)
     with mock.patch('tempfile.mkdtemp', return_value='/tmp/foo'):
       # old ssh version uses port
       with mock.patch('ssh.version', return_value=(6, 6)):
-        self.assertTrue(ssh.sock().endswith('%p'))
-      ssh._ssh_sock_path = None
+        self.assertTrue(proxy.sock().endswith('%p'))
+
+      proxy._sock_path = None
       # new ssh version uses hash
       with mock.patch('ssh.version', return_value=(6, 7)):
-        self.assertTrue(ssh.sock().endswith('%C'))
-      ssh._ssh_sock_path = None
+        self.assertTrue(proxy.sock().endswith('%C'))
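The sync.py hunks rely on a standard multiprocessing idiom: state that shouldn't be pickled with every task is handed to each worker exactly once through the Pool initializer, which stashes it somewhere the task functions can reach (here, the `Sync` class). A standalone sketch of the pattern (all names below are illustrative, not from the patch):

```python
import multiprocessing


def _init_worker(shared_proxy):
  # Runs once in every child; stash the shared object at module scope so
  # task functions can reach it without it riding along with each task.
  global _proxy
  _proxy = shared_proxy


def _task(name):
  return (name, _proxy)


if __name__ == '__main__':
  proxy = 'stand-in for ssh.ProxyManager'
  with multiprocessing.Pool(2, initializer=_init_worker,
                            initargs=(proxy,)) as pool:
    print(pool.map(_task, ['a', 'b']))
```

The `del Sync.ssh_proxy` at the end of `_Fetch` matters for the same reason: as the patch's own comment notes, once the manager is gone, any later multiprocessing use (e.g. checkouts) would try to pickle the stale reference and crash.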