diff --git a/README.md b/README.md index 0d575874de..d2049d6173 100644 --- a/README.md +++ b/README.md @@ -274,64 +274,60 @@ database implementation: package main import ( + "fmt" "time" + + "github.com/dolthub/go-mysql-server/sql/information_schema" + sqle "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/memory" "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/information_schema" ) -// Example of how to implement a MySQL server based on a Engine: -// -// ``` -// > mysql --host=127.0.0.1 --port=3306 -u root mydb -e "SELECT * FROM mytable" -// +----------+-------------------+-------------------------------+---------------------+ -// | name | email | phone_numbers | created_at | -// +----------+-------------------+-------------------------------+---------------------+ -// | John Doe | john@doe.com | ["555-555-555"] | 2018-04-18 09:41:13 | -// | John Doe | johnalt@doe.com | [] | 2018-04-18 09:41:13 | -// | Jane Doe | jane@doe.com | [] | 2018-04-18 09:41:13 | -// | Evil Bob | evilbob@gmail.com | ["555-666-555","666-666-666"] | 2018-04-18 09:41:13 | -// +----------+-------------------+-------------------------------+---------------------+ -// ``` +var ( + dbName = "mydb" + tableName = "mytable" + address = "localhost" + port = 3306 +) + func main() { + ctx := sql.NewEmptyContext() engine := sqle.NewDefault( sql.NewDatabaseProvider( - createTestDatabase(), + createTestDatabase(ctx), information_schema.NewInformationSchemaDatabase(), )) - engine.Analyzer.Catalog.MySQLDb.AddRootAccount() + config := server.Config{ Protocol: "tcp", - Address: "localhost:3306", + Address: fmt.Sprintf("%s:%d", address, port), } s, err := server.NewDefaultServer(config, engine) if err != nil { panic(err) } - s.Start() + if err = s.Start(); err != nil { + panic(err) + } } -func createTestDatabase() *memory.Database { - const ( - dbName = "mydb" - tableName = "mytable" - ) +func createTestDatabase(ctx *sql.Context) *memory.Database { db := memory.NewDatabase(dbName) table := memory.NewTable(tableName, sql.NewPrimaryKeySchema(sql.Schema{ - {Name: "name", Type: sql.Text, Nullable: false, Source: tableName}, - {Name: "email", Type: sql.Text, Nullable: false, Source: tableName}, + {Name: "name", Type: sql.Text, Nullable: false, Source: tableName, PrimaryKey: true}, + {Name: "email", Type: sql.Text, Nullable: false, Source: tableName, PrimaryKey: true}, {Name: "phone_numbers", Type: sql.JSON, Nullable: false, Source: tableName}, {Name: "created_at", Type: sql.Datetime, Nullable: false, Source: tableName}, - }), nil) - + }), db.GetForeignKeyCollection()) db.AddTable(tableName, table) - ctx := sql.NewEmptyContext() - _ = table.Insert(ctx, sql.NewRow("John Doe", "john@doe.com", sql.MustJSON(`["555-555-555"]`), time.Now())) - _ = table.Insert(ctx, sql.NewRow("John Doe", "johnalt@doe.com", sql.MustJSON(`[]`), time.Now())) - _ = table.Insert(ctx, sql.NewRow("Jane Doe", "jane@doe.com", sql.MustJSON(`[]`), time.Now())) - _ = table.Insert(ctx, sql.NewRow("Jane Deo", "janedeo@gmail.com", sql.MustJSON(`["556-565-566", "777-777-777"]`), time.Now())) + + creationTime := time.Unix(0, 1667304000000001000).UTC() + _ = table.Insert(ctx, sql.NewRow("Jane Deo", "janedeo@gmail.com", sql.MustJSON(`["556-565-566", "777-777-777"]`), creationTime)) + _ = table.Insert(ctx, sql.NewRow("Jane Doe", "jane@doe.com", sql.MustJSON(`[]`), creationTime)) + _ = table.Insert(ctx, sql.NewRow("John Doe", "john@doe.com", sql.MustJSON(`["555-555-555"]`), creationTime)) + _ = table.Insert(ctx, sql.NewRow("John Doe", "johnalt@doe.com", sql.MustJSON(`[]`), creationTime)) return db } ``` @@ -339,15 +335,15 @@ func createTestDatabase() *memory.Database { Then, you can connect to the server with any MySQL client: ```bash -> mysql --host=127.0.0.1 --port=3306 -u root mydb -e "SELECT * FROM mytable" -+----------+-------------------+-------------------------------+---------------------+ -| name | email | phone_numbers | created_at | -+----------+-------------------+-------------------------------+---------------------+ -| John Doe | john@doe.com | ["555-555-555"] | 2018-04-18 10:42:58 | -| John Doe | johnalt@doe.com | [] | 2018-04-18 10:42:58 | -| Jane Doe | jane@doe.com | [] | 2018-04-18 10:42:58 | -| Jane Doe | janedeo@gmail.com | ["556-565-566","777-777-777"] | 2018-04-18 10:42:58 | -+----------+-------------------+-------------------------------+---------------------+ +> mysql --host=localhost --port=3306 --user=root mydb --execute="SELECT * FROM mytable;" ++----------+-------------------+-------------------------------+----------------------------+ +| name | email | phone_numbers | created_at | ++----------+-------------------+-------------------------------+----------------------------+ +| Jane Deo | janedeo@gmail.com | ["556-565-566","777-777-777"] | 2022-11-01 12:00:00.000001 | +| Jane Doe | jane@doe.com | [] | 2022-11-01 12:00:00.000001 | +| John Doe | john@doe.com | ["555-555-555"] | 2022-11-01 12:00:00.000001 | +| John Doe | johnalt@doe.com | [] | 2022-11-01 12:00:00.000001 | ++----------+-------------------+-------------------------------+----------------------------+ ``` See the complete example [here](_example/main.go). diff --git a/_example/main.go b/_example/main.go index 8e83e01865..d31cae2a9f 100644 --- a/_example/main.go +++ b/_example/main.go @@ -1,4 +1,4 @@ -// Copyright 2020-2021 Dolthub, Inc. +// Copyright 2020-2022 Dolthub, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,69 +15,84 @@ package main import ( + "fmt" "time" + "github.com/dolthub/go-mysql-server/sql/information_schema" + sqle "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/memory" "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/information_schema" ) -// Example of how to implement a MySQL server based on a Engine: +// This is an example of how to implement a MySQL server. +// After running the example, you may connect to it using the following: // -// ``` -// > mysql --host=127.0.0.1 --port=3306 -u root mydb -e "SELECT * FROM mytable" -// +----------+-------------------+-------------------------------+---------------------+ -// | name | email | phone_numbers | created_at | -// +----------+-------------------+-------------------------------+---------------------+ -// | John Doe | john@doe.com | ["555-555-555"] | 2018-04-18 09:41:13 | -// | John Doe | johnalt@doe.com | [] | 2018-04-18 09:41:13 | -// | Jane Doe | jane@doe.com | [] | 2018-04-18 09:41:13 | -// | Evil Bob | evilbob@gmail.com | ["555-666-555","666-666-666"] | 2018-04-18 09:41:13 | -// +----------+-------------------+-------------------------------+---------------------+ -// ``` +// > mysql --host=localhost --port=3306 --user=root mydb --execute="SELECT * FROM mytable;" +// +----------+-------------------+-------------------------------+----------------------------+ +// | name | email | phone_numbers | created_at | +// +----------+-------------------+-------------------------------+----------------------------+ +// | Jane Deo | janedeo@gmail.com | ["556-565-566","777-777-777"] | 2022-11-01 12:00:00.000001 | +// | Jane Doe | jane@doe.com | [] | 2022-11-01 12:00:00.000001 | +// | John Doe | john@doe.com | ["555-555-555"] | 2022-11-01 12:00:00.000001 | +// | John Doe | johnalt@doe.com | [] | 2022-11-01 12:00:00.000001 | +// +----------+-------------------+-------------------------------+----------------------------+ +// +// The included MySQL client is used in this example, however any MySQL-compatible client will work. + +var ( + dbName = "mydb" + tableName = "mytable" + address = "localhost" + port = 3306 +) + +// For go-mysql-server developers: Remember to update the snippet in the README when this file changes. + func main() { + ctx := sql.NewEmptyContext() engine := sqle.NewDefault( sql.NewDatabaseProvider( - createTestDatabase(), + createTestDatabase(ctx), information_schema.NewInformationSchemaDatabase(), )) - engine.Analyzer.Catalog.MySQLDb.AddRootAccount() + // This variable may be found in the "users_example.go" file. Please refer to that file for a walkthrough on how to + // set up the "mysql" database to allow user creation and user checking when establishing connections. This is set + // to false for this example, but feel free to play around with it and see how it works. + if enableUsers { + if err := enableUserAccounts(ctx, engine); err != nil { + panic(err) + } + } config := server.Config{ Protocol: "tcp", - Address: "localhost:3306", + Address: fmt.Sprintf("%s:%d", address, port), } - s, err := server.NewDefaultServer(config, engine) if err != nil { panic(err) } - - s.Start() + if err = s.Start(); err != nil { + panic(err) + } } -func createTestDatabase() *memory.Database { - const ( - dbName = "mydb" - tableName = "mytable" - ) - +func createTestDatabase(ctx *sql.Context) *memory.Database { db := memory.NewDatabase(dbName) table := memory.NewTable(tableName, sql.NewPrimaryKeySchema(sql.Schema{ {Name: "name", Type: sql.Text, Nullable: false, Source: tableName, PrimaryKey: true}, {Name: "email", Type: sql.Text, Nullable: false, Source: tableName, PrimaryKey: true}, {Name: "phone_numbers", Type: sql.JSON, Nullable: false, Source: tableName}, - {Name: "created_at", Type: sql.Timestamp, Nullable: false, Source: tableName}, + {Name: "created_at", Type: sql.Datetime, Nullable: false, Source: tableName}, }), db.GetForeignKeyCollection()) - - creationTime := time.Unix(1524044473, 0).UTC() db.AddTable(tableName, table) - ctx := sql.NewEmptyContext() - table.Insert(ctx, sql.NewRow("John Doe", "john@doe.com", sql.JSONDocument{Val: []string{"555-555-555"}}, creationTime)) - table.Insert(ctx, sql.NewRow("John Doe", "johnalt@doe.com", sql.JSONDocument{Val: []string{}}, creationTime)) - table.Insert(ctx, sql.NewRow("Jane Doe", "jane@doe.com", sql.JSONDocument{Val: []string{}}, creationTime)) - table.Insert(ctx, sql.NewRow("Evil Bob", "evilbob@gmail.com", sql.JSONDocument{Val: []string{"555-666-555", "666-666-666"}}, creationTime)) + + creationTime := time.Unix(0, 1667304000000001000).UTC() + _ = table.Insert(ctx, sql.NewRow("Jane Deo", "janedeo@gmail.com", sql.MustJSON(`["556-565-566", "777-777-777"]`), creationTime)) + _ = table.Insert(ctx, sql.NewRow("Jane Doe", "jane@doe.com", sql.MustJSON(`[]`), creationTime)) + _ = table.Insert(ctx, sql.NewRow("John Doe", "john@doe.com", sql.MustJSON(`["555-555-555"]`), creationTime)) + _ = table.Insert(ctx, sql.NewRow("John Doe", "johnalt@doe.com", sql.MustJSON(`[]`), creationTime)) return db } diff --git a/_example/main_test.go b/_example/main_test.go new file mode 100644 index 0000000000..cf19afe17e --- /dev/null +++ b/_example/main_test.go @@ -0,0 +1,129 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "database/sql" + "fmt" + "net" + "testing" + + _ "github.com/go-sql-driver/mysql" + "github.com/gocraft/dbr/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var expectedResults = [][]string{ + {"Jane Deo", "janedeo@gmail.com", `["556-565-566","777-777-777"]`, "2022-11-01 12:00:00.000001"}, + {"Jane Doe", "jane@doe.com", `[]`, "2022-11-01 12:00:00.000001"}, + {"John Doe", "john@doe.com", `["555-555-555"]`, "2022-11-01 12:00:00.000001"}, + {"John Doe", "johnalt@doe.com", `[]`, "2022-11-01 12:00:00.000001"}, +} + +func TestExampleUsersDisabled(t *testing.T) { + enableUsers = false + useUnusedPort(t) + go func() { + main() + }() + + conn, err := dbr.Open("mysql", fmt.Sprintf("no_user:@tcp(%s:%d)/%s", address, port, dbName), nil) + require.NoError(t, err) + require.NoError(t, conn.Ping()) + + rows, err := conn.Query(fmt.Sprintf("SELECT * FROM %s;", tableName)) + require.NoError(t, err) + checkRows(t, expectedResults, rows) + require.NoError(t, conn.Close()) +} + +func TestExampleRootUserEnabled(t *testing.T) { + enableUsers = true + pretendThatFileExists = false + useUnusedPort(t) + go func() { + main() + }() + + conn, err := dbr.Open("mysql", fmt.Sprintf("no_user:@tcp(%s:%d)/%s", address, port, dbName), nil) + require.NoError(t, err) + require.ErrorContains(t, conn.Ping(), "User not found") + conn, err = dbr.Open("mysql", fmt.Sprintf("root:@tcp(%s:%d)/%s", address, port, dbName), nil) + require.NoError(t, err) + require.NoError(t, conn.Ping()) + + rows, err := conn.Query(fmt.Sprintf("SELECT * FROM %s;", tableName)) + require.NoError(t, err) + checkRows(t, expectedResults, rows) + require.NoError(t, conn.Close()) +} + +func TestExampleLoadedUser(t *testing.T) { + enableUsers = true + pretendThatFileExists = true + useUnusedPort(t) + go func() { + main() + }() + + conn, err := dbr.Open("mysql", fmt.Sprintf("no_user:@tcp(%s:%d)/%s", address, port, dbName), nil) + require.NoError(t, err) + require.ErrorContains(t, conn.Ping(), "User not found") + conn, err = dbr.Open("mysql", fmt.Sprintf("root:@tcp(%s:%d)/%s", address, port, dbName), nil) + require.NoError(t, err) + require.ErrorContains(t, conn.Ping(), "User not found") + conn, err = dbr.Open("mysql", + fmt.Sprintf("gms_user:123456@tcp(%s:%d)/%s?allowCleartextPasswords=true", address, port, dbName), nil) + require.NoError(t, err) + require.NoError(t, conn.Ping()) + + rows, err := conn.Query(fmt.Sprintf("SELECT * FROM %s;", tableName)) + require.NoError(t, err) + checkRows(t, expectedResults, rows) + require.NoError(t, conn.Close()) +} + +func checkRows(t *testing.T, expectedRows [][]string, actualRows *sql.Rows) { + rowIdx := -1 + for actualRows.Next() { + rowIdx++ + + if assert.Less(t, rowIdx, len(expectedRows)) { + compareRow := make([]string, len(expectedRows[rowIdx])) + connRow := make([]*string, len(compareRow)) + interfaceRow := make([]any, len(compareRow)) + for i := range connRow { + interfaceRow[i] = &connRow[i] + } + assert.NoError(t, actualRows.Scan(interfaceRow...)) + for i := range connRow { + if assert.NotNil(t, connRow[i]) { + compareRow[i] = *connRow[i] + } + } + assert.Equal(t, expectedRows[rowIdx], compareRow) + } + } + assert.NoError(t, actualRows.Close()) +} + +func useUnusedPort(t *testing.T) { + // Tests should grab an open port, otherwise they'll fail if some hardcoded port is already in use + listener, err := net.Listen("tcp", ":0") + require.NoError(t, err) + port = listener.Addr().(*net.TCPAddr).Port + require.NoError(t, listener.Close()) +} diff --git a/_example/users_example.go b/_example/users_example.go new file mode 100644 index 0000000000..459db906fe --- /dev/null +++ b/_example/users_example.go @@ -0,0 +1,104 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + sqle "github.com/dolthub/go-mysql-server" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/mysql_db" +) + +// This file walks through how an application might add support for users through the inclusion of the "mysql" database. +// By default, this database is disabled, and must be either explicitly enabled, or implicitly enabled by loading into +// it. When disabled, all users are accepted, meaning that anyone may connect to your application. This is great for +// testing, but not so great for an exposed application that wants to protect itself from malicious actors. + +var ( + // When set to true, calls "enableUserAccounts". + enableUsers = false + + // When set to false, we create an account named "root" that has all privileges and does not have a password set. + // When set to true, we create an account named "gms_user" that has all privileges and has the password "123456". + // This is a stand-in for checking whether a file exists that contains the "mysql" database's data. + pretendThatFileExists = false +) + +// MySQLPersister is an example struct which handles the persistence of the data in the "mysql" database. +type MySQLPersister struct { + Data []byte +} + +var _ mysql_db.MySQLDbPersistence = (*MySQLPersister)(nil) + +// Persist implements the interface mysql_db.MySQLDbPersistence. This function is simple, in that it simply stores +// the given data inside itself. A real application would persist to the file system. +func (m *MySQLPersister) Persist(ctx *sql.Context, data []byte) error { + m.Data = data + return nil +} + +func enableUserAccounts(ctx *sql.Context, engine *sqle.Engine) error { + mysqlDb := engine.Analyzer.Catalog.MySQLDb + + // The functions "AddRootAccount" and "LoadData" both automatically enable the "mysql" database, but this is just + // to explicitly show how one can manually enable (or disable) the database. + mysqlDb.Enabled = true + // The persister here simply stands-in for your provided persistence function. The database calls this whenever it + // needs to save any changes to any of the "mysql" database's tables. + persister := &MySQLPersister{} + mysqlDb.SetPersister(persister) + + // Here we show how a real application may choose to bootstrap their users. If we've previously created a file + // (generated by calling the above persister), then we may check that the file exists, and load the file if it does. + // If the file does not exist, then we create a "root" account, so that we may create our default users. Do remember + // to either remove the "root" account when done, or give it a password, otherwise you'll have an all-powerful + // account without any protection. + if pretendThatFileExists { + dataLoadedFromPretendFile := createLoadedData() + if err := mysqlDb.LoadData(ctx, dataLoadedFromPretendFile); err != nil { + return err + } + } else { + // AddRootAccount creates a password-less account named "root" that has all privileges. This is intended for use + // with testing, and also to set up the initial user accounts. A real application may want to check that a + // persisted file exists, and call "LoadData" if one does. If a file does not exist, it would call + // "AddRootAccount". + mysqlDb.AddRootAccount() + } + + return nil +} + +// createLoadedData returns data that would match what a file would contain if you had a single user named "gms_user" +// (on "localhost") that had all privileges, along with the password "123456". +func createLoadedData() []byte { + // As we continue development on go-mysql-server, our file format will continue to evolve. We guarantee backward + // compatibility for persisted "mysql" databases, so the below data will always work. This is only here for example + // purposes. + return []byte{ + 16, 0, 0, 0, 0, 0, 0, 0, 8, 0, 12, 0, 8, 0, 4, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, + 28, 0, 0, 0, 0, 0, 22, 0, 40, 0, 36, 0, 32, 0, 28, 0, 24, 0, 20, 0, 8, 0, 0, 0, 0, 0, 4, 0, 22, 0, 0, 0, 36, 0, + 0, 0, 168, 231, 96, 99, 0, 0, 0, 0, 0, 0, 0, 0, 28, 0, 0, 0, 72, 0, 0, 0, 104, 0, 0, 0, 248, 0, 0, 0, 4, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 41, 0, 0, 0, 42, 54, 66, 66, 52, 56, 51, 55, 69, 66, 55, 52, 51, 50, 57, 49, 48, 53, + 69, 69, 52, 53, 54, 56, 68, 68, 65, 55, 68, 67, 54, 55, 69, 68, 50, 67, 65, 50, 65, 68, 57, 0, 0, 0, 21, 0, 0, + 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 10, + 0, 16, 0, 12, 0, 8, 0, 4, 0, 10, 0, 0, 0, 12, 0, 0, 0, 12, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 0, + 0, 0, 23, 0, 0, 0, 7, 0, 0, 0, 6, 0, 0, 0, 3, 0, 0, 0, 24, 0, 0, 0, 16, 0, 0, 0, 8, 0, 0, 0, 22, 0, 0, 0, 21, 0, + 0, 0, 18, 0, 0, 0, 15, 0, 0, 0, 29, 0, 0, 0, 27, 0, 0, 0, 11, 0, 0, 0, 28, 0, 0, 0, 26, 0, 0, 0, 5, 0, 0, 0, 4, + 0, 0, 0, 2, 0, 0, 0, 30, 0, 0, 0, 25, 0, 0, 0, 20, 0, 0, 0, 19, 0, 0, 0, 14, 0, 0, 0, 13, 0, 0, 0, 17, 0, 0, 0, + 9, 0, 0, 0, 1, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 108, 111, 99, 97, 108, 104, 111, 115, 116, 0, 0, 0, + 8, 0, 0, 0, 103, 109, 115, 95, 117, 115, 101, 114, 0, 0, 0, 0, + } +} diff --git a/sql/mysql_db/mysql_db.go b/sql/mysql_db/mysql_db.go index 2e1befcad0..d96fc7a81e 100644 --- a/sql/mysql_db/mysql_db.go +++ b/sql/mysql_db/mysql_db.go @@ -318,6 +318,9 @@ func (db *MySQLDb) UserActivePrivilegeSet(ctx *sql.Context) PrivilegeSet { // privileged operation. This takes into account the active roles, which are set in the context, therefore the user is // also pulled from the context. func (db *MySQLDb) UserHasPrivileges(ctx *sql.Context, operations ...sql.PrivilegedOperation) bool { + if !db.Enabled { + return true + } privSet := db.UserActivePrivilegeSet(ctx) for _, operation := range operations { for _, operationPriv := range operation.Privileges { @@ -379,6 +382,9 @@ func (db *MySQLDb) GetTableNames(ctx *sql.Context) ([]string, error) { // AuthMethod implements the interface mysql.AuthServer. func (db *MySQLDb) AuthMethod(user, addr string) (string, error) { + if !db.Enabled { + return "mysql_native_password", nil + } var host string // TODO : need to check for network type instead of addr string if it's unix socket network, // macOS passes empty addr, but ubuntu returns "@" as addr for `localhost`