Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Buffer if custom content provider stream #5829

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AWSSDKforJavav2-9db145b.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "feature",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Buffer input data from ContentStreamProvider to avoid the need to reread the stream after calculating its length."
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,21 @@
@NotThreadSafe
public final class BufferingContentStreamProvider implements ContentStreamProvider {
private final ContentStreamProvider delegate;
private InputStream bufferedStream;
private final Long expectedLength;
private BufferStream bufferedStream;

private byte[] bufferedStreamData;
private int count;

public BufferingContentStreamProvider(ContentStreamProvider delegate) {
public BufferingContentStreamProvider(ContentStreamProvider delegate, Long expectedLength) {
this.delegate = delegate;
this.expectedLength = expectedLength;
}

@Override
public InputStream newStream() {
if (bufferedStreamData != null) {
return new ByteArrayInputStream(bufferedStreamData, 0, this.count);
return new ByteArrayStream(bufferedStreamData, 0, this.count);
}

if (bufferedStream == null) {
Expand All @@ -59,36 +61,54 @@ public InputStream newStream() {
return bufferedStream;
}

/**
 * A {@link ByteArrayInputStream} handed out by the enclosing provider that ties its own close to the
 * provider's {@code bufferedStream}.
 */
private class ByteArrayStream extends ByteArrayInputStream {

/**
 * @param buf the backing byte array, passed straight to the superclass.
 * @param offset the offset in {@code buf} of the first byte to read.
 * @param length the maximum number of bytes to read from {@code buf}.
 */
ByteArrayStream(byte[] buf, int offset, int length) {
super(buf, offset, length);
}

/**
 * Closes this stream and then closes the enclosing provider's {@code bufferedStream}.
 *
 * @throws IOException if closing {@code bufferedStream} fails.
 */
@Override
public void close() throws IOException {
super.close();
// Propagate the close to the enclosing provider's buffered stream as well.
bufferedStream.close();
}
}

private class BufferStream extends BufferedInputStream {
BufferStream(InputStream in) {
super(in);
}

@Override
public synchronized int read() throws IOException {
int read = super.read();
if (read < 0) {
saveBuffer();
}
return read;
public byte[] getBuf() {
return this.buf;
}

public int getCount() {
return this.count;
}

@Override
public synchronized int read(byte[] b, int off, int len) throws IOException {
int read = super.read(b, off, len);
if (read < 0) {
public void close() throws IOException {
if (!hasExpectedLength() || expectedLengthReached()) {
saveBuffer();
super.close();
}
return read;
}
}

private void saveBuffer() {
if (bufferedStreamData == null) {
IoUtils.closeQuietlyV2(in, null);
BufferingContentStreamProvider.this.bufferedStreamData = this.buf;
BufferingContentStreamProvider.this.count = this.count;
}
/**
 * Copies the buffer reference and byte count exposed by {@code bufferedStream} into this provider's
 * {@code bufferedStreamData} and {@code count} fields. Does nothing if {@code bufferedStreamData} is
 * already set, so only the first call has an effect.
 */
private void saveBuffer() {
if (bufferedStreamData == null) {
// Note: this stores the array returned by getBuf() directly; no copy is made here.
this.bufferedStreamData = bufferedStream.getBuf();
this.count = bufferedStream.getCount();
}
}

/**
 * @return {@code true} when an expected length was supplied and the count reported by
 * {@code bufferedStream} is at least that length; {@code false} otherwise.
 */
private boolean expectedLengthReached() {
// The null check comes first so that expectedLength is only unboxed when it is non-null.
return hasExpectedLength() && bufferedStream.getCount() >= expectedLength;
}

/**
 * @return {@code true} if this provider's {@code expectedLength} is non-null.
 */
private boolean hasExpectedLength() {
return this.expectedLength != null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,14 @@ public static RequestBody empty() {

/**
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider}.
* <p>
* Important: Be aware that this implementation requires buffering the contents for {@code ContentStreamProvider}, which can
* cause increased memory usage.
* <p>
* If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer
* to S3's documentation for
* <a href="https://docs.aws.amazon.com/AmazonS3/latest/API/s3_example_s3_Scenario_UploadStream_section.html">alternative
* methods</a>.
*
* @param provider The content provider.
* @param contentLength The content length.
Expand All @@ -217,17 +225,14 @@ public static RequestBody empty() {
* @return The created {@code RequestBody}.
*/
public static RequestBody fromContentProvider(ContentStreamProvider provider, long contentLength, String mimeType) {
return new RequestBody(provider, contentLength, mimeType);
return new RequestBody(new BufferingContentStreamProvider(provider, contentLength), contentLength, mimeType);
}

/**
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider} when the content length is unknown. If you
* are able to provide the content length at creation time, consider using {@link #fromInputStream(InputStream, long)} or
* {@link #fromContentProvider(ContentStreamProvider, long, String)} to negate the need to read through the stream to find
* the content length.
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider} when the content length is unknown.
* <p>
* Important: Be aware that this override requires the SDK to buffer the entirety of your content stream to compute the
* content length. This will cause increased memory usage.
* Important: Be aware that this implementation requires buffering the contents for {@code ContentStreamProvider}, which can
* cause increased memory usage.
* <p>
* If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer
* to S3's documentation for
Expand All @@ -240,7 +245,7 @@ public static RequestBody fromContentProvider(ContentStreamProvider provider, lo
* @return The created {@code RequestBody}.
*/
public static RequestBody fromContentProvider(ContentStreamProvider provider, String mimeType) {
return new RequestBody(new BufferingContentStreamProvider(provider), null, mimeType);
return new RequestBody(new BufferingContentStreamProvider(provider, null), null, mimeType);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
package software.amazon.awssdk.core.internal.sync;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -110,19 +112,29 @@ void newStream_closeClosesDelegateStream() throws IOException {
}

@Test
void newStream_allDataBuffered_closesDelegateStream() throws IOException {
void newStream_allDataBuffered_doesNotCloseDelegate() throws IOException {
InputStream delegateStream = Mockito.spy(new ByteArrayInputStream(TEST_DATA));

requestBody = RequestBody.fromContentProvider(() -> delegateStream, "text/plain");

IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
Mockito.verify(delegateStream, Mockito.atLeast(1)).read(any(), anyInt(), anyInt());
Mockito.verify(delegateStream).close();

IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
Mockito.verifyNoMoreInteractions(delegateStream);
}

/**
 * Verifies that closing the stream obtained from the request body's content stream provider, after it
 * has been drained, results in {@code close()} being called on the delegate stream.
 */
@Test
public void newStream_delegateStreamClosedOnBufferingStreamClose() throws IOException {
// Spy on the delegate so that its close() invocation can be verified below.
InputStream delegateStream = Mockito.spy(new ByteArrayInputStream(TEST_DATA));

requestBody = RequestBody.fromContentProvider(() -> delegateStream, "text/plain");

InputStream stream = requestBody.contentStreamProvider().newStream();
IoUtils.drainInputStream(stream);
stream.close();

Mockito.verify(delegateStream).close();
}

private static String getCrc32(InputStream inputStream) {
byte[] buff = new byte[1024];
int read;
Expand Down
Loading