Skip to content

Commit

Permalink
Buffer if content provider provided w/o len
Browse files Browse the repository at this point in the history
This updates the `RequestBody.fromContentProvider(ContentStreamProvider,
String)` override such that the underlying implementation will buffer
the contents of the stream in memory during the first pass through the
stream. This guards against issues where the implementation of
`ContentStreamProvider#newStream()` does not behave as expected.
  • Loading branch information
dagnir committed Jan 28, 2025
1 parent b1b5216 commit 097b1bc
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 2 deletions.
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AWSSDKforJavav2-9db145b.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "feature",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Buffer input data from ContentStreamProvider to avoid the need to reread the stream after calculating its length."
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.core.internal.sync;

import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely;

import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import software.amazon.awssdk.annotations.NotThreadSafe;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.io.ReleasableInputStream;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.utils.IoUtils;

/**
* {@code ContentStreamProvider} implementation that buffers the data stream data to memory as it's read. Once the underlying
* stream is read fully, all subsequent calls to {@link #newStream()} will use the buffered data.
*/
@SdkInternalApi
@NotThreadSafe
public final class BufferingContentStreamProvider implements ContentStreamProvider {
private final ContentStreamProvider delegate;
private InputStream bufferedStream;

private byte[] bufferedStreamData;
private int count;

public BufferingContentStreamProvider(ContentStreamProvider delegate) {
this.delegate = delegate;
}

@Override
public InputStream newStream() {
if (bufferedStreamData != null) {
return new ByteArrayInputStream(bufferedStreamData, 0, this.count);
}

if (bufferedStream == null) {
InputStream delegateStream = delegate.newStream();
bufferedStream = new BufferStream(delegateStream);
IoUtils.markStreamWithMaxReadLimit(bufferedStream, Integer.MAX_VALUE);

// Make non-closeable so that the SDK can't close those prematurely, before we can buffer all the data.
bufferedStream = ReleasableInputStream.wrap(bufferedStream).disableClose();
}

invokeSafely(bufferedStream::reset);
return bufferedStream;
}

private class BufferStream extends BufferedInputStream {
BufferStream(InputStream in) {
super(in);
}

@Override
public synchronized int read() throws IOException {
int read = super.read();
if (read < 0) {
saveBuffer();
}
return read;
}

@Override
public synchronized int read(byte[] b, int off, int len) throws IOException {
int read = super.read(b, off, len);
if (read < 0) {
saveBuffer();
}
return read;
}

private void saveBuffer() {
if (bufferedStreamData == null) {
IoUtils.closeQuietlyV2(in, null);
BufferingContentStreamProvider.this.bufferedStreamData = this.buf;
BufferingContentStreamProvider.this.count = this.count;
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.Arrays;
import java.util.Optional;
import software.amazon.awssdk.annotations.SdkPublicApi;
import software.amazon.awssdk.core.internal.sync.BufferingContentStreamProvider;
import software.amazon.awssdk.core.internal.sync.FileContentStreamProvider;
import software.amazon.awssdk.core.internal.util.Mimetype;
import software.amazon.awssdk.core.io.ReleasableInputStream;
Expand Down Expand Up @@ -220,15 +221,26 @@ public static RequestBody fromContentProvider(ContentStreamProvider provider, lo
}

/**
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider}.
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider} when the content length is unknown. If you
* are able to provide the content length at creation time, consider using {@link #fromInputStream(InputStream, long)} or
* {@link #fromContentProvider(ContentStreamProvider, long, String)} to negate the need to read through the stream to find
* the content length.
* <p>
* Important: Be aware that this override requires the SDK to buffer the entirety of your content stream to compute the
* content length. This will cause increased memory usage.
* <p>
* If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer
* S3's documentation for
* <a href="https://docs.aws.amazon.com/AmazonS3/latest/API/s3_example_s3_Scenario_UploadStream_section.html">alternative
* methods</a>.
*
* @param provider The content provider.
* @param mimeType The MIME type of the content.
*
* @return The created {@code RequestBody}.
*/
public static RequestBody fromContentProvider(ContentStreamProvider provider, String mimeType) {
return new RequestBody(provider, null, mimeType);
return new RequestBody(new BufferingContentStreamProvider(provider), null, mimeType);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.core.internal.sync;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import software.amazon.awssdk.checksums.DefaultChecksumAlgorithm;
import software.amazon.awssdk.checksums.SdkChecksum;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.utils.BinaryUtils;
import software.amazon.awssdk.utils.IoUtils;

class BufferingContentStreamProviderTest {
private static final SdkChecksum CRC32 = SdkChecksum.forAlgorithm(DefaultChecksumAlgorithm.CRC32);
private static final byte[] TEST_DATA = "BufferingContentStreamProviderTest".getBytes(StandardCharsets.UTF_8);
private static final String TEST_DATA_CHECKSUM = "f9ed1825";

private RequestBody requestBody;

@BeforeEach
void setup() {
ByteArrayInputStream stream = new ByteArrayInputStream(TEST_DATA);
requestBody = RequestBody.fromContentProvider(() -> stream, "text/plain");
}

@Test
void newStream_alwaysStartsAtBeginning() {
String stream1Crc32 = getCrc32(requestBody.contentStreamProvider().newStream());
String stream2Crc32 = getCrc32(requestBody.contentStreamProvider().newStream());

assertThat(stream1Crc32).isEqualTo(TEST_DATA_CHECKSUM);
assertThat(stream2Crc32).isEqualTo(TEST_DATA_CHECKSUM);
}

@Test
void newStream_buffersSkippedBytes() throws IOException {
InputStream stream1 = requestBody.contentStreamProvider().newStream();

assertThat(stream1.skip(Long.MAX_VALUE)).isEqualTo(TEST_DATA.length);

String stream2Crc32 = getCrc32(requestBody.contentStreamProvider().newStream());

assertThat(stream2Crc32).isEqualTo(TEST_DATA_CHECKSUM);
}

@Test
void newStream_oneByteReads_dataBufferedCorrectly() throws IOException {
InputStream stream = requestBody.contentStreamProvider().newStream();
int read;
do {
read = stream.read();
} while (read != -1);

assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(TEST_DATA_CHECKSUM);
}

@Test
void newStream_wholeArrayReads_dataBufferedCorrectly() throws IOException {
InputStream stream = requestBody.contentStreamProvider().newStream();
int read;
byte[] buff = new byte[32];
do {
read = stream.read(buff);
} while (read != -1);

assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(TEST_DATA_CHECKSUM);
}

@Test
void newStream_offsetArrayReads_dataBufferedCorrectly() throws IOException {
InputStream stream = requestBody.contentStreamProvider().newStream();
int read;
byte[] buff = new byte[32];
do {
read = stream.read(buff, 0, 32);
} while (read != -1);

assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(TEST_DATA_CHECKSUM);
}

@Test
void newStream_streamCannotBeClosedByCaller() throws IOException {
requestBody.contentStreamProvider().newStream().close();

assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(TEST_DATA_CHECKSUM);
}

@Test
void newStream_allDataBuffered_closesDelegateStream() throws IOException {
InputStream delegateStream = Mockito.spy(new ByteArrayInputStream(TEST_DATA));

requestBody = RequestBody.fromContentProvider(() -> delegateStream, "text/plain");

IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
Mockito.verify(delegateStream, Mockito.atLeast(1)).read(any(), anyInt(), anyInt());
Mockito.verify(delegateStream).close();

IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
Mockito.verifyNoMoreInteractions(delegateStream);
}

private static String getCrc32(InputStream inputStream) {
byte[] buff = new byte[1024];
int read;

CRC32.reset();
try {
while ((read = inputStream.read(buff)) != -1) {
CRC32.update(buff, 0, read);
}
} catch (IOException e) {
throw new RuntimeException(e);
}

return BinaryUtils.toHex(CRC32.getChecksumBytes());
}
}

0 comments on commit 097b1bc

Please sign in to comment.