diff --git a/c-ext/decompressor.c b/c-ext/decompressor.c index ccbeaa87..3a940dff 100644 --- a/c-ext/decompressor.c +++ b/c-ext/decompressor.c @@ -262,22 +262,39 @@ static PyObject *Decompressor_copy_stream(ZstdDecompressor *self, PyObject *Decompressor_decompress(ZstdDecompressor *self, PyObject *args, PyObject *kwargs) { - static char *kwlist[] = {"data", "max_output_size", NULL}; + static char *kwlist[] = { + "data", + "max_output_size", + "read_across_frames", + "allow_extra_data", + NULL + }; Py_buffer source; Py_ssize_t maxOutputSize = 0; + unsigned long long decompressedSize; + PyObject *readAcrossFrames = NULL; + PyObject *allowExtraData = NULL; size_t destCapacity; PyObject *result = NULL; size_t zresult; ZSTD_outBuffer outBuffer; ZSTD_inBuffer inBuffer; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y*|n:decompress", kwlist, - &source, &maxOutputSize)) { + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "y*|nOO:decompress", kwlist, + &source, &maxOutputSize, &readAcrossFrames, + &allowExtraData)) { return NULL; } + if (readAcrossFrames ? PyObject_IsTrue(readAcrossFrames) : 0) { + PyErr_SetString(ZstdError, + "ZstdDecompressor.read_across_frames=True is not yet implemented" + ); + goto finally; + } + if (ensure_dctx(self, 1)) { goto finally; } @@ -361,6 +378,16 @@ PyObject *Decompressor_decompress(ZstdDecompressor *self, PyObject *args, goto finally; } } + else if ((allowExtraData ? PyObject_IsTrue(allowExtraData) : 1) == 0 + && inBuffer.pos < inBuffer.size) { + PyErr_Format( + ZstdError, + "compressed input contains %zu bytes of unused data, which is disallowed", + inBuffer.size - inBuffer.pos + ); + Py_CLEAR(result); + goto finally; + } finally: PyBuffer_Release(&source); diff --git a/docs/news.rst b/docs/news.rst index ce9407cc..7a740097 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -103,6 +103,17 @@ Changes This may have fixed unconfirmed issues where ``unused_data`` was set prematurely. 
The new logic will also avoid an extra call to ``ZSTD_decompressStream()`` in some scenarios, possibly improving performance. +* ``ZstdDecompressor.decompress()`` now has a ``read_across_frames`` keyword + argument. It defaults to False. True is not yet implemented and will raise an + exception if used. The new argument will default to True in a future release + and is provided now so callers can start passing ``read_across_frames=False`` + to preserve the existing functionality during a future upgrade. +* ``ZstdDecompressor.decompress()`` now has an ``allow_extra_data`` keyword + argument to control whether an exception is raised if input contains extra + data. It defaults to True, preserving existing behavior of ignoring extra + data. It will likely default to False in a future release. Callers desiring + the current behavior are encouraged to explicitly pass + ``allow_extra_data=True`` so behavior won't change during a future upgrade. 0.18.0 (released 2022-06-20) ============================ diff --git a/rust-ext/src/decompressor.rs b/rust-ext/src/decompressor.rs index 3088668f..c20fde07 100644 --- a/rust-ext/src/decompressor.rs +++ b/rust-ext/src/decompressor.rs @@ -159,13 +159,26 @@ impl ZstdDecompressor { Ok((total_read, total_write)) } - #[args(buffer, max_output_size = "0")] + #[args( + buffer, + max_output_size = "0", + read_across_frames = "false", + allow_extra_data = "true" + )] fn decompress<'p>( &mut self, py: Python<'p>, buffer: PyBuffer, max_output_size: usize, + read_across_frames: bool, + allow_extra_data: bool, ) -> PyResult<&'p PyBytes> { + if read_across_frames { + return Err(ZstdError::new_err( + "ZstdDecompressor.read_across_frames=True is not yet implemented", + )); + } + self.setup_dctx(py, true)?; let output_size = @@ -215,6 +228,11 @@ impl ZstdDecompressor { "decompression error: decompressed {} bytes; expected {}", zresult, output_size ))) + } else if !allow_extra_data && in_buffer.pos < in_buffer.size { + Err(ZstdError::new_err(format!( 
+ "compressed input contains {} bytes of unused data, which is disallowed", + in_buffer.size - in_buffer.pos + ))) } else { // TODO avoid memory copy Ok(PyBytes::new(py, &dest_buffer)) diff --git a/tests/test_decompressor_decompress.py b/tests/test_decompressor_decompress.py index d3e8d0e9..0b48f2bf 100644 --- a/tests/test_decompressor_decompress.py +++ b/tests/test_decompressor_decompress.py @@ -183,6 +183,20 @@ def test_multiple_frames(self): dctx = zstd.ZstdDecompressor() self.assertEqual(dctx.decompress(foo + bar), b"foo") + self.assertEqual( + dctx.decompress(foo + bar, allow_extra_data=True), b"foo" + ) + + with self.assertRaisesRegex( + zstd.ZstdError, + "ZstdDecompressor.read_across_frames=True is not yet implemented", + ): + dctx.decompress(foo + bar, read_across_frames=True) + + with self.assertRaisesRegex( + zstd.ZstdError, "%d bytes of unused data, which is disallowed" % len(bar) + ): + dctx.decompress(foo + bar, allow_extra_data=False) def test_junk_after_frame(self): cctx = zstd.ZstdCompressor() @@ -190,3 +204,10 @@ def test_junk_after_frame(self): dctx = zstd.ZstdDecompressor() self.assertEqual(dctx.decompress(frame + b"junk"), b"foo") + + self.assertEqual(dctx.decompress(frame + b"junk", allow_extra_data=True), b"foo") + + with self.assertRaisesRegex( + zstd.ZstdError, "4 bytes of unused data, which is disallowed" + ): + dctx.decompress(frame + b"junk", allow_extra_data=False) diff --git a/zstandard/__init__.pyi b/zstandard/__init__.pyi index ed8c2e4f..795e0c6f 100644 --- a/zstandard/__init__.pyi +++ b/zstandard/__init__.pyi @@ -389,7 +389,11 @@ class ZstdDecompressor(object): ): ... def memory_size(self) -> int: ... def decompress( - self, data: ByteString, max_output_size: int = ... + self, + data: ByteString, + max_output_size: int = ..., + read_across_frames: bool = ..., + allow_extra_data: bool = ..., ) -> bytes: ... 
def stream_reader( self, diff --git a/zstandard/backend_cffi.py b/zstandard/backend_cffi.py index 7075faa6..39b49192 100644 --- a/zstandard/backend_cffi.py +++ b/zstandard/backend_cffi.py @@ -3006,7 +3006,10 @@ def decompress(self, data): # buffer. So if the output buffer is partially filled and the input # is exhausted, there's nothing more to write. So we've done all we # can. - elif in_buffer.pos == in_buffer.size and out_buffer.pos < out_buffer.size: + elif ( + in_buffer.pos == in_buffer.size + and out_buffer.pos < out_buffer.size + ): break else: out_buffer.pos = 0 @@ -3715,7 +3718,13 @@ def memory_size(self): """ return lib.ZSTD_sizeof_DCtx(self._dctx) - def decompress(self, data, max_output_size=0): + def decompress( + self, + data, + max_output_size=0, + read_across_frames=False, + allow_extra_data=True, + ): """ Decompress data in a single operation. @@ -3727,11 +3736,20 @@ def decompress(self, data, max_output_size=0): similar). If the input does not contain a full frame, an exception will be raised. - If the input contains multiple frames, only the first frame will be - decompressed. If you need to decompress multiple frames, use an API - like :py:meth:`ZstdCompressor.stream_reader` with + ``read_across_frames`` controls whether to read multiple zstandard + frames in the input. When False, decompression stops after reading the + first frame. This feature is not yet implemented but the argument is + provided for forward API compatibility when the default is changed to + True in a future release. For now, if you need to decompress multiple + frames, use an API like :py:meth:`ZstdDecompressor.stream_reader` with ``read_across_frames=True``. + ``allow_extra_data`` controls how to handle extra input data after a + fully decoded frame. If False, any extra data (which could be a valid + zstd frame) will result in ``ZstdError`` being raised. If True, extra + data is silently ignored.
The default will likely change to False in a + future release when ``read_across_frames`` defaults to True. + If the input contains extra data after a full frame, that extra input data is silently ignored. This behavior is undesirable in many scenarios and will likely be changed or controllable in a future release (see @@ -3783,6 +3801,11 @@ def decompress(self, data, max_output_size=0): ``bytes`` representing decompressed output. """ + if read_across_frames: + raise ZstdError( + "ZstdDecompressor.read_across_frames=True is not yet implemented" + ) + self._ensure_dctx() data_buffer = ffi.from_buffer(data) @@ -3830,6 +3853,13 @@ def decompress(self, data, max_output_size=0): "decompression error: decompressed %d bytes; expected %d" % (zresult, output_size) ) + elif not allow_extra_data and in_buffer.pos < in_buffer.size: + count = in_buffer.size - in_buffer.pos + + raise ZstdError( + "compressed input contains %d bytes of unused data, which is disallowed" + % count + ) return ffi.buffer(result_buffer, out_buffer.pos)[:]