From 8c23618563ccf7e89adbaadcdb6076a4b4e764bc Mon Sep 17 00:00:00 2001 From: Dongie Agnir <261310+dagnir@users.noreply.github.com> Date: Tue, 28 Jan 2025 21:02:08 -0800 Subject: [PATCH 1/2] Buffer if content provider provided w/o len (#5837) This updates the `RequestBody.fromContentProvider(ContentProvider, String)` method such that the underlying implementation will buffer the contents of the stream in memory during the first pass through the stream. --- .../feature-AWSSDKforJavav2-9db145b.json | 6 + .../sync/BufferingContentStreamProvider.java | 94 ++++++++++++ .../amazon/awssdk/core/sync/RequestBody.java | 16 +- .../BufferingContentStreamProviderTest.java | 141 ++++++++++++++++++ 4 files changed, 255 insertions(+), 2 deletions(-) create mode 100644 .changes/next-release/feature-AWSSDKforJavav2-9db145b.json create mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProvider.java create mode 100644 core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProviderTest.java diff --git a/.changes/next-release/feature-AWSSDKforJavav2-9db145b.json b/.changes/next-release/feature-AWSSDKforJavav2-9db145b.json new file mode 100644 index 000000000000..8eb1b249139c --- /dev/null +++ b/.changes/next-release/feature-AWSSDKforJavav2-9db145b.json @@ -0,0 +1,6 @@ +{ + "type": "feature", + "category": "AWS SDK for Java v2", + "contributor": "", + "description": "Buffer input data from ContentStreamProvider to avoid the need to reread the stream after calculating its length." +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProvider.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProvider.java new file mode 100644 index 000000000000..4a071e61c4dc --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProvider.java @@ -0,0 +1,94 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.sync; + +import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely; + +import java.io.BufferedInputStream; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import software.amazon.awssdk.annotations.NotThreadSafe; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.http.ContentStreamProvider; +import software.amazon.awssdk.utils.IoUtils; + +/** + * {@code ContentStreamProvider} implementation that buffers the data stream data to memory as it's read. Once the underlying + * stream is read fully, all subsequent calls to {@link #newStream()} will use the buffered data. + */ +@SdkInternalApi +@NotThreadSafe +public final class BufferingContentStreamProvider implements ContentStreamProvider { + private final ContentStreamProvider delegate; + private InputStream bufferedStream; + + private byte[] bufferedStreamData; + private int count; + + public BufferingContentStreamProvider(ContentStreamProvider delegate) { + this.delegate = delegate; + } + + @Override + public InputStream newStream() { + if (bufferedStreamData != null) { + return new ByteArrayInputStream(bufferedStreamData, 0, this.count); + } + + if (bufferedStream == null) { + InputStream delegateStream = delegate.newStream(); + bufferedStream = new BufferStream(delegateStream); + IoUtils.markStreamWithMaxReadLimit(bufferedStream, Integer.MAX_VALUE); + } + + invokeSafely(bufferedStream::reset); + return bufferedStream; + } + + private class BufferStream extends BufferedInputStream { + BufferStream(InputStream in) { + super(in); + } + + @Override + public synchronized int read() throws IOException { + int read = super.read(); + if (read < 0) { + saveBuffer(); + } + return read; + } + + @Override + public synchronized int read(byte[] b, int off, int len) throws IOException { + int read = super.read(b, off, len); + if (read < 0) { + saveBuffer(); + } + return read; + } + + private void saveBuffer() { + if (bufferedStreamData == null) { + IoUtils.closeQuietlyV2(in, null); + BufferingContentStreamProvider.this.bufferedStreamData = this.buf; + BufferingContentStreamProvider.this.count = this.count; + } + } + } + +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java index c5927c2db375..e8525122110e 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java @@ -31,6 +31,7 @@ import java.util.Arrays; import java.util.Optional; import software.amazon.awssdk.annotations.SdkPublicApi; +import software.amazon.awssdk.core.internal.sync.BufferingContentStreamProvider; import software.amazon.awssdk.core.internal.sync.FileContentStreamProvider; import software.amazon.awssdk.core.internal.util.Mimetype; import software.amazon.awssdk.core.io.ReleasableInputStream; @@ -220,7 +221,18 @@ public static RequestBody fromContentProvider(ContentStreamProvider provider, lo } /** - * Creates a {@link RequestBody} from the given {@link ContentStreamProvider}. + * Creates a {@link RequestBody} from the given {@link ContentStreamProvider} when the content length is unknown. If you + * are able to provide the content length at creation time, consider using {@link #fromInputStream(InputStream, long)} or + * {@link #fromContentProvider(ContentStreamProvider, long, String)} to negate the need to read through the stream to find + * the content length. + *
+ * Important: Be aware that this override requires the SDK to buffer the entirety of your content stream to compute the + * content length. This will cause increased memory usage. + *
+ * If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer
+ * S3's documentation for
+ * alternative
+ * methods.
*
* @param provider The content provider.
* @param mimeType The MIME type of the content.
@@ -228,7 +240,7 @@ public static RequestBody fromContentProvider(ContentStreamProvider provider, lo
* @return The created {@code RequestBody}.
*/
public static RequestBody fromContentProvider(ContentStreamProvider provider, String mimeType) {
- return new RequestBody(provider, null, mimeType);
+ return new RequestBody(new BufferingContentStreamProvider(provider), null, mimeType);
}
/**
diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProviderTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProviderTest.java
new file mode 100644
index 000000000000..f756701c4d5c
--- /dev/null
+++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProviderTest.java
@@ -0,0 +1,141 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.awssdk.core.internal.sync;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+import software.amazon.awssdk.checksums.DefaultChecksumAlgorithm;
+import software.amazon.awssdk.checksums.SdkChecksum;
+import software.amazon.awssdk.core.sync.RequestBody;
+import software.amazon.awssdk.utils.BinaryUtils;
+import software.amazon.awssdk.utils.IoUtils;
+
+class BufferingContentStreamProviderTest {
+ private static final SdkChecksum CRC32 = SdkChecksum.forAlgorithm(DefaultChecksumAlgorithm.CRC32);
+ private static final byte[] TEST_DATA = "BufferingContentStreamProviderTest".getBytes(StandardCharsets.UTF_8);
+ private static final String TEST_DATA_CHECKSUM = "f9ed1825";
+
+ private RequestBody requestBody;
+
+ @BeforeEach
+ void setup() {
+ ByteArrayInputStream stream = new ByteArrayInputStream(TEST_DATA);
+ requestBody = RequestBody.fromContentProvider(() -> stream, "text/plain");
+ }
+
+ @Test
+ void newStream_alwaysStartsAtBeginning() {
+ String stream1Crc32 = getCrc32(requestBody.contentStreamProvider().newStream());
+ String stream2Crc32 = getCrc32(requestBody.contentStreamProvider().newStream());
+
+ assertThat(stream1Crc32).isEqualTo(TEST_DATA_CHECKSUM);
+ assertThat(stream2Crc32).isEqualTo(TEST_DATA_CHECKSUM);
+ }
+
+ @Test
+ void newStream_buffersSkippedBytes() throws IOException {
+ InputStream stream1 = requestBody.contentStreamProvider().newStream();
+
+ assertThat(stream1.skip(Long.MAX_VALUE)).isEqualTo(TEST_DATA.length);
+
+ String stream2Crc32 = getCrc32(requestBody.contentStreamProvider().newStream());
+
+ assertThat(stream2Crc32).isEqualTo(TEST_DATA_CHECKSUM);
+ }
+
+ @Test
+ void newStream_oneByteReads_dataBufferedCorrectly() throws IOException {
+ InputStream stream = requestBody.contentStreamProvider().newStream();
+ int read;
+ do {
+ read = stream.read();
+ } while (read != -1);
+
+ assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(TEST_DATA_CHECKSUM);
+ }
+
+ @Test
+ void newStream_wholeArrayReads_dataBufferedCorrectly() throws IOException {
+ InputStream stream = requestBody.contentStreamProvider().newStream();
+ int read;
+ byte[] buff = new byte[32];
+ do {
+ read = stream.read(buff);
+ } while (read != -1);
+
+ assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(TEST_DATA_CHECKSUM);
+ }
+
+ @Test
+ void newStream_offsetArrayReads_dataBufferedCorrectly() throws IOException {
+ InputStream stream = requestBody.contentStreamProvider().newStream();
+ int read;
+ byte[] buff = new byte[32];
+ do {
+ read = stream.read(buff, 0, 32);
+ } while (read != -1);
+
+ assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(TEST_DATA_CHECKSUM);
+ }
+
+ @Test
+ void newStream_closeClosesDelegateStream() throws IOException {
+ InputStream stream = Mockito.spy(new ByteArrayInputStream(TEST_DATA));
+ requestBody = RequestBody.fromContentProvider(() -> stream, "text/plain");
+ requestBody.contentStreamProvider().newStream().close();
+
+ Mockito.verify(stream).close();
+ }
+
+ @Test
+ void newStream_allDataBuffered_closesDelegateStream() throws IOException {
+ InputStream delegateStream = Mockito.spy(new ByteArrayInputStream(TEST_DATA));
+
+ requestBody = RequestBody.fromContentProvider(() -> delegateStream, "text/plain");
+
+ IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
+ Mockito.verify(delegateStream, Mockito.atLeast(1)).read(any(), anyInt(), anyInt());
+ Mockito.verify(delegateStream).close();
+
+ IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
+ Mockito.verifyNoMoreInteractions(delegateStream);
+ }
+
+ private static String getCrc32(InputStream inputStream) {
+ byte[] buff = new byte[1024];
+ int read;
+
+ CRC32.reset();
+ try {
+ while ((read = inputStream.read(buff)) != -1) {
+ CRC32.update(buff, 0, read);
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+
+ return BinaryUtils.toHex(CRC32.getChecksumBytes());
+ }
+}
From b80823fbc1225f45d3fa494c25705fb625056e3b Mon Sep 17 00:00:00 2001
From: Dongie Agnir