Skip to content

Commit

Permalink
feat(auth): fallback to MDS credentials (#585)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbolduc authored Jan 4, 2025
1 parent faef96e commit e758f27
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 42 deletions.
124 changes: 101 additions & 23 deletions src/auth/src/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,22 +215,10 @@ pub mod traits {
/// [gcloud auth application-default]: https://cloud.google.com/sdk/gcloud/reference/auth/application-default
/// [gke-link]: https://cloud.google.com/kubernetes-engine
pub async fn create_access_token_credential() -> Result<Credential> {
let adc_path = adc_path().ok_or_else(||
// TODO(#442) - This should (successfully) fall back to MDS Credentials. We will temporarily return an error.
CredentialError::new(false, Box::from("Unable to find Application Default Credentials.")))?;

let contents = std::fs::read_to_string(adc_path).map_err(|e| {
match e.kind() {
std::io::ErrorKind::NotFound => {
// TODO(#442) - This should (successfully) fall back to MDS Credentials. We will temporarily return an error.
CredentialError::new(
false,
Box::from("Unable to find Application Default Credentials."),
)
}
_ => CredentialError::new(false, e.into()),
}
})?;
let contents = match load_adc()? {
AdcContents::Contents(contents) => contents,
AdcContents::FallbackToMds => return Ok(mds_credential::new()),
};
let js: serde_json::Value =
serde_json::from_str(&contents).map_err(|e| CredentialError::new(false, e.into()))?;
let cred_type = js
Expand All @@ -253,13 +241,50 @@ pub async fn create_access_token_credential() -> Result<Credential> {
}
}

#[derive(Debug, PartialEq)]
enum AdcPath {
FromEnv(String),
WellKnown(String),
}

#[derive(Debug, PartialEq)]
enum AdcContents {
Contents(String),
FallbackToMds,
}

fn path_not_found(path: String) -> CredentialError {
CredentialError::new(
false,
Box::from(format!(
"Failed to load Application Default Credentials (ADC) from {path}. Check that the `GOOGLE_APPLICATION_CREDENTIALS` environment variable points to a valid file."
)))
}

fn load_adc() -> Result<AdcContents> {
match adc_path() {
None => Ok(AdcContents::FallbackToMds),
Some(AdcPath::FromEnv(path)) => match std::fs::read_to_string(&path) {
Ok(contents) => Ok(AdcContents::Contents(contents)),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Err(path_not_found(path)),
Err(e) => Err(CredentialError::new(false, e.into())),
},
Some(AdcPath::WellKnown(path)) => match std::fs::read_to_string(path) {
Ok(contents) => Ok(AdcContents::Contents(contents)),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(AdcContents::FallbackToMds),
Err(e) => Err(CredentialError::new(false, e.into())),
},
}
}

/// The path to Application Default Credentials (ADC), as specified in [AIP-4110].
///
/// [AIP-4110]: https://google.aip.dev/auth/4110
fn adc_path() -> Option<String> {
std::env::var("GOOGLE_APPLICATION_CREDENTIALS")
.ok()
.or_else(adc_well_known_path)
fn adc_path() -> Option<AdcPath> {
if let Ok(path) = std::env::var("GOOGLE_APPLICATION_CREDENTIALS") {
return Some(AdcPath::FromEnv(path));
}
Some(AdcPath::WellKnown(adc_well_known_path()?))
}

/// The well-known path to ADC on Windows, as specified in [AIP-4113].
Expand All @@ -286,6 +311,7 @@ fn adc_well_known_path() -> Option<String> {
mod test {
use super::*;
use scoped_env::ScopedEnv;
use std::error::Error;

#[cfg(target_os = "windows")]
#[test]
Expand All @@ -299,7 +325,9 @@ mod test {
);
assert_eq!(
adc_path(),
Some("C:/Users/foo/gcloud/application_default_credentials.json".to_string())
Some(AdcPath::WellKnown(
"C:/Users/foo/gcloud/application_default_credentials.json".to_string()
))
);
}

Expand All @@ -325,7 +353,9 @@ mod test {
);
assert_eq!(
adc_path(),
Some("/home/foo/.config/gcloud/application_default_credentials.json".to_string())
Some(AdcPath::WellKnown(
"/home/foo/.config/gcloud/application_default_credentials.json".to_string()
))
);
}

Expand All @@ -348,7 +378,55 @@ mod test {
);
assert_eq!(
adc_path(),
Some("/usr/bar/application_default_credentials.json".to_string())
Some(AdcPath::FromEnv(
"/usr/bar/application_default_credentials.json".to_string()
))
);
}

#[test]
#[serial_test::serial]
fn load_adc_no_well_known_path_fallback_to_mds() {
let _e1 = ScopedEnv::remove("GOOGLE_APPLICATION_CREDENTIALS");
let _e2 = ScopedEnv::remove("HOME"); // For posix
let _e3 = ScopedEnv::remove("APPDATA"); // For windows
assert_eq!(load_adc().unwrap(), AdcContents::FallbackToMds);
}

#[test]
#[serial_test::serial]
fn load_adc_no_file_at_well_known_path_fallback_to_mds() {
// Create a new temp directory. There is not an ADC file in here.
let dir = tempfile::TempDir::new().unwrap();
let path = dir.path().to_str().unwrap();
let _e1 = ScopedEnv::remove("GOOGLE_APPLICATION_CREDENTIALS");
let _e2 = ScopedEnv::set("HOME", path); // For posix
let _e3 = ScopedEnv::set("APPDATA", path); // For windows
assert_eq!(load_adc().unwrap(), AdcContents::FallbackToMds);
}

#[test]
#[serial_test::serial]
fn load_adc_no_file_at_env_is_error() {
let _e = ScopedEnv::set("GOOGLE_APPLICATION_CREDENTIALS", "file-does-not-exist.json");
let err = load_adc().err().unwrap();
let msg = err.source().unwrap().to_string();
assert!(msg.contains("Failed to load Application Default Credentials"));
assert!(msg.contains("file-does-not-exist.json"));
assert!(msg.contains("GOOGLE_APPLICATION_CREDENTIALS"));
}

#[test]
#[serial_test::serial]
fn load_adc_success() {
let file = tempfile::NamedTempFile::new().unwrap();
let path = file.into_temp_path();
std::fs::write(&path, "contents").expect("Unable to write to temporary file.");
let _e = ScopedEnv::set("GOOGLE_APPLICATION_CREDENTIALS", path.to_str().unwrap());

assert_eq!(
load_adc().unwrap(),
AdcContents::Contents("contents".to_string())
);
}
}
26 changes: 15 additions & 11 deletions src/auth/src/credentials/mds_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::credentials::traits::dynamic::Credential;
use crate::credentials::Result;
use crate::credentials::traits::dynamic::Credential as CredentialTrait;
use crate::credentials::{Credential, Result};
use crate::errors::{is_retryable, CredentialError};
use crate::token::{Token, TokenProvider};
use async_trait::async_trait;
Expand All @@ -26,19 +26,26 @@ use time::OffsetDateTime;

const METADATA_FLAVOR_VALUE: &str = "Google";
const METADATA_FLAVOR: &str = "metadata-flavor";
#[allow(dead_code)] // TODO(#442) - implementation in progress
const METADATA_ROOT: &str = "http://metadata.google.internal/computeMetadata/v1";

#[allow(dead_code)] // TODO(#442) - implementation in progress
pub(crate) struct MDSCredential<T>
pub(crate) fn new() -> Credential {
let token_provider = MDSAccessTokenProvider {
endpoint: METADATA_ROOT.to_string(),
};
Credential {
inner: Box::new(MDSCredential { token_provider }),
}
}

struct MDSCredential<T>
where
T: TokenProvider,
{
token_provider: T,
}

#[async_trait::async_trait]
impl<T> Credential for MDSCredential<T>
impl<T> CredentialTrait for MDSCredential<T>
where
T: TokenProvider,
{
Expand All @@ -59,8 +66,7 @@ where
}
}

#[allow(dead_code)] // TODO(#442) - implementation in progress
#[derive(serde::Deserialize, serde::Serialize, PartialEq, Debug, Clone)]
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
struct ServiceAccountInfo {
email: String,
scopes: Option<Vec<String>>,
Expand All @@ -75,13 +81,12 @@ struct MDSTokenResponse {
token_type: String,
}

#[allow(dead_code)] // TODO(#442) - implementation in progress
struct MDSAccessTokenProvider {
endpoint: String,
}

#[allow(dead_code)]
impl MDSAccessTokenProvider {
#[allow(dead_code)]
async fn get_service_account_info(
request: &Client,
metadata_service_endpoint: String,
Expand Down Expand Up @@ -118,7 +123,6 @@ impl MDSAccessTokenProvider {
}

#[async_trait]
#[allow(dead_code)]
impl TokenProvider for MDSAccessTokenProvider {
async fn get_token(&mut self) -> Result<Token> {
let client = Client::new();
Expand Down
16 changes: 8 additions & 8 deletions src/auth/tests/credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,24 @@ mod test {

#[tokio::test]
#[serial_test::serial]
async fn create_access_token_credential_no_adc_filepath_fallback() {
// TODO(#442) - We should (successfully) fall back to MDS Credentials. But for now we are temporarily returning an error.
async fn create_access_token_credential_fallback_to_mds() {
let _e1 = ScopedEnv::remove("GOOGLE_APPLICATION_CREDENTIALS");
let _e2 = ScopedEnv::remove("HOME"); // For posix
let _e3 = ScopedEnv::remove("APPDATA"); // For windows
let err = create_access_token_credential().await.err().unwrap();
let msg = err.source().unwrap().to_string();
assert!(msg.contains("Unable to find Application Default Credentials"));

// We will assume that if credentials were created successfully, they are MDS Credentials.
create_access_token_credential().await.unwrap();
}

#[tokio::test]
#[serial_test::serial]
async fn create_access_token_credential_no_adc_file_fallback() {
// TODO(#442) - We should (successfully) fall back to MDS Credentials. But for now we are temporarily returning an error.
async fn create_access_token_credential_errors_if_adc_env_is_not_a_file() {
let _e = ScopedEnv::set("GOOGLE_APPLICATION_CREDENTIALS", "file-does-not-exist.json");
let err = create_access_token_credential().await.err().unwrap();
let msg = err.source().unwrap().to_string();
assert!(msg.contains("Unable to find Application Default Credentials"));
assert!(msg.contains("Failed to load Application Default Credentials"));
assert!(msg.contains("file-does-not-exist.json"));
assert!(msg.contains("GOOGLE_APPLICATION_CREDENTIALS"));
}

#[tokio::test]
Expand Down

0 comments on commit e758f27

Please sign in to comment.