Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: distinguish count(*) and row_number() from different tables #861

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,13 @@ impl Binder {
}

let node = match func.name.to_string().to_lowercase().as_str() {
"count" if args.is_empty() => Node::RowCount,
"count" if args.is_empty() => {
let num: usize = self.context().from.unwrap().into();
let id = self
.egraph
.add(Node::Constant(DataValue::Int32(num as i32)));
Node::CountStar(id)
}
"count" if func.distinct => Node::CountDistinct(args[0]),
"count" => Node::Count(args[0]),
"max" => Node::Max(args[0]),
Expand All @@ -418,7 +424,13 @@ impl Binder {
"first" => Node::First(args[0]),
"last" => Node::Last(args[0]),
"replace" => Node::Replace([args[0], args[1], args[2]]),
"row_number" => Node::RowNumber,
"row_number" => {
let num: usize = self.context().from.unwrap().into();
let id = self
.egraph
.add(Node::Constant(DataValue::Int32(num as i32)));
Node::RowNumber(id)
}
name => todo!("Unsupported function: {}", name),
};
let mut id = self.egraph.add(node);
Expand Down
7 changes: 7 additions & 0 deletions src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ struct Context {
/// Column aliases that can be accessed from the outside query.
/// `column_alias` -> id
output_aliases: HashMap<String, Id>,
/// The plan of `FROM` clause.
from: Option<Id>,
}

impl Binder {
Expand Down Expand Up @@ -339,6 +341,11 @@ impl Binder {
}
}

/// Get the current context.
///
/// Returns a shared reference to the last entry of `self.contexts`.
///
/// # Panics
///
/// Panics if `self.contexts` has no entries, because the `Option` returned by
/// `last()` is unwrapped unconditionally.
fn context(&self) -> &Context {
self.contexts.last().unwrap()
}

/// Add a column alias to the current context.
fn add_alias(&mut self, column_name: String, table_name: String, id: Id) {
let context = self.contexts.last_mut().unwrap();
Expand Down
7 changes: 4 additions & 3 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ impl Binder {
}

fn bind_select(&mut self, select: Select, order_by: Vec<OrderByExpr>) -> Result {
let from = self.bind_from(select.from)?;
let projection = self.bind_projection(select.projection, from)?;
let mut plan = self.bind_from(select.from)?;
self.contexts.last_mut().unwrap().from = Some(plan);

let projection = self.bind_projection(select.projection, plan)?;
let mut where_ = self.bind_where(select.selection)?;
let groupby = match select.group_by {
GroupByExpr::All => return Err(BindError::Todo("group by all".into())),
Expand All @@ -91,7 +93,6 @@ impl Binder {
Some(Distinct::On(exprs)) => self.bind_exprs(exprs)?,
};

let mut plan = from;
self.plan_apply(&mut where_, &mut plan);
plan = self.egraph.add(Node::Filter([where_, plan]));
let mut to_rewrite = [projection, distinct, having, orderby];
Expand Down
8 changes: 4 additions & 4 deletions src/executor/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ impl<'a> Evaluator<'a> {
}
Desc(a) | Ref(a) => self.next(*a).eval(chunk),
// for aggs, evaluate its children
RowCount => Ok(ArrayImpl::new_null(
CountStar(_) => Ok(ArrayImpl::new_null(
(0..chunk.cardinality()).map(|_| ()).collect(),
)),
Count(a) | Sum(a) | Min(a) | Max(a) | First(a) | Last(a) | CountDistinct(a) => {
Expand Down Expand Up @@ -160,7 +160,7 @@ impl<'a> Evaluator<'a> {
match self.node() {
Over([window, _, _]) => self.next(*window).init_agg_state(),
CountDistinct(_) => AggState::DistinctValue(HashSet::default()),
RowCount | RowNumber | Count(_) => AggState::Value(DataValue::Int32(0)),
CountStar(_) | RowNumber(_) | Count(_) => AggState::Value(DataValue::Int32(0)),
Sum(_) | Min(_) | Max(_) | First(_) | Last(_) => AggState::Value(DataValue::Null),
t => panic!("not aggregation: {t}"),
}
Expand Down Expand Up @@ -214,7 +214,7 @@ impl<'a> Evaluator<'a> {
use Expr::*;
Ok(match state {
AggState::Value(state) => AggState::Value(match self.node() {
RowCount => state.add(DataValue::Int32(chunk.cardinality() as _)),
CountStar(_) => state.add(DataValue::Int32(chunk.cardinality() as _)),
Count(a) => state.add(DataValue::Int32(self.next(*a).eval(chunk)?.count() as _)),
Sum(a) => state.add(self.next(*a).eval(chunk)?.sum()),
Min(a) => state.min(self.next(*a).eval(chunk)?.min_()),
Expand Down Expand Up @@ -244,7 +244,7 @@ impl<'a> Evaluator<'a> {
}
match state {
AggState::Value(state) => AggState::Value(match self.node() {
RowCount | RowNumber => state.add(DataValue::Int32(1)),
CountStar(_) | RowNumber(_) => state.add(DataValue::Int32(1)),
Count(_) => state.add(DataValue::Int32(!value.is_null() as _)),
Sum(_) => state.add(value),
Min(_) => state.min(value),
Expand Down
3 changes: 2 additions & 1 deletion src/planner/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ impl<'a> Explain<'a> {
),

// aggregations
RowCount | RowNumber => enode.to_string().into(),
CountStar(num) => format!("count(*)_{}", self.expr[*num].as_const()).into(),
RowNumber(num) => format!("row_number()_{}", self.expr[*num].as_const()).into(),
Max(a) | Min(a) | Sum(a) | Avg(a) | Count(a) | First(a) | Last(a)
| CountDistinct(a) => {
let name = enode.to_string();
Expand Down
16 changes: 11 additions & 5 deletions src/planner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,21 @@ define_language! {
"sum" = Sum(Id),
"avg" = Avg(Id),
"count" = Count(Id),
"count-distinct" = CountDistinct(Id),
"rowcount" = RowCount,
"count_distinct" = CountDistinct(Id),
"count_star" = CountStar(Id), // (count_star num)
// count the number of rows of the input table
// `num` is an integer literal, which is only used to
// distinguish count(*) from different tables
"first" = First(Id),
"last" = Last(Id),
// window functions
"over" = Over([Id; 3]), // (over window_function [partition_key..] [order_key..])
// TODO: support frame clause
// "range" = Range([Id; 2]), // (range start end)
"row_number" = RowNumber,
"row_number" = RowNumber(Id), // (row_number num)
// generate a sequence of numbers starting from 1
// `num` is an integer literal, which is only used to
// distinguish row_number() from different tables

// subquery related
"exists" = Exists(Id), // (exists plan)
Expand Down Expand Up @@ -240,7 +246,7 @@ impl Expr {
use Expr::*;
matches!(
self,
RowCount
CountStar(_)
| Max(_)
| Min(_)
| Sum(_)
Expand All @@ -254,7 +260,7 @@ impl Expr {

pub const fn is_window_function(&self) -> bool {
use Expr::*;
matches!(self, RowNumber) || self.is_aggregate_function()
matches!(self, RowNumber(_)) || self.is_aggregate_function()
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/planner/rules/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ pub fn analyze_columns(egraph: &EGraph, enode: &Expr) -> ColumnSet {
use Expr::*;
let columns = |i: &Id| &egraph[*i].data.columns;
match enode {
Column(_) | Ref(_) => [enode.clone()].into_iter().collect(),
Column(_) | Ref(_) => [enode.clone()].into(),
// others: merge from all children
_ => (enode.children().iter())
.flat_map(|id| columns(id).iter().cloned())
Expand Down
2 changes: 1 addition & 1 deletion src/planner/rules/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ pub fn analyze_type(
Avg(a) => check(enode, x(a)?, |a| a.is_number()),

// agg
RowCount | RowNumber | Count(_) | CountDistinct(_) => Ok(DataType::Int32),
CountStar(_) | RowNumber(_) | Count(_) | CountDistinct(_) => Ok(DataType::Int32),
First(a) | Last(a) => x(a),
Over([f, _, _]) => x(f),

Expand Down
12 changes: 6 additions & 6 deletions tests/planner_test/count.planner.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ explain select count(*) from t
/*
Projection
├── exprs:ref
│ └── rowcount
├── cost: 1.13
│ └── count(*)_4
├── cost: 1.23
├── rows: 1
└── Agg { aggs: [ rowcount ], cost: 1.11, rows: 1 }
└── Agg { aggs: [ count(*)_4 ], cost: 1.21, rows: 1 }
└── Scan { table: t, list: [], filter: true, cost: 0, rows: 1 }
*/

Expand All @@ -18,12 +18,12 @@ explain select count(*) + 1 from t
Projection
├── exprs:+
│ ├── lhs:ref
│ │ └── rowcount
│ │ └── count(*)_4
│ ├── rhs: 1

├── cost: 1.33
├── cost: 1.4300001
├── rows: 1
└── Agg { aggs: [ rowcount ], cost: 1.11, rows: 1 }
└── Agg { aggs: [ count(*)_4 ], cost: 1.21, rows: 1 }
└── Scan { table: t, list: [], filter: true, cost: 0, rows: 1 }
*/

56 changes: 28 additions & 28 deletions tests/planner_test/tpch.planner.sql
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ Projection
│ │ │ └── l_discount

│ └── ref
│ └── rowcount
├── cost: 70266880
│ └── count(*)_19
├── cost: 70566940
├── rows: 100
└── Order { by: [ l_returnflag, l_linestatus ], cost: 70266840, rows: 100 }
└── Order { by: [ l_returnflag, l_linestatus ], cost: 70566904, rows: 100 }
└── HashAgg
├── keys: [ l_returnflag, l_linestatus ]
├── aggs:
Expand All @@ -184,8 +184,8 @@ Projection
│ │ └── l_discount
│ ├── count
│ │ └── l_discount
│ └── rowcount
├── cost: 70265070
│ └── count(*)_19
├── cost: 70565140
├── rows: 100
└── Projection
├── exprs: [ l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus ]
Expand Down Expand Up @@ -682,11 +682,11 @@ Projection
├── exprs:
│ ┌── o_orderpriority
│ └── ref
│ └── rowcount
├── cost: 35742960
│ └── count(*)_12
├── cost: 35761710
├── rows: 10
└── Order { by: [ o_orderpriority ], cost: 35742960, rows: 10 }
└── HashAgg { keys: [ o_orderpriority ], aggs: [ rowcount ], cost: 35742904, rows: 10 }
└── Order { by: [ o_orderpriority ], cost: 35761710, rows: 10 }
└── HashAgg { keys: [ o_orderpriority ], aggs: [ count(*)_12 ], cost: 35761656, rows: 10 }
└── Projection { exprs: [ o_orderpriority ], cost: 35712024, rows: 187500 }
└── HashJoin
├── type: semi
Expand Down Expand Up @@ -1952,26 +1952,26 @@ Projection
│ │ └── count
│ │ └── o_orderkey
│ └── ref
│ └── rowcount
├── cost: 10053795
│ └── count(*)_46
├── cost: 10053796
├── rows: 10
└── Order
├── by:
│ ┌── desc
│ │ └── ref
│ │ └── rowcount
│ │ └── count(*)_46
│ └── desc
│ └── ref
│ └── count
│ └── o_orderkey
├── cost: 10053795
├── cost: 10053796
├── rows: 10
└── HashAgg
├── keys:ref
│ └── count
│ └── o_orderkey
├── aggs: [ rowcount ]
├── cost: 10053740
├── aggs: [ count(*)_46 ]
├── cost: 10053741
├── rows: 10
└── Projection
├── exprs:
Expand Down Expand Up @@ -2204,15 +2204,15 @@ Projection
│ ├── p_type
│ ├── p_size
│ └── ref
│ └── count-distinct
│ └── count_distinct
│ └── ps_suppkey
├── cost: 9952286
├── rows: 1000
└── Order
├── by:
│ ┌── desc
│ │ └── ref
│ │ └── count-distinct
│ │ └── count_distinct
│ │ └── ps_suppkey
│ ├── p_brand
│ ├── p_type
Expand All @@ -2221,7 +2221,7 @@ Projection
├── rows: 1000
└── HashAgg
├── keys: [ p_brand, p_type, p_size ]
├── aggs:count-distinct
├── aggs:count_distinct
│ └── ps_suppkey
├── cost: 9938269
├── rows: 1000
Expand Down Expand Up @@ -2961,20 +2961,20 @@ Projection
├── exprs:
│ ┌── s_name
│ └── ref
│ └── rowcount
├── cost: 124247200
│ └── count(*)_52
├── cost: 124265950
├── rows: 10
└── TopN
├── limit: 100
├── offset: 0
├── order_by:
│ ┌── desc
│ │ └── ref
│ │ └── rowcount
│ │ └── count(*)_52
│ └── s_name
├── cost: 124247200
├── cost: 124265950
├── rows: 10
└── HashAgg { keys: [ s_name ], aggs: [ rowcount ], cost: 124247144, rows: 10 }
└── HashAgg { keys: [ s_name ], aggs: [ count(*)_52 ], cost: 124265896, rows: 10 }
└── Projection { exprs: [ s_name ], cost: 124216260, rows: 187537.97 }
└── HashJoin
├── type: semi
Expand Down Expand Up @@ -3120,25 +3120,25 @@ Projection
│ ┌── ref
│ │ └── Substring { str: c_phone, start: 1, length: 2 }
│ ├── ref
│ │ └── rowcount
│ │ └── count(*)_95
│ └── ref
│ └── sum
│ └── c_acctbal
├── cost: 4399655
├── cost: 4403405
├── rows: 10
└── Order
├── by:ref
│ └── Substring { str: c_phone, start: 1, length: 2 }
├── cost: 4399654.5
├── cost: 4403404.5
├── rows: 10
└── HashAgg
├── keys:ref
│ └── Substring { str: c_phone, start: 1, length: 2 }
├── aggs:
│ ┌── rowcount
│ ┌── count(*)_95
│ └── sum
│ └── c_acctbal
├── cost: 4399590
├── cost: 4403340
├── rows: 10
└── Projection
├── exprs: [ Substring { str: c_phone, start: 1, length: 2 }, c_acctbal ]
Expand Down
8 changes: 8 additions & 0 deletions tests/sql/count.slt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ select count(*) as 'total' from t where v > 5
----
3

# count(*) from different tables should be distinct
query II
select c2, c1 from
(select count(*) as c1 from t where v <= 1),
(select count(*) as c2 from t);
----
8 1

statement ok
delete from t where v = 7

Expand Down
9 changes: 9 additions & 0 deletions tests/sql/window_function.slt
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,12 @@ SELECT a FROM t HAVING sum(a) OVER () > 0;

statement error window function calls cannot be nested
SELECT sum(sum(a) over ()) over () FROM t;

# row_number() from different tables should be distinct
query II rowsort
SELECT c2, c1 FROM
(SELECT row_number() OVER () as c1 FROM t WHERE a = 1),
(SELECT row_number() OVER () as c2 FROM t WHERE a > 1);
----
1 1
2 1
Loading