Skip to content

Commit

Permalink
feat: implement mapping for HashBucketFunctionTransformer
Browse files Browse the repository at this point in the history
  • Loading branch information
colin-sentry committed Dec 30, 2024
1 parent cc646b9 commit cc41bc3
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ query_processors:
- quantileTDigestWeighted
- processor: HashBucketFunctionTransformer
args:
hash_bucket_names:
- attr_str
- attr_num
hash_bucket_name_mapping:
attr_str: attr_str
attr_num: attr_num

validate_data_model: do_nothing # in order to reference aliased columns, we shouldn't validate columns purely based on the entity schema
validators:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ query_processors:
time_parse_columns:
- start_timestamp
- end_timestamp
- processor: HashBucketFunctionTransformer
args:
hash_bucket_name_mapping:
attr_str: attr_str
attr_f64: attr_num
attr_i64: attr_num
- processor: OptionalAttributeAggregationTransformer
args:
attribute_column_names:
Expand All @@ -108,11 +114,6 @@ query_processors:
curried_aggregation_names:
- quantile
- quantileTDigestWeighted
- processor: HashBucketFunctionTransformer
args:
hash_bucket_names:
- attr_str
- attr_num

validate_data_model: do_nothing # in order to reference aliased columns, we shouldn't validate columns purely based on the entity schema
validators:
Expand Down
22 changes: 12 additions & 10 deletions snuba/query/processors/logical/hash_bucket_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence
from typing import Mapping

from snuba.query.expressions import Column, Expression, FunctionCall, Literal
from snuba.query.logical import Query
Expand All @@ -22,11 +22,9 @@ class HashBucketFunctionTransformer(LogicalQueryProcessor):
It converts mapExists(attr_str, 'blah') to mapExists(attr_str_{hash('blah')%20}, 'blah')
"""

def __init__(
self,
hash_bucket_names: Sequence[str],
):
self.hash_bucket_names = hash_bucket_names
def __init__(self, hash_bucket_name_mapping: Mapping[str, str]):
super().__init__()
self.hash_bucket_name_mapping = hash_bucket_name_mapping

def process_query(self, query: Query, query_settings: QuerySettings) -> None:
def transform_map_keys_and_values_expression(exp: Expression) -> Expression:
Expand All @@ -40,7 +38,7 @@ def transform_map_keys_and_values_expression(exp: Expression) -> Expression:
if not isinstance(param, Column):
return exp

if param.column_name not in self.hash_bucket_names:
if param.column_name not in self.hash_bucket_name_mapping:
return exp

if exp.function_name not in ("mapKeys", "mapValues"):
Expand All @@ -56,7 +54,7 @@ def transform_map_keys_and_values_expression(exp: Expression) -> Expression:
parameters=(
Column(
None,
column_name=f"{param.column_name}_{i}",
column_name=f"{self.hash_bucket_name_mapping[param.column_name]}_{i}",
table_name=param.table_name,
),
),
Expand All @@ -76,7 +74,7 @@ def transform_map_contains_expression(exp: Expression) -> Expression:
if not isinstance(column, Column):
return exp

if column.column_name not in self.hash_bucket_names:
if column.column_name not in self.hash_bucket_name_mapping:
return exp

if exp.function_name != "mapContains":
Expand All @@ -91,7 +89,11 @@ def transform_map_contains_expression(exp: Expression) -> Expression:
alias=exp.alias,
function_name=exp.function_name,
parameters=(
Column(None, None, f"{column.column_name}_{bucket_idx}"),
Column(
None,
None,
f"{self.hash_bucket_name_mapping[column.column_name]}_{bucket_idx}",
),
key,
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@
condition=binary_condition(
"or",
f.mapContains(column("attr_str"), literal("blah"), alias="x"),
f.mapContains(column("attr_i64"), literal("blah"), alias="y"),
f.mapContains(column("attr_strz"), literal("blah"), alias="z"),
),
),
Expand All @@ -210,6 +211,7 @@
condition=binary_condition(
"or",
f.mapContains(column("attr_str_2"), literal("blah"), alias="x"),
f.mapContains(column("attr_num_2"), literal("blah"), alias="y"),
f.mapContains(column("attr_strz"), literal("blah"), alias="z"),
),
),
Expand All @@ -220,7 +222,9 @@
@pytest.mark.parametrize("pre_format, expected_query", test_data)
def test_format_expressions(pre_format: Query, expected_query: Query) -> None:
copy = deepcopy(pre_format)
HashBucketFunctionTransformer("attr_str").process_query(copy, HTTPQuerySettings())
HashBucketFunctionTransformer(
{"attr_str": "attr_str", "attr_i64": "attr_num"}
).process_query(copy, HTTPQuerySettings())
assert copy.get_selected_columns() == expected_query.get_selected_columns()
assert copy.get_groupby() == expected_query.get_groupby()
assert copy.get_condition() == expected_query.get_condition()

0 comments on commit cc41bc3

Please sign in to comment.