Skip to content

Commit

Permalink
Impl policy action direct
Browse files Browse the repository at this point in the history
Related to #13
  • Loading branch information
sorz committed Mar 15, 2023
1 parent 64f7680 commit 95e66ed
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 42 deletions.
2 changes: 1 addition & 1 deletion conf/policy.rules
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#
# Supported actions:
# - require <cap1> [or <cap2>|...]
# - (TODO) direct
# - direct

# Connection to TCP 8001 requires "cap1" on proxy's capabilities,
# TCP 8002 requires "cap1" or "cap2".
Expand Down
125 changes: 95 additions & 30 deletions src/policy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,72 @@ use flexstr::{SharedStr, ToSharedStr};
use capabilities::CapSet;
use tracing::info;

#[derive(Debug)]
pub enum Action {
Require(HashSet<CapSet>),
Direct,
}

impl Default for Action {
fn default() -> Self {
Self::Require(Default::default())
}
}

impl Action {
fn add_require(&mut self, caps: CapSet) {
match self {
Self::Direct => false,
Self::Require(set) => set.insert(caps),
};
}

fn set_direct(&mut self) {
*self = Self::Direct
}

fn len(&self) -> usize {
match self {
Self::Direct => 1,
Self::Require(set) => set.len(),
}
}

fn extend(&mut self, other: &Self) {
match other {
Self::Direct => self.set_direct(),
Self::Require(new_caps) => {
if let Self::Require(caps) = self {
caps.extend(new_caps.iter().cloned())
}
}
}
}
}

#[derive(Default)]
struct RuleSet<K: Eq + Hash>(HashMap<K, HashSet<CapSet>>);
struct RuleSet<K: Eq + Hash>(HashMap<K, Action>);

type ListenPortRuleSet = RuleSet<u16>;
type DstDomainRuleSet = RuleSet<SharedStr>;

impl<K: Eq + Hash> RuleSet<K> {
fn add(&mut self, key: K, caps: CapSet) {
fn add(&mut self, key: K, action: parser::RuleAction) {
// TODO: warning duplicated rules
self.0.entry(key).or_default().insert(caps);
let value = self.0.entry(key).or_default();
match action {
parser::RuleAction::Require(caps) => value.add_require(caps),
parser::RuleAction::Direct => value.set_direct(),
}
}

fn get<'a>(&'a self, key: &'a K) -> impl Iterator<Item = &'a CapSet> {
self.0.get(key).into_iter().flatten()
fn get<'a>(&'a self, key: &'a K) -> impl Iterator<Item = &'a Action> {
self.0.get(key).into_iter()
}
}

impl DstDomainRuleSet {
fn get_recursive<'a>(&'a self, name: &'a str) -> impl Iterator<Item = &'a CapSet> {
fn get_recursive<'a>(&'a self, name: &'a str) -> impl Iterator<Item = &'a Action> {
let mut skip = 0usize;
name.split_terminator('.')
.map(move |part| {
Expand All @@ -42,7 +89,6 @@ impl DstDomainRuleSet {
})
.chain(["."])
.filter_map(|key| self.0.get(key))
.flatten()
}
}

Expand Down Expand Up @@ -75,13 +121,12 @@ impl Policy {

fn add_rule(&mut self, rule: parser::Rule) {
let parser::Rule { filter, action } = rule;
let parser::RuleAction::Require(caps) = action;
match filter {
parser::RuleFilter::ListenPort(port) => {
self.listen_port_ruleset.add(port, caps);
self.listen_port_ruleset.add(port, action);
}
parser::RuleFilter::Sni(parts) => {
self.dst_domain_ruleset.add(parts.to_shared_str(), caps);
self.dst_domain_ruleset.add(parts.to_shared_str(), action);
}
}
}
Expand All @@ -94,35 +139,41 @@ impl Policy {
.fold(0, |acc, v| acc + v.len())
}

pub fn matches(
&self,
listen_port: Option<u16>,
dst_domain: Option<SharedStr>,
) -> HashSet<CapSet> {
let mut rules = HashSet::new();
pub fn matches(&self, listen_port: Option<u16>, dst_domain: Option<SharedStr>) -> Action {
let mut action: Action = Default::default();
if let Some(port) = listen_port {
rules.extend(self.listen_port_ruleset.get(&port).cloned());
self.listen_port_ruleset
.get(&port)
.for_each(|a| action.extend(a))
}
if let Some(name) = dst_domain {
rules.extend(self.dst_domain_ruleset.get_recursive(&name).cloned());
self.dst_domain_ruleset
.get_recursive(&name)
.for_each(|a| action.extend(a));
}
rules
action
}
}

#[test]
fn test_router_listen_port() {
fn test_policy_listen_port() {
use capabilities::CheckAllCapsMeet;

let rules = "
listen-port 1 require a
listen-port 2 require b
listen-port 2 require c or d
";
let router = Policy::load(rules.as_bytes()).unwrap();
assert_eq!(3, router.rule_count());
let p1 = router.matches(Some(1), None);
let p2 = router.matches(Some(2), None);
let policy = Policy::load(rules.as_bytes()).unwrap();
assert_eq!(3, policy.rule_count());
let p1 = match policy.matches(Some(1), None) {
Action::Require(a) => a,
_ => panic!(),
};
let p2 = match policy.matches(Some(2), None) {
Action::Require(a) => a,
_ => panic!(),
};
let abc = CapSet::new(["a", "b", "c"].into_iter());
let bc = CapSet::new(["b", "c"].into_iter());
let c = CapSet::new(["c"].into_iter());
Expand All @@ -135,8 +186,8 @@ fn test_router_listen_port() {
}

#[test]
fn test_router_get_domain_caps_requirements() {
let router = Policy::load(
fn test_policy_get_domain_caps_requirements() {
let policy = Policy::load(
"
dst domain . require root
dst domain com. require com
Expand All @@ -147,18 +198,32 @@ fn test_router_get_domain_caps_requirements() {
.unwrap();
assert_eq!(
3,
router
policy
.dst_domain_ruleset
.get_recursive("test.example.com")
.count()
);
assert_eq!(
3,
router
policy
.dst_domain_ruleset
.get_recursive("example.com")
.count()
);
assert_eq!(2, router.dst_domain_ruleset.get_recursive("com").count());
assert_eq!(1, router.dst_domain_ruleset.get_recursive("net").count());
assert_eq!(2, policy.dst_domain_ruleset.get_recursive("com").count());
assert_eq!(1, policy.dst_domain_ruleset.get_recursive("net").count());
}

#[test]
fn test_policy_action_direct() {
let rules = "
listen-port 1 require a
listen-port 1 direct
dst domain test require c
";
let policy = Policy::load(rules.as_bytes()).unwrap();
let direct = policy.matches(Some(1), Some("test".into()));
let require = policy.matches(Some(2), Some("test".into()));
assert!(matches!(direct, Action::Direct));
assert!(matches!(require, Action::Require(_)));
}
17 changes: 13 additions & 4 deletions src/policy/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub enum RuleFilter {
#[derive(Debug, PartialEq, Eq)]
pub enum RuleAction {
Require(CapSet),
Direct,
}

#[derive(Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -89,8 +90,14 @@ fn action_require(input: &str) -> IResult<&str, RuleAction> {
.parse(input)
}

fn action_direct(input: &str) -> IResult<&str, RuleAction> {
tag_no_case("direct")
.map(|_| RuleAction::Direct)
.parse(input)
}

fn rule_action(input: &str) -> IResult<&str, RuleAction> {
action_require(input)
alt((action_require, action_direct)).parse(input)
}

fn rule(input: &str) -> IResult<&str, Rule> {
Expand Down Expand Up @@ -153,13 +160,15 @@ fn test_dst_domain_filter() {
}

#[test]
fn test_action_require() {
let (rem, caps) = action_require("require a or b\n").unwrap();
fn test_action() {
let (rem, action) = rule_action("require a or b\n").unwrap();
assert_eq!("\n", rem);
assert_eq!(
RuleAction::Require(CapSet::new(["a", "b"].into_iter())),
caps
action
);
let (_, action) = rule_action("direct\n").unwrap();
assert_eq!(RuleAction::Direct, action);
}

#[test]
Expand Down
19 changes: 12 additions & 7 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use moproxy::{
client::{Connectable, NewClient},
futures_stream::TcpListenerStream,
monitor::Monitor,
policy::{parser, Policy},
policy::{parser, Action, Policy},
proxy::{ProxyProto, ProxyServer, UserPassAuthCredential},
web::{self, AutoRemoveFile},
};
Expand Down Expand Up @@ -153,15 +153,19 @@ impl MoProxy {

fn servers_with_policy(&self, client: &NewClient) -> Vec<Arc<ProxyServer>> {
let from_port = client.from_port;
let caps = self
let action = self
.policy
.read()
.matches(Some(from_port), client.dest.host.domain());
self.monitor
.servers()
.into_iter()
.filter(|s| caps.iter().all(|c| s.capable_anyof(c)))
.collect()
match action {
Action::Direct => vec![self.direct_server.clone()],
Action::Require(caps) => self
.monitor
.servers()
.into_iter()
.filter(|s| caps.iter().all(|c| s.capable_anyof(c)))
.collect(),
}
}

#[instrument(level = "error", skip_all, fields(on_port=sock.local_addr()?.port(), peer=?sock.peer_addr()?))]
Expand All @@ -184,6 +188,7 @@ impl MoProxy {
match client {
Ok(client) => client.serve().await?,
Err(client) if args.allow_direct => {
// FIXME: skip this if it's already a direct connection
client
.direct_connect(self.direct_server.clone())
.await?
Expand Down

0 comments on commit 95e66ed

Please sign in to comment.