diff --git a/crates/frontend/src/api.rs b/crates/frontend/src/api.rs index f68bada8..6fef73e9 100644 --- a/crates/frontend/src/api.rs +++ b/crates/frontend/src/api.rs @@ -23,7 +23,7 @@ use crate::{ pub async fn fetch_dimensions( filters: &PaginationParams, tenant: String, - org_id: String + org_id: String, ) -> Result, ServerFnError> { let client = reqwest::Client::new(); let host = use_host_server(); diff --git a/crates/frontend/src/app.rs b/crates/frontend/src/app.rs index 6f019fd1..5d39b154 100755 --- a/crates/frontend/src/app.rs +++ b/crates/frontend/src/app.rs @@ -99,7 +99,9 @@ pub fn app(app_envs: Envs) -> impl IntoView { path="/admin/organisations" view=move || { view! { - + + + } } /> diff --git a/crates/frontend/src/components/default_config_form/utils.rs b/crates/frontend/src/components/default_config_form/utils.rs index 1795cab0..163fc44a 100644 --- a/crates/frontend/src/components/default_config_form/utils.rs +++ b/crates/frontend/src/components/default_config_form/utils.rs @@ -24,7 +24,7 @@ pub async fn update_default_config( key: String, tenant: String, payload: DefaultConfigUpdateReq, - org_id: String + org_id: String, ) -> Result { let host = get_host(); let url = format!("{host}/default-config/{key}"); diff --git a/crates/frontend/src/components/experiment_conclude_form/utils.rs b/crates/frontend/src/components/experiment_conclude_form/utils.rs index c4e8bdae..8414680d 100644 --- a/crates/frontend/src/components/experiment_conclude_form/utils.rs +++ b/crates/frontend/src/components/experiment_conclude_form/utils.rs @@ -7,7 +7,7 @@ pub async fn conclude_experiment( exp_id: String, variant_id: String, tenant: &String, - org_id: &String + org_id: &String, ) -> Result { let client = reqwest::Client::new(); let host = get_host(); diff --git a/crates/frontend/src/components/function_form/utils.rs b/crates/frontend/src/components/function_form/utils.rs index 870914db..8023121c 100644 --- a/crates/frontend/src/components/function_form/utils.rs +++ b/crates/frontend/src/components/function_form/utils.rs @@ -14,7 +14,7 @@ pub async fn create_function( runtime_version: String, description: String, tenant: String, - org_id: String + org_id: String, ) -> Result { let payload = FunctionCreateRequest { function_name, @@ -29,7 +29,7 @@ pub async fn create_function( url, reqwest::Method::POST, Some(payload), - construct_request_headers(&[("x-tenant", &tenant),("x-org-id", &org_id)])?, + construct_request_headers(&[("x-tenant", &tenant), ("x-org-id", &org_id)])?, ) .await?; diff --git a/crates/frontend/src/components/workspace_form.rs b/crates/frontend/src/components/workspace_form.rs index 6925379c..c6114461 100644 --- a/crates/frontend/src/components/workspace_form.rs +++ b/crates/frontend/src/components/workspace_form.rs @@ -61,7 +61,12 @@ where let handle_submit = handle_submit_clone; async move { let result = if is_edit { - update_workspace(workspace_name_rs.get(), org_id.get().to_string(), update_payload).await + update_workspace( + workspace_name_rs.get(), + org_id.get().to_string(), + update_payload, + ) + .await } else { create_workspace(org_id.get().to_string(), create_payload).await }; diff --git a/crates/frontend/src/pages/context_override.rs b/crates/frontend/src/pages/context_override.rs index 169e1464..de431992 100644 --- a/crates/frontend/src/pages/context_override.rs +++ b/crates/frontend/src/pages/context_override.rs @@ -152,27 +152,36 @@ pub fn context_override() -> impl IntoView { let (modal_visible, set_modal_visible) = create_signal(false); let (delete_id, set_delete_id) = create_signal::>(None); - let page_resource: Resource<(String, String), PageResource> = create_blocking_resource( - move || (tenant_rws.get().0, org_rws.get().0), - |(current_tenant, org_id)| async move { - let empty_list_filters = PaginationParams::all_entries(); - let (config_result, dimensions_result, default_config_result) = join!( - fetch_config(current_tenant.to_string(), None, org_id.clone()), - fetch_dimensions(&empty_list_filters, current_tenant.to_string(), org_id.clone()), - fetch_default_config(&empty_list_filters, current_tenant.to_string(), org_id.clone()) - ); - PageResource { - config: config_result.unwrap_or_default(), - dimensions: dimensions_result - .unwrap_or_default() - .data - .into_iter() - .filter(|d| d.dimension != "variantIds") - .collect(), - default_config: default_config_result.unwrap_or_default().data, - } - }, - ); + let page_resource: Resource<(String, String), PageResource> = + create_blocking_resource( + move || (tenant_rws.get().0, org_rws.get().0), + |(current_tenant, org_id)| async move { + let empty_list_filters = PaginationParams::all_entries(); + let (config_result, dimensions_result, default_config_result) = join!( + fetch_config(current_tenant.to_string(), None, org_id.clone()), + fetch_dimensions( + &empty_list_filters, + current_tenant.to_string(), + org_id.clone() + ), + fetch_default_config( + &empty_list_filters, + current_tenant.to_string(), + org_id.clone() + ) + ); + PageResource { + config: config_result.unwrap_or_default(), + dimensions: dimensions_result + .unwrap_or_default() + .data + .into_iter() + .filter(|d| d.dimension != "variantIds") + .collect(), + default_config: default_config_result.unwrap_or_default().data, + } + }, + ); let handle_context_create = Callback::new(move |_| { set_form_mode.set(Some(FormMode::Create)); @@ -241,7 +250,8 @@ pub fn context_override() -> impl IntoView { let confirm_delete = Callback::new(move |_| { if let Some(id) = delete_id.get().clone() { spawn_local(async move { - let result = delete_context(tenant_rws.get().0, id, org_rws.get().0).await; + let result = + delete_context(tenant_rws.get().0, id, org_rws.get().0).await; match result { Ok(_) => { diff --git a/crates/frontend/src/pages/custom_types.rs b/crates/frontend/src/pages/custom_types.rs index 757633f6..e4d8e3cb 100644 --- a/crates/frontend/src/pages/custom_types.rs +++ b/crates/frontend/src/pages/custom_types.rs @@ -74,7 +74,8 @@ pub fn types_page() -> impl IntoView { let tenant = tenant_rws.get().0; let org = org_rws.get().0; let _ = - delete_type(tenant, row_data.clone().type_name, org).await; + delete_type(tenant, row_data.clone().type_name, org) + .await; types_resource.refetch(); } }); diff --git a/crates/frontend/src/pages/function.rs b/crates/frontend/src/pages/function.rs index b615adef..d093c16f 100644 --- a/crates/frontend/src/pages/function.rs +++ b/crates/frontend/src/pages/function.rs @@ -59,7 +59,8 @@ pub fn function_page() -> impl IntoView { let combined_resource: Resource<(String, String, String), CombinedResource> = create_blocking_resource(source, |(function_name, tenant, org_id)| async move { let function_result = - fetch_function(function_name.to_string(), tenant.to_string(), org_id).await; + fetch_function(function_name.to_string(), tenant.to_string(), org_id) + .await; CombinedResource { function: function_result.ok(), diff --git a/crates/frontend/src/pages/function/function_create.rs b/crates/frontend/src/pages/function/function_create.rs index f29c434f..969a122e 100644 --- a/crates/frontend/src/pages/function/function_create.rs +++ b/crates/frontend/src/pages/function/function_create.rs @@ -3,7 +3,10 @@ use leptos_router::use_navigate; use serde::{Deserialize, Serialize}; use superposition_types::database::models::cac::Function; -use crate::{components::function_form::FunctionEditor, types::{OrganisationId, Tenant}}; +use crate::{ + components::function_form::FunctionEditor, + types::{OrganisationId, Tenant}, +}; #[derive(Serialize, Deserialize, Clone, Debug)] struct CombinedResource { diff --git a/crates/frontend/src/pages/organisations.rs b/crates/frontend/src/pages/organisations.rs index edcaf987..ac8b331d 100644 --- a/crates/frontend/src/pages/organisations.rs +++ b/crates/frontend/src/pages/organisations.rs @@ -1,58 +1,86 @@ use leptos::*; +use serde_json::{Map, Value}; use crate::api::fetch_organisations; -use crate::components::skeleton::Skeleton; +use crate::components::{ + skeleton::Skeleton, + stat::Stat, + table::{ + types::{Column, ColumnSortable}, + Table, + }, +}; use crate::utils::use_host_server; #[component] pub fn organisations() -> impl IntoView { - let host = StoredValue::new(use_host_server()); - let (organisation_rs, organisation_ws) = create_signal::>(None); - + let host = use_host_server(); let organisation_resource = create_local_resource( || (), |_| async { fetch_organisations().await.unwrap_or_default() }, ); + let table_columns = create_memo(move |_| { + let host = host.clone(); + let navigate = move |_: &str, row: &Map| { + let organisation_id = row["organisation_id"] + .as_str() + .clone() + .unwrap_or_default() + .to_string(); + view! { + + } + .into_view() + }; + + vec![Column::new( + "organisation_id".to_string(), + None, + navigate, + ColumnSortable::No, + )] + }); + view! { -
+
} }> {move || { + let organisations = organisation_resource.get().unwrap_or_default(); + let table_rows = organisations.clone() + .into_iter() + .map(|organisation| { + let mut map = Map::new(); + map.insert(String::from("organisation_id"), Value::String(organisation)); + map + }) + .collect::>>(); + view! { - -
Select Organisation
- - - +
+ +
+ } }} - - + } } diff --git a/crates/service_utils/src/middlewares/tenant.rs b/crates/service_utils/src/middlewares/tenant.rs index 4605e17d..e5310664 100644 --- a/crates/service_utils/src/middlewares/tenant.rs +++ b/crates/service_utils/src/middlewares/tenant.rs @@ -15,15 +15,15 @@ use std::rc::Rc; use superposition_types::TenantConfig; pub struct OrgWorkspaceMiddlewareFactory { - enable_org_id_header: bool, - enable_workspace_id_header: bool, + enable_org_id: bool, + enable_workspace_id: bool, } impl OrgWorkspaceMiddlewareFactory { - pub fn new(enable_org_id_header: bool, enable_workspace_id_header: bool) -> Self { + pub fn new(enable_org_id: bool, enable_workspace_id: bool) -> Self { Self { - enable_org_id_header, - enable_workspace_id_header, + enable_org_id, + enable_workspace_id, } } } @@ -43,16 +43,16 @@ where fn new_transform(&self, service: S) -> Self::Future { ready(Ok(OrgWorkspaceMiddleware { service: Rc::new(service), - enable_org_id_header: self.enable_org_id_header, - enable_workspace_id_header: self.enable_workspace_id_header, + enable_org_id: self.enable_org_id, + enable_workspace_id: self.enable_workspace_id, })) } } pub struct OrgWorkspaceMiddleware { service: Rc, - enable_org_id_header: bool, - enable_workspace_id_header: bool, + enable_org_id: bool, + enable_workspace_id: bool, } fn extract_org_workspace_from_header( @@ -106,8 +106,8 @@ where fn call(&self, req: ServiceRequest) -> Self::Future { let srv = self.service.clone(); - let enable_org_id = self.enable_org_id_header; - let enable_workspace_id = self.enable_workspace_id_header; + let enable_org_id = self.enable_org_id; + let enable_workspace_id = self.enable_workspace_id; Box::pin(async move { let app_state = match req.app_data::>() { @@ -124,8 +124,10 @@ where }; let request_path = req.uri().path().replace(&base, ""); - let request_pattern = - req.match_pattern().unwrap_or_else(|| request_path.clone()); + let request_pattern = req + .match_pattern() + .map(|a| a.replace(&base, "")) + .unwrap_or_else(|| request_path.clone()); let pkg_regex = Regex::new(".*/pkg/.+") .map_err(|err| error::ErrorInternalServerError(err.to_string()))?; let assets_regex = Regex::new(".*/assets/.+") @@ -196,21 +198,29 @@ where ) }); - // TODO: validate the tenant - let (validated_tenant, tenant_config) = match (org, enable_org_id, workspace, enable_workspace_id) { - (None, true, None, true) => return Err(error::ErrorBadRequest("The parameters org id and workspace id are required, and must be passed through headers/url params/query params. Consult the documentation to know which to use for this endpoint")), - (None, true, _, _) => return Err(error::ErrorBadRequest("The parameter org id is required, and must be passed through headers/url params/query params. Consult the documentation to know which to use for this endpoint")), - (_, _, None, true) => return Err(error::ErrorBadRequest("The parameter workspace id is required, and must be passed through headers/url params/query params. Consult the documentation to know which to use for this endpoint")), - (Some(org_id), _, Some(workspace_id), _) => { + let workspace_id = match (enable_workspace_id, workspace) { + (true, None) => return Err(error::ErrorBadRequest("The parameter workspace id is required, and must be passed through headers/url params/query params. Consult the documentation to know which to use for this endpoint")), + (true, Some(workspace_id)) => workspace_id, + (false, _) => "public", + }; + + // TODO: validate the tenant, get correct TenantConfig + let (validated_tenant, tenant_config) = match (enable_org_id, org) { + (true, None) => return Err(error::ErrorBadRequest("The parameter org id is required, and must be passed through headers/url params/query params. Consult the documentation to know which to use for this endpoint")), + (true, Some(org_id)) => { let tenant = format!("{org_id}_{workspace_id}"); (Tenant(tenant), TenantConfig::default()) - } - (_, _, _, _) => (Tenant("public".into()), TenantConfig::default()), + }, + (false, _) => (Tenant("public".into()), TenantConfig::default()), }; + let organisation = org + .map(String::from) + .map(OrganisationId) + .unwrap_or_default(); + req.extensions_mut().insert(validated_tenant); - req.extensions_mut() - .insert(OrganisationId(org.unwrap_or("juspay").into())); + req.extensions_mut().insert(organisation); req.extensions_mut().insert(tenant_config); } diff --git a/crates/service_utils/src/service/types.rs b/crates/service_utils/src/service/types.rs index 6ca09f9e..abd70ed6 100644 --- a/crates/service_utils/src/service/types.rs +++ b/crates/service_utils/src/service/types.rs @@ -192,7 +192,6 @@ impl FromRequest for DbConnection { req: &actix_web::HttpRequest, _: &mut actix_web::dev::Payload, ) -> Self::Future { - let app_state = match req.app_data::>() { Some(state) => state, None => { @@ -238,6 +237,13 @@ impl FromRequest for CustomHeaders { #[derive(Deref, DerefMut, Clone, Debug)] pub struct OrganisationId(pub String); + +impl Default for OrganisationId { + fn default() -> Self { + Self(String::from("superposition")) + } +} + impl FromRequest for OrganisationId { type Error = Error; type Future = Ready>; @@ -264,7 +270,7 @@ impl FromRequest for OrganisationId { "message": "x-org-id was not set. Please ensure you are passing in the x-tenant header" }))) } else { - Ok(OrganisationId("juspay".into())) + Ok(OrganisationId::default()) } } }; diff --git a/crates/superposition/src/auth.rs b/crates/superposition/src/auth.rs index 254d2275..c375f2ee 100644 --- a/crates/superposition/src/auth.rs +++ b/crates/superposition/src/auth.rs @@ -3,6 +3,7 @@ mod no_auth; mod oidc; use std::{ + collections::HashSet, future::{ready, Ready}, sync::Arc, }; @@ -15,7 +16,7 @@ use actix_web::{ web::{self, Data, Path}, Error, HttpMessage, HttpRequest, HttpResponse, Scope, }; -use authenticator::{Authenticator, SwitchOrgParams}; +use authenticator::{Authenticator, Login, SwitchOrgParams}; use aws_sdk_kms::Client; use futures_util::future::LocalBoxFuture; use no_auth::DisabledAuthenticator; @@ -32,6 +33,31 @@ pub struct AuthMiddleware { auth_handler: AuthHandler, } +impl AuthMiddleware { + fn get_login_type( + &self, + request: &ServiceRequest, + exception: &HashSet, + ) -> Login { + let path_prefix = self.auth_handler.0.get_path_prefix(); + let request_pattern = request + .match_pattern() + .map(|a| a.replace(&path_prefix, "")) + .unwrap_or_else(|| request.uri().path().replace(&path_prefix, "")); + + let excep = exception.contains(&request_pattern) + // Implies it's a local/un-forwarded request. + || !request.headers().contains_key(header::USER_AGENT); + let org_request = request.path().matches("/organisations").count() > 0; + + match (excep, org_request) { + (true, false) => Login::None, + (_, true) => Login::Global, + (false, false) => Login::Org, + } + } +} + impl Service for AuthMiddleware where S: Service, Error = Error>, @@ -69,7 +95,12 @@ where (_, _) => None, } }) - .unwrap_or_else(|| self.auth_handler.0.authenticate(&request)); + .unwrap_or_else(|| { + let login_type = self + .get_login_type(&request, &state.tenant_middleware_exclusion_list); + + self.auth_handler.0.authenticate(&request, &login_type) + }); match result { Ok(user) => { @@ -99,7 +130,11 @@ impl AuthHandler { routes(self.clone()) } - pub async fn init(kms_client: &Option, app_env: &AppEnv) -> Self { + pub async fn init( + kms_client: &Option, + app_env: &AppEnv, + path_prefix: String, + ) -> Self { let auth_provider: String = get_from_env_unsafe("AUTH_PROVIDER").unwrap(); let mut auth = auth_provider.split('+'); @@ -110,6 +145,7 @@ impl AuthHandler { .split(",") .map(String::from) .collect(), + path_prefix, )), Some("OIDC") => { let url = Url::parse(auth.next().unwrap()) @@ -119,9 +155,15 @@ impl AuthHandler { let cid = get_from_env_unsafe("OIDC_CLIENT_ID").unwrap(); let csecret = get_oidc_client_secret(kms_client, app_env).await; Arc::new( - oidc::OIDCAuthenticator::new(url, base_url, cid, csecret) - .await - .unwrap(), + oidc::OIDCAuthenticator::new( + url, + base_url, + path_prefix, + cid, + csecret, + ) + .await + .unwrap(), ) } _ => panic!("Missing/Unknown authenticator."), diff --git a/crates/superposition/src/auth/authenticator.rs b/crates/superposition/src/auth/authenticator.rs index 0e8391c7..018d9270 100644 --- a/crates/superposition/src/auth/authenticator.rs +++ b/crates/superposition/src/auth/authenticator.rs @@ -1,6 +1,14 @@ -use actix_web::{dev::ServiceRequest, web::Path, HttpRequest, HttpResponse, Scope}; +use std::fmt::Display; + +use actix_web::{ + dev::ServiceRequest, + http::header::{HeaderMap, HeaderValue}, + web::Path, + HttpRequest, HttpResponse, Scope, +}; use futures_util::future::LocalBoxFuture; use serde::Deserialize; +use service_utils::service::types::OrganisationId; use superposition_types::User; #[derive(Deserialize)] @@ -8,8 +16,55 @@ pub(super) struct SwitchOrgParams { pub(super) organisation_id: String, } +fn extract_org_from_header(headers: &HeaderMap) -> Option<&str> { + headers + .get("x-org-id") + .and_then(|header_value: &HeaderValue| header_value.to_str().ok()) +} + +fn extract_org_from_url(path: &str, match_pattern: Option) -> Option<&str> { + match_pattern.and_then(move |pattern| { + let pattern_segments = pattern.split('/'); + let path_segments = path.split('/').collect::>(); + + std::iter::zip(path_segments, pattern_segments) + .find(|(_, pattern_seg)| *pattern_seg == "{org_id}") + .map(|(path_seg, _)| path_seg) + }) +} + +fn extract_org_from_query_params(query_str: &str) -> Option<&str> { + query_str + .split('&') + .find(|segment| segment.contains("org=")) + .and_then(|tenant_query_param| tenant_query_param.split('=').nth(1)) +} + +#[derive(Debug)] +pub enum Login { + None, + Global, + Org, +} + +impl Display for Login { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::None => write!(f, "none"), + Self::Global => write!(f, "user"), + Self::Org => write!(f, "org_user"), + } + } +} + pub trait Authenticator: Sync + Send { - fn authenticate(&self, request: &ServiceRequest) -> Result; + fn get_path_prefix(&self) -> String; + + fn authenticate( + &self, + request: &ServiceRequest, + login_type: &Login, + ) -> Result; fn routes(&self) -> Scope; fn get_organisations(&self, req: &HttpRequest) -> HttpResponse; @@ -19,4 +74,13 @@ pub trait Authenticator: Sync + Send { req: &HttpRequest, path: &Path, ) -> LocalBoxFuture<'static, actix_web::Result>; + + fn get_org_id(&self, request: &ServiceRequest) -> OrganisationId { + extract_org_from_header(request.headers()) + .or_else(|| extract_org_from_url(request.path(), request.match_pattern())) + .or_else(|| extract_org_from_query_params(request.query_string())) + .map(String::from) + .map(OrganisationId) + .unwrap_or_default() + } } diff --git a/crates/superposition/src/auth/no_auth.rs b/crates/superposition/src/auth/no_auth.rs index 8121394e..c11b2b3e 100644 --- a/crates/superposition/src/auth/no_auth.rs +++ b/crates/superposition/src/auth/no_auth.rs @@ -7,18 +7,32 @@ use actix_web::{ use futures_util::future::LocalBoxFuture; use superposition_types::User; -use super::authenticator::{Authenticator, SwitchOrgParams}; +use super::authenticator::{Authenticator, Login, SwitchOrgParams}; -pub struct DisabledAuthenticator(Vec); +pub struct DisabledAuthenticator { + organisations: Vec, + path_prefix: String, +} impl DisabledAuthenticator { - pub fn new(organisations: Vec) -> Self { - Self(organisations) + pub fn new(organisations: Vec, path_prefix: String) -> Self { + Self { + organisations, + path_prefix, + } } } impl Authenticator for DisabledAuthenticator { - fn authenticate(&self, _: &ServiceRequest) -> Result { + fn get_path_prefix(&self) -> String { + self.path_prefix.clone() + } + + fn authenticate( + &self, + _: &ServiceRequest, + _: &Login, + ) -> Result { Ok(User::default()) } @@ -27,26 +41,35 @@ impl Authenticator for DisabledAuthenticator { } fn get_organisations(&self, _: &actix_web::HttpRequest) -> HttpResponse { - HttpResponse::Ok().json(serde_json::json!(self.0)) + HttpResponse::Ok().json(serde_json::json!(self.organisations)) } fn switch_organisation( &self, _: &HttpRequest, - path: &Path, + path_params: &Path, ) -> LocalBoxFuture<'static, actix_web::Result> { - let cookie = Cookie::build("org_user", "org_token") - .path("/") + let path = if self.path_prefix.as_str() == "" { + String::from("/") + } else { + self.path_prefix.clone() + }; + let cookie = Cookie::build(Login::Org.to_string(), "org_token") + .path(path) .http_only(true) .max_age(Duration::days(1)) .finish(); - let org_id = path.organisation_id.clone(); + let org_id = path_params.organisation_id.clone(); + let path_prefix = self.path_prefix.clone(); Box::pin(async move { Ok(HttpResponse::Found() .cookie(cookie) - .insert_header(("Location", format!("/admin/{org_id}/workspaces"))) + .insert_header(( + "Location", + format!("{path_prefix}/admin/{org_id}/workspaces"), + )) .finish()) }) } diff --git a/crates/superposition/src/auth/oidc.rs b/crates/superposition/src/auth/oidc.rs index 78b2f179..9049ea47 100644 --- a/crates/superposition/src/auth/oidc.rs +++ b/crates/superposition/src/auth/oidc.rs @@ -19,12 +19,15 @@ use openidconnect::{ }; use service_utils::helpers::get_from_env_unsafe; use superposition_types::User; -use types::{LoginParams, ProtectionCookie, UserClaims, UserTokenResponse}; +use types::{ + GlobalUserClaims, GlobalUserTokenResponse, LoginParams, OrgUserClaims, + OrgUserTokenResponse, ProtectionCookie, +}; use utils::{presence_no_check, try_user_from, verify_presence}; use crate::auth::authenticator::SwitchOrgParams; -use super::authenticator::Authenticator; +use super::authenticator::{Authenticator, Login}; #[derive(Clone)] pub struct OIDCAuthenticator { @@ -33,6 +36,7 @@ pub struct OIDCAuthenticator { client_id: String, client_secret: String, base_url: String, + path_prefix: String, issuer_endpoint_format: String, token_endpoint_format: String, } @@ -41,6 +45,7 @@ impl OIDCAuthenticator { pub async fn new( idp_url: url::Url, base_url: String, + path_prefix: String, client_id: String, client_secret: String, ) -> Result> { @@ -64,7 +69,9 @@ impl OIDCAuthenticator { ClientId::new(client_id.clone()), Some(ClientSecret::new(client_secret.clone())), ) - .set_redirect_uri(RedirectUrl::new(format!("{}/oidc/login", base_url))?); + .set_redirect_uri(RedirectUrl::new(format!( + "{base_url}{path_prefix}/oidc/login" + ))?); Ok(Self { client, @@ -72,11 +79,46 @@ impl OIDCAuthenticator { client_id, client_secret, base_url, + path_prefix, issuer_endpoint_format, token_endpoint_format, }) } + fn get_org_client(&self, org_id: &String) -> Result { + let issuer_url = match self.get_issuer_url(org_id) { + Ok(issuer_url) => issuer_url, + Err(e) => return Err(format!("Unable to create issuer url: {e}")), + }; + + let token_url = match self.get_token_url(org_id) { + Ok(token_url) => token_url, + Err(e) => return Err(format!("Unable to create token url: {e}")), + }; + + let redirect_url = match RedirectUrl::new(format!( + "{}{}/", + self.base_url.clone(), + self.path_prefix + )) { + Ok(redirect_url) => redirect_url, + Err(e) => return Err(format!("Unable to create redirect url: {e}")), + }; + + let provider = self + .provider_metadata + .clone() + .set_issuer(issuer_url) + .set_token_endpoint(Some(token_url)); + + Ok(CoreClient::from_provider_metadata( + provider, + ClientId::new(self.client_id.clone()), + Some(ClientSecret::new(self.client_secret.clone())), + ) + .set_redirect_uri(redirect_url)) + } + fn get_issuer_url( &self, organisation_id: &String, @@ -97,7 +139,15 @@ impl OIDCAuthenticator { TokenUrl::new(token_endpoint) } - fn new_redirect(&self) -> HttpResponse { + fn get_cookie_path(&self) -> String { + if self.path_prefix.as_str() == "" { + String::from('/') + } else { + self.path_prefix.clone() + } + } + + fn new_redirect(&self, cookie_name: &str) -> HttpResponse { let (auth_url, csrf_token, nonce) = self .client .authorize_url( @@ -115,8 +165,8 @@ impl OIDCAuthenticator { }; let cookie_result = serde_json::to_string(&protection) - .map_err(|err| { - log::error!("Unable to stringify data: {err}"); + .map_err(|e| { + log::error!("Unable to stringify data: {e}"); ErrorInternalServerError(format!("Unable to stringify data")) }) .map(|cookie| { @@ -124,7 +174,7 @@ impl OIDCAuthenticator { .max_age(Duration::days(7)) // .http_only(true) // .same_site(SameSite::Strict) - .path("/") + .path(self.get_cookie_path()) .finish() }); @@ -134,7 +184,7 @@ impl OIDCAuthenticator { .cookie(p_cookie) // Deletes the cookie. .cookie( - Cookie::build("user", "") + Cookie::build(cookie_name, "") .max_age(Duration::seconds(0)) .finish(), ) @@ -143,40 +193,88 @@ impl OIDCAuthenticator { } } - fn decode_token(&self, cookie: &str) -> Result { - let ctr = serde_json::from_str::(cookie) - .map_err(|e| format!("Error while decoding token {e}"))?; + fn decode_global_token(&self, cookie: &str) -> Result { + let ctr = serde_json::from_str::(cookie) + .map_err(|e| format!("Error while decoding token: {e}"))?; ctr.id_token() .ok_or(String::from("Id Token not found"))? .claims(&self.client.id_token_verifier(), verify_presence) - .map_err(|err| format!("Error in claims verification {err}")) + .map_err(|e| format!("Error in claims verification: {e}")) .cloned() } -} -impl Authenticator for OIDCAuthenticator { - fn authenticate(&self, request: &ServiceRequest) -> Result { - let token = request.cookie("user").and_then(|c| { - self.decode_token(c.value()) + fn get_global_user(&self, request: &ServiceRequest) -> Result { + let token = request.cookie(&Login::Global.to_string()).and_then(|c| { + self.decode_global_token(c.value()) .map_err(|e| log::error!("Error in decoding user : {e}")) .ok() }); - let path = &request.path(); - let excep = path.matches("login").count() > 0 - // Implies it's a local/un-forwarded request. - || !request.headers().contains_key(header::USER_AGENT) - || path.matches("health").count() > 0 - || path.matches("ready").count() > 0; - - if excep { - Ok(User::default()) - } else if let Some(token_response) = token { - Ok(try_user_from(&token_response).map_err(|err| { - log::error!("Unable to get user {err}"); + if let Some(token_response) = token { + Ok(try_user_from(&token_response).map_err(|e| { + log::error!("Unable to get user: {e}"); ErrorBadRequest(String::from("Unable to get user")) })?) } else { - Err(self.new_redirect()) + log::error!("Error user not found in cookies"); + Err(self.new_redirect(&Login::Global.to_string())) + } + } + + fn decode_org_token( + &self, + org_id: &String, + cookie: &str, + ) -> Result { + let client = self + .get_org_client(org_id) + .map_err(|e| format!("Error in getting Org specific client: {e}"))?; + let id_token_verifier = client.id_token_verifier(); + + let ctr = serde_json::from_str::(cookie) + .map_err(|e| format!("Error while decoding token: {e}"))?; + ctr.id_token() + .ok_or(String::from("Id Token not found"))? + .claims(&id_token_verifier, presence_no_check) + .map_err(|e| format!("Error in claims verification: {e}")) + .cloned() + } + + fn get_org_user(&self, request: &ServiceRequest) -> Result { + let org_id = self.get_org_id(request); + let token = request.cookie(&Login::Org.to_string()).and_then(|c| { + self.decode_org_token(&org_id.0, c.value()) + .map_err(|e| log::error!("Error in decoding org_user : {e}")) + .ok() + }); + if let Some(token_response) = token { + Ok(try_user_from(&token_response).map_err(|e| { + log::error!("Unable to get org_user: {e}"); + ErrorBadRequest(String::from("Unable to get user")) + })?) + } else { + log::error!("Error org_user not found in cookies"); + Err(self.new_redirect(&Login::Org.to_string())) + } + } +} + +impl Authenticator for OIDCAuthenticator { + fn get_path_prefix(&self) -> String { + self.path_prefix.clone() + } + + fn authenticate( + &self, + request: &ServiceRequest, + login_type: &Login, + ) -> Result { + match login_type { + Login::None => Ok(User::default()), + Login::Global => self.get_global_user(request), + Login::Org => match self.get_global_user(request) { + Err(e) => Err(e), + Ok(_) => self.get_org_user(request), + }, } } @@ -188,9 +286,9 @@ impl Authenticator for OIDCAuthenticator { fn get_organisations(&self, req: &HttpRequest) -> HttpResponse { let organisations = req - .cookie("user") + .cookie(&Login::Global.to_string()) .and_then(|user_cookie| { - self.decode_token(user_cookie.value()) + self.decode_global_token(user_cookie.value()) .map_err(|e| log::error!("Error in decoding user : {e}")) .ok() }) @@ -200,7 +298,7 @@ impl Authenticator for OIDCAuthenticator { Some(organisations) => { HttpResponse::Ok().json(serde_json::json!(organisations)) } - None => self.new_redirect(), + None => self.new_redirect(&Login::Global.to_string()), } } @@ -209,53 +307,22 @@ impl Authenticator for OIDCAuthenticator { req: &HttpRequest, path: &Path, ) -> LocalBoxFuture<'static, actix_web::Result> { - let issuer_url = match self.get_issuer_url(&path.organisation_id) { - Ok(issuer_url) => issuer_url, - Err(e) => { - log::error!("Unable to create issuer url {e}"); - return Box::pin(async move { - Err(ErrorBadRequest(String::from("Unable to create issuer url"))) - }); - } - }; - - let token_url = match self.get_token_url(&path.organisation_id) { - Ok(token_url) => token_url, + let client = match self.get_org_client(&path.organisation_id) { + Ok(client) => client, Err(e) => { - log::error!("Unable to create token url {e}"); + log::error!("Error in getting Org specific client: {e}"); return Box::pin(async move { - Err(ErrorInternalServerError("Unable to create token url")) + Err(ErrorInternalServerError(String::from( + "Error in getting Org specific client", + ))) }); } }; - let redirect_url = match RedirectUrl::new(format!("{}/", self.base_url.clone())) { - Ok(redirect_url) => redirect_url, - Err(e) => { - log::error!("Unable to create redirect url {e}"); - return Box::pin(async move { - Err(ErrorInternalServerError("Unable to create redirect url")) - }); - } - }; - - let provider = self - .provider_metadata - .clone() - .set_issuer(issuer_url) - .set_token_endpoint(Some(token_url)); - - let client = CoreClient::from_provider_metadata( - provider, - ClientId::new(self.client_id.clone()), - Some(ClientSecret::new(self.client_secret.clone())), - ) - .set_redirect_uri(redirect_url); - let user = req - .cookie("user") + .cookie(&Login::Global.to_string()) .and_then(|user_cookie| { - self.decode_token(user_cookie.value()) + self.decode_global_token(user_cookie.value()) .map_err(|e| log::error!("Error in decoding user : {e}")) .ok() }) @@ -280,7 +347,9 @@ impl Authenticator for OIDCAuthenticator { let org_id = path.organisation_id.clone(); let user = ResourceOwnerUsername::new(username.to_string()); let pass = ResourceOwnerPassword::new(switch_pass); - let redirect = self.new_redirect(); + let redirect = self.new_redirect(&Login::Org.to_string()); + let prefix = self.path_prefix.clone(); + let cookie_path = self.get_cookie_path(); Box::pin(async move { let response = client @@ -291,22 +360,20 @@ impl Authenticator for OIDCAuthenticator { .map_err(|e| log::error!("Failed to switch organisation for token: {e}")) .and_then(|tr| { tr.id_token() - .ok_or_else(|| eprintln!("No identity-token!")) - .and_then(|t| { - t.claims(&client.id_token_verifier(), presence_no_check) - .map_err(|e| log::error!("Couldn't verify claims: {e}")) - })?; + .ok_or_else(|| log::error!("No identity-token!"))? + .claims(&client.id_token_verifier(), presence_no_check) + .map_err(|e| log::error!("Couldn't verify claims: {e}"))?; Ok(tr) }); match response { Ok(r) => { - let token = serde_json::to_string(&r).map_err(|err| { - log::error!("Unable to stringify data: {err}"); + let token = serde_json::to_string(&r).map_err(|e| { + log::error!("Unable to stringify data: {e}"); ErrorInternalServerError(format!("Unable to stringify data")) })?; - let cookie = Cookie::build("org_user", token) - .path("/") + let cookie = Cookie::build(Login::Org.to_string(), token) + .path(cookie_path) .http_only(true) .max_age(Duration::days(1)) .finish(); @@ -314,7 +381,7 @@ impl Authenticator for OIDCAuthenticator { .cookie(cookie) .insert_header(( "Location", - format!("/admin/{org_id}/workspaces"), + format!("{prefix}/admin/{org_id}/workspaces"), )) .finish()) } @@ -334,12 +401,12 @@ async fn login( p_cookie } else { log::error!("OIDC: Missing/Bad protection-cookie, redirecting..."); - return Ok(data.new_redirect()); + return Ok(data.new_redirect(&Login::Global.to_string())); }; if params.state.secret() != p_cookie.csrf.secret() { log::error!("OIDC: Bad csrf",); - return Ok(data.new_redirect()); + return Ok(data.new_redirect(&Login::Global.to_string())); } // Exchange the code with a token. @@ -361,20 +428,23 @@ async fn login( match response { Ok(r) => { - let token = serde_json::to_string(&r).map_err(|err| { - log::error!("Unable to stringify data: {err}"); + let token = serde_json::to_string(&r).map_err(|e| { + log::error!("Unable to stringify data: {e}"); ErrorInternalServerError(format!("Unable to stringify data")) })?; - let cookie = Cookie::build("user", token) - .path("/") + let cookie = Cookie::build(Login::Global.to_string(), token) + .path(data.get_cookie_path()) .http_only(true) .max_age(Duration::days(1)) .finish(); Ok(HttpResponse::Found() .cookie(cookie) - .insert_header(("Location", "/admin/organisations")) + .insert_header(( + "Location", + format!("{}/admin/organisations", data.path_prefix.clone()), + )) .finish()) } - Err(()) => Ok(data.new_redirect()), + Err(()) => Ok(data.new_redirect(&Login::Global.to_string())), } } diff --git a/crates/superposition/src/auth/oidc/types.rs b/crates/superposition/src/auth/oidc/types.rs index 002d77b8..db701e8c 100644 --- a/crates/superposition/src/auth/oidc/types.rs +++ b/crates/superposition/src/auth/oidc/types.rs @@ -1,8 +1,9 @@ use actix_web::HttpRequest; use openidconnect::{ core::{ - CoreGenderClaim, CoreJsonWebKeyType, CoreJweContentEncryptionAlgorithm, - CoreJwsSigningAlgorithm, CoreTokenType, + CoreGenderClaim, CoreIdTokenClaims, CoreJsonWebKeyType, + CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, CoreTokenResponse, + CoreTokenType, }, AdditionalClaims, AuthorizationCode, CsrfToken, EmptyExtraTokenFields, IdTokenClaims, IdTokenFields, Nonce, StandardTokenResponse, @@ -10,15 +11,15 @@ use openidconnect::{ use serde::{Deserialize, Serialize}; #[derive(Serialize, Debug, Deserialize, Clone)] -pub(super) struct ExtraClaims { +pub(super) struct GlobalUserExtraClaims { pub(super) organisations: Vec, pub(super) switch_pass: String, } -impl AdditionalClaims for ExtraClaims {} +impl AdditionalClaims for GlobalUserExtraClaims {} -pub(super) type CoreIdTokenFields = IdTokenFields< - ExtraClaims, +pub(super) type GlobalUserCoreIdTokenFields = IdTokenFields< + GlobalUserExtraClaims, EmptyExtraTokenFields, CoreGenderClaim, CoreJweContentEncryptionAlgorithm, @@ -26,10 +27,12 @@ pub(super) type CoreIdTokenFields = IdTokenFields< CoreJsonWebKeyType, >; -pub(super) type UserTokenResponse = - StandardTokenResponse; +pub(super) type GlobalUserTokenResponse = + StandardTokenResponse; +pub(super) type GlobalUserClaims = IdTokenClaims; -pub(super) type UserClaims = IdTokenClaims; +pub(super) type OrgUserTokenResponse = CoreTokenResponse; +pub(super) type OrgUserClaims = CoreIdTokenClaims; #[derive(Deserialize, Serialize)] pub(super) struct ProtectionCookie { diff --git a/crates/superposition/src/auth/oidc/utils.rs b/crates/superposition/src/auth/oidc/utils.rs index a08bebe6..00ab795f 100644 --- a/crates/superposition/src/auth/oidc/utils.rs +++ b/crates/superposition/src/auth/oidc/utils.rs @@ -1,8 +1,6 @@ -use openidconnect::Nonce; +use openidconnect::{AdditionalClaims, GenderClaim, IdTokenClaims, Nonce}; use superposition_types::User; -use super::types::UserClaims; - pub(super) fn verify_presence(n: Option<&Nonce>) -> Result<(), String> { if n.is_some() { Ok(()) @@ -15,7 +13,9 @@ pub(super) fn presence_no_check(_: Option<&Nonce>) -> Result<(), String> { Ok(()) } -pub(super) fn try_user_from(claims: &UserClaims) -> Result { +pub(super) fn try_user_from( + claims: &IdTokenClaims, +) -> Result { let user = User { email: claims .email() diff --git a/crates/superposition/src/main.rs b/crates/superposition/src/main.rs index 35755275..a19b4b67 100644 --- a/crates/superposition/src/main.rs +++ b/crates/superposition/src/main.rs @@ -119,7 +119,7 @@ async fn main() -> Result<()> { .await, ); - let auth = AuthHandler::init(&kms_client, &app_env).await; + let auth = AuthHandler::init(&kms_client, &app_env, base.clone()).await; HttpServer::new(move || { let leptos_options = &conf.leptos_options; @@ -139,8 +139,6 @@ async fn main() -> Result<()> { .service(web::redirect("/", ui_redirect_path.to_string())) .service(web::redirect("/admin", ui_redirect_path.to_string())) .service(web::redirect("/admin/{tenant}/", "default-config")) - .service(auth.routes()) - .service(auth.org_routes()) .leptos_routes( leptos_options.to_owned(), routes.to_owned(), @@ -152,6 +150,8 @@ async fn main() -> Result<()> { "/health", get().to(|| async { HttpResponse::Ok().body("Health is good :D") }), ) + .service(auth.routes()) + .service(auth.org_routes()) /***************************** V1 Routes *****************************/ .service( scope("/context")