diff --git a/.changes/next-release/feature-AWSSDKforJavav2-9db145b.json b/.changes/next-release/feature-AWSSDKforJavav2-9db145b.json new file mode 100644 index 00000000000..8eb1b249139 --- /dev/null +++ b/.changes/next-release/feature-AWSSDKforJavav2-9db145b.json @@ -0,0 +1,6 @@ +{ + "type": "feature", + "category": "AWS SDK for Java v2", + "contributor": "", + "description": "Buffer input data from ContentStreamProvider to avoid the need to reread the stream after calculating its length." +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProvider.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProvider.java index 4a071e61c4d..6ba80306e7f 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProvider.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProvider.java @@ -34,19 +34,21 @@ @NotThreadSafe public final class BufferingContentStreamProvider implements ContentStreamProvider { private final ContentStreamProvider delegate; - private InputStream bufferedStream; + private final Long expectedLength; + private BufferStream bufferedStream; private byte[] bufferedStreamData; private int count; - public BufferingContentStreamProvider(ContentStreamProvider delegate) { + public BufferingContentStreamProvider(ContentStreamProvider delegate, Long expectedLength) { this.delegate = delegate; + this.expectedLength = expectedLength; } @Override public InputStream newStream() { if (bufferedStreamData != null) { - return new ByteArrayInputStream(bufferedStreamData, 0, this.count); + return new ByteArrayStream(bufferedStreamData, 0, this.count); } if (bufferedStream == null) { @@ -59,36 +61,54 @@ public InputStream newStream() { return bufferedStream; } - private class BufferStream extends BufferedInputStream { + class ByteArrayStream extends ByteArrayInputStream { + + 
ByteArrayStream(byte[] buf, int offset, int length) { + super(buf, offset, length); + } + + @Override + public void close() throws IOException { + super.close(); + bufferedStream.close(); + } + } + + class BufferStream extends BufferedInputStream { BufferStream(InputStream in) { super(in); } - @Override - public synchronized int read() throws IOException { - int read = super.read(); - if (read < 0) { - saveBuffer(); - } - return read; + public byte[] getBuf() { + return this.buf; + } + + public int getCount() { + return this.count; } @Override - public synchronized int read(byte[] b, int off, int len) throws IOException { - int read = super.read(b, off, len); - if (read < 0) { + public void close() throws IOException { + if (!hasExpectedLength() || expectedLengthReached()) { saveBuffer(); + super.close(); } - return read; } + } - private void saveBuffer() { - if (bufferedStreamData == null) { - IoUtils.closeQuietlyV2(in, null); - BufferingContentStreamProvider.this.bufferedStreamData = this.buf; - BufferingContentStreamProvider.this.count = this.count; - } + private void saveBuffer() { + if (bufferedStreamData == null) { + this.bufferedStreamData = bufferedStream.getBuf(); + this.count = bufferedStream.getCount(); } } + private boolean expectedLengthReached() { + return bufferedStream.getCount() >= expectedLength; + } + + private boolean hasExpectedLength() { + return this.expectedLength != null; + } + } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java index e8525122110..c2664e2c032 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java @@ -209,6 +209,14 @@ public static RequestBody empty() { /** * Creates a {@link RequestBody} from the given {@link ContentStreamProvider}. + *
+ * Important: Be aware that this implementation requires buffering the contents for {@code ContentStreamProvider}, which can + * cause increased memory usage. + *
+ * If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer to + * S3's documentation for + * alternative + * methods. * * @param provider The content provider. * @param contentLength The content length. @@ -217,17 +225,14 @@ public static RequestBody empty() { * @return The created {@code RequestBody}. */ public static RequestBody fromContentProvider(ContentStreamProvider provider, long contentLength, String mimeType) { - return new RequestBody(provider, contentLength, mimeType); + return new RequestBody(new BufferingContentStreamProvider(provider, contentLength), contentLength, mimeType); } /** - * Creates a {@link RequestBody} from the given {@link ContentStreamProvider} when the content length is unknown. If you - * are able to provide the content length at creation time, consider using {@link #fromInputStream(InputStream, long)} or - * {@link #fromContentProvider(ContentStreamProvider, long, String)} to negate the need to read through the stream to find - * the content length. + * Creates a {@link RequestBody} from the given {@link ContentStreamProvider} when the content length is unknown. *
- * Important: Be aware that this override requires the SDK to buffer the entirety of your content stream to compute the - * content length. This will cause increased memory usage. + * Important: Be aware that this implementation requires buffering the contents for {@code ContentStreamProvider}, which can + * cause increased memory usage. *
* If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer * S3's documentation for @@ -240,7 +245,7 @@ public static RequestBody fromContentProvider(ContentStreamProvider provider, lo * @return The created {@code RequestBody}. */ public static RequestBody fromContentProvider(ContentStreamProvider provider, String mimeType) { - return new RequestBody(new BufferingContentStreamProvider(provider), null, mimeType); + return new RequestBody(new BufferingContentStreamProvider(provider, null), null, mimeType); } /** diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProviderTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProviderTest.java index f756701c4d5..7741a4a819e 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProviderTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProviderTest.java @@ -16,13 +16,12 @@ package software.amazon.awssdk.core.internal.sync; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; +import java.util.Random; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mockito; @@ -110,17 +109,89 @@ void newStream_closeClosesDelegateStream() throws IOException { } @Test - void newStream_allDataBuffered_closesDelegateStream() throws IOException { + public void newStream_delegateStreamClosedOnBufferingStreamClose() throws IOException { InputStream delegateStream = Mockito.spy(new ByteArrayInputStream(TEST_DATA)); requestBody = RequestBody.fromContentProvider(() -> delegateStream, 
"text/plain"); - IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream()); - Mockito.verify(delegateStream, Mockito.atLeast(1)).read(any(), anyInt(), anyInt()); + InputStream stream = requestBody.contentStreamProvider().newStream(); + IoUtils.drainInputStream(stream); + stream.close(); + Mockito.verify(delegateStream).close(); + } + + @Test + public void newStream_lengthKnown_readUpToLengthThenClosed_newStreamUsesBufferedData() throws IOException { + ByteArrayInputStream stream = new ByteArrayInputStream(TEST_DATA); + requestBody = RequestBody.fromContentProvider(() -> stream, TEST_DATA.length, "text/plain"); + + int totalRead = 0; + int read; + + InputStream stream1 = requestBody.contentStreamProvider().newStream(); + do { + read = stream1.read(); + if (read != -1) { + ++totalRead; + } + } while (read != -1); + + assertThat(totalRead).isEqualTo(TEST_DATA.length); + + stream1.close(); + + assertThat(requestBody.contentStreamProvider().newStream()) + .isInstanceOf(BufferingContentStreamProvider.ByteArrayStream.class); + } + + @Test + public void newStream_lengthKnown_partialRead_close_doesNotBufferData() throws IOException { + // We need a large buffer because BufferedInputStream buffers data in chunks. If the buffer is small enough, a single + // read() on the BufferedInputStream might actually buffer all the delegate's data. 
+ + byte[] newData = new byte[16536]; + new Random().nextBytes(newData); + ByteArrayInputStream stream = new ByteArrayInputStream(newData); + requestBody = RequestBody.fromContentProvider(() -> stream, newData.length, "text/plain"); + + InputStream stream1 = requestBody.contentStreamProvider().newStream(); + int read = stream1.read(); + assertThat(read).isNotEqualTo(-1); + + stream1.close(); + + InputStream stream2 = requestBody.contentStreamProvider().newStream(); + assertThat(stream2).isInstanceOf(BufferingContentStreamProvider.BufferStream.class); + + assertThat(getCrc32(stream2)).isEqualTo(getCrc32(new ByteArrayInputStream(newData))); + } + + @Test + public void newStream_bufferedDataStreamPartialRead_closed_bufferedDataIsNotReplaced() throws IOException { + byte[] newData = new byte[16536]; + new Random().nextBytes(newData); + String newDataChecksum = getCrc32(new ByteArrayInputStream(newData)); + + ByteArrayInputStream stream = new ByteArrayInputStream(newData); + + requestBody = RequestBody.fromContentProvider(() -> stream, "text/plain"); + InputStream stream1 = requestBody.contentStreamProvider().newStream(); + IoUtils.drainInputStream(stream1); + stream1.close(); + + InputStream stream2 = requestBody.contentStreamProvider().newStream(); + assertThat(stream2).isInstanceOf(BufferingContentStreamProvider.ByteArrayStream.class); + + int read = stream2.read(); + assertThat(read).isNotEqualTo(-1); + + stream2.close(); + + InputStream stream3 = requestBody.contentStreamProvider().newStream(); + assertThat(stream3).isInstanceOf(BufferingContentStreamProvider.ByteArrayStream.class); - IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream()); - Mockito.verifyNoMoreInteractions(delegateStream); + assertThat(getCrc32(stream3)).isEqualTo(newDataChecksum); } private static String getCrc32(InputStream inputStream) {