diff --git a/assert.go b/assert.go index 17916c4..72c9903 100644 --- a/assert.go +++ b/assert.go @@ -28,15 +28,18 @@ import ( "fmt" "sort" "strings" + "sync" ) var failureSprint FailureSprintFunc = DefaultFailureSprint +var failurePool = &sync.Pool{New: func() interface{} { return &Failure{} }} + // Failure 在断言出错时输出的错误信息 type Failure struct { Action string // 操作名称,比如 Equal,NotEqual 等方法名称。 Values map[string]interface{} // 断言出错时返回的一些额外参数 - User string // 断言出错时用户反馈的额外信息 + user []interface{} // 断言出错时用户反馈的额外信息 } // FailureSprintFunc 将 [Failure] 转换成文本的函数 @@ -49,6 +52,7 @@ type FailureSprintFunc = func(*Failure) string // [New] 方法在默认情况下继承由此方法设置的值。 func SetFailureSprintFunc(f FailureSprintFunc) { failureSprint = f } +// GetFailureSprintFunc 获取当前的 [FailureSprintFunc] 方法 func GetFailureSprintFunc() FailureSprintFunc { return failureSprint } // DefaultFailureSprint 默认的 [FailureSprintFunc] 实现 @@ -73,9 +77,9 @@ func DefaultFailureSprint(f *Failure) string { } } - if f.User != "" { + if u := f.User(); u != "" { s.WriteString("用户反馈信息:") - s.WriteString(f.User) + s.WriteString(u) } return s.String() @@ -87,19 +91,25 @@ func DefaultFailureSprint(f *Failure) string { // 对数据进行格式化,否则采用 fmt.Sprint(user...) 格式化数据; // kv 表示当前错误返回的数据; func NewFailure(action string, user []interface{}, kv map[string]interface{}) *Failure { - var u string - if len(user) > 0 { - switch v := user[0].(type) { - case string: - u = fmt.Sprintf(v, user[1:]...) - default: - u = fmt.Sprint(user...) - } + f := failurePool.Get().(*Failure) + f.Action = action + f.user = user + f.Values = kv + return f +} + +// User 返回用户提交的返馈信息 +func (f *Failure) User() string { + // NOTE: 通过函数的方式返回字符串,而不是直接在 [NewFailure] 直接处理完,可以确保在未使用的情况下无需初始化。 + + if len(f.user) == 0 { + return "" } - return &Failure{ - Action: action, - User: u, - Values: kv, + switch v := f.user[0].(type) { + case string: + return fmt.Sprintf(v, f.user[1:]...) + default: + return fmt.Sprint(f.user...) } } diff --git a/assert_test.go b/assert_test.go index 66f3815..eacd13f 100644 --- a/assert_test.go +++ b/assert_test.go @@ -8,7 +8,7 @@ import "testing" func TestDefaultFailureSprint(t *testing.T) { f := NewFailure("A", nil, nil) - if f.Action != "A" || f.User != "" || len(f.Values) != 0 { + if f.Action != "A" || f.User() != "" || len(f.Values) != 0 { t.Error("err1") } if s := DefaultFailureSprint(f); s != "A 断言失败!" { @@ -17,7 +17,7 @@ func TestDefaultFailureSprint(t *testing.T) { // 带 user f = NewFailure("AB", []interface{}{1, 2}, nil) - if f.Action != "AB" || f.User != "1 2" || len(f.Values) != 0 { + if f.Action != "AB" || f.User() != "1 2" || len(f.Values) != 0 { t.Error("err3") } if s := DefaultFailureSprint(f); s != "AB 断言失败!用户反馈信息:1 2" { @@ -26,7 +26,7 @@ func TestDefaultFailureSprint(t *testing.T) { // 带 values f = NewFailure("AB", nil, map[string]interface{}{"k1": "v1", "k2": 2}) - if f.Action != "AB" || f.User != "" || len(f.Values) != 2 { + if f.Action != "AB" || f.User() != "" || len(f.Values) != 2 { t.Error("err5") } if s := DefaultFailureSprint(f); s != "AB 断言失败!反馈以下参数:\nk1=v1\nk2=2\n" { @@ -35,7 +35,7 @@ func TestDefaultFailureSprint(t *testing.T) { // 带 user,values f = NewFailure("AB", []interface{}{1, 2}, map[string]interface{}{"k1": "v1", "k2": 2}) - if f.Action != "AB" || f.User == "" || len(f.Values) != 2 { + if f.Action != "AB" || f.User() == "" || len(f.Values) != 2 { t.Error("err7") } if s := DefaultFailureSprint(f); s != "AB 断言失败!反馈以下参数:\nk1=v1\nk2=2\n用户反馈信息:1 2" { diff --git a/assertion.go b/assertion.go index 1157971..b89612b 100644 --- a/assertion.go +++ b/assertion.go @@ -16,14 +16,14 @@ import ( "time" ) -// Assertion 是对 testing 包的一些简单包装 +// Assertion 是对 [testing.TB] 的二次包装 type Assertion struct { tb testing.TB print func(...interface{}) f FailureSprintFunc } -// New 返回 Assertion 对象 +// New 返回 [Assertion] 对象 // // fatal 决定在出错时是调用 tb.Error 还是 tb.Fatal; func New(tb testing.TB, fatal bool) *Assertion { @@ -39,7 +39,7 @@ func New(tb testing.TB, fatal bool) *Assertion { } } -// NewWithEnv 以指定的环境变量初始化 *Assertion 对象 +// NewWithEnv 以指定的环境变量初始化 [Assertion] 对象 // // env 是以 [testing.TB.Setenv] 的形式调用。 func NewWithEnv(tb testing.TB, fatal bool, env map[string]string) *Assertion { @@ -59,6 +59,7 @@ func (a *Assertion) Assert(expr bool, f *Failure) *Assertion { a.TB().Helper() a.print(a.f(f)) } + failurePool.Put(f) return a } diff --git a/assertion_test.go b/assertion_test.go index 1be9306..e550f72 100644 --- a/assertion_test.go +++ b/assertion_test.go @@ -187,9 +187,9 @@ func TestAssertion_TypeEqual(t *testing.T) { func TestAssertion_Same(t *testing.T) { a := New(t, false) - a.NotSame(5, 5) - a.NotSame(struct{}{}, struct{}{}) - a.NotSame(func() {}, func() {}) + a.NotSame(5, 5). + NotSame(struct{}{}, struct{}{}). + NotSame(func() {}, func() {}) i := 5 a.NotSame(i, i)