Skip to content

Commit

Permalink
feat: JOINs
Browse files Browse the repository at this point in the history
still some bugs with precedence that need to be worked out
  • Loading branch information
DerekStride committed Jul 1, 2024
1 parent 728d20f commit f6b8a7a
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 12 deletions.
5 changes: 4 additions & 1 deletion lib/sql_tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,8 @@ def parser = @parser ||= TreeStand::Parser.new("sql")
end

Table = Data.define(:name, :alias)
Column = Data.define(:table, :name)
Column = Data.define(:table, :name) do
def to_s = "#{table.name}.#{name}"
def inspect = to_s
end
end
3 changes: 3 additions & 0 deletions lib/sql_tools/inner_join.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module SqlTools
InnerJoin = Data.define(:object, :predicate)
end
3 changes: 3 additions & 0 deletions lib/sql_tools/left_join.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module SqlTools
LeftJoin = Data.define(:object, :predicate)
end
5 changes: 4 additions & 1 deletion lib/sql_tools/predicate.rb
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
module SqlTools
class Predicate
Binary = Struct.new(:left, :operator, :right)
Binary = Struct.new(:left, :operator, :right) do
def to_s = "(#{left} #{operator} #{right})"
def inspect = to_s
end

class Builder
def initialize(query)
Expand Down
35 changes: 35 additions & 0 deletions lib/sql_tools/predicate_filter.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
module SqlTools
class PredicateFilter
def initialize(query)
@query = query
end

def filter(object)
@stack = []
filter_recursive(object, @query.predicate)

right = @stack.pop

while left = @stack.pop
predicate = Predicate::Binary.new(left, "AND", right)
right = predicate
end

right
end

private

def filter_recursive(object, predicate)
case predicate
when Predicate::Binary
@stack << predicate if filter_recursive(object, predicate.left) || filter_recursive(object, predicate.right)
false
when SqlTools::Column
predicate.table == object
else
raise "Unknown predicate type: #{predicate}"
end
end
end
end
40 changes: 36 additions & 4 deletions lib/sql_tools/query.rb
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module SqlTools
class Query
attr_accessor :select, :from
attr_accessor :select, :from, :join_nodes
attr_reader :common_table_expressions

def initialize
@common_table_expressions = {}
@join_nodes = []
end

def selections
Expand Down Expand Up @@ -46,6 +47,26 @@ def selections
end
end

def joins
join_nodes.map do |join_node|
object_name = join_node.find_node(<<~QUERY).text
(join
(relation
(object_reference name: (identifier) @object_name)))
QUERY

filter = PredicateFilter.new(self)
object = object_alias_map[object_name]
predicate = filter.filter(object)

if join_node.children.any? { |child| child.type == :keyword_left }
LeftJoin.new(object, predicate)
else
InnerJoin.new(object, predicate)
end
end
end

def relations
objects.each_with_object({}) do |object, map|
map[object] ||= []
Expand All @@ -56,8 +77,8 @@ def relations
end
end

def predicates
@predicates ||= begin
def predicate
@predicate ||= begin
nodes = from.query(<<~QUERY).map { |captures| captures["predicate"] }
(from
(join
Expand All @@ -71,7 +92,18 @@ def predicates
visitor = PredicateVisitor.new(predicate).visit
binding.b unless visitor.stack.size == 1
builder.build(visitor.stack.last)
end.to_set
end

right = predicates.pop

# This needs to pluck the left & right from binary expressions & rebuild the tree.
# TODO: maybe this is the rotate algorithm, TBD
while left = predicates.pop
predicate = Predicate::Binary.new(left, "AND", right)
right = predicate
end

right
end
end

Expand Down
4 changes: 4 additions & 0 deletions lib/sql_tools/query_visitor.rb
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,9 @@ def on_select(node)
def on_from(node)
@query.from = node
end

def on_join(node)
@query.join_nodes << node
end
end
end
148 changes: 148 additions & 0 deletions test/unit/join_test.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# frozen_string_literal: true

require "test_helper"

module SqlTools
class JoinTest < Minitest::Test
def test_join_predicates
query = query_from_sql(<<~SQL)
SELECT *
FROM table_a a
JOIN table_b b
ON a.id = b.a_id
SQL

expected = Predicate::Binary.new(
left: Column.new(table: Table.new("table_a", "a"), name: "id"),
operator: "=",
right: Column.new(table: Table.new("table_b", "b"), name: "a_id"),
)

assert_equal(1, query.joins.size)
assert_equal(expected, query.predicate)
assert_equal(expected, query.joins.first.predicate)
end

def test_join_predicates_in_the_where
skip("bug in the parser on empty join")
query = query_from_sql(<<~SQL)
SELECT *
FROM table_a a
JOIN table_b b
WHERE a.id = b.a_id
SQL

expected = Predicate::Binary.new(
left: Column.new(table: Table.new("table_a", "a"), name: "id"),
operator: "=",
right: Column.new(table: Table.new("table_b", "b"), name: "a_id"),
)

assert_equal(1, query.joins.size)
assert_equal(expected, query.predicate)
assert_equal(expected, query.joins.first.predicate)
end

def test_join_predicate_spread_out
query = query_from_sql(<<~SQL)
SELECT *
FROM table_a a
JOIN table_b b
ON a.id = b.a_id
WHERE a.shard_id = b.shard_id
SQL

expected = Predicate::Binary.new(
left: Predicate::Binary.new(
left: Column.new(table: Table.new("table_a", "a"), name: "id"),
operator: "=",
right: Column.new(table: Table.new("table_b", "b"), name: "a_id"),
),
operator: "AND",
right: Predicate::Binary.new(
left: Column.new(table: Table.new("table_a", "a"), name: "shard_id"),
operator: "=",
right: Column.new(table: Table.new("table_b", "b"), name: "shard_id"),
),
)

assert_equal(1, query.joins.size)
assert_equal(expected, query.predicate)
assert_equal(expected, query.joins.first.predicate)
end

def test_many_join_predicates_spread_out
skip("predicate precedence is messed up here")
query = query_from_sql(<<~SQL)
SELECT *
FROM table_a a
JOIN table_b b
ON a.id = b.a_id
AND a.shard_id = b.shard_id
WHERE a.user_id = b.user_id
AND a.id = 4
SQL

assert_equal(
Predicate::Binary.new(
left: Predicate::Binary.new(
left: Predicate::Binary.new(
left: Predicate::Binary.new(
left: Column.new(table: Table.new("table_a", "a"), name: "id"),
operator: "=",
right: Column.new(table: Table.new("table_b", "b"), name: "a_id"),
),
operator: "AND",
right: Predicate::Binary.new(
left: Column.new(table: Table.new("table_a", "a"), name: "shard_id"),
operator: "=",
right: Column.new(table: Table.new("table_b", "b"), name: "shard_id"),
),
),
operator: "AND",
right: Predicate::Binary.new(
left: Column.new(table: Table.new("table_a", "a"), name: "user_id"),
operator: "=",
right: Column.new(table: Table.new("table_b", "b"), name: "user_id"),
),
),
operator: "AND",
right: Predicate::Binary.new(
left: Column.new(table: Table.new("table_a", "a"), name: "id"),
operator: "=",
right: "4",
),
),
query.predicate,
)

assert_equal(1, query.joins.size)
assert_equal(
Predicate::Binary.new(
left: Predicate::Binary.new(
left: Predicate::Binary.new(
left: Column.new(table: Table.new("table_a", "a"), name: "id"),
operator: "=",
right: Column.new(table: Table.new("table_b", "b"), name: "a_id"),
),
operator: "AND",
right: Predicate::Binary.new(
left: Column.new(table: Table.new("table_a", "a"), name: "shard_id"),
operator: "=",
right: Column.new(table: Table.new("table_b", "b"), name: "shard_id"),
),
),
operator: "AND",
right: Predicate::Binary.new(
left: Column.new(table: Table.new("table_a", "a"), name: "user_id"),
operator: "=",
right: Column.new(table: Table.new("table_b", "b"), name: "user_id"),
),
),
query.joins.first.predicate,
)
end
end
end


62 changes: 56 additions & 6 deletions test/unit/predicate_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@ def test_simple_predicate
WHERE id = 1
SQL

assert_equal(1, query.predicates.size)
assert_equal(
Predicate::Binary.new(
left: Column.new(table: Table.new("table", "table"), name: "id"),
operator: "=",
right: "1",
),
query.predicates.first,
query.predicate,
)
end

Expand All @@ -30,7 +29,6 @@ def test_multiple_predicates
AND name = "derek"
SQL

assert_equal(1, query.predicates.size)
assert_equal(
Predicate::Binary.new(
left: Predicate::Binary.new(
Expand All @@ -45,7 +43,60 @@ def test_multiple_predicates
right: "\"derek\"",
),
),
query.predicates.first,
query.predicate,
)
end

def test_many_predicates_with_precedence
query = query_from_sql(<<~SQL)
SELECT *
FROM table
WHERE id = 1
AND name = "derek"
AND phone = "555-0000"
OR id >= 2
AND name = "stride"
SQL

assert_equal(
Predicate::Binary.new(
left: Predicate::Binary.new(
left: Predicate::Binary.new(
left: Predicate::Binary.new(
left: Column.new(table: Table.new("table", "table"), name: "id"),
operator: "=",
right: "1",
),
operator: "AND",
right: Predicate::Binary.new(
left: Column.new(table: Table.new("table", "table"), name: "name"),
operator: "=",
right: "\"derek\"",
),
),
operator: "AND",
right: Predicate::Binary.new(
left: Column.new(table: Table.new("table", "table"), name: "phone"),
operator: "=",
right: "\"555-0000\"",
),
),
operator: "OR",
right: Predicate::Binary.new(
left: Predicate::Binary.new(
left: Column.new(table: Table.new("table", "table"), name: "id"),
operator: ">=",
right: "2",
),
operator: "AND",
right: Predicate::Binary.new(
left: Column.new(table: Table.new("table", "table"), name: "name"),
operator: "=",
right: "\"stride\"",
),
),
),
query.predicate,
)
end

Expand All @@ -59,7 +110,6 @@ def test_multiple_predicates_with_precedence
AND name = "stride"
SQL

assert_equal(1, query.predicates.size)
assert_equal(
Predicate::Binary.new(
left: Predicate::Binary.new(
Expand Down Expand Up @@ -90,7 +140,7 @@ def test_multiple_predicates_with_precedence
),
),
),
query.predicates.first,
query.predicate,
)
end
end
Expand Down

0 comments on commit f6b8a7a

Please sign in to comment.