From 743966da9f74a0c7ec9e11cc697ec5699aed9f86 Mon Sep 17 00:00:00 2001 From: Ron Frederick Date: Fri, 29 Nov 2024 16:40:46 -0800 Subject: [PATCH] Add client and server support for SFTP copy-data extension This commit adds client and server support for the SFTP "copy-data" extension, and a new remote_copy() method on SFTPClient which allows you to make a request to copy bytes between two files on the remote server without needing to download and re-upload the data, if the server supports it. Thanks go to Ali Khosravi for suggesting this addition. --- asyncssh/sftp.py | 138 +++++++++++++++++++++++++++++++++++++++++++-- docs/api.rst | 7 ++- tests/test_sftp.py | 85 ++++++++++++++++++++++++---- 3 files changed, 211 insertions(+), 19 deletions(-) diff --git a/asyncssh/sftp.py b/asyncssh/sftp.py index ad992f8b..3809d3e5 100644 --- a/asyncssh/sftp.py +++ b/asyncssh/sftp.py @@ -161,6 +161,8 @@ MAX_SFTP_WRITE_LEN = 4*1024*1024 # 4 MiB MAX_SFTP_PACKET_LEN = MAX_SFTP_WRITE_LEN + 1024 +_COPY_DATA_BLOCK_SIZE = 256*1024 # 256 KiB + _MAX_SFTP_REQUESTS = 128 _MAX_READDIR_NAMES = 128 @@ -806,6 +808,24 @@ async def run(self) -> None: if self._progress_handler and self._total_bytes == 0: self._progress_handler(self._srcpath, self._dstpath, 0, 0) + if self._srcfs == self._dstfs and \ + isinstance(self._srcfs, SFTPClient): + try: + await self._srcfs.remote_copy( + cast(SFTPClientFile, self._src), + cast(SFTPClientFile, self._dst)) + except SFTPOpUnsupported: + pass + else: + self._bytes_copied = self._total_bytes + + if self._progress_handler: + self._progress_handler(self._srcpath, self._dstpath, + self._bytes_copied, + self._total_bytes) + + return + async for _, datalen in self.iter(): if datalen: self._bytes_copied += datalen @@ -822,8 +842,6 @@ async def run(self) -> None: setattr(exc, 'offset', self._bytes_copied) raise exc - - finally: if self._src: # pragma: no branch await self._src.close() @@ -2472,6 +2490,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop, 
self._supports_fsync = False self._supports_lsetstat = False self._supports_limits = False + self._supports_copy_data = False @property def version(self) -> int: @@ -2692,6 +2711,8 @@ async def start(self) -> None: self._supports_lsetstat = True elif name == b'limits@openssh.com' and data == b'1': self._supports_limits = True + elif name == b'copy-data' and data == b'1': + self._supports_copy_data = True if version == 3: # Check if the server has a buggy SYMLINK implementation @@ -3090,6 +3111,26 @@ async def fsync(self, handle: bytes) -> None: else: raise SFTPOpUnsupported('fsync not supported') + async def copy_data(self, read_from_handle: bytes, read_from_offset: int, + read_from_length: int, write_to_handle: bytes, + write_to_offset: int) -> None: + """Make an SFTP copy data request""" + + if self._supports_copy_data: + self.logger.debug1('Sending copy-data from handle %s, ' + 'offset %d, length %d to handle %s, ' + 'offset %d', read_from_handle.hex(), + read_from_offset, read_from_length, + write_to_handle.hex(), write_to_offset) + + await self._make_request(b'copy-data', String(read_from_handle), + UInt64(read_from_offset), + UInt64(read_from_length), + String(write_to_handle), + UInt64(write_to_offset)) + else: + raise SFTPOpUnsupported('copy-data not supported') + def exit(self) -> None: """Handle a request to close the SFTP session""" @@ -3142,6 +3183,15 @@ async def __aexit__(self, _exc_type: Optional[Type[BaseException]], await self.close() return False + @property + def handle(self) -> bytes: + """Return handle or raise an error if closed""" + + if self._handle is None: + raise ValueError('I/O operation on closed file') + + return self._handle + async def _end(self) -> int: """Return the offset of the end of the file""" @@ -4233,6 +4283,35 @@ async def mcopy(self, srcpaths: _SFTPPaths, block_size, max_requests, progress_handler, error_handler) + async def remote_copy(self, src: SFTPClientFile, dst: SFTPClientFile, + src_offset: int = 0, src_length: int 
= 0, + dst_offset: int = 0) -> None: + """Copy data between remote files + + :param src: + The remote file object to read data from + :param dst: + The remote file object to write data to + :param src_offset: (optional) + The offset to begin reading data from + :param src_length: (optional) + The number of bytes to attempt to copy + :param dst_offset: (optional) + The offset to begin writing data to + :type src: :class:`SFTPClientFile` + :type dst: :class:`SFTPClientFile` + :type src_offset: `int` + :type src_length: `int` + :type dst_offset: `int` + + :raises: :exc:`SFTPError` if the server doesn't support this + extension or returns an error + + """ + + await self._handler.copy_data(src.handle, src_offset, src_length, + dst.handle, dst_offset) + async def glob(self, patterns: _SFTPPaths, error_handler: SFTPErrorHandler = None) -> \ Sequence[BytesOrStr]: @@ -5583,7 +5662,8 @@ class SFTPServerHandler(SFTPHandler): (b'hardlink@openssh.com', b'1'), (b'fsync@openssh.com', b'1'), (b'lsetstat@openssh.com', b'1'), - (b'limits@openssh.com', b'1')] + (b'limits@openssh.com', b'1'), + (b'copy-data', b'1')] _attrib_extensions: List[bytes] = [] @@ -6437,6 +6517,55 @@ async def _process_limits(self, packet: SSHPacket) -> SFTPLimits: return SFTPLimits(MAX_SFTP_PACKET_LEN, MAX_SFTP_READ_LEN, MAX_SFTP_WRITE_LEN, nfiles) + async def _process_copy_data(self, packet: SSHPacket) -> None: + """Process an incoming copy data request""" + + read_from_handle = packet.get_string() + read_from_offset = packet.get_uint64() + read_from_length = packet.get_uint64() + write_to_handle = packet.get_string() + write_to_offset = packet.get_uint64() + packet.check_end() + + self.logger.debug1('Received copy-data from handle %s, ' + 'offset %d, length %d to handle %s, ' + 'offset %d', read_from_handle.hex(), + read_from_offset, read_from_length, + write_to_handle.hex(), write_to_offset) + + src = self._file_handles.get(read_from_handle) + dst = self._file_handles.get(write_to_handle) + + if src and dst: 
+ read_to_end = read_from_length == 0 + + while read_to_end or read_from_length: + if read_to_end: + size = _COPY_DATA_BLOCK_SIZE + else: + size = min(read_from_length, _COPY_DATA_BLOCK_SIZE) + + data = self._server.read(src, read_from_offset, size) + + if inspect.isawaitable(data): + data = await cast(Awaitable[bytes], data) + + result = self._server.write(dst, write_to_offset, data) + + if inspect.isawaitable(result): + await result + + if len(data) < size: + break + + read_from_offset += size + write_to_offset += size + + if not read_to_end: + read_from_length -= size + else: + raise SFTPInvalidHandle('Invalid file handle') + _packet_handlers: Dict[Union[int, bytes], _SFTPPacketHandler] = { FXP_OPEN: _process_open, FXP_CLOSE: _process_close, @@ -6465,7 +6594,8 @@ async def _process_limits(self, packet: SSHPacket) -> SFTPLimits: b'hardlink@openssh.com': _process_openssh_link, b'fsync@openssh.com': _process_fsync, b'lsetstat@openssh.com': _process_lsetstat, - b'limits@openssh.com': _process_limits + b'limits@openssh.com': _process_limits, + b'copy-data': _process_copy_data } async def run(self) -> None: diff --git a/docs/api.rst b/docs/api.rst index 6bb83105..434e14fb 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1077,16 +1077,17 @@ SFTP Support .. autoattribute:: limits ======================================================================= = - ===================== = + =========================== = File transfer methods - ===================== = + =========================== = .. automethod:: get .. automethod:: put .. automethod:: copy .. automethod:: mget .. automethod:: mput .. automethod:: mcopy - ===================== = + .. 
automethod:: remote_copy + =========================== = ============================================================================================================================================================================================================================== = File access methods diff --git a/tests/test_sftp.py b/tests/test_sftp.py index 13068d50..3a5f3d55 100644 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -748,6 +748,26 @@ async def test_copy(self, sftp): finally: remove('src dst') + def test_copy_non_remote(self): + """Test copying without using remote_copy function""" + + @sftp_test + async def _test_copy_non_remote(self, sftp): + """Test copying without using remote_copy function""" + + for src in ('src', b'src', Path('src')): + with self.subTest(src=type(src)): + try: + self._create_file('src') + await sftp.copy(src, 'dst') + self._check_file('src', 'dst') + finally: + remove('src dst') + + with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): + # pylint: disable=no-value-for-parameter + _test_copy_non_remote(self) + @sftp_test async def test_copy_progress(self, sftp): """Test copying a file over SFTP with progress reporting""" @@ -769,7 +789,9 @@ def _report_progress(_srcpath, _dstpath, bytes_copied, _total_bytes): progress_handler=_report_progress) self._check_file('src', 'dst') - self.assertEqual(len(reports), (size // 8192) + 1) + if method != 'copy': + self.assertEqual(len(reports), (size // 8192) + 1) + self.assertEqual(reports[-1], size) finally: remove('src dst') @@ -1130,6 +1152,37 @@ def err_handler(exc): finally: remove('src1 src2 dst') + @sftp_test + async def test_remote_copy_arguments(self, sftp): + """Test remote copy arguments""" + + try: + self._create_file('src', os.urandom(2*1024*1024)) + + async with sftp.open('src', 'rb') as src: + async with sftp.open('dst', 'wb') as dst: + await sftp.remote_copy(src, dst, 0, 1024*1024, 0) + await sftp.remote_copy(src, dst, 1024*1024, 0, 1024*1024) + + 
self._check_file('src', 'dst') + finally: + remove('src dst') + + @sftp_test + async def test_remote_copy_closed_file(self, sftp): + """Test remote copy of a closed file""" + + try: + self._create_file('file') + + async with sftp.open('file', 'rb') as f: + await f.close() + + with self.assertRaises(ValueError): + await sftp.remote_copy(f, f) + finally: + remove('file') + @sftp_test async def test_glob(self, sftp): """Test a glob pattern match over SFTP""" @@ -3173,6 +3226,9 @@ async def _return_invalid_handle(self, path, pflags, attrs): with self.assertRaises(SFTPFailure): await f.fsync() + with self.assertRaises(SFTPFailure): + await sftp.remote_copy(f, f) + with self.assertRaises(SFTPFailure): await f.close() @@ -4300,19 +4356,24 @@ async def start_server(cls): return await cls.create_server(sftp_factory=_IOErrorSFTPServer) - @sftp_test - async def test_put_error(self, sftp): - """Test error when putting a file to an SFTP server""" + def test_copy_error(self): + """Test error when copying a file on an SFTP server""" - for method in ('get', 'put', 'copy'): - with self.subTest(method=method): - try: - self._create_file('src', 8*1024*1024*'\0') + @sftp_test + async def _test_copy_error(self, sftp): + """Test error when copying a file on an SFTP server""" - with self.assertRaises(SFTPFailure): - await getattr(sftp, method)('src', 'dst') - finally: - remove('src dst') + try: + self._create_file('src', 8*1024*1024*'\0') + + with self.assertRaises(SFTPFailure): + await sftp.copy('src', 'dst') + finally: + remove('src dst') + + with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): + # pylint: disable=no-value-for-parameter + _test_copy_error(self) @sftp_test async def test_read_error(self, sftp):