From 0bc73254f41acb140187e0c89606311f88de5b7b Mon Sep 17 00:00:00 2001 From: Ron Frederick Date: Mon, 18 Dec 2023 07:41:57 -0800 Subject: [PATCH] Implement "strict kex" support to harden AsyncSSH against Terrapin Attack MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit implements "strict kex" support and other countermeasures to protect against the Terrapin Attack described in CVE-2023-48795. Thanks once again go to Fabian Bäumer, Marcus Brinkmann, and Jörg Schwenk for identifying and reporting this vulnerability and providing detailed analysis and suggestions about proposed fixes. --- asyncssh/connection.py | 67 ++++++++++++--- tests/test_connection.py | 149 +++++++++++++++++++++++++++++----- tests/test_connection_auth.py | 4 +- tests/util.py | 14 ++++ 4 files changed, 200 insertions(+), 34 deletions(-) diff --git a/asyncssh/connection.py b/asyncssh/connection.py index 2f814e9..ffe5703 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -861,6 +861,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._kexinit_sent = False self._kex_complete = False self._ignore_first_kex = False + self._strict_kex = False self._gss: Optional[GSSBase] = None self._gss_kex = False @@ -1398,10 +1399,13 @@ def _choose_alg(self, alg_type: str, local_algs: Sequence[bytes], (alg_type, b','.join(local_algs).decode('ascii'), b','.join(remote_algs).decode('ascii'))) - def _get_ext_info_kex_alg(self) -> List[bytes]: - """Return the kex alg to add if any to request extension info""" + def _get_extra_kex_algs(self) -> List[bytes]: + """Return the extra kex algs to add""" - return [b'ext-info-c' if self.is_client() else b'ext-info-s'] + if self.is_client(): + return [b'ext-info-c', b'kex-strict-c-v00@openssh.com'] + else: + return [b'ext-info-s', b'kex-strict-s-v00@openssh.com'] def _send(self, data: bytes) -> None: """Send data to the SSH connection""" @@ -1546,6 +1550,11 @@ def _recv_packet(self) -> bool: else: skip_reason = 'kex not in progress' exc_reason = 'Key exchange not in progress' + elif self._strict_kex and not self._recv_encryption and \ + MSG_IGNORE <= pkttype <= MSG_DEBUG: + skip_reason = 'strict kex violation' + exc_reason = 'Strict key exchange violation: ' \ + 'unexpected packet type %d received' % pkttype elif MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST: if self._auth: handler = self._auth @@ -1581,8 +1590,13 @@ def _recv_packet(self) -> bool: raise ProtocolError(str(exc)) from None if not processed: - self.logger.debug1('Unknown packet type %d received', pkttype) - self.send_packet(MSG_UNIMPLEMENTED, UInt32(seq)) + if self._strict_kex and not self._recv_encryption: + exc_reason = 'Strict key exchange violation: ' \ + 'unexpected packet type %d received' % pkttype + else: + self.logger.debug1('Unknown packet type %d received', + pkttype) + self.send_packet(MSG_UNIMPLEMENTED, UInt32(seq)) if exc_reason: raise ProtocolError(exc_reason) @@ -1591,9 +1605,16 @@ def _recv_packet(self) -> bool: self._auth_final = True if self._transport: - self._recv_seq = (seq + 1) & 0xffffffff self._recv_handler = self._recv_pkthdr + if self._recv_seq == 0xffffffff and not self._recv_encryption: + raise ProtocolError('Sequence rollover before kex complete') + + if pkttype == MSG_NEWKEYS and self._strict_kex: + self._recv_seq = 0 + else: + self._recv_seq = (seq + 1) & 0xffffffff + return True def send_packet(self, pkttype: int, *args: bytes, @@ -1645,7 +1666,15 @@ def send_packet(self, pkttype: int, *args: bytes, mac = b'' self._send(packet + mac) - self._send_seq = (seq + 1) & 0xffffffff + + if self._send_seq == 0xffffffff and not self._send_encryption: + self._send_seq = 0 + raise ProtocolError('Sequence rollover before kex complete') + + if pkttype == MSG_NEWKEYS and self._strict_kex: + self._send_seq = 0 + else: + self._send_seq = (seq + 1) & 0xffffffff if self._kex_complete: self._rekey_bytes_sent += pktlen @@ -1689,7 +1718,7 @@ def _send_kexinit(self) -> None: kex_algs = expand_kex_algs(self._kex_algs, gss_mechs, bool(self._server_host_key_algs)) + \ - self._get_ext_info_kex_alg() + self._get_extra_kex_algs() host_key_algs = self._server_host_key_algs or [b'null'] @@ -2191,13 +2220,27 @@ def _process_kexinit(self, _pkttype: int, _pktid: int, if self.is_server(): self._client_kexinit = packet.get_consumed_payload() - if b'ext-info-c' in peer_kex_algs and not self._session_id: - self._can_send_ext_info = True + if not self._session_id: + if b'ext-info-c' in peer_kex_algs: + self._can_send_ext_info = True + + if b'kex-strict-c-v00@openssh.com' in peer_kex_algs: + self._strict_kex = True else: self._server_kexinit = packet.get_consumed_payload() - if b'ext-info-s' in peer_kex_algs and not self._session_id: - self._can_send_ext_info = True + if not self._session_id: + if b'ext-info-s' in peer_kex_algs: + self._can_send_ext_info = True + + if b'kex-strict-s-v00@openssh.com' in peer_kex_algs: + self._strict_kex = True + + if self._strict_kex and not self._recv_encryption and \ + self._recv_seq != 0: + raise ProtocolError('Strict key exchange violation: ' + 'KEXINIT was not the first packet') + if self._kexinit_sent: self._kexinit_sent = False diff --git a/tests/test_connection.py b/tests/test_connection.py index 65b1542..387cb31 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -30,9 +30,10 @@ from unittest.mock import patch import asyncssh -from asyncssh.constants import MSG_DEBUG +from asyncssh.constants import MSG_IGNORE, MSG_DEBUG from asyncssh.constants import MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT -from asyncssh.constants import MSG_KEXINIT, MSG_NEWKEYS, MSG_KEX_FIRST +from asyncssh.constants import MSG_KEXINIT, MSG_NEWKEYS +from asyncssh.constants import MSG_KEX_FIRST, MSG_KEX_LAST from asyncssh.constants import MSG_USERAUTH_REQUEST, MSG_USERAUTH_SUCCESS from asyncssh.constants import MSG_USERAUTH_FAILURE, MSG_USERAUTH_BANNER from asyncssh.constants import MSG_USERAUTH_FIRST @@ -43,6 +44,7 @@ from asyncssh.crypto.cipher import GCMCipher from asyncssh.encryption import get_encryption_algs from asyncssh.kex import get_kex_algs +from asyncssh.kex_dh import MSG_KEX_ECDH_REPLY from asyncssh.mac import _HMAC, _mac_handler, get_mac_algs from asyncssh.packet import Boolean, NameList, String, UInt32 from asyncssh.public_key import get_default_public_key_algs @@ -51,8 +53,8 @@ from .server import Server, ServerTestCase -from .util import asynctest, gss_available, nc_available, patch_gss -from .util import patch_getnameinfo, x509_available +from .util import asynctest, patch_extra_kex, patch_getnameinfo, patch_gss +from .util import gss_available, nc_available, x509_available class _CheckAlgsClientConnection(asyncssh.SSHClientConnection): @@ -930,22 +932,6 @@ def unsupported_kex_alg(): with self.assertRaises(asyncssh.KeyExchangeFailed): await self.connect(kex_algs=['fail']) - @asynctest - async def test_skip_ext_info(self): - """Test not requesting extension info from the server""" - - def skip_ext_info(self): - """Don't request extension information""" - - # pylint: disable=unused-argument - - return [] - - with patch('asyncssh.connection.SSHConnection._get_ext_info_kex_alg', - skip_ext_info): - async with self.connect(): - pass - @asynctest async def test_unknown_ext_info(self): """Test receiving unknown extension information""" @@ -970,6 +956,54 @@ def send_newkeys(self, k, h): with self.assertRaises(asyncssh.ProtocolError): await self.connect() + @asynctest + async def test_message_before_kexinit_strict_kex(self): + """Test receiving a message before KEXINIT with strict_kex enabled""" + + def send_packet(self, pkttype, *args, **kwargs): + if pkttype == MSG_KEXINIT: + self.send_packet(MSG_IGNORE, String(b'')) + + asyncssh.connection.SSHConnection.send_packet( + self, pkttype, *args, **kwargs) + + with patch('asyncssh.connection.SSHClientConnection.send_packet', + send_packet): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() + + @asynctest + async def test_message_during_kex_strict_kex(self): + """Test receiving an unexpected message with strict_kex enabled""" + + def send_packet(self, pkttype, *args, **kwargs): + if pkttype == MSG_KEX_ECDH_REPLY: + self.send_packet(MSG_IGNORE, String(b'')) + + asyncssh.connection.SSHConnection.send_packet( + self, pkttype, *args, **kwargs) + + with patch('asyncssh.connection.SSHServerConnection.send_packet', + send_packet): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() + + @asynctest + async def test_unknown_message_during_kex_strict_kex(self): + """Test receiving an unknown message with strict_kex enabled""" + + def send_packet(self, pkttype, *args, **kwargs): + if pkttype == MSG_KEX_ECDH_REPLY: + self.send_packet(MSG_KEX_LAST) + + asyncssh.connection.SSHConnection.send_packet( + self, pkttype, *args, **kwargs) + + with patch('asyncssh.connection.SSHServerConnection.send_packet', + send_packet): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() + @asynctest async def test_encryption_algs(self): """Test connecting with different encryption algorithms""" @@ -1602,6 +1636,81 @@ async def test_internal_error(self): await self.create_connection(_InternalErrorClient) +@patch_extra_kex +class _TestConnectionNoStrictKex(ServerTestCase): + """Unit tests for connection API with ext info and strict kex disabled""" + + @classmethod + async def start_server(cls): + """Start an SSH server to connect to""" + + return (await cls.create_server(_TunnelServer, gss_host=(), + compression_algs='*', + encryption_algs='*', + kex_algs='*', mac_algs='*')) + + @asynctest + async def test_skip_ext_info(self): + """Test not requesting extension info from the server""" + + async with self.connect(): + pass + + @asynctest + async def test_message_before_kexinit(self): + """Test receiving a message before KEXINIT""" + + def send_packet(self, pkttype, *args, **kwargs): + if pkttype == MSG_KEXINIT: + self.send_packet(MSG_IGNORE, String(b'')) + + asyncssh.connection.SSHConnection.send_packet( + self, pkttype, *args, **kwargs) + + with patch('asyncssh.connection.SSHClientConnection.send_packet', + send_packet): + async with self.connect(): + pass + + @asynctest + async def test_message_during_kex(self): + """Test receiving an unexpected message in key exchange""" + + def send_packet(self, pkttype, *args, **kwargs): + if pkttype == MSG_KEX_ECDH_REPLY: + self.send_packet(MSG_IGNORE, String(b'')) + + asyncssh.connection.SSHConnection.send_packet( + self, pkttype, *args, **kwargs) + + with patch('asyncssh.connection.SSHServerConnection.send_packet', + send_packet): + async with self.connect(): + pass + + @asynctest + async def test_sequence_wrap_during_kex(self): + """Test sequence wrap during initial key exchange""" + + def send_packet(self, pkttype, *args, **kwargs): + if pkttype == MSG_KEXINIT: + if self._options.command == 'send': + self._send_seq = 0xfffffffe + else: + self._recv_seq = 0xfffffffe + + asyncssh.connection.SSHConnection.send_packet( + self, pkttype, *args, **kwargs) + + with patch('asyncssh.connection.SSHClientConnection.send_packet', + send_packet): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect(command='send') + + with self.assertRaises(asyncssh.ProtocolError): + await self.connect(command='recv') + + class _TestConnectionListenSock(ServerTestCase): """Unit test for specifying a listen socket""" diff --git a/tests/test_connection_auth.py b/tests/test_connection_auth.py index 822c418..75c8c47 100644 --- a/tests/test_connection_auth.py +++ b/tests/test_connection_auth.py @@ -739,7 +739,7 @@ def skip_ext_info(self): return [] - with patch('asyncssh.connection.SSHConnection._get_ext_info_kex_alg', + with patch('asyncssh.connection.SSHConnection._get_extra_kex_algs', skip_ext_info): try: async with self.connect(username='user', @@ -1245,7 +1245,7 @@ def skip_ext_info(self): return [] - with patch('asyncssh.connection.SSHConnection._get_ext_info_kex_alg', + with patch('asyncssh.connection.SSHConnection._get_extra_kex_algs', skip_ext_info): try: async with self.connect(username='ckey', client_keys='ckey', diff --git a/tests/util.py b/tests/util.py index 6600a5b..d80256f 100644 --- a/tests/util.py +++ b/tests/util.py @@ -106,6 +106,20 @@ def getnameinfo(sockaddr, flags): return patch('socket.getnameinfo', getnameinfo)(cls) +def patch_extra_kex(cls): + """Decorator for skipping extra kex algs""" + + def skip_extra_kex_algs(self): + """Don't send extra key exchange algorithms""" + + # pylint: disable=unused-argument + + return [] + + return patch('asyncssh.connection.SSHConnection._get_extra_kex_algs', + skip_extra_kex_algs)(cls) + + def patch_gss(cls): """Decorator for patching GSSAPI classes"""