diff --git a/conf/policy.rules b/conf/policy.rules index ff2e169..eb816f8 100644 --- a/conf/policy.rules +++ b/conf/policy.rules @@ -9,7 +9,7 @@ # # Supported actions: # - require [or |...] -# - (TODO) direct +# - direct # Connection to TCP 8001 requires "cap1" on proxy's capabilities, # TCP 8002 requires "cap1" or "cap2". diff --git a/src/policy/mod.rs b/src/policy/mod.rs index ee69a7b..fb8f3f2 100644 --- a/src/policy/mod.rs +++ b/src/policy/mod.rs @@ -14,25 +14,72 @@ use flexstr::{SharedStr, ToSharedStr}; use capabilities::CapSet; use tracing::info; +#[derive(Debug)] +pub enum Action { + Require(HashSet), + Direct, +} + +impl Default for Action { + fn default() -> Self { + Self::Require(Default::default()) + } +} + +impl Action { + fn add_require(&mut self, caps: CapSet) { + match self { + Self::Direct => false, + Self::Require(set) => set.insert(caps), + }; + } + + fn set_direct(&mut self) { + *self = Self::Direct + } + + fn len(&self) -> usize { + match self { + Self::Direct => 1, + Self::Require(set) => set.len(), + } + } + + fn extend(&mut self, other: &Self) { + match other { + Self::Direct => self.set_direct(), + Self::Require(new_caps) => { + if let Self::Require(caps) = self { + caps.extend(new_caps.iter().cloned()) + } + } + } + } +} + #[derive(Default)] -struct RuleSet(HashMap>); +struct RuleSet(HashMap); type ListenPortRuleSet = RuleSet; type DstDomainRuleSet = RuleSet; impl RuleSet { - fn add(&mut self, key: K, caps: CapSet) { + fn add(&mut self, key: K, action: parser::RuleAction) { // TODO: warning duplicated rules - self.0.entry(key).or_default().insert(caps); + let value = self.0.entry(key).or_default(); + match action { + parser::RuleAction::Require(caps) => value.add_require(caps), + parser::RuleAction::Direct => value.set_direct(), + } } - fn get<'a>(&'a self, key: &'a K) -> impl Iterator { - self.0.get(key).into_iter().flatten() + fn get<'a>(&'a self, key: &'a K) -> impl Iterator { + self.0.get(key).into_iter() } } impl DstDomainRuleSet { - fn get_recursive<'a>(&'a self, name: &'a str) -> impl Iterator { + fn get_recursive<'a>(&'a self, name: &'a str) -> impl Iterator { let mut skip = 0usize; name.split_terminator('.') .map(move |part| { @@ -42,7 +89,6 @@ impl DstDomainRuleSet { }) .chain(["."]) .filter_map(|key| self.0.get(key)) - .flatten() } } @@ -75,13 +121,12 @@ impl Policy { fn add_rule(&mut self, rule: parser::Rule) { let parser::Rule { filter, action } = rule; - let parser::RuleAction::Require(caps) = action; match filter { parser::RuleFilter::ListenPort(port) => { - self.listen_port_ruleset.add(port, caps); + self.listen_port_ruleset.add(port, action); } parser::RuleFilter::Sni(parts) => { - self.dst_domain_ruleset.add(parts.to_shared_str(), caps); + self.dst_domain_ruleset.add(parts.to_shared_str(), action); } } } @@ -94,24 +139,24 @@ impl Policy { .fold(0, |acc, v| acc + v.len()) } - pub fn matches( - &self, - listen_port: Option, - dst_domain: Option, - ) -> HashSet { - let mut rules = HashSet::new(); + pub fn matches(&self, listen_port: Option, dst_domain: Option) -> Action { + let mut action: Action = Default::default(); if let Some(port) = listen_port { - rules.extend(self.listen_port_ruleset.get(&port).cloned()); + self.listen_port_ruleset + .get(&port) + .for_each(|a| action.extend(a)) } if let Some(name) = dst_domain { - rules.extend(self.dst_domain_ruleset.get_recursive(&name).cloned()); + self.dst_domain_ruleset + .get_recursive(&name) + .for_each(|a| action.extend(a)); } - rules + action } } #[test] -fn test_router_listen_port() { +fn test_policy_listen_port() { use capabilities::CheckAllCapsMeet; let rules = " @@ -119,10 +164,16 @@ fn test_router_listen_port() { listen-port 2 require b listen-port 2 require c or d "; - let router = Policy::load(rules.as_bytes()).unwrap(); - assert_eq!(3, router.rule_count()); - let p1 = router.matches(Some(1), None); - let p2 = router.matches(Some(2), None); + let policy = Policy::load(rules.as_bytes()).unwrap(); + assert_eq!(3, policy.rule_count()); + let p1 = match policy.matches(Some(1), None) { + Action::Require(a) => a, + _ => panic!(), + }; + let p2 = match policy.matches(Some(2), None) { + Action::Require(a) => a, + _ => panic!(), + }; let abc = CapSet::new(["a", "b", "c"].into_iter()); let bc = CapSet::new(["b", "c"].into_iter()); let c = CapSet::new(["c"].into_iter()); @@ -135,8 +186,8 @@ fn test_router_listen_port() { } #[test] -fn test_router_get_domain_caps_requirements() { - let router = Policy::load( +fn test_policy_get_domain_caps_requirements() { + let policy = Policy::load( " dst domain . require root dst domain com. require com @@ -147,18 +198,32 @@ fn test_router_get_domain_caps_requirements() { .unwrap(); assert_eq!( 3, - router + policy .dst_domain_ruleset .get_recursive("test.example.com") .count() ); assert_eq!( 3, - router + policy .dst_domain_ruleset .get_recursive("example.com") .count() ); - assert_eq!(2, router.dst_domain_ruleset.get_recursive("com").count()); - assert_eq!(1, router.dst_domain_ruleset.get_recursive("net").count()); + assert_eq!(2, policy.dst_domain_ruleset.get_recursive("com").count()); + assert_eq!(1, policy.dst_domain_ruleset.get_recursive("net").count()); +} + +#[test] +fn test_policy_action_direct() { + let rules = " + listen-port 1 require a + listen-port 1 direct + dst domain test require c + "; + let policy = Policy::load(rules.as_bytes()).unwrap(); + let direct = policy.matches(Some(1), Some("test".into())); + let require = policy.matches(Some(2), Some("test".into())); + assert!(matches!(direct, Action::Direct)); + assert!(matches!(require, Action::Require(_))); } diff --git a/src/policy/parser.rs b/src/policy/parser.rs index 102e35b..b183073 100644 --- a/src/policy/parser.rs +++ b/src/policy/parser.rs @@ -20,6 +20,7 @@ pub enum RuleFilter { #[derive(Debug, PartialEq, Eq)] pub enum RuleAction { Require(CapSet), + Direct, } #[derive(Debug, PartialEq, Eq)] @@ -89,8 +90,14 @@ fn action_require(input: &str) -> IResult<&str, RuleAction> { .parse(input) } +fn action_direct(input: &str) -> IResult<&str, RuleAction> { + tag_no_case("direct") + .map(|_| RuleAction::Direct) + .parse(input) +} + fn rule_action(input: &str) -> IResult<&str, RuleAction> { - action_require(input) + alt((action_require, action_direct)).parse(input) } fn rule(input: &str) -> IResult<&str, Rule> { @@ -153,13 +160,15 @@ fn test_dst_domain_filter() { } #[test] -fn test_action_require() { - let (rem, caps) = action_require("require a or b\n").unwrap(); +fn test_action() { + let (rem, action) = rule_action("require a or b\n").unwrap(); assert_eq!("\n", rem); assert_eq!( RuleAction::Require(CapSet::new(["a", "b"].into_iter())), - caps + action ); + let (_, action) = rule_action("direct\n").unwrap(); + assert_eq!(RuleAction::Direct, action); } #[test] diff --git a/src/server.rs b/src/server.rs index ddb9559..54032aa 100644 --- a/src/server.rs +++ b/src/server.rs @@ -11,7 +11,7 @@ use moproxy::{ client::{Connectable, NewClient}, futures_stream::TcpListenerStream, monitor::Monitor, - policy::{parser, Policy}, + policy::{parser, Action, Policy}, proxy::{ProxyProto, ProxyServer, UserPassAuthCredential}, web::{self, AutoRemoveFile}, }; @@ -153,15 +153,19 @@ impl MoProxy { fn servers_with_policy(&self, client: &NewClient) -> Vec> { let from_port = client.from_port; - let caps = self + let action = self .policy .read() .matches(Some(from_port), client.dest.host.domain()); - self.monitor - .servers() - .into_iter() - .filter(|s| caps.iter().all(|c| s.capable_anyof(c))) - .collect() + match action { + Action::Direct => vec![self.direct_server.clone()], + Action::Require(caps) => self + .monitor + .servers() + .into_iter() + .filter(|s| caps.iter().all(|c| s.capable_anyof(c))) + .collect(), + } } #[instrument(level = "error", skip_all, fields(on_port=sock.local_addr()?.port(), peer=?sock.peer_addr()?))] @@ -184,6 +188,7 @@ impl MoProxy { match client { Ok(client) => client.serve().await?, Err(client) if args.allow_direct => { + // FIXME: skip this if it's already a direct connection client .direct_connect(self.direct_server.clone()) .await?