Skip to content

Commit

Permalink
fix(mc2mc): udfs inside merge load method (#68)
Browse files Browse the repository at this point in the history
* fix: udfs inside merge load method

* fix: udf and vars in single separator function
  • Loading branch information
deryrahman authored Feb 5, 2025
1 parent dcc06d8 commit c4f5ec2
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 20 deletions.
10 changes: 6 additions & 4 deletions mc2mc/internal/query/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ func (b *Builder) Build() (string, error) {
if b.method == MERGE {
query := RemoveComments(b.query)
hr, query := SeparateHeadersAndQuery(query)
vars, query := SeparateVariablesAndQuery(query)
varsAndUDFs, query := SeparateVariablesUDFsAndQuery(query)
queries := semicolonPattern.Split(query, -1)
if len(queries) <= 1 {
return b.query, nil
}
query = b.constructMergeQuery(hr, vars, queries)
query = b.constructMergeQuery(hr, varsAndUDFs, queries)
return query, nil
}

Expand Down Expand Up @@ -180,7 +180,7 @@ func (b *Builder) constructOverridedValues(query string) (string, error) {
}

// constructMergeQueries constructs merge queries with headers and variables
func (b *Builder) constructMergeQuery(hr, vars string, queries []string) string {
func (b *Builder) constructMergeQuery(hr, varsAndUDFs string, queries []string) string {
builder := strings.Builder{}
for i, q := range queries {
q = strings.TrimSpace(q)
Expand All @@ -189,7 +189,9 @@ func (b *Builder) constructMergeQuery(hr, vars string, queries []string) string
}
builder.WriteString(fmt.Sprintf("%s\n", hr))
if !IsDDL(q) {
builder.WriteString(fmt.Sprintf("%s\n", vars))
if varsAndUDFs != "" {
builder.WriteString(fmt.Sprintf("%s\n", varsAndUDFs))
}
}
builder.WriteString(fmt.Sprintf("%s;", q))
if i < len(queries)-1 {
Expand Down
69 changes: 69 additions & 0 deletions mc2mc/internal/query/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,75 @@ MERGE INTO append_test
USING (SELECT * FROM @src) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 2;`, query)
})
t.Run("returns query for merge load method with multiple dml and ddl and contains function", func(t *testing.T) {
queryToExecute := `SET odps.table.append2.enable=true;
CREATE TABLE IF NOT EXISTS append_test (id bigint)
TBLPROPERTIES('table.format.version'='2');
FUNCTION castStringToBoolean (@field STRING) AS CASE
WHEN TOLOWER(@field) = '1.0' THEN true
WHEN TOLOWER(@field) = '0.0' THEN false
WHEN TOLOWER(@field) = '1' THEN true
WHEN TOLOWER(@field) = '0' THEN false
WHEN TOLOWER(@field) = 'true' THEN true
WHEN TOLOWER(@field) = 'false' THEN false
END;
function my_add(@a BIGINT) as @a + 1;
INSERT OVERWRITE TABLE append_test VALUES(0),(1);
@src := SELECT my_add(1) id;
MERGE INTO append_test
USING (SELECT castStringToBoolean(id) FROM @src) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 2;`
odspClient := &mockOdpsClient{}

query, err := query.NewBuilder(
logger.NewDefaultLogger(),
odspClient,
query.WithQuery(queryToExecute),
query.WithMethod(query.MERGE),
).Build()
assert.NoError(t, err)
assert.Equal(t, `SET odps.table.append2.enable=true;
CREATE TABLE IF NOT EXISTS append_test (id bigint)
TBLPROPERTIES('table.format.version'='2');
--*--optimus-break-marker--*--
SET odps.table.append2.enable=true;
FUNCTION castStringToBoolean (@field STRING) AS CASE
WHEN TOLOWER(@field) = '1.0' THEN true
WHEN TOLOWER(@field) = '0.0' THEN false
WHEN TOLOWER(@field) = '1' THEN true
WHEN TOLOWER(@field) = '0' THEN false
WHEN TOLOWER(@field) = 'true' THEN true
WHEN TOLOWER(@field) = 'false' THEN false
END;
function my_add(@a BIGINT) as @a + 1;
@src := SELECT my_add(1) id;
INSERT OVERWRITE TABLE append_test VALUES(0),(1);
--*--optimus-break-marker--*--
SET odps.table.append2.enable=true;
FUNCTION castStringToBoolean (@field STRING) AS CASE
WHEN TOLOWER(@field) = '1.0' THEN true
WHEN TOLOWER(@field) = '0.0' THEN false
WHEN TOLOWER(@field) = '1' THEN true
WHEN TOLOWER(@field) = '0' THEN false
WHEN TOLOWER(@field) = 'true' THEN true
WHEN TOLOWER(@field) = 'false' THEN false
END;
function my_add(@a BIGINT) as @a + 1;
@src := SELECT my_add(1) id;
MERGE INTO append_test
USING (SELECT castStringToBoolean(id) FROM @src) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 2;`, query)
})
}
Expand Down
23 changes: 12 additions & 11 deletions mc2mc/internal/query/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ var (
multiCommentPattern = regexp.MustCompile(`(?s)\s*/\*.*?\*/\s*\n?`) // regex to match multi-line comments
headerPattern = regexp.MustCompile(`(?i)^set`) // regex to match header statements
variablePattern = regexp.MustCompile(`(?i)^@`) // regex to match variable statements
udfPattern = regexp.MustCompile(`(?i)^function\s+`) // regex to match UDF statements
ddlPattern = regexp.MustCompile(`(?i)^CREATE\s+`) // regex to match DDL statements
)

Expand Down Expand Up @@ -53,8 +54,8 @@ func SeparateHeadersAndQuery(query string) (string, string) {
return headerStr, queryStr
}

func SeparateVariablesAndQuery(query string) (string, string) {
variables := []string{}
func SeparateVariablesUDFsAndQuery(query string) (string, string) {
variablesAndUDFs := []string{}
query = strings.TrimSpace(query)
remainingQueries := []string{}

Expand All @@ -66,26 +67,26 @@ func SeparateVariablesAndQuery(query string) (string, string) {
continue
}
stmtWithoutComment := commentPattern.ReplaceAllString(stmt, "")
if variablePattern.MatchString(stmtWithoutComment) {
variables = append(variables, stmt)
if variablePattern.MatchString(stmtWithoutComment) || udfPattern.MatchString(stmtWithoutComment) {
variablesAndUDFs = append(variablesAndUDFs, stmt)
} else {
remainingQueries = append(remainingQueries, stmt)
}
}

variableStr := ""
if len(variables) > 0 {
for i, variable := range variables {
variables[i] = strings.TrimSpace(variable)
variableUDFStr := ""
if len(variablesAndUDFs) > 0 {
for i, variable := range variablesAndUDFs {
variablesAndUDFs[i] = strings.TrimSpace(variable)
}
variableStr = strings.Join(variables, ";\n")
variableStr += ";"
variableUDFStr = strings.Join(variablesAndUDFs, ";\n")
variableUDFStr += ";"
}

// join the remaining queries back together
queryStr := strings.Join(remainingQueries, ";\n")

return variableStr, queryStr
return variableUDFStr, queryStr
}

func RemoveComments(query string) string {
Expand Down
47 changes: 42 additions & 5 deletions mc2mc/internal/query/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,14 @@ where CAST(event_timestamp as DATE) = '{{ .DSTART | Date }}'
})
}

func TestSeparateVariablesAndQuery(t *testing.T) {
func TestSeparateVariablesUDFsAndQuery(t *testing.T) {
t.Run("returns query without variables", func(t *testing.T) {
q1 := `MERGE INTO append_test
USING (SELECT * FROM @src) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 2;`
variables, query := query.SeparateVariablesAndQuery(q1)
variables, query := query.SeparateVariablesUDFsAndQuery(q1)
assert.Empty(t, variables)
assert.Equal(t, `MERGE INTO append_test
USING (SELECT * FROM @src) source
Expand All @@ -130,7 +130,7 @@ on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 2;`

variables, query := query.SeparateVariablesAndQuery(q1)
variables, query := query.SeparateVariablesUDFsAndQuery(q1)
assert.Empty(t, variables)
assert.Equal(t, `MERGE INTO append_test
USING (SELECT * FROM @src) source
Expand All @@ -145,7 +145,7 @@ USING (SELECT * FROM @src) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 2;`
variables, query := query.SeparateVariablesAndQuery(q1)
variables, query := query.SeparateVariablesUDFsAndQuery(q1)
assert.Equal(t, "@src := SELECT 1 id;", variables)
assert.Equal(t, `MERGE INTO append_test
USING (SELECT * FROM @src) source
Expand All @@ -170,7 +170,7 @@ USING (SELECT * FROM @src2) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 3;`
variables, query := query.SeparateVariablesAndQuery(q1)
variables, query := query.SeparateVariablesUDFsAndQuery(q1)
assert.Equal(t, `@src := SELECT id
FROM src_table
WHERE id = 1;
Expand All @@ -186,6 +186,43 @@ MERGE INTO append_test
USING (SELECT * FROM @src2) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 3`, query)
})
t.Run("splits multiline variables + udfs and queries", func(t *testing.T) {
q1 := `@src := SELECT id
FROM src_table
WHERE id = 1;
function my_add(@a BIGINT) as @a + 1;
@src2 := SELECT id
FROM src_table
WHERE id = 2;
MERGE INTO append_test
USING (SELECT * FROM @src) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 2;
MERGE INTO append_test
USING (SELECT * FROM @src2) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 3;`
variables, query := query.SeparateVariablesUDFsAndQuery(q1)
assert.Equal(t, `@src := SELECT id
FROM src_table
WHERE id = 1;
function my_add(@a BIGINT) as @a + 1;
@src2 := SELECT id
FROM src_table
WHERE id = 2;`, variables)
assert.Equal(t, `MERGE INTO append_test
USING (SELECT * FROM @src) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 2;
MERGE INTO append_test
USING (SELECT * FROM @src2) source
on append_test.id = source.id
WHEN MATCHED THEN UPDATE
SET append_test.id = 3`, query)
})
}
Expand Down

0 comments on commit c4f5ec2

Please sign in to comment.