diff --git a/git_command.py b/git_command.py index b6a4a343..954bebad 100644 --- a/git_command.py +++ b/git_command.py @@ -16,6 +16,7 @@ import os import sys import subprocess +import tempfile from error import GitError from trace import REPO_TRACE, IsTrace, Trace @@ -26,6 +27,27 @@ GIT_DIR = 'GIT_DIR' LAST_GITDIR = None LAST_CWD = None +_ssh_proxy_path = None +_ssh_sock_path = None + +def _ssh_sock(create=True): + global _ssh_sock_path + if _ssh_sock_path is None: + if not create: + return None + _ssh_sock_path = os.path.join( + tempfile.mkdtemp('', 'ssh-'), + 'master-%r@%h:%p') + return _ssh_sock_path + +def _ssh_proxy(): + global _ssh_proxy_path + if _ssh_proxy_path is None: + _ssh_proxy_path = os.path.join( + os.path.dirname(__file__), + 'git_ssh') + return _ssh_proxy_path + class _GitCall(object): def version(self): @@ -52,6 +74,7 @@ class GitCommand(object): capture_stdout = False, capture_stderr = False, disable_editor = False, + ssh_proxy = False, cwd = None, gitdir = None): env = dict(os.environ) @@ -68,6 +91,9 @@ class GitCommand(object): if disable_editor: env['GIT_EDITOR'] = ':' + if ssh_proxy: + env['REPO_SSH_SOCK'] = _ssh_sock() + env['GIT_SSH'] = _ssh_proxy() if project: if not cwd: diff --git a/git_config.py b/git_config.py index 7e642a4c..163b0809 100644 --- a/git_config.py +++ b/git_config.py @@ -16,11 +16,14 @@ import cPickle import os import re +import subprocess import sys +import time +from signal import SIGTERM from urllib2 import urlopen, HTTPError from error import GitError, UploadError from trace import Trace -from git_command import GitCommand +from git_command import GitCommand, _ssh_sock R_HEADS = 'refs/heads/' R_TAGS = 'refs/tags/' @@ -331,6 +334,79 @@ class RefSpec(object): return s +_ssh_cache = {} +_ssh_master = True + +def _open_ssh(host, port=None): + global _ssh_master + + if port is None: + port = 22 + + key = '%s:%s' % (host, port) + if key in _ssh_cache: + return True + + if not _ssh_master \ + or 'GIT_SSH' in os.environ \ + or sys.platform == 'win32': + # failed earlier, or cygwin ssh can't do this + # + return False + + command = ['ssh', + '-o','ControlPath %s' % _ssh_sock(), + '-p',str(port), + '-M', + '-N', + host] + try: + Trace(': %s', ' '.join(command)) + p = subprocess.Popen(command) + except Exception, e: + _ssh_master = False + print >>sys.stderr, \ + '\nwarn: cannot enable ssh control master for %s:%s\n%s' \ + % (host,port, str(e)) + return False + + _ssh_cache[key] = p + time.sleep(1) + return True + +def close_ssh(): + for key,p in _ssh_cache.iteritems(): + os.kill(p.pid, SIGTERM) + p.wait() + _ssh_cache.clear() + + d = _ssh_sock(create=False) + if d: + try: + os.rmdir(os.path.dirname(d)) + except OSError: + pass + +URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):') +URI_ALL = re.compile(r'^([a-z][a-z+]*)://([^@/]*@?[^/])/') + +def _preconnect(url): + m = URI_ALL.match(url) + if m: + scheme = m.group(1) + host = m.group(2) + if ':' in host: + host, port = host.split(':') + if scheme in ('ssh', 'git+ssh', 'ssh+git'): + return _open_ssh(host, port) + return False + + m = URI_SCP.match(url) + if m: + host = m.group(1) + return _open_ssh(host) + + class Remote(object): """Configuration options related to a remote. """ @@ -344,6 +420,9 @@ class Remote(object): self._Get('fetch', all=True)) self._review_protocol = None + def PreConnectFetch(self): + return _preconnect(self.url) + @property def ReviewProtocol(self): if self._review_protocol is None: diff --git a/git_ssh b/git_ssh new file mode 100755 index 00000000..63aa63c2 --- /dev/null +++ b/git_ssh @@ -0,0 +1,2 @@ +#!/bin/sh +exec ssh -o "ControlPath $REPO_SSH_SOCK" "$@" diff --git a/main.py b/main.py index 6fa1e51b..774b9038 100755 --- a/main.py +++ b/main.py @@ -28,6 +28,7 @@ import re import sys from trace import SetTrace +from git_config import close_ssh from command import InteractiveCommand from command import MirrorSafeCommand from command import PagedCommand @@ -212,7 +213,10 @@ def _Main(argv): repo = _Repo(opt.repodir) try: - repo._Run(argv) + try: + repo._Run(argv) + finally: + close_ssh() except KeyboardInterrupt: sys.exit(1) except RepoChangedException, rce: diff --git a/project.py b/project.py index fd3f0b8d..304480a8 100644 --- a/project.py +++ b/project.py @@ -969,11 +969,19 @@ class Project(object): def _RemoteFetch(self, name=None): if not name: name = self.remote.name + + ssh_proxy = False + if self.GetRemote(name).PreConnectFetch(): + ssh_proxy = True + cmd = ['fetch'] if not self.worktree: cmd.append('--update-head-ok') cmd.append(name) - return GitCommand(self, cmd, bare = True).Wait() == 0 + return GitCommand(self, + cmd, + bare = True, + ssh_proxy = ssh_proxy).Wait() == 0 def _Checkout(self, rev, quiet=False): cmd = ['checkout']