Skip to content

Commit

Permalink
Implement "strict kex" support to harden AsyncSSH against Terrapin At…
Browse files Browse the repository at this point in the history
…tack

This commit implements "strict kex" support and other countermeasures to
protect against the Terrapin Attack described in CVE-2023-48795. Thanks
once again go to Fabian Bäumer, Marcus Brinkmann, and Jörg Schwenk for
identifying and reporting this vulnerability and providing detailed
analysis and suggestions about proposed fixes.
  • Loading branch information
ronf committed Dec 18, 2023
1 parent a788cfb commit 0bc7325
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 34 deletions.
67 changes: 55 additions & 12 deletions asyncssh/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop,
self._kexinit_sent = False
self._kex_complete = False
self._ignore_first_kex = False
self._strict_kex = False

self._gss: Optional[GSSBase] = None
self._gss_kex = False
Expand Down Expand Up @@ -1398,10 +1399,13 @@ def _choose_alg(self, alg_type: str, local_algs: Sequence[bytes],
(alg_type, b','.join(local_algs).decode('ascii'),
b','.join(remote_algs).decode('ascii')))

def _get_ext_info_kex_alg(self) -> List[bytes]:
"""Return the kex alg to add if any to request extension info"""
def _get_extra_kex_algs(self) -> List[bytes]:
"""Return the extra kex algs to add"""

return [b'ext-info-c' if self.is_client() else b'ext-info-s']
if self.is_client():
return [b'ext-info-c', b'[email protected]']
else:
return [b'ext-info-s', b'[email protected]']

def _send(self, data: bytes) -> None:
"""Send data to the SSH connection"""
Expand Down Expand Up @@ -1546,6 +1550,11 @@ def _recv_packet(self) -> bool:
else:
skip_reason = 'kex not in progress'
exc_reason = 'Key exchange not in progress'
elif self._strict_kex and not self._recv_encryption and \
MSG_IGNORE <= pkttype <= MSG_DEBUG:
skip_reason = 'strict kex violation'
exc_reason = 'Strict key exchange violation: ' \
'unexpected packet type %d received' % pkttype
elif MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST:
if self._auth:
handler = self._auth
Expand Down Expand Up @@ -1581,8 +1590,13 @@ def _recv_packet(self) -> bool:
raise ProtocolError(str(exc)) from None

if not processed:
self.logger.debug1('Unknown packet type %d received', pkttype)
self.send_packet(MSG_UNIMPLEMENTED, UInt32(seq))
if self._strict_kex and not self._recv_encryption:
exc_reason = 'Strict key exchange violation: ' \
'unexpected packet type %d received' % pkttype
else:
self.logger.debug1('Unknown packet type %d received',
pkttype)
self.send_packet(MSG_UNIMPLEMENTED, UInt32(seq))

if exc_reason:
raise ProtocolError(exc_reason)
Expand All @@ -1591,9 +1605,16 @@ def _recv_packet(self) -> bool:
self._auth_final = True

if self._transport:
self._recv_seq = (seq + 1) & 0xffffffff
self._recv_handler = self._recv_pkthdr

if self._recv_seq == 0xffffffff and not self._recv_encryption:
raise ProtocolError('Sequence rollover before kex complete')

if pkttype == MSG_NEWKEYS and self._strict_kex:
self._recv_seq = 0
else:
self._recv_seq = (seq + 1) & 0xffffffff

return True

def send_packet(self, pkttype: int, *args: bytes,
Expand Down Expand Up @@ -1645,7 +1666,15 @@ def send_packet(self, pkttype: int, *args: bytes,
mac = b''

self._send(packet + mac)
self._send_seq = (seq + 1) & 0xffffffff

if self._send_seq == 0xffffffff and not self._send_encryption:
self._send_seq = 0
raise ProtocolError('Sequence rollover before kex complete')

if pkttype == MSG_NEWKEYS and self._strict_kex:
self._send_seq = 0
else:
self._send_seq = (seq + 1) & 0xffffffff

if self._kex_complete:
self._rekey_bytes_sent += pktlen
Expand Down Expand Up @@ -1689,7 +1718,7 @@ def _send_kexinit(self) -> None:

kex_algs = expand_kex_algs(self._kex_algs, gss_mechs,
bool(self._server_host_key_algs)) + \
self._get_ext_info_kex_alg()
self._get_extra_kex_algs()

host_key_algs = self._server_host_key_algs or [b'null']

Expand Down Expand Up @@ -2191,13 +2220,27 @@ def _process_kexinit(self, _pkttype: int, _pktid: int,
if self.is_server():
self._client_kexinit = packet.get_consumed_payload()

if b'ext-info-c' in peer_kex_algs and not self._session_id:
self._can_send_ext_info = True
if not self._session_id:
if b'ext-info-c' in peer_kex_algs:
self._can_send_ext_info = True

if b'[email protected]' in peer_kex_algs:
self._strict_kex = True
else:
self._server_kexinit = packet.get_consumed_payload()

if b'ext-info-s' in peer_kex_algs and not self._session_id:
self._can_send_ext_info = True
if not self._session_id:
if b'ext-info-s' in peer_kex_algs:
self._can_send_ext_info = True

if b'[email protected]' in peer_kex_algs:
self._strict_kex = True

if self._strict_kex and not self._recv_encryption and \
self._recv_seq != 0:
raise ProtocolError('Strict key exchange violation: '
'KEXINIT was not the first packet')


if self._kexinit_sent:
self._kexinit_sent = False
Expand Down
149 changes: 129 additions & 20 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
from unittest.mock import patch

import asyncssh
from asyncssh.constants import MSG_DEBUG
from asyncssh.constants import MSG_IGNORE, MSG_DEBUG
from asyncssh.constants import MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT
from asyncssh.constants import MSG_KEXINIT, MSG_NEWKEYS, MSG_KEX_FIRST
from asyncssh.constants import MSG_KEXINIT, MSG_NEWKEYS
from asyncssh.constants import MSG_KEX_FIRST, MSG_KEX_LAST
from asyncssh.constants import MSG_USERAUTH_REQUEST, MSG_USERAUTH_SUCCESS
from asyncssh.constants import MSG_USERAUTH_FAILURE, MSG_USERAUTH_BANNER
from asyncssh.constants import MSG_USERAUTH_FIRST
Expand All @@ -43,6 +44,7 @@
from asyncssh.crypto.cipher import GCMCipher
from asyncssh.encryption import get_encryption_algs
from asyncssh.kex import get_kex_algs
from asyncssh.kex_dh import MSG_KEX_ECDH_REPLY
from asyncssh.mac import _HMAC, _mac_handler, get_mac_algs
from asyncssh.packet import Boolean, NameList, String, UInt32
from asyncssh.public_key import get_default_public_key_algs
Expand All @@ -51,8 +53,8 @@

from .server import Server, ServerTestCase

from .util import asynctest, gss_available, nc_available, patch_gss
from .util import patch_getnameinfo, x509_available
from .util import asynctest, patch_extra_kex, patch_getnameinfo, patch_gss
from .util import gss_available, nc_available, x509_available


class _CheckAlgsClientConnection(asyncssh.SSHClientConnection):
Expand Down Expand Up @@ -930,22 +932,6 @@ def unsupported_kex_alg():
with self.assertRaises(asyncssh.KeyExchangeFailed):
await self.connect(kex_algs=['fail'])

@asynctest
async def test_skip_ext_info(self):
"""Test not requesting extension info from the server"""

def skip_ext_info(self):
"""Don't request extension information"""

# pylint: disable=unused-argument

return []

with patch('asyncssh.connection.SSHConnection._get_ext_info_kex_alg',
skip_ext_info):
async with self.connect():
pass

@asynctest
async def test_unknown_ext_info(self):
"""Test receiving unknown extension information"""
Expand All @@ -970,6 +956,54 @@ def send_newkeys(self, k, h):
with self.assertRaises(asyncssh.ProtocolError):
await self.connect()

@asynctest
async def test_message_before_kexinit_strict_kex(self):
"""Test receiving a message before KEXINIT with strict_kex enabled"""

def send_packet(self, pkttype, *args, **kwargs):
if pkttype == MSG_KEXINIT:
self.send_packet(MSG_IGNORE, String(b''))

asyncssh.connection.SSHConnection.send_packet(
self, pkttype, *args, **kwargs)

with patch('asyncssh.connection.SSHClientConnection.send_packet',
send_packet):
with self.assertRaises(asyncssh.ProtocolError):
await self.connect()

@asynctest
async def test_message_during_kex_strict_kex(self):
"""Test receiving an unexpected message with strict_kex enabled"""

def send_packet(self, pkttype, *args, **kwargs):
if pkttype == MSG_KEX_ECDH_REPLY:
self.send_packet(MSG_IGNORE, String(b''))

asyncssh.connection.SSHConnection.send_packet(
self, pkttype, *args, **kwargs)

with patch('asyncssh.connection.SSHServerConnection.send_packet',
send_packet):
with self.assertRaises(asyncssh.ProtocolError):
await self.connect()

@asynctest
async def test_unknown_message_during_kex_strict_kex(self):
"""Test receiving an unknown message with strict_kex enabled"""

def send_packet(self, pkttype, *args, **kwargs):
if pkttype == MSG_KEX_ECDH_REPLY:
self.send_packet(MSG_KEX_LAST)

asyncssh.connection.SSHConnection.send_packet(
self, pkttype, *args, **kwargs)

with patch('asyncssh.connection.SSHServerConnection.send_packet',
send_packet):
with self.assertRaises(asyncssh.ProtocolError):
await self.connect()

@asynctest
async def test_encryption_algs(self):
"""Test connecting with different encryption algorithms"""
Expand Down Expand Up @@ -1602,6 +1636,81 @@ async def test_internal_error(self):
await self.create_connection(_InternalErrorClient)


@patch_extra_kex
class _TestConnectionNoStrictKex(ServerTestCase):
"""Unit tests for connection API with ext info and strict kex disabled"""

@classmethod
async def start_server(cls):
"""Start an SSH server to connect to"""

return (await cls.create_server(_TunnelServer, gss_host=(),
compression_algs='*',
encryption_algs='*',
kex_algs='*', mac_algs='*'))

@asynctest
async def test_skip_ext_info(self):
"""Test not requesting extension info from the server"""

async with self.connect():
pass

@asynctest
async def test_message_before_kexinit(self):
"""Test receiving a message before KEXINIT"""

def send_packet(self, pkttype, *args, **kwargs):
if pkttype == MSG_KEXINIT:
self.send_packet(MSG_IGNORE, String(b''))

asyncssh.connection.SSHConnection.send_packet(
self, pkttype, *args, **kwargs)

with patch('asyncssh.connection.SSHClientConnection.send_packet',
send_packet):
async with self.connect():
pass

@asynctest
async def test_message_during_kex(self):
"""Test receiving an unexpected message in key exchange"""

def send_packet(self, pkttype, *args, **kwargs):
if pkttype == MSG_KEX_ECDH_REPLY:
self.send_packet(MSG_IGNORE, String(b''))

asyncssh.connection.SSHConnection.send_packet(
self, pkttype, *args, **kwargs)

with patch('asyncssh.connection.SSHServerConnection.send_packet',
send_packet):
async with self.connect():
pass

@asynctest
async def test_sequence_wrap_during_kex(self):
"""Test sequence wrap during initial key exchange"""

def send_packet(self, pkttype, *args, **kwargs):
if pkttype == MSG_KEXINIT:
if self._options.command == 'send':
self._send_seq = 0xfffffffe
else:
self._recv_seq = 0xfffffffe

asyncssh.connection.SSHConnection.send_packet(
self, pkttype, *args, **kwargs)

with patch('asyncssh.connection.SSHClientConnection.send_packet',
send_packet):
with self.assertRaises(asyncssh.ProtocolError):
await self.connect(command='send')

with self.assertRaises(asyncssh.ProtocolError):
await self.connect(command='recv')


class _TestConnectionListenSock(ServerTestCase):
"""Unit test for specifying a listen socket"""

Expand Down
4 changes: 2 additions & 2 deletions tests/test_connection_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ def skip_ext_info(self):

return []

with patch('asyncssh.connection.SSHConnection._get_ext_info_kex_alg',
with patch('asyncssh.connection.SSHConnection._get_extra_kex_algs',
skip_ext_info):
try:
async with self.connect(username='user',
Expand Down Expand Up @@ -1245,7 +1245,7 @@ def skip_ext_info(self):

return []

with patch('asyncssh.connection.SSHConnection._get_ext_info_kex_alg',
with patch('asyncssh.connection.SSHConnection._get_extra_kex_algs',
skip_ext_info):
try:
async with self.connect(username='ckey', client_keys='ckey',
Expand Down
14 changes: 14 additions & 0 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,20 @@ def getnameinfo(sockaddr, flags):
return patch('socket.getnameinfo', getnameinfo)(cls)


def patch_extra_kex(cls):
"""Decorator for skipping extra kex algs"""

def skip_extra_kex_algs(self):
"""Don't send extra key exchange algorithms"""

# pylint: disable=unused-argument

return []

return patch('asyncssh.connection.SSHConnection._get_extra_kex_algs',
skip_extra_kex_algs)(cls)


def patch_gss(cls):
"""Decorator for patching GSSAPI classes"""

Expand Down

0 comments on commit 0bc7325

Please sign in to comment.