diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 33f8fbf5e9..d89577336a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -20,7 +20,6 @@ Added Changed ~~~~~~~ - * Install pack with the latest tag version if it exists when branch is not specialized. (improvement) #4743 * Implement "continue" engine command to orquesta workflow. (improvement) #4740 @@ -65,6 +64,7 @@ Changed Fixed ~~~~~ +* Fix ssh zombies when using ProxyCommand from ssh config #4881 [Eric Edgar] * Fix rbac with execution view where the rbac is unable to verify the pack or uid of the execution because it was not returned from the action execution db. This would result in an internal server error when trying to view the results of a single execution. diff --git a/st2actions/tests/unit/test_paramiko_ssh.py b/st2actions/tests/unit/test_paramiko_ssh.py index e0697664a3..c6da6e3058 100644 --- a/st2actions/tests/unit/test_paramiko_ssh.py +++ b/st2actions/tests/unit/test_paramiko_ssh.py @@ -802,3 +802,32 @@ def test_use_ssh_config_port_value_provided_in_the_config(self, mock_sshclient): call_kwargs = mock_client.connect.call_args[1] self.assertEqual(call_kwargs['port'], 9999) + + @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', + MagicMock(return_value=False)) + def test_socket_closed(self): + conn_params = {'hostname': 'dummy.host.org', + 'username': 'ubuntu', + 'password': 'pass', + 'timeout': '600'} + ssh_client = ParamikoSSHClient(**conn_params) + + # Make sure .close() doesn't actually call anything real + ssh_client.client = Mock() + ssh_client.sftp_client = None + ssh_client.bastion_client = None + + ssh_client.socket = Mock() + + # Make sure we havent called any close methods at this point + # TODO: Replace these with .assert_not_called() once it's Python 3.6+ only + self.assertEqual(ssh_client.socket.process.kill.call_count, 0) + self.assertEqual(ssh_client.socket.process.poll.call_count, 0) + + # Call the function that has changed + ssh_client.close() + + # Make sure we have called kill and poll + # TODO: Replace these with .assert_called_once() once it's Python 3.6+ only + self.assertEqual(ssh_client.socket.process.kill.call_count, 1) + self.assertEqual(ssh_client.socket.process.poll.call_count, 1) diff --git a/st2common/st2common/runners/paramiko_ssh.py b/st2common/st2common/runners/paramiko_ssh.py index c489629fb3..2e5154db9f 100644 --- a/st2common/st2common/runners/paramiko_ssh.py +++ b/st2common/st2common/runners/paramiko_ssh.py @@ -123,6 +123,7 @@ def __init__(self, hostname, port=DEFAULT_SSH_PORT, username=None, password=None self.bastion_client = None self.bastion_socket = None + self.socket = None def connect(self): """ @@ -455,6 +456,12 @@ def close(self): self.client.close() + if self.socket: + self.logger.debug('Closing proxycommand socket connection') + # https://github.com/paramiko/paramiko/issues/789 Avoid zombie ssh processes + self.socket.process.kill() + self.socket.process.poll() + if self.sftp_client: self.sftp_client.close() @@ -698,8 +705,8 @@ def _connect(self, host, socket=None): '_username': self.username, '_timeout': self.timeout} self.logger.debug('Connecting to server', extra=extra) - socket = socket or ssh_config_file_info.get('sock', None) - if socket: + self.socket = socket or ssh_config_file_info.get('sock', None) + if self.socket: conninfo['sock'] = socket client = paramiko.SSHClient()