diff --git a/lib/sql_tools.rb b/lib/sql_tools.rb index ea783c6..cca485b 100644 --- a/lib/sql_tools.rb +++ b/lib/sql_tools.rb @@ -17,5 +17,8 @@ def parser = @parser ||= TreeStand::Parser.new("sql") end Table = Data.define(:name, :alias) - Column = Data.define(:table, :name) + Column = Data.define(:table, :name) do + def to_s = "#{table.name}.#{name}" + def inspect = to_s + end end diff --git a/lib/sql_tools/inner_join.rb b/lib/sql_tools/inner_join.rb new file mode 100644 index 0000000..7e32c67 --- /dev/null +++ b/lib/sql_tools/inner_join.rb @@ -0,0 +1,3 @@ +module SqlTools + InnerJoin = Data.define(:object, :predicate) +end diff --git a/lib/sql_tools/left_join.rb b/lib/sql_tools/left_join.rb new file mode 100644 index 0000000..82f9007 --- /dev/null +++ b/lib/sql_tools/left_join.rb @@ -0,0 +1,3 @@ +module SqlTools + LeftJoin = Data.define(:object, :predicate) +end diff --git a/lib/sql_tools/predicate.rb b/lib/sql_tools/predicate.rb index 72b3f8f..29af960 100644 --- a/lib/sql_tools/predicate.rb +++ b/lib/sql_tools/predicate.rb @@ -1,6 +1,9 @@ module SqlTools class Predicate - Binary = Struct.new(:left, :operator, :right) + Binary = Struct.new(:left, :operator, :right) do + def to_s = "(#{left} #{operator} #{right})" + def inspect = to_s + end class Builder def initialize(query) diff --git a/lib/sql_tools/predicate_filter.rb b/lib/sql_tools/predicate_filter.rb new file mode 100644 index 0000000..77a6d02 --- /dev/null +++ b/lib/sql_tools/predicate_filter.rb @@ -0,0 +1,35 @@ +module SqlTools + class PredicateFilter + def initialize(query) + @query = query + end + + def filter(object) + @stack = [] + filter_recursive(object, @query.predicate) + + right = @stack.pop + + while left = @stack.pop + predicate = Predicate::Binary.new(left, "AND", right) + right = predicate + end + + right + end + + private + + def filter_recursive(object, predicate) + case predicate + when Predicate::Binary + @stack << predicate if filter_recursive(object, predicate.left) || filter_recursive(object, predicate.right) + false + when SqlTools::Column + predicate.table == object + else + raise "Unknown predicate type: #{predicate}" + end + end + end +end diff --git a/lib/sql_tools/query.rb b/lib/sql_tools/query.rb index 510db3d..ecea578 100644 --- a/lib/sql_tools/query.rb +++ b/lib/sql_tools/query.rb @@ -1,10 +1,11 @@ module SqlTools class Query - attr_accessor :select, :from + attr_accessor :select, :from, :join_nodes attr_reader :common_table_expressions def initialize @common_table_expressions = {} + @join_nodes = [] end def selections @@ -46,6 +47,26 @@ def selections end end + def joins + join_nodes.map do |join_node| + object_name = join_node.find_node(<<~QUERY).text + (join + (relation + (object_reference name: (identifier) @object_name))) + QUERY + + filter = PredicateFilter.new(self) + object = object_alias_map[object_name] + predicate = filter.filter(object) + + if join_node.children.any? { |child| child.type == :keyword_left } + LeftJoin.new(object, predicate) + else + InnerJoin.new(object, predicate) + end + end + end + def relations objects.each_with_object({}) do |object, map| map[object] ||= [] @@ -56,8 +77,8 @@ def relations end end - def predicates - @predicates ||= begin + def predicate + @predicate ||= begin nodes = from.query(<<~QUERY).map { |captures| captures["predicate"] } (from (join @@ -71,7 +92,18 @@ def predicates visitor = PredicateVisitor.new(predicate).visit binding.b unless visitor.stack.size == 1 builder.build(visitor.stack.last) - end.to_set + end + + right = predicates.pop + + # This needs to pluck the left & right from binary expressions & rebuild the tree. + # TODO: maybe this is the rotate algorithm, TBD + while left = predicates.pop + predicate = Predicate::Binary.new(left, "AND", right) + right = predicate + end + + right end end diff --git a/lib/sql_tools/query_visitor.rb b/lib/sql_tools/query_visitor.rb index 7eaaf85..abc9bc1 100644 --- a/lib/sql_tools/query_visitor.rb +++ b/lib/sql_tools/query_visitor.rb @@ -20,5 +20,9 @@ def on_select(node) def on_from(node) @query.from = node end + + def on_join(node) + @query.join_nodes << node + end end end diff --git a/test/unit/join_test.rb b/test/unit/join_test.rb new file mode 100644 index 0000000..b506a11 --- /dev/null +++ b/test/unit/join_test.rb @@ -0,0 +1,148 @@ +# frozen_string_literal: true + +require "test_helper" + +module SqlTools + class JoinTest < Minitest::Test + def test_join_predicates + query = query_from_sql(<<~SQL) + SELECT * + FROM table_a a + JOIN table_b b + ON a.id = b.a_id + SQL + + expected = Predicate::Binary.new( + left: Column.new(table: Table.new("table_a", "a"), name: "id"), + operator: "=", + right: Column.new(table: Table.new("table_b", "b"), name: "a_id"), + ) + + assert_equal(1, query.joins.size) + assert_equal(expected, query.predicate) + assert_equal(expected, query.joins.first.predicate) + end + + def test_join_predicates_in_the_where + skip("bug in the parser on empty join") + query = query_from_sql(<<~SQL) + SELECT * + FROM table_a a + JOIN table_b b + WHERE a.id = b.a_id + SQL + + expected = Predicate::Binary.new( + left: Column.new(table: Table.new("table_a", "a"), name: "id"), + operator: "=", + right: Column.new(table: Table.new("table_b", "b"), name: "a_id"), + ) + + assert_equal(1, query.joins.size) + assert_equal(expected, query.predicate) + assert_equal(expected, query.joins.first.predicate) + end + + def test_join_predicate_spread_out + query = query_from_sql(<<~SQL) + SELECT * + FROM table_a a + JOIN table_b b + ON a.id = b.a_id + WHERE a.shard_id = b.shard_id + SQL + + expected = Predicate::Binary.new( + left: Predicate::Binary.new( + left: Column.new(table: Table.new("table_a", "a"), name: "id"), + operator: "=", + right: Column.new(table: Table.new("table_b", "b"), name: "a_id"), + ), + operator: "AND", + right: Predicate::Binary.new( + left: Column.new(table: Table.new("table_a", "a"), name: "shard_id"), + operator: "=", + right: Column.new(table: Table.new("table_b", "b"), name: "shard_id"), + ), + ) + + assert_equal(1, query.joins.size) + assert_equal(expected, query.predicate) + assert_equal(expected, query.joins.first.predicate) + end + + def test_many_join_predicates_spread_out + skip("predicate precedence is messed up here") + query = query_from_sql(<<~SQL) + SELECT * + FROM table_a a + JOIN table_b b + ON a.id = b.a_id + AND a.shard_id = b.shard_id + WHERE a.user_id = b.user_id + AND a.id = 4 + SQL + + assert_equal( + Predicate::Binary.new( + left: Predicate::Binary.new( + left: Predicate::Binary.new( + left: Predicate::Binary.new( + left: Column.new(table: Table.new("table_a", "a"), name: "id"), + operator: "=", + right: Column.new(table: Table.new("table_b", "b"), name: "a_id"), + ), + operator: "AND", + right: Predicate::Binary.new( + left: Column.new(table: Table.new("table_a", "a"), name: "shard_id"), + operator: "=", + right: Column.new(table: Table.new("table_b", "b"), name: "shard_id"), + ), + ), + operator: "AND", + right: Predicate::Binary.new( + left: Column.new(table: Table.new("table_a", "a"), name: "user_id"), + operator: "=", + right: Column.new(table: Table.new("table_b", "b"), name: "user_id"), + ), + ), + operator: "AND", + right: Predicate::Binary.new( + left: Column.new(table: Table.new("table_a", "a"), name: "id"), + operator: "=", + right: "4", + ), + ), + query.predicate, + ) + + assert_equal(1, query.joins.size) + assert_equal( + Predicate::Binary.new( + left: Predicate::Binary.new( + left: Predicate::Binary.new( + left: Column.new(table: Table.new("table_a", "a"), name: "id"), + operator: "=", + right: Column.new(table: Table.new("table_b", "b"), name: "a_id"), + ), + operator: "AND", + right: Predicate::Binary.new( + left: Column.new(table: Table.new("table_a", "a"), name: "shard_id"), + operator: "=", + right: Column.new(table: Table.new("table_b", "b"), name: "shard_id"), + ), + ), + operator: "AND", + right: Predicate::Binary.new( + left: Column.new(table: Table.new("table_a", "a"), name: "user_id"), + operator: "=", + right: Column.new(table: Table.new("table_b", "b"), name: "user_id"), + ), + ), + query.joins.first.predicate, + ) + end + end +end + + diff --git a/test/unit/predicate_test.rb b/test/unit/predicate_test.rb index 58fdad0..f4af4e2 100644 --- a/test/unit/predicate_test.rb +++ b/test/unit/predicate_test.rb @@ -11,14 +11,13 @@ def test_simple_predicate WHERE id = 1 SQL - assert_equal(1, query.predicates.size) assert_equal( Predicate::Binary.new( left: Column.new(table: Table.new("table", "table"), name: "id"), operator: "=", right: "1", ), - query.predicates.first, + query.predicate, ) end @@ -30,7 +29,6 @@ def test_multiple_predicates AND name = "derek" SQL - assert_equal(1, query.predicates.size) assert_equal( Predicate::Binary.new( left: Predicate::Binary.new( @@ -45,7 +43,60 @@ def test_multiple_predicates right: "\"derek\"", ), ), - query.predicates.first, + query.predicate, + ) + end + + def test_many_predicates_with_precedence + query = query_from_sql(<<~SQL) + SELECT * + FROM table + WHERE id = 1 + AND name = "derek" + AND phone = "555-0000" + OR id >= 2 + AND name = "stride" + SQL + + assert_equal( + Predicate::Binary.new( + left: Predicate::Binary.new( + left: Predicate::Binary.new( + left: Predicate::Binary.new( + left: Column.new(table: Table.new("table", "table"), name: "id"), + operator: "=", + right: "1", + ), + operator: "AND", + right: Predicate::Binary.new( + left: Column.new(table: Table.new("table", "table"), name: "name"), + operator: "=", + right: "\"derek\"", + ), + ), + operator: "AND", + right: Predicate::Binary.new( + left: Column.new(table: Table.new("table", "table"), name: "phone"), + operator: "=", + right: "\"555-0000\"", + ), + ), + operator: "OR", + right: Predicate::Binary.new( + left: Predicate::Binary.new( + left: Column.new(table: Table.new("table", "table"), name: "id"), + operator: ">=", + right: "2", + ), + operator: "AND", + right: Predicate::Binary.new( + left: Column.new(table: Table.new("table", "table"), name: "name"), + operator: "=", + right: "\"stride\"", + ), + ), + ), + query.predicate, ) end @@ -59,7 +110,6 @@ def test_multiple_predicates_with_precedence AND name = "stride" SQL - assert_equal(1, query.predicates.size) assert_equal( Predicate::Binary.new( left: Predicate::Binary.new( @@ -90,7 +140,7 @@ def test_multiple_predicates_with_precedence ), ), ), - query.predicates.first, + query.predicate, ) end end