Commit

Buffer if custom content provider stream
This updates the
`RequestBody.fromContentProvider(ContentStreamProvider,long,String)`
override such that the underlying implementation will buffer the
contents of the stream in memory during the first pass through the
stream.

This is a follow-up to #5837.
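For context, a minimal sketch (not part of this commit) of the call path the change affects. The class name, payload, and the commented-out S3 call with its bucket/key names are hypothetical; the point is that a provider backed by a one-shot stream can be replayed, because the bytes read on the first pass are now buffered.

```java
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.http.ContentStreamProvider;

public class KnownLengthBufferingSketch {
    public static void main(String[] args) {
        byte[] payload = "hello from a one-shot stream".getBytes(StandardCharsets.UTF_8);

        // Deliberately naive provider: it always hands back the same stream
        // instance, so it cannot be re-read from the beginning on its own.
        InputStream oneShot = new ByteArrayInputStream(payload);
        ContentStreamProvider provider = () -> oneShot;

        // With this change, the known-length override wraps the provider so the
        // bytes read on the first pass are kept in memory and can be replayed
        // (for example on a retry) without calling provider.newStream() again.
        RequestBody body = RequestBody.fromContentProvider(provider, payload.length, "text/plain");

        // Hypothetical use: s3Client.putObject(req -> req.bucket("my-bucket").key("my-key"), body);
    }
}
```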
dagnir committed Jan 29, 2025
1 parent 70ae0dd commit 13d507d
Showing 4 changed files with 138 additions and 36 deletions.
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AWSSDKforJavav2-9db145b.json
@@ -0,0 +1,6 @@
+{
+    "type": "feature",
+    "category": "AWS SDK for Java v2",
+    "contributor": "",
+    "description": "Buffer input data from ContentStreamProvider to avoid the need to reread the stream after calculating its length."
+}
BufferingContentStreamProvider.java
@@ -34,19 +34,21 @@
 @NotThreadSafe
 public final class BufferingContentStreamProvider implements ContentStreamProvider {
     private final ContentStreamProvider delegate;
-    private InputStream bufferedStream;
+    private final Long expectedLength;
+    private BufferStream bufferedStream;
 
     private byte[] bufferedStreamData;
     private int count;
 
-    public BufferingContentStreamProvider(ContentStreamProvider delegate) {
+    public BufferingContentStreamProvider(ContentStreamProvider delegate, Long expectedLength) {
         this.delegate = delegate;
+        this.expectedLength = expectedLength;
     }
 
     @Override
     public InputStream newStream() {
         if (bufferedStreamData != null) {
-            return new ByteArrayInputStream(bufferedStreamData, 0, this.count);
+            return new ByteArrayStream(bufferedStreamData, 0, this.count);
         }
 
         if (bufferedStream == null) {
@@ -59,36 +61,54 @@ public InputStream newStream() {
         return bufferedStream;
     }
 
-    private class BufferStream extends BufferedInputStream {
+    class ByteArrayStream extends ByteArrayInputStream {
+
+        ByteArrayStream(byte[] buf, int offset, int length) {
+            super(buf, offset, length);
+        }
+
+        @Override
+        public void close() throws IOException {
+            super.close();
+            bufferedStream.close();
+        }
+    }
+
+    class BufferStream extends BufferedInputStream {
         BufferStream(InputStream in) {
             super(in);
         }
 
-        @Override
-        public synchronized int read() throws IOException {
-            int read = super.read();
-            if (read < 0) {
-                saveBuffer();
-            }
-            return read;
+        public byte[] getBuf() {
+            return this.buf;
+        }
+
+        public int getCount() {
+            return this.count;
         }
 
         @Override
-        public synchronized int read(byte[] b, int off, int len) throws IOException {
-            int read = super.read(b, off, len);
-            if (read < 0) {
+        public void close() throws IOException {
+            if (!hasExpectedLength() || expectedLengthReached()) {
                 saveBuffer();
             }
-            return read;
-        }
-
-        private void saveBuffer() {
-            if (bufferedStreamData == null) {
-                IoUtils.closeQuietlyV2(in, null);
-                BufferingContentStreamProvider.this.bufferedStreamData = this.buf;
-                BufferingContentStreamProvider.this.count = this.count;
-            }
+            super.close();
         }
     }
 
+    private void saveBuffer() {
+        if (bufferedStreamData == null) {
+            this.bufferedStreamData = bufferedStream.getBuf();
+            this.count = bufferedStream.getCount();
+        }
+    }
+
+    private boolean expectedLengthReached() {
+        return bufferedStream.getCount() >= expectedLength;
+    }
+
+    private boolean hasExpectedLength() {
+        return this.expectedLength != null;
+    }
+
 }
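To make the buffering behavior above concrete, here is a rough sketch (not from the commit) that exercises it through the public RequestBody API: the first newStream() wraps the delegate stream in BufferStream, and once that stream has been drained to the expected length and closed, later newStream() calls are served from the saved byte array, so the delegate provider is only asked for a stream once. The class name, payload, and counter are illustrative assumptions.

```java
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.atomic.AtomicInteger;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.utils.IoUtils;

public class BufferingSemanticsSketch {
    public static void main(String[] args) throws Exception {
        byte[] data = "some request payload".getBytes(StandardCharsets.UTF_8);
        AtomicInteger delegateCalls = new AtomicInteger();

        // Counts how often a fresh stream is actually requested from the delegate.
        ContentStreamProvider delegate = () -> {
            delegateCalls.incrementAndGet();
            return new ByteArrayInputStream(data);
        };

        RequestBody body = RequestBody.fromContentProvider(delegate, data.length, "text/plain");

        // First pass: reads go through BufferStream and are captured on close().
        try (InputStream first = body.contentStreamProvider().newStream()) {
            IoUtils.drainInputStream(first);
        }

        // Second pass: served from the buffered byte[]; the delegate is not called again.
        try (InputStream second = body.contentStreamProvider().newStream()) {
            IoUtils.drainInputStream(second);
        }

        // With the buffering in place this should print 1.
        System.out.println("delegate.newStream() invocations: " + delegateCalls.get());
    }
}
```

Note that when an expected length is supplied, BufferStream.close() only saves the buffer once the count has reached that length, so a partially read stream is not cached; the partial-read test in the test file further down demonstrates this.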
RequestBody.java
@@ -209,6 +209,14 @@ public static RequestBody empty() {
 
     /**
      * Creates a {@link RequestBody} from the given {@link ContentStreamProvider}.
+     * <p>
+     * Important: Be aware that this implementation requires buffering the contents for {@code ContentStreamProvider}, which can
+     * cause increased memory usage.
+     * <p>
+     * If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer
+     * S3's documentation for
+     * <a href="https://docs.aws.amazon.com/AmazonS3/latest/API/s3_example_s3_Scenario_UploadStream_section.html">alternative
+     * methods</a>.
      *
      * @param provider The content provider.
      * @param contentLength The content length.
@@ -217,17 +225,14 @@ public static RequestBody empty() {
      * @return The created {@code RequestBody}.
      */
     public static RequestBody fromContentProvider(ContentStreamProvider provider, long contentLength, String mimeType) {
-        return new RequestBody(provider, contentLength, mimeType);
+        return new RequestBody(new BufferingContentStreamProvider(provider, contentLength), contentLength, mimeType);
     }
 
     /**
-     * Creates a {@link RequestBody} from the given {@link ContentStreamProvider} when the content length is unknown. If you
-     * are able to provide the content length at creation time, consider using {@link #fromInputStream(InputStream, long)} or
-     * {@link #fromContentProvider(ContentStreamProvider, long, String)} to negate the need to read through the stream to find
-     * the content length.
+     * Creates a {@link RequestBody} from the given {@link ContentStreamProvider} when the content length is unknown.
      * <p>
-     * Important: Be aware that this override requires the SDK to buffer the entirety of your content stream to compute the
-     * content length. This will cause increased memory usage.
+     * Important: Be aware that this implementation requires buffering the contents for {@code ContentStreamProvider}, which can
+     * cause increased memory usage.
      * <p>
      * If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer
      * S3's documentation for
@@ -240,7 +245,7 @@ public static RequestBody fromContentProvider(ContentStreamProvider provider, lo
      * @return The created {@code RequestBody}.
      */
     public static RequestBody fromContentProvider(ContentStreamProvider provider, String mimeType) {
-        return new RequestBody(new BufferingContentStreamProvider(provider), null, mimeType);
+        return new RequestBody(new BufferingContentStreamProvider(provider, null), null, mimeType);
     }
 
     /**
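And a similarly hedged sketch of the unknown-length override touched above: because no length is supplied, the whole stream has to be buffered so the SDK can calculate the content length, which is what the Javadoc warning and the S3 link are about. The class name and the file path are hypothetical.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.http.ContentStreamProvider;

public class UnknownLengthSketch {
    public static void main(String[] args) {
        // Hypothetical source whose length is not known up front.
        ContentStreamProvider provider = () -> {
            try {
                return Files.newInputStream(Paths.get("/tmp/unknown-length-payload"));
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        };

        // The SDK buffers the entire stream to determine the content length, so
        // memory usage grows with the payload size; for large objects of unknown
        // size, S3's multipart-upload alternatives are the usual recommendation.
        RequestBody body = RequestBody.fromContentProvider(provider, "application/octet-stream");
    }
}
```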
BufferingContentStreamProviderTest.java
@@ -16,13 +16,12 @@
 package software.amazon.awssdk.core.internal.sync;
 
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyInt;
 
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.nio.charset.StandardCharsets;
+import java.util.Random;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.mockito.Mockito;
@@ -110,17 +109,89 @@ void newStream_closeClosesDelegateStream() throws IOException {
     }
 
     @Test
-    void newStream_allDataBuffered_closesDelegateStream() throws IOException {
+    public void newStream_delegateStreamClosedOnBufferingStreamClose() throws IOException {
         InputStream delegateStream = Mockito.spy(new ByteArrayInputStream(TEST_DATA));
 
         requestBody = RequestBody.fromContentProvider(() -> delegateStream, "text/plain");
 
-        IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
-        Mockito.verify(delegateStream, Mockito.atLeast(1)).read(any(), anyInt(), anyInt());
+        InputStream stream = requestBody.contentStreamProvider().newStream();
+        IoUtils.drainInputStream(stream);
+        stream.close();
 
         Mockito.verify(delegateStream).close();
-        IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
-        Mockito.verifyNoMoreInteractions(delegateStream);
+    }
+
+    @Test
+    public void newStream_lengthKnown_readUpToLengthThenClosed_newStreamUsesBufferedData() throws IOException {
+        ByteArrayInputStream stream = new ByteArrayInputStream(TEST_DATA);
+        requestBody = RequestBody.fromContentProvider(() -> stream, TEST_DATA.length, "text/plain");
+
+        int totalRead = 0;
+        int read;
+
+        InputStream stream1 = requestBody.contentStreamProvider().newStream();
+        do {
+            read = stream1.read();
+            if (read != -1) {
+                ++totalRead;
+            }
+        } while (read != -1);
+
+        assertThat(totalRead).isEqualTo(TEST_DATA.length);
+
+        stream1.close();
+
+        assertThat(requestBody.contentStreamProvider().newStream())
+            .isInstanceOf(BufferingContentStreamProvider.ByteArrayStream.class);
+    }
+
+    @Test
+    public void newStream_lengthKnown_partialRead_close_doesNotBufferData() throws IOException {
+        // We need a large buffer because BufferedInputStream buffers data in chunks. If the buffer is small enough, a single
+        // read() on the BufferedInputStream might actually buffer all the delegate's data.
+
+        byte[] newData = new byte[16536];
+        new Random().nextBytes(newData);
+        ByteArrayInputStream stream = new ByteArrayInputStream(newData);
+        requestBody = RequestBody.fromContentProvider(() -> stream, newData.length, "text/plain");
+
+        InputStream stream1 = requestBody.contentStreamProvider().newStream();
+        int read = stream1.read();
+        assertThat(read).isNotEqualTo(-1);
+
+        stream1.close();
+
+        InputStream stream2 = requestBody.contentStreamProvider().newStream();
+        assertThat(stream2).isInstanceOf(BufferingContentStreamProvider.BufferStream.class);
+
+        assertThat(getCrc32(stream2)).isEqualTo(getCrc32(new ByteArrayInputStream(newData)));
+    }
+
+    @Test
+    public void newStream_bufferedDataStreamPartialRead_closed_bufferedDataIsNotReplaced() throws IOException {
+        byte[] newData = new byte[16536];
+        new Random().nextBytes(newData);
+        String newDataChecksum = getCrc32(new ByteArrayInputStream(newData));
+
+        ByteArrayInputStream stream = new ByteArrayInputStream(newData);
+
+        requestBody = RequestBody.fromContentProvider(() -> stream, "text/plain");
+        InputStream stream1 = requestBody.contentStreamProvider().newStream();
+        IoUtils.drainInputStream(stream1);
+        stream1.close();
+
+        InputStream stream2 = requestBody.contentStreamProvider().newStream();
+        assertThat(stream2).isInstanceOf(BufferingContentStreamProvider.ByteArrayStream.class);
+
+        int read = stream2.read();
+        assertThat(read).isNotEqualTo(-1);
+
+        stream2.close();
+
+        InputStream stream3 = requestBody.contentStreamProvider().newStream();
+        assertThat(stream3).isInstanceOf(BufferingContentStreamProvider.ByteArrayStream.class);
+
+        assertThat(getCrc32(stream3)).isEqualTo(newDataChecksum);
     }
 
     private static String getCrc32(InputStream inputStream) {
