Skip to content

Commit

Permalink
Merge pull request #445 from dolthub/zachmu/issue-439
Browse files Browse the repository at this point in the history
Allow limits and offsets to use value args (? in prepared statements)
  • Loading branch information
zachmu authored Jun 2, 2021
2 parents 643401f + ca21092 commit 9dbddef
Show file tree
Hide file tree
Showing 17 changed files with 263 additions and 62 deletions.
12 changes: 6 additions & 6 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,18 +152,18 @@ func (e *Engine) QueryNodeWithBindings(
return nil, nil, err
}

analyzed, err = e.Analyzer.Analyze(ctx, parsed, nil)
if err != nil {
return nil, nil, err
}

if len(bindings) > 0 {
analyzed, err = plan.ApplyBindings(analyzed, bindings)
parsed, err = plan.ApplyBindings(parsed, bindings)
if err != nil {
return nil, nil, err
}
}

analyzed, err = e.Analyzer.Analyze(ctx, parsed, nil)
if err != nil {
return nil, nil, err
}

transactionDatabase, err := e.beginTransaction(ctx, parsed)
if err != nil {
return nil, nil, err
Expand Down
21 changes: 20 additions & 1 deletion enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ func TestQueryErrors(t *testing.T, harness Harness) {
t.Skipf("skipping query %s", tt.Query)
}
}
AssertErr(t, engine, harness, tt.Query, tt.ExpectedErr)
AssertErrWithBindings(t, engine, harness, tt.Query, tt.Bindings, tt.ExpectedErr)
})
}
}
Expand Down Expand Up @@ -2864,6 +2864,25 @@ func AssertErr(t *testing.T, e *sqle.Engine, harness Harness, query string, expe
AssertErrWithCtx(t, e, NewContext(harness), query, expectedErrKind, errStrs...)
}

// AssertErrWithBindings asserts that the given query returns an error during its execution, optionally specifying a
// type of error.
func AssertErrWithBindings(t *testing.T, e *sqle.Engine, harness Harness, query string, bindings map[string]sql.Expression, expectedErrKind *errors.Kind, errStrs ...string) {
ctx := NewContext(harness)
_, iter, err := e.QueryWithBindings(ctx, query, bindings)
if err == nil {
_, err = sql.RowIterToRows(ctx, iter)
}
require.Error(t, err)
if expectedErrKind != nil {
require.True(t, expectedErrKind.Is(err), "Expected error of type %s but got %s", expectedErrKind, err)
}
// If there are multiple error strings then we only match against the first
if len(errStrs) >= 1 {
require.Equal(t, errStrs[0], err.Error())
}

}

// AssertErrWithCtx is the same as AssertErr, but uses the context given instead of creating one from a harness
func AssertErrWithCtx(t *testing.T, e *sqle.Engine, ctx *sql.Context, query string, expectedErrKind *errors.Kind, errStrs ...string) {
_, iter, err := e.Query(ctx, query)
Expand Down
48 changes: 48 additions & 0 deletions enginetest/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -1221,10 +1221,29 @@ var QueryTests = []QueryTest{
Query: "SELECT i FROM mytable WHERE s = 'first row' ORDER BY i DESC LIMIT 1;",
Expected: []sql.Row{{int64(1)}},
},
{
Query: "SELECT i FROM mytable WHERE s = 'first row' ORDER BY i DESC LIMIT 0;",
Expected: []sql.Row{},
},
{
Query: "SELECT i FROM mytable ORDER BY i LIMIT 1 OFFSET 1;",
Expected: []sql.Row{{int64(2)}},
},
{
Query: "SELECT i FROM mytable WHERE s = 'first row' ORDER BY i DESC LIMIT ?;",
Bindings: map[string]sql.Expression{
"v1": expression.NewLiteral(1, sql.Int8),
},
Expected: []sql.Row{{int64(1)}},
},
{
Query: "SELECT i FROM mytable ORDER BY i LIMIT ? OFFSET 2;",
Bindings: map[string]sql.Expression{
"v1": expression.NewLiteral(1, sql.Int8),
"v2": expression.NewLiteral(1, sql.Int8),
},
Expected: []sql.Row{{int64(3)}},
},
{
Query: "SELECT i FROM mytable WHERE i NOT IN (SELECT i FROM (SELECT * FROM (SELECT i as i, s as s FROM mytable) f) s)",
Expected: []sql.Row{},
Expand Down Expand Up @@ -5792,6 +5811,7 @@ var ExplodeQueries = []QueryTest{

type QueryErrorTest struct {
Query string
Bindings map[string]sql.Expression
ExpectedErr *errors.Kind
}

Expand Down Expand Up @@ -5995,6 +6015,34 @@ var errorQueries = []QueryErrorTest{
Query: "SELECT a FROM (select i,s FROM mytable) mt (a,b,c) order by a desc;",
ExpectedErr: sql.ErrColumnCountMismatch,
},
{
Query: "SELECT i FROM mytable limit ?",
ExpectedErr: sql.ErrInvalidSyntax,
Bindings: map[string]sql.Expression{
"v1": expression.NewLiteral(-100, sql.Int8),
},
},
{
Query: "SELECT i FROM mytable limit ?",
ExpectedErr: sql.ErrInvalidType,
Bindings: map[string]sql.Expression{
"v1": expression.NewLiteral("100", sql.LongText),
},
},
{
Query: "SELECT i FROM mytable limit 10, ?",
ExpectedErr: sql.ErrInvalidSyntax,
Bindings: map[string]sql.Expression{
"v1": expression.NewLiteral(-100, sql.Int8),
},
},
{
Query: "SELECT i FROM mytable limit 10, ?",
ExpectedErr: sql.ErrInvalidType,
Bindings: map[string]sql.Expression{
"v1": expression.NewLiteral("100", sql.LongText),
},
},
}

// WriteQueryTest is a query test for INSERT, UPDATE, etc. statements. It has a query to run and a select query to
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ require (
github.com/VividCortex/gohistogram v1.0.0 // indirect
github.com/cespare/xxhash v1.1.0
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
github.com/dolthub/vitess v0.0.0-20210524220733-7b048c544267
github.com/dolthub/vitess v0.0.0-20210530214338-7755381e6501
github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239 // indirect
github.com/go-kit/kit v0.9.0
github.com/go-sql-driver/mysql v1.6.0
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
github.com/dolthub/vitess v0.0.0-20210524220733-7b048c544267 h1:g3KWBmLtSWlEbwUF4NV4a4jzE5aec8n2ZHWwSDy9IGY=
github.com/dolthub/vitess v0.0.0-20210524220733-7b048c544267/go.mod h1:hUE8oSk2H5JZnvtlLBhJPYC8WZCA5AoSntdLTcBvdBM=
github.com/dolthub/vitess v0.0.0-20210530214338-7755381e6501 h1:QO+maZZoP4PUwS5Clk/lo5AvZ8J5jHevbC/tTAfLe70=
github.com/dolthub/vitess v0.0.0-20210530214338-7755381e6501/go.mod h1:hUE8oSk2H5JZnvtlLBhJPYC8WZCA5AoSntdLTcBvdBM=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239 h1:Ghm4eQYC0nEPnSJdVkTrXpu9KtoVCSo1hg7mtI7G9KU=
github.com/fastly/go-utils v0.0.0-20180712184237-d95a45783239/go.mod h1:Gdwt2ce0yfBxPvZrHkprdPPTTS3N5rwmLE8T22KBXlw=
github.com/go-kit/kit v0.9.0 h1:wDJmvq38kDhkVxi50ni9ykkdUr1PKgqKOoi01fa0Mdk=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
Expand Down Expand Up @@ -132,7 +131,6 @@ golang.org/x/tools v0.0.0-20190830154057-c17b040389b9/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
Expand Down
4 changes: 2 additions & 2 deletions sql/analyzer/parallelize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,15 @@ func TestIsParallelizable(t *testing.T) {
{
"limit",
plan.NewLimit(
5,
expression.NewLiteral(5, sql.Int8),
plan.NewResolvedTable(nil, nil, nil),
),
false,
},
{
"offset",
plan.NewOffset(
5,
expression.NewLiteral(5, sql.Int8),
plan.NewResolvedTable(nil, nil, nil),
),
false,
Expand Down
1 change: 1 addition & 0 deletions sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
// OnceBeforeDefault contains the rules to be applied just once before the
// DefaultRules.
var OnceBeforeDefault = []Rule{
{"validate_offset_and_limit", validateLimitAndOffset},
{"load_stored_procedures", loadStoredProcedures},
{"resolve_views", resolveViews},
{"lift_common_table_expressions", liftCommonTableExpressions},
Expand Down
54 changes: 54 additions & 0 deletions sql/analyzer/validation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,60 @@ var DefaultValidationRules = []Rule{
{validateUnionSchemasMatchRule, validateUnionSchemasMatch},
}

// validateLimitAndOffset ensures that only integer literals are used for limit and offset values
func validateLimitAndOffset(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.Node, error) {
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
case *plan.Limit:
switch e := n.Limit.(type) {
case *expression.Literal:
if !sql.IsInteger(e.Type()) {
return nil, sql.ErrInvalidType.New(e.Type().String())
}
i, err := e.Eval(ctx, nil)
if err != nil {
return nil, err
}

i64, err := sql.Int64.Convert(i)
if err != nil {
return nil, err
}
if i64.(int64) < 0 {
return nil, sql.ErrInvalidSyntax.New("negative limit")
}
default:
return nil, sql.ErrInvalidType.New(e.Type().String())
}
return n, nil
case *plan.Offset:
switch e := n.Offset.(type) {
case *expression.Literal:
if !sql.IsInteger(e.Type()) {
return nil, sql.ErrInvalidType.New(e.Type().String())
}
i, err := e.Eval(ctx, nil)
if err != nil {
return nil, err
}

i64, err := sql.Int64.Convert(i)
if err != nil {
return nil, err
}
if i64.(int64) < 0 {
return nil, sql.ErrInvalidSyntax.New("negative offset")
}
default:
return nil, sql.ErrInvalidType.New(e.Type().String())
}
return n, nil
default:
return n, nil
}
})
}

func validateIsResolved(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.Node, error) {
span, _ := ctx.Span("validate_is_resolved")
defer span.Finish()
Expand Down
4 changes: 4 additions & 0 deletions sql/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ var (
// ErrSavepointDoesNotExist is returned when a RELEASE SAVEPOINT or ROLLBACK TO SAVEPOINT statement references a
// non-existent savepoint identifier
ErrSavepointDoesNotExist = errors.NewKind("SAVEPOINT %s does not exist")

// ErrInvalidSyntax is returned for syntax errors that aren't picked up by the parser, e.g. the wrong type of
// expression used in part of statement.
ErrInvalidSyntax = errors.NewKind("Invalid syntax: %s")
)

func CastSQLError(err error) (*mysql.SQLError, bool) {
Expand Down
2 changes: 2 additions & 0 deletions sql/expression/literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ func (p *Literal) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {

func (p *Literal) String() string {
switch v := p.value.(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", v)
case string:
return fmt.Sprintf("%q", v)
case []byte:
Expand Down
22 changes: 10 additions & 12 deletions sql/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ func convertSelect(ctx *sql.Context, s *sqlparser.Select) (sql.Node, error) {
}
} else if ok, val := sql.HasDefaultValue(ctx, ctx.Session, "sql_select_limit"); !ok {
limit := mustCastNumToInt64(val)
node = plan.NewLimit(limit, node)
node = plan.NewLimit(expression.NewLiteral(limit, sql.Int64), node)
}

// Finally, if common table expressions were provided, wrap the top-level node in a With node to capture them
Expand Down Expand Up @@ -1869,15 +1869,11 @@ func limitToLimit(
limit sqlparser.Expr,
child sql.Node,
) (*plan.Limit, error) {
rowCount, err := getInt64Value(ctx, limit, "LIMIT with non-integer literal")
rowCount, err := ExprToExpression(ctx, limit)
if err != nil {
return nil, err
}

if rowCount < 0 {
return nil, ErrUnsupportedSyntax.New("LIMIT must be >= 0")
}

return plan.NewLimit(rowCount, child), nil
}

Expand All @@ -1895,16 +1891,12 @@ func offsetToOffset(
offset sqlparser.Expr,
child sql.Node,
) (*plan.Offset, error) {
o, err := getInt64Value(ctx, offset, "OFFSET with non-integer literal")
rowCount, err := ExprToExpression(ctx, offset)
if err != nil {
return nil, err
}

if o < 0 {
return nil, ErrUnsupportedSyntax.New("OFFSET must be >= 0")
}

return plan.NewOffset(o, child), nil
return plan.NewOffset(rowCount, child), nil
}

// getInt64Literal returns an int64 *expression.Literal for the value given, or an unsupported error with the string
Expand All @@ -1915,6 +1907,12 @@ func getInt64Literal(ctx *sql.Context, expr sqlparser.Expr, errStr string) (*exp
return nil, err
}

switch e := e.(type) {
case *expression.Literal:
if !sql.IsInteger(e.Type()) {
return nil, ErrUnsupportedFeature.New(errStr)
}
}
nl, ok := e.(*expression.Literal)
if !ok || !sql.IsInteger(nl.Type()) {
return nil, ErrUnsupportedFeature.New(errStr)
Expand Down
Loading

0 comments on commit 9dbddef

Please sign in to comment.