Skip to content

Commit

Permalink
feat: auth path prefix support and org_user authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
ayushjain17 committed Dec 31, 2024
1 parent 094665f commit e69f0f3
Show file tree
Hide file tree
Showing 11 changed files with 435 additions and 187 deletions.
4 changes: 3 additions & 1 deletion crates/frontend/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ pub fn app(app_envs: Envs) -> impl IntoView {
path="/admin/organisations"
view=move || {
view! {
<Organisations/>
<Layout show_side_nav=false>
<Organisations/>
</Layout>
}
}
/>
Expand Down
98 changes: 63 additions & 35 deletions crates/frontend/src/pages/organisations.rs
Original file line number Diff line number Diff line change
@@ -1,58 +1,86 @@
use leptos::*;
use serde_json::{Map, Value};

use crate::api::fetch_organisations;
use crate::components::skeleton::Skeleton;
use crate::components::{
skeleton::Skeleton,
stat::Stat,
table::{
types::{Column, ColumnSortable},
Table,
},
};
use crate::utils::use_host_server;

#[component]
pub fn organisations() -> impl IntoView {
let host = StoredValue::new(use_host_server());
let (organisation_rs, organisation_ws) = create_signal::<Option<String>>(None);

let host = use_host_server();
let organisation_resource = create_local_resource(
|| (),
|_| async { fetch_organisations().await.unwrap_or_default() },
);

let table_columns = create_memo(move |_| {
let host = host.clone();
let navigate = move |_: &str, row: &Map<String, Value>| {
let organisation_id = row["organisation_id"]
.as_str()
.clone()
.unwrap_or_default()
.to_string();
view! {
<button
formaction=format!("{host}/organisations/switch/{organisation_id}")
class="cursor-pointer text-blue-500"
>
{organisation_id}
</button>
}
.into_view()
};

vec![Column::new(
"organisation_id".to_string(),
None,
navigate,
ColumnSortable::No,
)]
});

view! {
<div class="h-screen w-full flex flex-col items-center justify-center">
<form class="p-8">
<Suspense fallback=move || {
view! { <Skeleton /> }
}>
{move || {
let organisations = organisation_resource.get().unwrap_or_default();
let table_rows = organisations.clone()
.into_iter()
.map(|organisation| {
let mut map = Map::new();
map.insert(String::from("organisation_id"), Value::String(organisation));
map
})
.collect::<Vec<Map<String, Value>>>();

view! {
<form action=format!(
"{}/organisations/switch/{}",
host.get_value(),
organisation_rs.get().unwrap_or_default(),
)>
<div>Select Organisation</div>
<select
class="w-[300px] border border-black"
value=organisation_rs.get().unwrap_or_default()
on:change=move |event| {
let organisation = event_target_value(&event);
organisation_ws
.set(
if organisation.as_str() != "" { Some(organisation) } else { None },
);
}
>
<option value=String::from("")>Select Organisation</option>
<For
each=move || organisation_resource.get().clone().unwrap_or_default()
key=|organisation| organisation.clone()
children=move |organisation| {
view! { <option value=organisation.clone() selected={organisation == organisation_rs.get().unwrap_or_default()}>{organisation}</option> }
}
/>
</select>
<button disabled=organisation_rs.get().is_none()>Submit</button>
</form>
<div class="pb-4">
<Stat
heading="Oraganisations"
icon="ri-building-fill"
number={organisations.len().to_string()}
/>
</div>
<Table
class="card-body card rounded-lg w-full bg-base-100 shadow"
cell_class="min-w-48 font-mono".to_string()
rows=table_rows
key_column="id".to_string()
columns=table_columns.get()
/>
}
}}

</Suspense>
</div>
</form>
}
}
56 changes: 33 additions & 23 deletions crates/service_utils/src/middlewares/tenant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ use std::rc::Rc;
use superposition_types::TenantConfig;

pub struct OrgWorkspaceMiddlewareFactory {
enable_org_id_header: bool,
enable_workspace_id_header: bool,
enable_org_id: bool,
enable_workspace_id: bool,
}

impl OrgWorkspaceMiddlewareFactory {
pub fn new(enable_org_id_header: bool, enable_workspace_id_header: bool) -> Self {
pub fn new(enable_org_id: bool, enable_workspace_id: bool) -> Self {
Self {
enable_org_id_header,
enable_workspace_id_header,
enable_org_id,
enable_workspace_id,
}
}
}
Expand All @@ -43,16 +43,16 @@ where
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(OrgWorkspaceMiddleware {
service: Rc::new(service),
enable_org_id_header: self.enable_org_id_header,
enable_workspace_id_header: self.enable_workspace_id_header,
enable_org_id: self.enable_org_id,
enable_workspace_id: self.enable_workspace_id,
}))
}
}

pub struct OrgWorkspaceMiddleware<S> {
service: Rc<S>,
enable_org_id_header: bool,
enable_workspace_id_header: bool,
enable_org_id: bool,
enable_workspace_id: bool,
}

fn extract_org_workspace_from_header(
Expand Down Expand Up @@ -106,8 +106,8 @@ where

fn call(&self, req: ServiceRequest) -> Self::Future {
let srv = self.service.clone();
let enable_org_id = self.enable_org_id_header;
let enable_workspace_id = self.enable_workspace_id_header;
let enable_org_id = self.enable_org_id;
let enable_workspace_id = self.enable_workspace_id;

Box::pin(async move {
let app_state = match req.app_data::<Data<AppState>>() {
Expand All @@ -124,8 +124,10 @@ where
};

let request_path = req.uri().path().replace(&base, "");
let request_pattern =
req.match_pattern().unwrap_or_else(|| request_path.clone());
let request_pattern = req
.match_pattern()
.map(|a| a.replace(&base, ""))
.unwrap_or_else(|| request_path.clone());
let pkg_regex = Regex::new(".*/pkg/.+")
.map_err(|err| error::ErrorInternalServerError(err.to_string()))?;
let assets_regex = Regex::new(".*/assets/.+")
Expand Down Expand Up @@ -196,21 +198,29 @@ where
)
});

// TODO: validate the tenant
let (validated_tenant, tenant_config) = match (org, enable_org_id, workspace, enable_workspace_id) {
(None, true, None, true) => return Err(error::ErrorBadRequest("The parameters org id and workspace id are required, and must be passed through headers/url params/query params. Consult the documentation to know which to use for this endpoint")),
(None, true, _, _) => return Err(error::ErrorBadRequest("The parameter org id is required, and must be passed through headers/url params/query params. Consult the documentation to know which to use for this endpoint")),
(_, _, None, true) => return Err(error::ErrorBadRequest("The parameter workspace id is required, and must be passed through headers/url params/query params. Consult the documentation to know which to use for this endpoint")),
(Some(org_id), _, Some(workspace_id), _) => {
let workspace_id = match (enable_workspace_id, workspace) {
(true, None) => return Err(error::ErrorBadRequest("The parameter workspace id is required, and must be passed through headers/url params/query params. Consult the documentation to know which to use for this endpoint")),
(true, Some(workspace_id)) => workspace_id,
(false, _) => "public",
};

// TODO: validate the tenant, get correct TenantConfig
let (validated_tenant, tenant_config) = match (enable_org_id, org) {
(true, None) => return Err(error::ErrorBadRequest("The parameter org id is required, and must be passed through headers/url params/query params. Consult the documentation to know which to use for this endpoint")),
(true, Some(org_id)) => {
let tenant = format!("{org_id}_{workspace_id}");
(Tenant(tenant), TenantConfig::default())
}
(_, _, _, _) => (Tenant("public".into()), TenantConfig::default()),
},
(false, _) => (Tenant("public".into()), TenantConfig::default()),
};

let organisation = org
.map(String::from)
.map(OrganisationId)
.unwrap_or_default();

req.extensions_mut().insert(validated_tenant);
req.extensions_mut()
.insert(OrganisationId(org.unwrap_or("juspay").into()));
req.extensions_mut().insert(organisation);
req.extensions_mut().insert(tenant_config);
}

Expand Down
10 changes: 8 additions & 2 deletions crates/service_utils/src/service/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ impl FromRequest for DbConnection {
req: &actix_web::HttpRequest,
_: &mut actix_web::dev::Payload,
) -> Self::Future {

let app_state = match req.app_data::<Data<AppState>>() {
Some(state) => state,
None => {
Expand Down Expand Up @@ -238,6 +237,13 @@ impl FromRequest for CustomHeaders {

#[derive(Deref, DerefMut, Clone, Debug)]
pub struct OrganisationId(pub String);

impl Default for OrganisationId {
fn default() -> Self {
Self(String::from("superposition"))
}
}

impl FromRequest for OrganisationId {
type Error = Error;
type Future = Ready<Result<Self, Self::Error>>;
Expand All @@ -264,7 +270,7 @@ impl FromRequest for OrganisationId {
"message": "x-org-id was not set. Please ensure you are passing in the x-tenant header"
})))
} else {
Ok(OrganisationId("juspay".into()))
Ok(OrganisationId::default())
}
}
};
Expand Down
54 changes: 48 additions & 6 deletions crates/superposition/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod no_auth;
mod oidc;

use std::{
collections::HashSet,
future::{ready, Ready},
sync::Arc,
};
Expand All @@ -15,7 +16,7 @@ use actix_web::{
web::{self, Data, Path},
Error, HttpMessage, HttpRequest, HttpResponse, Scope,
};
use authenticator::{Authenticator, SwitchOrgParams};
use authenticator::{Authenticator, Login, SwitchOrgParams};
use aws_sdk_kms::Client;
use futures_util::future::LocalBoxFuture;
use no_auth::DisabledAuthenticator;
Expand All @@ -32,6 +33,31 @@ pub struct AuthMiddleware<S> {
auth_handler: AuthHandler,
}

impl<S> AuthMiddleware<S> {
fn get_login_type(
&self,
request: &ServiceRequest,
exception: &HashSet<String>,
) -> Login {
let path_prefix = self.auth_handler.0.get_path_prefix();
let request_pattern = request
.match_pattern()
.map(|a| a.replace(&path_prefix, ""))
.unwrap_or_else(|| request.uri().path().replace(&path_prefix, ""));

let excep = exception.contains(&request_pattern)
// Implies it's a local/un-forwarded request.
|| !request.headers().contains_key(header::USER_AGENT);
let org_request = request.path().matches("/organisations").count() > 0;

match (excep, org_request) {
(true, false) => Login::None,
(_, true) => Login::Global,
(false, false) => Login::Org,
}
}
}

impl<S, B> Service<ServiceRequest> for AuthMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
Expand Down Expand Up @@ -69,7 +95,12 @@ where
(_, _) => None,
}
})
.unwrap_or_else(|| self.auth_handler.0.authenticate(&request));
.unwrap_or_else(|| {
let login_type = self
.get_login_type(&request, &state.tenant_middleware_exclusion_list);

self.auth_handler.0.authenticate(&request, &login_type)
});

match result {
Ok(user) => {
Expand Down Expand Up @@ -99,7 +130,11 @@ impl AuthHandler {
routes(self.clone())
}

pub async fn init(kms_client: &Option<Client>, app_env: &AppEnv) -> Self {
pub async fn init(
kms_client: &Option<Client>,
app_env: &AppEnv,
path_prefix: String,
) -> Self {
let auth_provider: String = get_from_env_unsafe("AUTH_PROVIDER").unwrap();
let mut auth = auth_provider.split('+');

Expand All @@ -110,6 +145,7 @@ impl AuthHandler {
.split(",")
.map(String::from)
.collect(),
path_prefix,
)),
Some("OIDC") => {
let url = Url::parse(auth.next().unwrap())
Expand All @@ -119,9 +155,15 @@ impl AuthHandler {
let cid = get_from_env_unsafe("OIDC_CLIENT_ID").unwrap();
let csecret = get_oidc_client_secret(kms_client, app_env).await;
Arc::new(
oidc::OIDCAuthenticator::new(url, base_url, cid, csecret)
.await
.unwrap(),
oidc::OIDCAuthenticator::new(
url,
base_url,
path_prefix,
cid,
csecret,
)
.await
.unwrap(),
)
}
_ => panic!("Missing/Unknown authenticator."),
Expand Down
Loading

0 comments on commit e69f0f3

Please sign in to comment.