diff --git a/asyncssh/config.py b/asyncssh/config.py index dfb4b97..5007ef7 100644 --- a/asyncssh/config.py +++ b/asyncssh/config.py @@ -60,12 +60,15 @@ class SSHConfig: _percent_expand = {'AuthorizedKeysFile'} _handlers: Dict[str, Tuple[str, Callable]] = {} - def __init__(self, last_config: Optional['SSHConfig'], reload: bool): + def __init__(self, last_config: Optional['SSHConfig'], reload: bool, + canonical: bool, final: bool): if last_config: self._last_options = last_config.get_options(reload) else: self._last_options = {} + self._canonical = canonical + self._final = True if final else None self._default_path = Path('~', '.ssh').expanduser() self._path = Path() self._line_no = 0 @@ -153,35 +156,53 @@ def _match(self, option: str, args: List[str]) -> None: # pylint: disable=unused-argument + matching = True + while args: match = args.pop(0).lower() + if match[0] == '!': + match = match[1:] + negated = True + else: + negated = False + + if match == 'final' and self._final is None: + self._final = False + if match == 'all': - self._matching = True - continue + result = True + elif match == 'canonical': + result = self._canonical + elif match == 'final': + result = cast(bool, self._final) + else: + match_val = self._match_val(match) - match_val = self._match_val(match) + if match != 'exec' and match_val is None: + self._error(f'Invalid match condition {match}') - if match != 'exec' and match_val is None: - self._error('Invalid match condition') + try: + arg = args.pop(0) + except IndexError: + self._error(f'Missing {match} match pattern') + + if matching: + if match == 'exec': + result = _exec(arg) + elif match in ('address', 'localaddress'): + host_pat = HostPatternList(arg) + ip = ip_address(cast(str, match_val)) \ + if match_val else None + result = host_pat.matches(None, match_val, ip) + else: + wild_pat = WildcardPatternList(arg) + result = wild_pat.matches(match_val) - try: - if match == 'exec': - self._matching = _exec(args.pop(0)) - elif match in ('address', 'localaddress'): - host_pat = HostPatternList(args.pop(0)) - ip = ip_address(cast(str, match_val)) \ - if match_val else None - self._matching = host_pat.matches(None, match_val, ip) - else: - wild_pat = WildcardPatternList(args.pop(0)) - self._matching = wild_pat.matches(match_val) - except IndexError: - self._error(f'Missing {match} match pattern') + if matching and result == negated: + matching = False - if not self._matching: - args.clear() - break + self._matching = matching def _set_bool(self, option: str, args: List[str]) -> None: """Set a boolean config option""" @@ -276,6 +297,23 @@ def _set_address_family(self, option: str, args: List[str]) -> None: if option not in self._options: self._options[option] = value + def _set_canonicalize_host(self, option: str, args: List[str]) -> None: + """Set a canonicalize host config option""" + + value_str = args.pop(0).lower() + + if value_str in ('yes', 'true'): + value: Union[bool, str] = True + elif value_str in ('no', 'false'): + value = False + elif value_str == 'always': + value = value_str + else: + self._error(f'Invalid {option} value: {value_str}') + + if option not in self._options: + self._options[option] = value + def _set_rekey_limits(self, option: str, args: List[str]) -> None: """Set rekey limits config option""" @@ -295,6 +333,11 @@ def _set_rekey_limits(self, option: str, args: List[str]) -> None: if option not in self._options: self._options[option] = byte_limit, time_limit + def has_match_final(self) -> bool: + """Return whether this config includes a 'Match final' block""" + + return self._final is not None + def parse(self, path: Path) -> None: """Parse an OpenSSH config file and return matching declarations""" @@ -384,10 +427,10 @@ def get_options(self, reload: bool) -> Dict[str, object]: @classmethod def load(cls, last_config: Optional['SSHConfig'], config_paths: ConfigPaths, reload: bool, - *args: object) -> 'SSHConfig': + canonical: bool, final: bool, *args: object) -> 'SSHConfig': """Load a list of OpenSSH config files into a config object""" - config = cls(last_config, reload, *args) + config = cls(last_config, reload, canonical, final, *args) if config_paths: if isinstance(config_paths, (str, PurePath)): @@ -429,8 +472,9 @@ class SSHClientConfig(SSHConfig): 'IdentityFile', 'ProxyCommand', 'RemoteCommand'} def __init__(self, last_config: 'SSHConfig', reload: bool, - local_user: str, user: str, host: str, port: int) -> None: - super().__init__(last_config, reload) + canonical: bool, final: bool, local_user: str, + user: str, host: str, port: int) -> None: + super().__init__(last_config, reload, canonical, final) self._local_user = local_user self._orig_host = host @@ -485,10 +529,10 @@ def _set_request_tty(self, option: str, args: List[str]) -> None: value: Union[bool, str] = True elif value_str in ('no', 'false'): value = False - elif value_str not in ('force', 'auto'): - self._error(f'Invalid {option} value: {value_str}') - else: + elif value_str in ('force', 'auto'): value = value_str + else: + self._error(f'Invalid {option} value: {value_str}') if option not in self._options: self._options[option] = value @@ -531,6 +575,11 @@ def _set_tokens(self) -> None: ('AddressFamily', SSHConfig._set_address_family), ('BindAddress', SSHConfig._set_string), + ('CanonicalDomains', SSHConfig._set_string_list), + ('CanonicalizeFallbackLocal', SSHConfig._set_bool), + ('CanonicalizeHostname', SSHConfig._set_canonicalize_host), + ('CanonicalizeMaxDots', SSHConfig._set_int), + ('CanonicalizePermittedCNAMEs', SSHConfig._set_string_list), ('CASignatureAlgorithms', SSHConfig._set_string), ('CertificateFile', SSHConfig._append_string), ('ChallengeResponseAuthentication', SSHConfig._set_bool), @@ -579,9 +628,9 @@ class SSHServerConfig(SSHConfig): """Settings from an OpenSSH server config file""" def __init__(self, last_config: 'SSHConfig', reload: bool, - local_addr: str, local_port: int, user: str, - host: str, addr: str) -> None: - super().__init__(last_config, reload) + canonical: bool, final: bool, local_addr: str, + local_port: int, user: str, host: str, addr: str) -> None: + super().__init__(last_config, reload, canonical, final) self._local_addr = local_addr self._local_port = local_port @@ -618,6 +667,11 @@ def _set_tokens(self) -> None: ('AuthorizedKeysFile', SSHConfig._set_string_list), ('AllowAgentForwarding', SSHConfig._set_bool), ('BindAddress', SSHConfig._set_string), + ('CanonicalDomains', SSHConfig._set_string_list), + ('CanonicalizeFallbackLocal', SSHConfig._set_bool), + ('CanonicalizeHostname', SSHConfig._set_canonicalize_host), + ('CanonicalizeMaxDots', SSHConfig._set_int), + ('CanonicalizePermittedCNAMEs', SSHConfig._set_string_list), ('CASignatureAlgorithms', SSHConfig._set_string), ('ChallengeResponseAuthentication', SSHConfig._set_bool), ('Ciphers', SSHConfig._set_string), diff --git a/asyncssh/connection.py b/asyncssh/connection.py index f7ee231..dec13a3 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -120,7 +120,7 @@ from .packet import Boolean, Byte, NameList, String, UInt32, PacketDecodeError from .packet import SSHPacket, SSHPacketHandler, SSHPacketLogger -from .pattern import WildcardPattern +from .pattern import WildcardPattern, WildcardPatternList from .pkcs11 import load_pkcs11_keys @@ -223,6 +223,7 @@ async def create_server(self, session_factory: TCPListenerFactory, _AuthKeysArg = DefTuple[Union[None, str, List[str], SSHAuthorizedKeys]] _ClientHostKey = Union[SSHKeyPair, SSHKeySignKeyPair] _ClientKeysArg = Union[KeyListArg, KeyPairListArg] +_CNAMEArg = DefTuple[Union[Sequence[str], Sequence[Tuple[str, str]]]] _GlobalRequest = Tuple[Optional[_PacketHandler], SSHPacket, bool] _GlobalRequestResult = Tuple[int, SSHPacket] @@ -273,6 +274,52 @@ async def create_server(self, session_factory: TCPListenerFactory, _DEFAULT_MAX_LINE_LENGTH = 1024 # 1024 characters +async def _resolve_host(host, loop: asyncio.AbstractEventLoop) -> Optional[str]: + """Attempt to resolve a hostname, returning a canonical name""" + + try: + addrinfo = await loop.getaddrinfo(host + '.', 0, + flags=socket.AI_CANONNAME) + except socket.gaierror: + return None + else: + return addrinfo[0][3] + + +async def _canonicalize_host(loop: asyncio.AbstractEventLoop, + options: 'SSHConnectionOptions') -> Optional[str]: + """Canonicalize a host name""" + + host = options.host + + if not options.canonicalize_hostname or not options.canonical_domains or \ + host.count('.') > options.canonicalize_max_dots or \ + (await _resolve_host(host, loop)): + return None + + for domain in options.canonical_domains: + canon_host = f'{host}.{domain}' + cname = await _resolve_host(canon_host, loop) + + if cname is not None: + if cname: + for patterns in options.canonicalize_permitted_cnames: + host_pat, cname_pat = map(WildcardPatternList, patterns) + + if host_pat.matches(canon_host) and \ + cname_pat.matches(cname): + canon_host = cname + break + + print(f'{host} => {canon_host}') + return canon_host + + if not options.canonicalize_fallback_local: + raise OSError(f'Unable to canonicalize hostname "{host}"') + + return None + + async def _open_proxy( loop: asyncio.AbstractEventLoop, command: Sequence[str], conn_factory: Callable[[], _Conn]) -> _Conn: @@ -348,7 +395,7 @@ def close(self) -> None: return cast(_Conn, cast(_ProxyCommandTunnel, tunnel).get_conn()) -async def _open_tunnel(tunnels: object, passphrase: Optional[BytesOrStr], +async def _open_tunnel(tunnels: object, options: '_Options', config: DefTuple[ConfigPaths]) -> \ Optional['SSHClientConnection']: """Parse and open connection to tunnel over""" @@ -373,10 +420,13 @@ async def _open_tunnel(tunnels: object, passphrase: Optional[BytesOrStr], last_conn = conn conn = await connect(host, port, username=username, - passphrase=passphrase, tunnel=conn, + passphrase=options.passphrase, tunnel=conn, config=config) conn.set_tunnel(last_conn) + if options.canonicalize_hostname != 'always': + options.canonicalize_hostname = False + return conn else: return None @@ -388,6 +438,17 @@ async def _connect(options: '_Options', config: DefTuple[ConfigPaths], conn_factory: Callable[[], _Conn], msg: str) -> _Conn: """Make outbound TCP or SSH tunneled connection""" + options.waiter = loop.create_future() + + canon_host = await _canonicalize_host(loop, options) + + host = canon_host if canon_host else options.host + canonical = bool(canon_host) + final = options.config.has_match_final() + + if canonical or final: + options.update(host=host, reload=True, canonical=canonical, final=final) + host = options.host port = options.port tunnel = options.tunnel @@ -396,9 +457,7 @@ async def _connect(options: '_Options', config: DefTuple[ConfigPaths], proxy_command = options.proxy_command free_conn = True - options.waiter = loop.create_future() - - new_tunnel = await _open_tunnel(tunnel, options.passphrase, config) + new_tunnel = await _open_tunnel(tunnel, options, config) tunnel: _TunnelConnectorProtocol try: @@ -446,6 +505,8 @@ async def _connect(options: '_Options', config: DefTuple[ConfigPaths], options.waiter.cancel() raise + conn.set_extra_info(host=host, port=port) + try: await options.waiter free_conn = False @@ -474,7 +535,7 @@ def tunnel_factory(_orig_host: str, _orig_port: int) -> SSHTCPSession: tunnel = options.tunnel family = options.family - new_tunnel = await _open_tunnel(tunnel, options.passphrase, config) + new_tunnel = await _open_tunnel(tunnel, options, config) tunnel: _TunnelListenerProtocol if sock: @@ -761,7 +822,7 @@ def update(self, **kwargs: object) -> None: """ - self._options.update(kwargs) + self._options.update(**kwargs) class SSHConnection(SSHPacketHandler, asyncio.Protocol): @@ -1267,7 +1328,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: if sock: sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, - self._tcp_keepalive) + 1 if self._tcp_keepalive else 0) if sock.family in (socket.AF_INET, socket.AF_INET6): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -2792,6 +2853,8 @@ def get_extra_info(self, name: str, default: Any = None) -> Any: it is established. Supported values include everything supported by a socket transport plus: + | host + | port | username | client_version | server_version @@ -7102,6 +7165,11 @@ class SSHConnectionOptions(Options): family: int local_addr: HostPort tcp_keepalive: bool + canonicalize_hostname: Union[bool, str] + canonical_domains: Sequence[str] + canonicalize_fallback_local: bool + canonicalize_max_dots: int + canonicalize_permitted_cnames: Sequence[Tuple[str, str]] kex_algs: Sequence[bytes] encryption_algs: Sequence[bytes] mac_algs: Sequence[bytes] @@ -7142,6 +7210,11 @@ def prepare(self, config: SSHConfig, # type: ignore passphrase: Optional[BytesOrStr], proxy_command: DefTuple[_ProxyCommand], family: DefTuple[int], local_addr: DefTuple[HostPort], tcp_keepalive: DefTuple[bool], + canonicalize_hostname: DefTuple[Union[bool, str]], + canonical_domains: DefTuple[Sequence[str]], + canonicalize_fallback_local: DefTuple[bool], + canonicalize_max_dots: DefTuple[int], + canonicalize_permitted_cnames: _CNAMEArg, kex_algs: _AlgsArg, encryption_algs: _AlgsArg, mac_algs: _AlgsArg, compression_algs: _AlgsArg, signature_algs: _AlgsArg, host_based_auth: _AuthArg, @@ -7157,6 +7230,20 @@ def prepare(self, config: SSHConfig, # type: ignore keepalive_count_max: int) -> None: """Prepare common connection configuration options""" + def _split_cname_patterns( + patterns: Union[str, Tuple[str, str]]) -> Tuple[str, str]: + """Split CNAME patterns""" + + if isinstance(patterns, str): + domains = patterns.split(':') + + if len(domains) == 2: + patterns = cast(Tuple[str, str], tuple(domains)) + else: + raise ValueError('CNAME rules must contain two patterns') + + return patterns + self.config = config self.protocol_factory = protocol_factory self.version = _validate_version(version) @@ -7187,6 +7274,32 @@ def prepare(self, config: SSHConfig, # type: ignore self.tcp_keepalive = cast(bool, tcp_keepalive if tcp_keepalive != () else config.get('TCPKeepAlive', True)) + self.canonicalize_hostname = \ + cast(Union[bool, str], canonicalize_hostname + if canonicalize_hostname != () + else config.get('CanonicalizeHostname', False)) + + self.canonical_domains = \ + cast(Sequence[str], canonical_domains if canonical_domains != () + else config.get('CanonicalDomains', ())) + + self.canonicalize_fallback_local = \ + cast(bool, canonicalize_fallback_local \ + if canonicalize_fallback_local != () + else config.get('CanonicalizeFallbackLocal', True)) + + self.canonicalize_max_dots = \ + cast(int, canonicalize_max_dots if canonicalize_max_dots != () + else config.get('CanonicalizeMaxDots', 1)) + + permitted_cnames = \ + cast(Sequence[str], canonicalize_permitted_cnames + if canonicalize_permitted_cnames != () + else config.get('CanonicalizePermittedCNAMEs', ())) + + self.canonicalize_permitted_cnames = \ + [_split_cname_patterns(patterns) for patterns in permitted_cnames] + self.kex_algs, self.encryption_algs, self.mac_algs, \ self.compression_algs, self.signature_algs = \ _validate_algs(config, kex_algs, encryption_algs, mac_algs, @@ -7584,6 +7697,40 @@ class SSHClientConnectionOptions(SSHConnectionOptions): without getting a response before disconnecting from the server. This defaults to 3, but only applies when keepalive_interval is non-zero. + :param tcp_keepalive: (optional) + Whether or not to enable keepalive probes at the TCP level to + detect broken connections, defaulting to `True`. + :param canonicalize_hostname: (optional) + Whether or not to enable hostname canonicalization, defaulting + to `False`, in which case hostnames are passed as-is to the + system resolver. If set to `True`, requests that don't involve + a proxy tunnel or command will attempt to canonicalize the hostname + using canonical_domains and rules in canonicalize_permitted_cnames. + If set to `'always'`, hostname canonicalization is also applied + to proxied requests. + :param canonical_domains: (optional) + When canonicalize_hostname is set, this specifies list of domain + suffixes in which to search for the hostname. + :param canonicalize_fallback_local: (optional) + Whether or not to fall back to looking up the hostname against + the system resolver's search domains when no matches are found + in canonical_domains, defaulting to `True`. + :param canonicalize_max_dots: (optional) + Tha maximum number of dots which can appear in a hostname + before hostname canonicalization is disabled, defaulting + to 1. Hostnames with more than this number of dots are + treated as already being fully qualified and passed as-is + to the system resolver. + :param canonicalize_permitted_cnames: (optional) + Patterns to match against to decide whether hostname + canonicalization should return a CNAME. This argument + contains a list of pairs of wildcard pattern lists. The + first pattern is matched against the hostname found after + adding one of the search domains from canonical_domains and + the second pattern is matched against the associated CNAME. + If a match can be found in the list for both patterns, the + CNAME is returned as the canonical hostname. The default + is an empty list, preventing CNAMEs from being returned. :param command: (optional) The default remote command to execute on client sessions. An interactive shell is started if no command or subsystem is @@ -7724,6 +7871,12 @@ class SSHClientConnectionOptions(SSHConnectionOptions): :type login_timeout: *see* :ref:`SpecifyingTimeIntervals` :type keepalive_interval: *see* :ref:`SpecifyingTimeIntervals` :type keepalive_count_max: `int` + :type tcp_keepalive: `bool` + :type canonicalize_hostname: `bool` or `'always'` + :type canonical_domains: `list` of `str` + :type canonicalize_fallback_local: `bool` + :type canonicalize_max_dots: `int` + :type canonicalize_permitted_cnames: `list` of `tuple` of 2 `str` values :type command: `str` :type subsystem: `str` :type env: `dict` with `str` keys and values @@ -7796,6 +7949,7 @@ def prepare(self, # type: ignore loop: Optional[asyncio.AbstractEventLoop] = None, last_config: Optional[SSHConfig] = None, config: DefTuple[ConfigPaths] = None, reload: bool = False, + canonical: bool = False, final: bool = False, client_factory: Optional[_ClientFactory] = None, client_version: _VersionArg = (), host: str = '', port: DefTuple[int] = (), tunnel: object = (), @@ -7803,6 +7957,11 @@ def prepare(self, # type: ignore family: DefTuple[int] = (), local_addr: DefTuple[HostPort] = (), tcp_keepalive: DefTuple[bool] = (), + canonicalize_hostname: DefTuple[Union[bool, str]] = (), + canonical_domains: DefTuple[Sequence[str]] = (), + canonicalize_fallback_local: DefTuple[bool] = (), + canonicalize_max_dots: DefTuple[int] = (), + canonicalize_permitted_cnames: DefTuple[Sequence[str]] = (), kex_algs: _AlgsArg = (), encryption_algs: _AlgsArg = (), mac_algs: _AlgsArg = (), compression_algs: _AlgsArg = (), signature_algs: _AlgsArg = (), host_based_auth: _AuthArg = (), @@ -7870,8 +8029,9 @@ def prepare(self, # type: ignore config = [default_config] if os.access(default_config, os.R_OK) else [] - config = SSHClientConfig.load(last_config, config, reload, - local_username, username, host, port) + config = SSHClientConfig.load(last_config, config, reload, canonical, + final, local_username, username, host, + port) if x509_trusted_certs == (): default_x509_certs = Path('~', '.ssh', 'ca-bundle.crt').expanduser() @@ -7907,10 +8067,12 @@ def prepare(self, # type: ignore super().prepare(config, client_factory or SSHClient, client_version, host, port, tunnel, passphrase, proxy_command, family, - local_addr, tcp_keepalive, kex_algs, encryption_algs, - mac_algs, compression_algs, signature_algs, - host_based_auth, public_key_auth, kbdint_auth, - password_auth, x509_trusted_certs, + local_addr, tcp_keepalive, canonicalize_hostname, + canonical_domains, canonicalize_fallback_local, + canonicalize_max_dots, canonicalize_permitted_cnames, + kex_algs, encryption_algs, mac_algs, compression_algs, + signature_algs, host_based_auth, public_key_auth, + kbdint_auth, password_auth, x509_trusted_certs, x509_trusted_cert_paths, x509_purposes, rekey_bytes, rekey_seconds, connect_timeout, login_timeout, keepalive_interval, keepalive_count_max) @@ -8351,7 +8513,28 @@ class SSHServerConnectionOptions(SSHConnectionOptions): non-zero. :param tcp_keepalive: (optional) Whether or not to enable keepalive probes at the TCP level to - detect broken connections, defaulting to `True` + detect broken connections, defaulting to `True`. + :param canonicalize_hostname: (optional) + Whether or not to enable hostname canonicalization, defaulting + to `False`, in which case hostnames are passed as-is to the + system resolver. If set to `True`, requests that don't involve + a proxy tunnel or command will attempt to canonicalize the hostname + using canonical_domains and rules in canonicalize_permitted_cnames. + If set to `'always'`, hostname canonicalization is also applied + to proxied requests. + :param canonical_domains: (optional) + When canonicalize_hostname is set, this specifies list of domain + suffixes in which to search for the hostname. + :param canonicalize_fallback_local: (optional) + Whether or not to fall back to looking up the hostname against + the system resolver's search domains when no matches are found + in canonical_domains, defaulting to `True`. + :param canonicalize_max_dots: (optional) + Tha maximum number of dots which can appear in a hostname + before hostname canonicalization is disabled, defaulting + to 1. Hostnames with more than this number of dots are + treated as already being fully qualified and passed as-is + to the system resolver. :param config: (optional) Paths to OpenSSH server configuration files to load. This configuration will be used as a fallback to override the @@ -8426,6 +8609,12 @@ class SSHServerConnectionOptions(SSHConnectionOptions): :type login_timeout: *see* :ref:`SpecifyingTimeIntervals` :type keepalive_interval: *see* :ref:`SpecifyingTimeIntervals` :type keepalive_count_max: `int` + :type tcp_keepalive: `bool` + :type canonicalize_hostname: `bool` or `'always'` + :type canonical_domains: `list` of `str` + :type canonicalize_fallback_local: `bool` + :type canonicalize_max_dots: `int` + :type canonicalize_permitted_cnames: `list` of `tuple` of 2 `str` values :type config: `list` of `str` :type options: :class:`SSHServerConnectionOptions` @@ -8468,6 +8657,7 @@ def prepare(self, # type: ignore loop: Optional[asyncio.AbstractEventLoop] = None, last_config: Optional[SSHConfig] = None, config: DefTuple[ConfigPaths] = None, reload: bool = False, + canonical: bool = False, final: bool = False, accept_addr: str = '', accept_port: int = 0, username: str = '', client_host: str = '', client_addr: str = '', @@ -8478,6 +8668,11 @@ def prepare(self, # type: ignore family: DefTuple[int] = (), local_addr: DefTuple[HostPort] = (), tcp_keepalive: DefTuple[bool] = (), + canonicalize_hostname: DefTuple[Union[bool, str]] = (), + canonical_domains: DefTuple[Sequence[str]] = (), + canonicalize_fallback_local: DefTuple[bool] = (), + canonicalize_max_dots: DefTuple[int] = (), + canonicalize_permitted_cnames: DefTuple[Sequence[str]] = (), kex_algs: _AlgsArg = (), encryption_algs: _AlgsArg = (), mac_algs: _AlgsArg = (), compression_algs: _AlgsArg = (), signature_algs: _AlgsArg = (), host_based_auth: _AuthArg = (), @@ -8521,8 +8716,8 @@ def prepare(self, # type: ignore max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> None: """Prepare server connection configuration options""" - config = SSHServerConfig.load(last_config, config, reload, - accept_addr, accept_port, username, + config = SSHServerConfig.load(last_config, config, reload, canonical, + final, accept_addr, accept_port, username, client_host, client_addr) if login_timeout == (): @@ -8548,10 +8743,12 @@ def prepare(self, # type: ignore super().prepare(config, server_factory or SSHServer, server_version, host, port, tunnel, passphrase, proxy_command, family, - local_addr, tcp_keepalive, kex_algs, encryption_algs, - mac_algs, compression_algs, signature_algs, - host_based_auth, public_key_auth, kbdint_auth, - password_auth, x509_trusted_certs, + local_addr, tcp_keepalive, canonicalize_hostname, + canonical_domains, canonicalize_fallback_local, + canonicalize_max_dots, canonicalize_permitted_cnames, + kex_algs, encryption_algs, mac_algs, compression_algs, + signature_algs, host_based_auth, public_key_auth, + kbdint_auth, password_auth, x509_trusted_certs, x509_trusted_cert_paths, x509_purposes, rekey_bytes, rekey_seconds, connect_timeout, login_timeout, keepalive_interval, keepalive_count_max) diff --git a/asyncssh/misc.py b/asyncssh/misc.py index 2fb9c8e..a3d0fb7 100644 --- a/asyncssh/misc.py +++ b/asyncssh/misc.py @@ -456,7 +456,7 @@ def __init__(self, options: Optional['Options'] = None, **kwargs: object): def prepare(self, **kwargs: object) -> None: """Pre-process configuration options""" - def update(self, kwargs: Dict[str, object]) -> None: + def update(self, **kwargs: object) -> None: """Update options based on keyword parameters passed in""" self.kwargs.update(kwargs) diff --git a/docs/api.rst b/docs/api.rst index 335ed36..f544b0e 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1600,6 +1600,11 @@ The following OpenSSH client config options are currently supported: | AddressFamily | BindAddress + | CanonicalDomains + | CanonicalizeFallbackLocal + | CanonicalizeHostname + | CanonicalizeMaxDots + | CanonicalizePermittedCNAMEs | CASignatureAlgorithms | CertificateFile | ChallengeResponseAuthentication @@ -1697,6 +1702,11 @@ The following OpenSSH server config options are currently supported: | AuthorizedKeysFile | AllowAgentForwarding | BindAddress + | CanonicalDomains + | CanonicalizeFallbackLocal + | CanonicalizeHostname + | CanonicalizeMaxDots + | CanonicalizePermittedCNAMEs | CASignatureAlgorithms | ChallengeResponseAuthentication | Ciphers diff --git a/tests/test_config.py b/tests/test_config.py index f0280aa..41380a5 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -108,6 +108,19 @@ def test_set_address_family(self): 'AddressFamily inet6') self.assertEqual(config.get('AddressFamily'), socket.AF_INET) + def test_set_canonicaize_host(self): + """Test canonicalize host config option""" + + for value, result in (('yes', True), ('true', True), + ('no', False), ('false', False), + ('always', 'always')): + config = self._parse_config(f'CanonicalizeHostname {value}') + self.assertEqual(config.get('CanonicalizeHostname'), result) + + config = self._parse_config('CanonicalizeHostname yes\n' + 'CanonicalizeHostname no') + self.assertEqual(config.get('CanonicalizeHostname'), True) + def test_set_rekey_limit(self): """Test rekey limit config option""" @@ -182,6 +195,24 @@ def test_match_all(self): config = self._parse_config('Match user xxx\nMatch all\nPort 2222') self.assertEqual(config.get('Port'), 2222) + def test_match_negated(self): + """Test a match block which never matches due to negation""" + + config = self._parse_config('Match !all user xxx\nPort 2222') + self.assertEqual(config.get('Port'), None) + + def test_match_canonical(self): + """Test a match block which matches when the host is canonicalized""" + + config = self._parse_config('Match canonical\nPort 2222') + self.assertEqual(config.get('Port'), None) + + def test_match_final(self): + """Test a match block which matches on the final parsing pass""" + + config = self._parse_config('Match final\nPort 2222') + self.assertEqual(config.get('Port'), None) + def test_match_exec(self): """Test a match block which runs a subprocess""" @@ -231,6 +262,8 @@ def test_errors(self): ('Unbalanced quotes', 'BindAddress "foo'), ('Extra data at end', 'BindAddress foo bar'), ('Invalid address family', 'AddressFamily xxx'), + ('Invalid canonicalization option', + 'CanonicalizeHostname xxx'), ('Invalid boolean', 'Compression xxx'), ('Invalid integer', 'Port xxx'), ('Invalid match condition', 'Match xxx')): @@ -243,13 +276,14 @@ class _TestClientConfig(_TestConfig): """Unit tests for client config objects""" def _load_config(self, config, last_config=None, reload=False, - local_user='user', user=(), host='host', port=()): + canonical=False, final=False, local_user='user', + user=(), host='host', port=()): """Load a client configuration""" # pylint: disable=arguments-differ - return SSHClientConfig.load(last_config, config, reload, - local_user, user, host, port) + return SSHClientConfig.load(last_config, config, reload, canonical, + final, local_user, user, host, port) def test_set_string_none(self): """Test string config option""" @@ -478,14 +512,16 @@ class _TestServerConfig(_TestConfig): """Unit tests for server config objects""" def _load_config(self, config, last_config=None, reload=False, + canonical=False, final=False, local_addr='127.0.0.1', local_port=22, user='user', host=None, addr='127.0.0.1'): """Load a server configuration""" # pylint: disable=arguments-differ - return SSHServerConfig.load(last_config, config, reload, - local_addr, local_port, user, host, addr) + return SSHServerConfig.load(last_config, config, reload, canonical, + final, local_addr, local_port, user, + host, addr) def test_match_local_address(self): """Test matching on local address""" diff --git a/tests/test_connection.py b/tests/test_connection.py index 9eadf45..e4e5852 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -53,7 +53,8 @@ from .server import Server, ServerTestCase -from .util import asynctest, patch_extra_kex, patch_getnameinfo, patch_gss +from .util import asynctest, patch_extra_kex, patch_getaddrinfo +from .util import patch_getnameinfo, patch_gss from .util import gss_available, nc_available, x509_available @@ -2671,3 +2672,109 @@ async def test_ssh_listen_context_manager(self): async with asyncssh.connect('127.0.0.1', listen_port, known_hosts=(['skey.pub'], [], [])): pass + + +@patch_getaddrinfo +class _TestCanonicalizeHost(ServerTestCase): + """Test hostname canonicalization""" + + @classmethod + async def start_server(cls): + """Start an SSH server to connect to""" + + #import logging + #logging.basicConfig(level='DEBUG') + #asyncssh.set_debug_level(2) + return await cls.create_server(_TunnelServer) + + @asynctest + async def test_canonicalize(self): + """Test hostname canonicalization""" + + async with self.connect('testhost', known_hosts=(['skey.pub'], [], []), + canonicalize_hostname=True, + canonical_domains=['test']) as conn: + self.assertEqual(conn.get_extra_info('host'), 'testhost.test') + + @asynctest + async def test_canonicalize_proxy(self): + """Test hostname canonicalization with proxy""" + + with open('config', 'w') as f: + f.write('UserKnownHostsFile none\n' + 'Match host localhost\nPubkeyAuthentication no') + + async with self.connect('testhost', config='config', + tunnel=f'localhost:{self._server_port}', + canonicalize_hostname=True, + canonical_domains=['test']) as conn: + self.assertEqual(conn.get_extra_info('host'), 'testhost.test') + + @asynctest + async def test_canonicalize_always(self): + """Test hostname canonicalization for all connections""" + + with open('config', 'w') as f: + f.write('UserKnownHostsFile none\n' + 'Match host localhost\nPubkeyAuthentication no') + + async with self.connect('testhost', config='config', + tunnel=f'localhost:{self._server_port}', + canonicalize_hostname='always', + canonical_domains=['test']) as conn: + self.assertEqual(conn.get_extra_info('host'), 'testhost.test') + + @asynctest + async def test_canonicalize_failure(self): + """Test hostname canonicalization failure""" + + with self.assertRaises(socket.gaierror): + await self.connect('unknown', known_hosts=(['skey.pub'], [], []), + canonicalize_hostname=True, + canonical_domains=['test']) + + @asynctest + async def test_canonicalize_failed_no_fallback(self): + """Test hostname canonicalization""" + + with self.assertRaises(OSError): + await self.connect('unknown', known_hosts=(['skey.pub'], [], []), + canonicalize_hostname=True, + canonical_domains=['test'], + canonicalize_fallback_local=False) + + @asynctest + async def test_cname_returned(self): + """Test hostname canonicalization with cname returned""" + + async with self.connect('testcname', + known_hosts=(['skey.pub'], [], []), + canonicalize_hostname=True, + canonical_domains=['test'], + canonicalize_permitted_cnames= \ + [('*.test', '*.test')]) as conn: + self.assertEqual(conn.get_extra_info('host'), 'cname.test') + + @asynctest + async def test_cname_not_returned(self): + """Test hostname canonicalization with cname not returned""" + + async with self.connect('testcname', + known_hosts=(['skey.pub'], [], []), + canonicalize_hostname=True, + canonical_domains=['test'], + canonicalize_permitted_cnames= \ + ['*.xxx:*.test']) as conn: + self.assertEqual(conn.get_extra_info('host'), 'testcname.test') + + @asynctest + async def test_bad_cname_rules(self): + """Test hostname canonicalization with bad cname rules""" + + with self.assertRaises(ValueError): + await self.connect('testcname', + known_hosts=(['skey.pub'], [], []), + canonicalize_hostname=True, + canonical_domains=['test'], + canonicalize_permitted_cnames= \ + ['*.xxx:*.test:*.xxx']) diff --git a/tests/util.py b/tests/util.py index 0c364d8..bb7caf3 100644 --- a/tests/util.py +++ b/tests/util.py @@ -94,6 +94,34 @@ def async_wrapper(self, *args, **kwargs): return async_wrapper +def patch_getaddrinfo(cls): + """Decorator for patching socket.getaddrinfo""" + + # pylint: disable=redefined-builtin + + cls.orig_getaddrinfo = socket.getaddrinfo + + hosts = {'testhost.test': '', + 'testcname.test': 'cname.test', + 'cname.test': ''} + + def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): + """Mock DNS lookup of server hostname""" + + # pylint: disable=unused-argument + + if host.endswith('.'): + host = host[:-1] + + try: + return [(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, + hosts[host], ('127.0.0.1', port))] + except KeyError: + return cls.orig_getaddrinfo(host, port, family, type, proto, flags) + + return patch('socket.getaddrinfo', getaddrinfo)(cls) + + def patch_getnameinfo(cls): """Decorator for patching socket.getnameinfo"""