diff --git a/.changes/2.30.8.json b/.changes/2.30.8.json
index af5b473e317a..62e605cdd50e 100644
--- a/.changes/2.30.8.json
+++ b/.changes/2.30.8.json
@@ -37,6 +37,12 @@
             "category": "Timestream InfluxDB",
             "contributor": "",
             "description": "Adds 'allocatedStorage' parameter to UpdateDbInstance API that allows increasing the database instance storage size and 'dbStorageType' parameter to UpdateDbInstance API that allows changing the storage type of the database instance"
+        },
+        {
+            "type": "feature",
+            "category": "AWS SDK for Java v2",
+            "contributor": "",
+            "description": "Buffer input data from ContentStreamProvider to avoid the need to reread the stream after calculating its length."
         }
     ]
 }
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3f5a537a0eed..127dea2b661e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,10 @@
   - ### Features
     - Adds 'allocatedStorage' parameter to UpdateDbInstance API that allows increasing the database instance storage size and 'dbStorageType' parameter to UpdateDbInstance API that allows changing the storage type of the database instance
 
+## __AWS SDK for Java v2__
+  - ### Features
+    - Buffer input data from ContentStreamProvider to avoid the need to reread the stream after calculating its length.
+
 # __2.30.7__ __2025-01-27__
 ## __AWS Elemental MediaConvert__
   - ### Features
diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProvider.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProvider.java
new file mode 100644
index 000000000000..4a071e61c4dc
--- /dev/null
+++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProvider.java
@@ -0,0 +1,94 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ *  http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.awssdk.core.internal.sync;
+
+import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely;
+
+import java.io.BufferedInputStream;
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import software.amazon.awssdk.annotations.NotThreadSafe;
+import software.amazon.awssdk.annotations.SdkInternalApi;
+import software.amazon.awssdk.http.ContentStreamProvider;
+import software.amazon.awssdk.utils.IoUtils;
+
+/**
+ * {@code ContentStreamProvider} implementation that buffers the stream data to memory as it's read. Once the underlying
+ * stream is read fully, all subsequent calls to {@link #newStream()} will use the buffered data.
+ */
+@SdkInternalApi
+@NotThreadSafe
+public final class BufferingContentStreamProvider implements ContentStreamProvider {
+    private final ContentStreamProvider delegate;
+    private InputStream bufferedStream;
+
+    private byte[] bufferedStreamData;
+    private int count;
+
+    public BufferingContentStreamProvider(ContentStreamProvider delegate) {
+        this.delegate = delegate;
+    }
+
+    @Override
+    public InputStream newStream() {
+        if (bufferedStreamData != null) {
+            return new ByteArrayInputStream(bufferedStreamData, 0, this.count);
+        }
+
+        if (bufferedStream == null) {
+            InputStream delegateStream = delegate.newStream();
+            bufferedStream = new BufferStream(delegateStream);
+            IoUtils.markStreamWithMaxReadLimit(bufferedStream, Integer.MAX_VALUE);
+        }
+
+        invokeSafely(bufferedStream::reset);
+        return bufferedStream;
+    }
+
+    private class BufferStream extends BufferedInputStream {
+        BufferStream(InputStream in) {
+            super(in);
+        }
+
+        @Override
+        public synchronized int read() throws IOException {
+            int read = super.read();
+            if (read < 0) {
+                saveBuffer();
+            }
+            return read;
+        }
+
+        @Override
+        public synchronized int read(byte[] b, int off, int len) throws IOException {
+            int read = super.read(b, off, len);
+            if (read < 0) {
+                saveBuffer();
+            }
+            return read;
+        }
+
+        private void saveBuffer() {
+            if (bufferedStreamData == null) {
+                IoUtils.closeQuietlyV2(in, null);
+                BufferingContentStreamProvider.this.bufferedStreamData = this.buf;
+                BufferingContentStreamProvider.this.count = this.count;
+            }
+        }
+    }
+
+}
diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java
index c5927c2db375..e8525122110e 100644
--- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java
+++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java
@@ -31,6 +31,7 @@
 import java.util.Arrays;
 import java.util.Optional;
 import software.amazon.awssdk.annotations.SdkPublicApi;
+import software.amazon.awssdk.core.internal.sync.BufferingContentStreamProvider;
 import software.amazon.awssdk.core.internal.sync.FileContentStreamProvider;
 import software.amazon.awssdk.core.internal.util.Mimetype;
 import software.amazon.awssdk.core.io.ReleasableInputStream;
@@ -220,7 +221,16 @@ public static RequestBody fromContentProvider(ContentStreamProvider provider, lo
     }
 
     /**
-     * Creates a {@link RequestBody} from the given {@link ContentStreamProvider}.
+     * Creates a {@link RequestBody} from the given {@link ContentStreamProvider} when the content length is unknown. If you
+     * are able to provide the content length at creation time, consider using {@link #fromInputStream(InputStream, long)} or
+     * {@link #fromContentProvider(ContentStreamProvider, long, String)} to avoid the need to read through the stream to find
+     * the content length.
+     * <p>
+     * Important: Be aware that this overload requires the SDK to buffer the entirety of your content stream to compute the
+     * content length. This will cause increased memory usage.
+     * <p>
+     * If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can
+     * refer to S3's documentation for alternative methods.
      *
      * @param provider The content provider.
     * @param mimeType The MIME type of the content.
@@ -228,7 +238,7 @@ public static RequestBody fromContentProvider(ContentStreamProvider provider, lo
      * @return The created {@code RequestBody}.
      */
     public static RequestBody fromContentProvider(ContentStreamProvider provider, String mimeType) {
-        return new RequestBody(provider, null, mimeType);
+        return new RequestBody(new BufferingContentStreamProvider(provider), null, mimeType);
     }
 
     /**
diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProviderTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProviderTest.java
new file mode 100644
index 000000000000..f756701c4d5c
--- /dev/null
+++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProviderTest.java
@@ -0,0 +1,141 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ *  http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.awssdk.core.internal.sync;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+import software.amazon.awssdk.checksums.DefaultChecksumAlgorithm;
+import software.amazon.awssdk.checksums.SdkChecksum;
+import software.amazon.awssdk.core.sync.RequestBody;
+import software.amazon.awssdk.utils.BinaryUtils;
+import software.amazon.awssdk.utils.IoUtils;
+
+class BufferingContentStreamProviderTest {
+    private static final SdkChecksum CRC32 = SdkChecksum.forAlgorithm(DefaultChecksumAlgorithm.CRC32);
+    private static final byte[] TEST_DATA = "BufferingContentStreamProviderTest".getBytes(StandardCharsets.UTF_8);
+    private static final String TEST_DATA_CHECKSUM = "f9ed1825";
+
+    private RequestBody requestBody;
+
+    @BeforeEach
+    void setup() {
+        ByteArrayInputStream stream = new ByteArrayInputStream(TEST_DATA);
+        requestBody = RequestBody.fromContentProvider(() -> stream, "text/plain");
+    }
+
+    @Test
+    void newStream_alwaysStartsAtBeginning() {
+        String stream1Crc32 = getCrc32(requestBody.contentStreamProvider().newStream());
+        String stream2Crc32 = getCrc32(requestBody.contentStreamProvider().newStream());
+
+        assertThat(stream1Crc32).isEqualTo(TEST_DATA_CHECKSUM);
+        assertThat(stream2Crc32).isEqualTo(TEST_DATA_CHECKSUM);
+    }
+
+    @Test
+    void newStream_buffersSkippedBytes() throws IOException {
+        InputStream stream1 = requestBody.contentStreamProvider().newStream();
+
+        assertThat(stream1.skip(Long.MAX_VALUE)).isEqualTo(TEST_DATA.length);
+
+        String stream2Crc32 = getCrc32(requestBody.contentStreamProvider().newStream());
+
+        assertThat(stream2Crc32).isEqualTo(TEST_DATA_CHECKSUM);
+    }
+
+    @Test
+    void newStream_oneByteReads_dataBufferedCorrectly() throws IOException {
+        InputStream stream = requestBody.contentStreamProvider().newStream();
+        int read;
+        do {
+            read = stream.read();
+        } while (read != -1);
+
+        assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(TEST_DATA_CHECKSUM);
+    }
+
+    @Test
+    void newStream_wholeArrayReads_dataBufferedCorrectly() throws IOException {
+        InputStream stream = requestBody.contentStreamProvider().newStream();
+        int read;
+        byte[] buff = new byte[32];
+        do {
+            read = stream.read(buff);
+        } while (read != -1);
+
+        assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(TEST_DATA_CHECKSUM);
+    }
+
+    @Test
+    void newStream_offsetArrayReads_dataBufferedCorrectly() throws IOException {
+        InputStream stream = requestBody.contentStreamProvider().newStream();
+        int read;
+        byte[] buff = new byte[32];
+        do {
+            read = stream.read(buff, 0, 32);
+        } while (read != -1);
+
+        assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(TEST_DATA_CHECKSUM);
+    }
+
+    @Test
+    void newStream_closeClosesDelegateStream() throws IOException {
+        InputStream stream = Mockito.spy(new ByteArrayInputStream(TEST_DATA));
+        requestBody = RequestBody.fromContentProvider(() -> stream, "text/plain");
+        requestBody.contentStreamProvider().newStream().close();
+
+        Mockito.verify(stream).close();
+    }
+
+    @Test
+    void newStream_allDataBuffered_closesDelegateStream() throws IOException {
+        InputStream delegateStream = Mockito.spy(new ByteArrayInputStream(TEST_DATA));
+
+        requestBody = RequestBody.fromContentProvider(() -> delegateStream, "text/plain");
+
+        IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
+        Mockito.verify(delegateStream, Mockito.atLeast(1)).read(any(), anyInt(), anyInt());
+        Mockito.verify(delegateStream).close();
+
+        IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
+        Mockito.verifyNoMoreInteractions(delegateStream);
+    }
+
+    private static String getCrc32(InputStream inputStream) {
+        byte[] buff = new byte[1024];
+        int read;
+
+        CRC32.reset();
+        try {
+            while ((read = inputStream.read(buff)) != -1) {
+                CRC32.update(buff, 0, read);
+            }
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+
+        return BinaryUtils.toHex(CRC32.getChecksumBytes());
+    }
+}
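Below is a brief, hypothetical usage sketch (not part of the patch) illustrating the behavior this change introduces: when no content length is supplied, RequestBody.fromContentProvider(provider, mimeType) now wraps the provider in BufferingContentStreamProvider, so the bytes read from the delegate stream are buffered in memory and later newStream() calls replay the buffer instead of rereading the delegate. The class and variable names are invented for illustration; only RequestBody.fromContentProvider, contentStreamProvider().newStream(), and the existing IoUtils.toByteArray utility come from the SDK.

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.utils.IoUtils;

// Hypothetical caller code, for illustration only; not part of this patch.
public class UnknownLengthRequestBodyExample {
    public static void main(String[] args) throws Exception {
        // Stands in for any single-use stream whose length is not known up front.
        InputStream source = new ByteArrayInputStream("example payload".getBytes(StandardCharsets.UTF_8));

        // No content length is supplied, so the SDK wraps the provider in BufferingContentStreamProvider.
        RequestBody body = RequestBody.fromContentProvider(() -> source, "text/plain");

        // The first pass reads the delegate stream to the end and buffers it in memory.
        byte[] firstRead = IoUtils.toByteArray(body.contentStreamProvider().newStream());

        // The second pass is served from the in-memory buffer; the delegate stream is not read again.
        byte[] secondRead = IoUtils.toByteArray(body.contentStreamProvider().newStream());

        System.out.println(firstRead.length == secondRead.length); // prints: true
    }
}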