diff --git a/connection.go b/connection.go index df116fb3..42098199 100644 --- a/connection.go +++ b/connection.go @@ -15,13 +15,14 @@ var Connections = map[string]*Connection{} // Connection represents all necessary details to talk with a datastore type Connection struct { - ID string - Store store - Dialect dialect - Elapsed int64 - TX *Tx - eager bool - eagerFields []string + ID string + Store store + Dialect dialect + Elapsed int64 + TX *Tx + eager bool + eagerFields []string + OptimizeCount bool } func (c *Connection) String() string { diff --git a/finders.go b/finders.go index 9026d14b..504bb44f 100644 --- a/finders.go +++ b/finders.go @@ -348,10 +348,32 @@ func (q Query) CountByField(model interface{}, field string) (int, error) { tmpQuery.Paginator = nil tmpQuery.orderClauses = clauses{} tmpQuery.limitResults = 0 - query, args := tmpQuery.ToSQL(&Model{Value: model}) - // when query contains custom selected fields / executed using RawQuery, - // sql may already contains limit and offset + var query, countQuery string + var args []interface{} + var isRaw bool + + if tmpQuery.RawSQL != nil && tmpQuery.RawSQL.Fragment != "" { + isRaw = true + } + + // Count can't be optimized if the query contains raw SQL due to ToSQL internals + if tmpQuery.OptimizeCount && !isRaw { + tmpQuery.addColumns = []string{} // Optimizing Count means giving up selecting any distinct columns. + // This can be changed in the future but will also have to address the issue of + // table aliasing in column names -- AKA reevaluating how Model.ignoreTableName works. + query, args = tmpQuery.ToSQL(&Model{Value: model, ignoreTableName: true}, + fmt.Sprintf("COUNT(%s) as row_count", field)) + } else { + if tmpQuery.OptimizeCount && isRaw { + log(logging.Warn, "Query contains raw SQL; COUNT cannot be optimized") + } + + query, args = tmpQuery.ToSQL(&Model{Value: model}) + } + + //when query contains custom selected fields / executed using RawQuery, + // sql may already contains limit and offset if rLimitOffset.MatchString(query) { foundLimit := rLimitOffset.FindString(query) query = query[0 : len(query)-len(foundLimit)] @@ -360,7 +382,12 @@ func (q Query) CountByField(model interface{}, field string) (int, error) { query = query[0 : len(query)-len(foundLimit)] } - countQuery := fmt.Sprintf("SELECT COUNT(%s) AS row_count FROM (%s) a", field, query) + if tmpQuery.OptimizeCount { + countQuery = query + } else { + countQuery = fmt.Sprintf("SELECT COUNT(%s) AS row_count FROM (%s) a", field, query) + } + log(logging.SQL, countQuery, args...) return q.Connection.Store.Get(res, countQuery, args...) }) diff --git a/finders_test.go b/finders_test.go index 389dc80a..7ff3e51d 100644 --- a/finders_test.go +++ b/finders_test.go @@ -727,6 +727,40 @@ func Test_Count(t *testing.T) { }) } +func Test_Count_Optimized(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + transaction(func(tx *Connection) { + r := require.New(t) + + tx.OptimizeCount = true + + user := User{Name: nulls.NewString("Dylan")} + err := tx.Create(&user) + r.NoError(err) + c, err := tx.Count(&user) + r.NoError(err) + r.Equal(c, 1) + + c, err = tx.Where("1=1").CountByField(&user, "distinct id") + r.NoError(err) + r.Equal(c, 1) + // should ignore order in count + + c, err = tx.Order("id desc").Count(&user) + r.NoError(err) + r.Equal(c, 1) + + var uAQ []UsersAddressQuery + _, err = Q(tx).Select("users_addresses.*").LeftJoin("users", "users.id=users_addresses.user_id").Count(&uAQ) + r.NoError(err) + + _, err = Q(tx).Select("users_addresses.*", "users.name", "users.email").LeftJoin("users", "users.id=users_addresses.user_id").Count(&uAQ) + r.NoError(err) + }) +} + func Test_Count_Disregards_Pagination(t *testing.T) { if PDB == nil { t.Skip("skipping integration tests") diff --git a/model.go b/model.go index 58019f73..e5512675 100644 --- a/model.go +++ b/model.go @@ -27,8 +27,9 @@ type modelIterable func(*Model) error // that is passed in to many functions. type Model struct { Value - tableName string - As string + tableName string + As string + ignoreTableName bool } // ID returns the ID of the Model. All models must have an `ID` field this is diff --git a/query.go b/query.go index b9af3998..64937388 100644 --- a/query.go +++ b/query.go @@ -25,6 +25,7 @@ type Query struct { havingClauses havingClauses Paginator *Paginator Connection *Connection + OptimizeCount bool } // Clone will fill targetQ query with the connection used in q, if @@ -42,6 +43,7 @@ func (q *Query) Clone(targetQ *Query) { targetQ.groupClauses = q.groupClauses targetQ.havingClauses = q.havingClauses targetQ.addColumns = q.addColumns + targetQ.OptimizeCount = q.OptimizeCount if q.Paginator != nil { paginator := *q.Paginator @@ -196,6 +198,7 @@ func Q(c *Connection) *Query { eager: c.eager, eagerFields: c.eagerFields, eagerMode: eagerModeNil, + OptimizeCount: c.OptimizeCount, } } diff --git a/sql_builder.go b/sql_builder.go index edd05c46..103a33ed 100644 --- a/sql_builder.go +++ b/sql_builder.go @@ -216,10 +216,13 @@ var columnCacheMutex = sync.RWMutex{} func (sq *sqlBuilder) buildColumns() columns.Columns { tableName := sq.Model.TableName() + asName := sq.Model.As - if asName == "" { + // If asName is not explicitly set and ignoreTableName is set, then don't us an AS name (alias) + if asName == "" && !sq.Model.ignoreTableName { asName = strings.Replace(tableName, ".", "_", -1) } + acl := len(sq.AddColumns) if acl == 0 { columnCacheMutex.RLock()