Skip to content

Commit

Permalink
Merge pull request confidential-containers#96 from arronwy/stream
Browse files Browse the repository at this point in the history
Add image layers stream pulling support
  • Loading branch information
jiangliu authored Dec 31, 2022
2 parents 133960c + 74c9daf commit cf1f7f9
Show file tree
Hide file tree
Showing 17 changed files with 734 additions and 101 deletions.
9 changes: 6 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,23 @@ edition = "2018"

[dependencies]
anyhow = ">=1.0"
async-compression = { version = "0.3.15", features = ["futures-io", "tokio", "gzip", "zstd"] }
flate2 = "1.0"
flume = "0.10.14"
futures-util = "0.3"
libc = "0.2"
nix = "0.23.0"
oci-distribution = "0.9.3"
oci-distribution = "0.9.4"
oci-spec = { git = "https://github.com/containers/oci-spec-rs" }
ocicrypt-rs = { git = "https://github.com/confidential-containers/ocicrypt-rs", rev = "8bd6dfe", optional = true }
ocicrypt-rs = { git = "https://github.com/confidential-containers/ocicrypt-rs", rev = "6c84dde", features = ["default", "async-io"], optional = true }
attestation_agent = { git = "https://github.com/confidential-containers/attestation-agent", rev = "cbdd744", optional = true }
serde = { version = ">=1.0.27", features = ["serde_derive", "rc"] }
serde_json = ">=1.0.9"
serde_yaml = "0.8"
sha2 = ">=0.10"
tar = "0.4.37"
tokio = {version = "1.0", features = ["full"]}
zstd = "0.9"
zstd = "0.12.1"
fs_extra = "1.2.0"
walkdir = "2"
dircpy = "0.3.12"
Expand All @@ -43,6 +45,7 @@ url = "2.2.2"
[dev-dependencies]
filetime = "0.2"
tempfile = "3.2"
openssl = "0.10.44"
strum = "0.24"
strum_macros = "0.24"
serial_test = "0.9.0"
Expand Down
7 changes: 2 additions & 5 deletions src/bundle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

use anyhow::{anyhow, Result};
use anyhow::{bail, Result};
use std::collections::HashMap;
use std::path::{Path, PathBuf};

Expand Down Expand Up @@ -207,10 +207,7 @@ pub fn create_runtime_config(
let bundle_config = bundle_path.join(BUNDLE_CONFIG);

if bundle_config.exists() {
return Err(anyhow!(
"OCI config file already exists: {:?}",
bundle_config
));
bail!("OCI config file already exists: {:?}", bundle_config);
}

spec.save(&bundle_config)?;
Expand Down
53 changes: 51 additions & 2 deletions src/decoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
//
// SPDX-License-Identifier: Apache-2.0

use anyhow::{anyhow, Result};
use anyhow::{bail, Result};
use flate2;
use oci_distribution::manifest;
use oci_spec::image::MediaType;
use serde::Deserialize;
use std::convert::TryFrom;
use std::fmt;
use std::io;
use tokio::io::{AsyncRead, BufReader};
use zstd;

pub const ERR_BAD_MEDIA_TYPE: &str = "unhandled media type";
Expand Down Expand Up @@ -58,6 +59,14 @@ impl Compression {
)),
}
}

pub fn async_gzip_decompress(&self, input: (impl AsyncRead + Unpin)) -> impl AsyncRead + Unpin {
async_compression::tokio::bufread::GzipDecoder::new(BufReader::new(input))
}

pub fn async_zstd_decompress(&self, input: (impl AsyncRead + Unpin)) -> impl AsyncRead + Unpin {
async_compression::tokio::bufread::ZstdDecoder::new(BufReader::new(input))
}
}

// Decompress a gzip encoded data with flate2 crate.
Expand Down Expand Up @@ -105,7 +114,7 @@ impl TryFrom<&str> for Compression {
MediaType::ImageLayerZstd | MediaType::ImageLayerNonDistributableZstd => {
Compression::Zstd
}
_ => return Err(anyhow!("{}: {}", ERR_BAD_MEDIA_TYPE, media_type)),
_ => bail!("{}: {}", ERR_BAD_MEDIA_TYPE, media_type),
};

Ok(decoder)
Expand All @@ -115,6 +124,7 @@ impl TryFrom<&str> for Compression {
#[cfg(test)]
mod tests {
use super::*;
use anyhow::anyhow;
use flate2::write::GzEncoder;
use std::io::Write;

Expand Down Expand Up @@ -167,6 +177,45 @@ mod tests {
assert_eq!(data, output);
}

#[tokio::test]
async fn test_async_gzip_decode() {
let data: Vec<u8> = b"This is some text!".to_vec();

let mut encoder = GzEncoder::new(Vec::new(), flate2::Compression::default());
encoder.write_all(&data).unwrap();
let bytes = encoder.finish().unwrap();

let mut output = Vec::new();

let compression = Compression::default();
let mut reader = compression.async_gzip_decompress(bytes.as_slice());
assert!(
tokio::io::AsyncReadExt::read_to_end(&mut reader, &mut output)
.await
.is_ok()
);
assert_eq!(data, output);
}

#[tokio::test]
async fn test_async_zstd_decode() {
let data: Vec<u8> = b"This is some text!".to_vec();
let level = 1;

let bytes = zstd::encode_all(&data[..], level).unwrap();

let mut output = Vec::new();
let compression = Compression::Zstd;
let mut reader = compression.async_zstd_decompress(bytes.as_slice());
assert!(
tokio::io::AsyncReadExt::read_to_end(&mut reader, &mut output)
.await
.is_ok()
);

assert_eq!(data, output);
}

#[tokio::test]
async fn test_try_from_compression() {
#[derive(Debug)]
Expand Down
63 changes: 53 additions & 10 deletions src/decrypt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
//
// SPDX-License-Identifier: Apache-2.0

use anyhow::{anyhow, Result};
use anyhow::{anyhow, bail, Result};
use ocicrypt_rs::config::CryptoConfig;
use ocicrypt_rs::encryption::decrypt_layer;
use ocicrypt_rs::encryption::{async_decrypt_layer, decrypt_layer, decrypt_layer_key_opts_data};
use ocicrypt_rs::helpers::create_decrypt_config;
use ocicrypt_rs::spec::{
MEDIA_TYPE_LAYER_ENC, MEDIA_TYPE_LAYER_GZIP_ENC, MEDIA_TYPE_LAYER_NON_DISTRIBUTABLE_ENC,
Expand All @@ -15,6 +15,7 @@ use oci_distribution::manifest;
use oci_distribution::manifest::OciDescriptor;

use std::io::Read;
use tokio::io::AsyncRead;

#[derive(Default, Clone, Debug)]
pub struct Decryptor {
Expand Down Expand Up @@ -70,15 +71,11 @@ impl Decryptor {
decrypt_config: &str,
) -> Result<Vec<u8>> {
if !self.is_encrypted() {
return Err(anyhow!(
"{}: {}",
Self::ERR_UNENCRYPTED_MEDIA_TYPE,
self.media_type
));
bail!("{}: {}", Self::ERR_UNENCRYPTED_MEDIA_TYPE, self.media_type);
}

if decrypt_config.is_empty() {
return Err(anyhow!(Self::ERR_EMPTY_CFG));
bail!(Self::ERR_EMPTY_CFG);
}

let cc = create_decrypt_config(vec![decrypt_config.to_string()], vec![])?;
Expand All @@ -97,6 +94,52 @@ impl Decryptor {
Err(anyhow!("decrypt failed!"))
}
}

pub async fn get_decrypt_key(
&self,
descriptor: &OciDescriptor,
decrypt_config: &str,
) -> Result<Vec<u8>> {
if !self.is_encrypted() {
bail!("unencrypted media type: {}", self.media_type);
}

if decrypt_config.is_empty() {
bail!("decrypt_config is empty");
}

let cc = create_decrypt_config(vec![decrypt_config.to_string()], vec![])?;
let descript = descriptor.clone();

// ocicrypt-rs keyprovider module will create a new runtime to talk with
// attestation agent, to avoid startup a runtime within a runtime, we
// spawn a new thread here.
let handler = tokio::task::spawn_blocking(move || {
if let Some(decrypt_config) = cc.decrypt_config {
decrypt_layer_key_opts_data(&decrypt_config, &descript)
} else {
Err(anyhow!("no decrypt config available"))
}
});

if let Ok(priv_opts_data) = handler.await? {
Ok(priv_opts_data)
} else {
Err(anyhow!("failed to retrive decrypt key!"))
}
}

pub fn async_get_plaintext_layer(
&self,
encrypted_layer: impl AsyncRead,
descriptor: &OciDescriptor,
priv_opts_data: &[u8],
) -> Result<impl tokio::io::AsyncRead> {
let (layer_decryptor, _dec_digest) =
async_decrypt_layer(encrypted_layer, descriptor, priv_opts_data)
.map_err(|e| anyhow!("failed to async decrypt layer {}", e.to_string()))?;
Ok(layer_decryptor)
}
}

fn decrypt_layer_data(
Expand Down Expand Up @@ -234,15 +277,15 @@ mod tests {
media_type: "",
descriptor: OciDescriptor::default(),
encrypted_layer: Vec::<u8>::new(),
decrypt_config: "foo",
decrypt_config: "provider:grpc",
result: Err(anyhow!(ERR_OCICRYPT_RS_DECRYPT_FAIL)),
},
TestData {
encrypted: true,
media_type: MEDIA_TYPE_LAYER_ENC,
descriptor: OciDescriptor::default(),
encrypted_layer: Vec::<u8>::new(),
decrypt_config: "foo",
decrypt_config: "provider:grpc",
result: Err(anyhow!(ERR_OCICRYPT_RS_DECRYPT_FAIL)),
},
];
Expand Down
22 changes: 10 additions & 12 deletions src/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

use anyhow::{anyhow, Result};
use anyhow::{anyhow, bail, Result};
use log::warn;
use oci_distribution::secrets::RegistryAuth;
use oci_distribution::Reference;
Expand Down Expand Up @@ -174,7 +174,7 @@ impl ImageClient {
let auth = RegistryAuth::Basic(username.to_string(), password.to_string());
Some(auth)
} else {
return Err(anyhow!("Invalid authentication info ({:?})", auth_info));
bail!("Invalid authentication info ({:?})", auth_info);
}
} else {
None
Expand Down Expand Up @@ -203,7 +203,7 @@ impl ImageClient {
Arc::new(Mutex::new(SecureChannel::new(aa_kbc_params).await?));
Some(secure_channel)
} else {
return Err(anyhow!("Secure channel creation needs aa_kbc_params."));
bail!("Secure channel creation needs aa_kbc_params.");
}
}
false => None,
Expand Down Expand Up @@ -242,10 +242,10 @@ impl ImageClient {
let snapshot = match self.snapshots.get_mut(&self.config.default_snapshot) {
Some(s) => s,
_ => {
return Err(anyhow!(
bail!(
"default snapshot {} not found",
&self.config.default_snapshot
));
);
}
};

Expand Down Expand Up @@ -277,9 +277,7 @@ impl ImageClient {

let diff_ids = image_data.image_config.rootfs().diff_ids();
if diff_ids.len() != image_manifest.layers.len() {
return Err(anyhow!(
"Pulled number of layers mismatch with image config diff_ids"
));
bail!("Pulled number of layers mismatch with image config diff_ids");
}

let mut unique_layers = Vec::new();
Expand All @@ -295,7 +293,7 @@ impl ImageClient {

let unique_layers_len = unique_layers.len();
let layer_metas = client
.pull_layers(
.async_pull_layers(
unique_layers,
diff_ids,
decrypt_config,
Expand All @@ -312,10 +310,10 @@ impl ImageClient {

self.meta_store.lock().await.layer_db.extend(layer_db);
if unique_layers_len != image_data.layer_metas.len() {
return Err(anyhow!(
bail!(
" {} layers failed to pull",
unique_layers_len - image_data.layer_metas.len()
));
);
}

let image_id = create_bundle(&image_data, bundle_dir, snapshot)?;
Expand Down Expand Up @@ -346,7 +344,7 @@ fn create_bundle(

let image_config = image_data.image_config.clone();
if image_config.os() != &Os::Linux {
return Err(anyhow!("unsupport OS image {:?}", image_config.os()));
bail!("unsupport OS image {:?}", image_config.os());
}

create_runtime_config(&image_config, bundle_dir)?;
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ pub mod pull;
pub mod secure_channel;
pub mod signature;
pub mod snapshots;
pub mod stream;
pub mod unpack;
Loading

0 comments on commit cf1f7f9

Please sign in to comment.