diff --git a/core/src/main/java/com/alibaba/druid/sql/ast/statement/SQLMergeStatement.java b/core/src/main/java/com/alibaba/druid/sql/ast/statement/SQLMergeStatement.java index 4059c49b11..3e8a873277 100644 --- a/core/src/main/java/com/alibaba/druid/sql/ast/statement/SQLMergeStatement.java +++ b/core/src/main/java/com/alibaba/druid/sql/ast/statement/SQLMergeStatement.java @@ -138,6 +138,21 @@ public WhenUpdate cloneTo() { cloneTo(x); return x; } + + @Override + public boolean replace(SQLExpr expr, SQLExpr target) { + boolean isSuccess = false; + if (expr instanceof SQLUpdateSetItem && target instanceof SQLUpdateSetItem) { + for (int i = 0; i < items.size(); i++) { + if (items.get(i) == expr) { + target.setParent(this); + items.set(i, (SQLUpdateSetItem) target); + isSuccess = true; + } + } + } + return isSuccess || super.replace(expr, target); + } } public static class WhenInsert extends When { @@ -203,6 +218,27 @@ public WhenInsert cloneTo() { cloneTo(x); return x; } + + @Override + public boolean replace(SQLExpr expr, SQLExpr target) { + boolean isSuccess = false; + for (int i = 0; i < columns.size(); i++) { + if (columns.get(i) == expr) { + target.setParent(this); + columns.set(i, target); + isSuccess = true; + } + } + + for (int i = 0; i < values.size(); i++) { + if (values.get(i) == expr) { + target.setParent(this); + values.set(i, target); + isSuccess = true; + } + } + return isSuccess || super.replace(expr, target); + } } public static class WhenDelete extends When { @@ -228,7 +264,7 @@ public WhenDelete cloneTo() { } } - public abstract static class When extends SQLObjectImpl { + public abstract static class When extends SQLObjectImpl implements SQLReplaceable { protected boolean not; protected SQLName by; protected SQLExpr where; @@ -281,5 +317,14 @@ public void setWhere(SQLExpr x) { } this.where = x; } + + public boolean replace(SQLExpr expr, SQLExpr target) { + if (this.where == expr) { + target.setParent(this); + this.where = target; + return true; + } + return false; + } } } diff --git a/core/src/test/java/com/alibaba/druid/bvt/sql/ReplaceTest.java b/core/src/test/java/com/alibaba/druid/bvt/sql/ReplaceTest.java new file mode 100644 index 0000000000..c08e50aff3 --- /dev/null +++ b/core/src/test/java/com/alibaba/druid/bvt/sql/ReplaceTest.java @@ -0,0 +1,28 @@ +package com.alibaba.druid.bvt.sql; + +import com.alibaba.druid.sql.SQLUtils; +import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr; +import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr; +import com.alibaba.druid.sql.ast.statement.SQLMergeStatement; +import org.junit.Assert; +import org.junit.Test; + + +public class ReplaceTest { + @Test + public void test_when() { + SQLMergeStatement.WhenInsert when = new SQLMergeStatement.WhenInsert(); + when.addColumn(new SQLIdentifierExpr("id")); + when.addColumn(new SQLIdentifierExpr("id2")); + SQLIdentifierExpr current_timestamp_identifier1 = new SQLIdentifierExpr("current_timestamp"); + SQLIdentifierExpr current_timestamp_identifier2 = new SQLIdentifierExpr("current_timestamp"); + when.addValue(current_timestamp_identifier1); + when.addValue(current_timestamp_identifier2); + SQLMethodInvokeExpr current_timestamp_method1 = new SQLMethodInvokeExpr("current_timestamp"); + SQLMethodInvokeExpr current_timestamp_method2 = new SQLMethodInvokeExpr("current_timestamp"); + SQLUtils.replaceInParent(current_timestamp_identifier1, current_timestamp_method1); + SQLUtils.replaceInParent(current_timestamp_identifier2, current_timestamp_method2); + Assert.assertEquals(current_timestamp_method1, when.getValues().get(0)); + Assert.assertEquals(current_timestamp_method2, when.getValues().get(1)); + } +}