diff --git a/groups/group.go b/groups/group.go index 1d6f3b8..53bf698 100644 --- a/groups/group.go +++ b/groups/group.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/lxzan/concurrency/internal" "sync" + "sync/atomic" "time" ) @@ -13,48 +14,51 @@ const ( defaultWaitTimeout = 60 * time.Second // 默认线程同步等待超时 ) +var defaultCaller Caller = func(args any, f func(any) error) error { return f(args) } + type ( Caller func(args any, f func(any) error) error Group[T any] struct { - options *options - mu *sync.Mutex // 锁 - errs []error // 错误 - done chan bool // 信号 - q []T // 任务队列 - taskDone int64 // 已完成任务数量 - taskTotal int64 // 总任务数量 - OnMessage func(args T) error // 任务处理 - OnError func(err error) // 错误处理 + options *options // 配置 + mu sync.Mutex // 锁 + ctx context.Context // 上下文 + cancelFunc context.CancelFunc // 取消函数 + canceled atomic.Uint32 // 是否已取消 + errs []error // 错误 + done chan bool // 完成信号 + q []T // 任务队列 + taskDone int64 // 已完成任务数量 + taskTotal int64 // 总任务数量 + OnMessage func(args T) error // 任务处理 + OnError func(err error) // 错误处理 } ) // New 新建一个任务集 func New[T any](opts ...Option) *Group[T] { - o := &options{ - timeout: defaultWaitTimeout, - concurrency: defaultConcurrency, - caller: func(args any, f func(any) error) error { return f(args) }, - } + o := new(options) + opts = append(opts, withInitialize()) for _, f := range opts { f(o) } c := &Group[T]{ options: o, - mu: &sync.Mutex{}, q: make([]T, 0), taskDone: 0, done: make(chan bool), } + c.ctx, c.cancelFunc = context.WithTimeout(context.Background(), o.timeout) c.OnMessage = func(args T) error { return nil } c.OnError = func(err error) {} + return c } -func (c *Group[T]) clear() { +func (c *Group[T]) clearJob() { c.mu.Lock() c.q = c.q[:0] c.mu.Unlock() @@ -82,19 +86,21 @@ func (c *Group[T]) incrAndIsDone() bool { return ok } -func (c *Group[T]) hasError() bool { +func (c *Group[T]) getError() error { c.mu.Lock() defer c.mu.Unlock() - return len(c.errs) > 0 + return errors.Join(c.errs...) +} + +func (c *Group[T]) jobFunc(v any) error { + if c.canceled.Load() == 1 { + return nil + } + return c.OnMessage(v.(T)) } func (c *Group[T]) do(args T) { - if err := c.options.caller(args, func(v any) error { - if c.options.cancel && c.hasError() { - return nil - } - return c.OnMessage(v.(T)) - }); err != nil { + if err := c.options.caller(args, c.jobFunc); err != nil { c.mu.Lock() c.errs = append(c.errs, err) c.mu.Unlock() @@ -119,6 +125,13 @@ func (c *Group[T]) Len() int { return x } +// Cancel 取消队列中剩余任务的执行 +func (c *Group[T]) Cancel() { + if c.canceled.CompareAndSwap(0, 1) { + c.cancelFunc() + } +} + // Push 往任务队列中追加任务 func (c *Group[T]) Push(eles ...T) { c.mu.Lock() @@ -148,13 +161,13 @@ func (c *Group[T]) Start() error { } } - ctx, cancel := context.WithTimeout(context.Background(), c.options.timeout) - defer cancel() + defer c.cancelFunc() + select { case <-c.done: - return errors.Join(c.errs...) - case <-ctx.Done(): - c.clear() - return ctx.Err() + return c.getError() + case <-c.ctx.Done(): + c.clearJob() + return c.ctx.Err() } } diff --git a/groups/group_test.go b/groups/group_test.go index 76c0512..b890b59 100644 --- a/groups/group_test.go +++ b/groups/group_test.go @@ -82,7 +82,7 @@ func TestNewTaskGroup(t *testing.T) { }) t.Run("cancel", func(t *testing.T) { - ctl := New[int](WithCancel(), WithConcurrency(1)) + ctl := New[int](WithConcurrency(1)) ctl.Push(1, 3, 5) arr := make([]int, 0) ctl.OnMessage = func(args int) error { @@ -96,6 +96,9 @@ func TestNewTaskGroup(t *testing.T) { return nil } } + ctl.OnError = func(err error) { + ctl.Cancel() + } err := ctl.Start() as.Error(err) as.ElementsMatch(arr, []int{1, 3}) diff --git a/groups/options.go b/groups/options.go index 96c9e4c..ae2bec3 100644 --- a/groups/options.go +++ b/groups/options.go @@ -1,6 +1,7 @@ package groups import ( + "github.com/lxzan/concurrency/internal" "github.com/pkg/errors" "runtime" "time" @@ -11,7 +12,6 @@ type options struct { timeout time.Duration concurrency int64 caller Caller - cancel bool } type Option func(o *options) @@ -24,16 +24,9 @@ func WithTimeout(t time.Duration) Option { } // WithConcurrency 设置最大并发 -func WithConcurrency(n int64) Option { +func WithConcurrency(n uint32) Option { return func(o *options) { - o.concurrency = n - } -} - -// WithCancel 设置遇到错误放弃执行剩余任务 -func WithCancel() Option { - return func(o *options) { - o.cancel = true + o.concurrency = int64(n) } } @@ -55,3 +48,11 @@ func WithRecovery() Option { } } } + +func withInitialize() Option { + return func(o *options) { + o.timeout = internal.SelectValue(o.timeout <= 0, defaultWaitTimeout, o.timeout) + o.concurrency = internal.SelectValue(o.concurrency <= 0, defaultConcurrency, o.concurrency) + o.caller = internal.SelectValue(o.caller == nil, defaultCaller, o.caller) + } +}