diff --git a/generator/mysql/mysql_generator.go b/generator/mysql/mysql_generator.go index 7495bec4..fe2fa2ac 100644 --- a/generator/mysql/mysql_generator.go +++ b/generator/mysql/mysql_generator.go @@ -38,7 +38,7 @@ func Generate(destDir string, dbConn DBConnection, generatorTemplate ...template } defer db.Close() - err = generate(db, dbConn.DBName, destDir, generatorTemplate...) + err = GenerateDB(db, dbConn.DBName, destDir, generatorTemplate...) if err != nil { return err } @@ -70,7 +70,7 @@ func GenerateDSN(dsn, destDir string, templates ...template.Template) error { } defer db.Close() - err = generate(db, cfg.DBName, destDir, templates...) + err = GenerateDB(db, cfg.DBName, destDir, templates...) if err != nil { return fmt.Errorf("failed to generate: %w", err) } @@ -96,7 +96,8 @@ func openConnection(connectionString string) (*sql.DB, error) { return db, nil } -func generate(db *sql.DB, dbName, destDir string, templates ...template.Template) error { +// GenerateDB generates jet files using the provided *sql.DB +func GenerateDB(db *sql.DB, dbName, destDir string, templates ...template.Template) error { fmt.Println("Retrieving database information...") // No schemas in MySQL schemaMetaData, err := metadata.GetSchema(db, &mySqlQuerySet{}, dbName) diff --git a/generator/postgres/postgres_generator.go b/generator/postgres/postgres_generator.go index 0b503ef8..83469037 100644 --- a/generator/postgres/postgres_generator.go +++ b/generator/postgres/postgres_generator.go @@ -56,6 +56,14 @@ func GenerateDSN(dsn, schema, destDir string, templates ...template.Template) er defer db.Close() fmt.Println("Retrieving schema information...") + return GenerateDB(db, schema, cfg.Database, destDir, templates...) +} + +// GenerateDB generates jet files using the provided *sql.DB +func GenerateDB(db *sql.DB, dbName, schema, destDir string, templates ...template.Template) error { + if dbName == "" { + return fmt.Errorf("database name is required") + } generatorTemplate := template.Default(postgres.Dialect) if len(templates) > 0 { generatorTemplate = templates[0] @@ -66,7 +74,7 @@ func GenerateDSN(dsn, schema, destDir string, templates ...template.Template) er return fmt.Errorf("failed to get '%s' schema metadata: %w", schema, err) } - dirPath := filepath.Join(destDir, cfg.Database) + dirPath := filepath.Join(destDir, dbName) err = template.ProcessSchema(dirPath, schemaMetadata, generatorTemplate) if err != nil { diff --git a/generator/sqlite/sqlite_generator.go b/generator/sqlite/sqlite_generator.go index eadc72b1..a2a5cf8a 100644 --- a/generator/sqlite/sqlite_generator.go +++ b/generator/sqlite/sqlite_generator.go @@ -3,6 +3,7 @@ package sqlite import ( "database/sql" "fmt" + "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/sqlite" @@ -17,7 +18,11 @@ func GenerateDSN(dsn, destDir string, templates ...template.Template) error { defer db.Close() fmt.Println("Retrieving schema information...") + return GenerateDB(db, destDir, templates...) +} +// GenerateDB generates jet files using the provided *sql.DB +func GenerateDB(db *sql.DB, destDir string, templates ...template.Template) error { generatorTemplate := template.Default(sqlite.Dialect) if len(templates) > 0 { generatorTemplate = templates[0]