Skip to content

Commit

Permalink
perf: 优化 Failure 的使用
Browse files Browse the repository at this point in the history
  • Loading branch information
caixw committed Feb 27, 2024
1 parent cd3a79a commit 9785f22
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 25 deletions.
40 changes: 25 additions & 15 deletions assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,18 @@ import (
"fmt"
"sort"
"strings"
"sync"
)

var failureSprint FailureSprintFunc = DefaultFailureSprint

var failurePool = &sync.Pool{New: func() interface{} { return &Failure{} }}

// Failure 在断言出错时输出的错误信息
type Failure struct {
Action string // 操作名称,比如 Equal,NotEqual 等方法名称。
Values map[string]interface{} // 断言出错时返回的一些额外参数
User string // 断言出错时用户反馈的额外信息
user []interface{} // 断言出错时用户反馈的额外信息
}

// FailureSprintFunc 将 [Failure] 转换成文本的函数
Expand All @@ -49,6 +52,7 @@ type FailureSprintFunc = func(*Failure) string
// [New] 方法在默认情况下继承由此方法设置的值。
func SetFailureSprintFunc(f FailureSprintFunc) { failureSprint = f }

// GetFailureSprintFunc 获取当前的 [FailureSprintFunc] 方法
func GetFailureSprintFunc() FailureSprintFunc { return failureSprint }

// DefaultFailureSprint 默认的 [FailureSprintFunc] 实现
Expand All @@ -73,9 +77,9 @@ func DefaultFailureSprint(f *Failure) string {
}
}

if f.User != "" {
if u := f.User(); u != "" {
s.WriteString("用户反馈信息:")
s.WriteString(f.User)
s.WriteString(u)
}

return s.String()
Expand All @@ -87,19 +91,25 @@ func DefaultFailureSprint(f *Failure) string {
// 对数据进行格式化,否则采用 fmt.Sprint(user...) 格式化数据;
// kv 表示当前错误返回的数据;
func NewFailure(action string, user []interface{}, kv map[string]interface{}) *Failure {
var u string
if len(user) > 0 {
switch v := user[0].(type) {
case string:
u = fmt.Sprintf(v, user[1:]...)
default:
u = fmt.Sprint(user...)
}
f := failurePool.Get().(*Failure)
f.Action = action
f.user = user
f.Values = kv
return f
}

// User 返回用户提交的返馈信息
func (f *Failure) User() string {
// NOTE: 通过函数的方式返回字符串,而不是直接在 [NewFailure] 直接处理完,可以确保在未使用的情况下无需初始化。

if len(f.user) == 0 {
return ""
}

return &Failure{
Action: action,
User: u,
Values: kv,
switch v := f.user[0].(type) {
case string:
return fmt.Sprintf(v, f.user[1:]...)
default:
return fmt.Sprint(f.user...)
}
}
8 changes: 4 additions & 4 deletions assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import "testing"

func TestDefaultFailureSprint(t *testing.T) {
f := NewFailure("A", nil, nil)
if f.Action != "A" || f.User != "" || len(f.Values) != 0 {
if f.Action != "A" || f.User() != "" || len(f.Values) != 0 {
t.Error("err1")
}
if s := DefaultFailureSprint(f); s != "A 断言失败!" {
Expand All @@ -17,7 +17,7 @@ func TestDefaultFailureSprint(t *testing.T) {

// 带 user
f = NewFailure("AB", []interface{}{1, 2}, nil)
if f.Action != "AB" || f.User != "1 2" || len(f.Values) != 0 {
if f.Action != "AB" || f.User() != "1 2" || len(f.Values) != 0 {
t.Error("err3")
}
if s := DefaultFailureSprint(f); s != "AB 断言失败!用户反馈信息:1 2" {
Expand All @@ -26,7 +26,7 @@ func TestDefaultFailureSprint(t *testing.T) {

// 带 values
f = NewFailure("AB", nil, map[string]interface{}{"k1": "v1", "k2": 2})
if f.Action != "AB" || f.User != "" || len(f.Values) != 2 {
if f.Action != "AB" || f.User() != "" || len(f.Values) != 2 {
t.Error("err5")
}
if s := DefaultFailureSprint(f); s != "AB 断言失败!反馈以下参数:\nk1=v1\nk2=2\n" {
Expand All @@ -35,7 +35,7 @@ func TestDefaultFailureSprint(t *testing.T) {

// 带 user,values
f = NewFailure("AB", []interface{}{1, 2}, map[string]interface{}{"k1": "v1", "k2": 2})
if f.Action != "AB" || f.User == "" || len(f.Values) != 2 {
if f.Action != "AB" || f.User() == "" || len(f.Values) != 2 {
t.Error("err7")
}
if s := DefaultFailureSprint(f); s != "AB 断言失败!反馈以下参数:\nk1=v1\nk2=2\n用户反馈信息:1 2" {
Expand Down
7 changes: 4 additions & 3 deletions assertion.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ import (
"time"
)

// Assertion 是对 testing 包的一些简单包装
// Assertion 是对 [testing.TB] 的二次包装
type Assertion struct {
tb testing.TB
print func(...interface{})
f FailureSprintFunc
}

// New 返回 Assertion 对象
// New 返回 [Assertion] 对象
//
// fatal 决定在出错时是调用 tb.Error 还是 tb.Fatal;
func New(tb testing.TB, fatal bool) *Assertion {
Expand All @@ -39,7 +39,7 @@ func New(tb testing.TB, fatal bool) *Assertion {
}
}

// NewWithEnv 以指定的环境变量初始化 *Assertion 对象
// NewWithEnv 以指定的环境变量初始化 [Assertion] 对象
//
// env 是以 [testing.TB.Setenv] 的形式调用。
func NewWithEnv(tb testing.TB, fatal bool, env map[string]string) *Assertion {
Expand All @@ -59,6 +59,7 @@ func (a *Assertion) Assert(expr bool, f *Failure) *Assertion {
a.TB().Helper()
a.print(a.f(f))
}
failurePool.Put(f)
return a
}

Expand Down
6 changes: 3 additions & 3 deletions assertion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ func TestAssertion_TypeEqual(t *testing.T) {
func TestAssertion_Same(t *testing.T) {
a := New(t, false)

a.NotSame(5, 5)
a.NotSame(struct{}{}, struct{}{})
a.NotSame(func() {}, func() {})
a.NotSame(5, 5).
NotSame(struct{}{}, struct{}{}).
NotSame(func() {}, func() {})

i := 5
a.NotSame(i, i)
Expand Down

0 comments on commit 9785f22

Please sign in to comment.