diff --git a/consumer.go b/consumer.go index da3e79d..56ec765 100644 --- a/consumer.go +++ b/consumer.go @@ -20,6 +20,7 @@ type Consumer struct { taskQueue chan core.QueuedMessage runFunc func(context.Context, core.QueuedMessage) error stop chan struct{} + exit chan struct{} logger Logger stopOnce sync.Once stopFlag int32 @@ -101,6 +102,9 @@ func (s *Consumer) Shutdown() error { s.stopOnce.Do(func() { close(s.stop) close(s.taskQueue) + if len(s.taskQueue) > 0 { + <-s.exit + } }) return nil } @@ -127,6 +131,10 @@ loop: select { case task, ok := <-s.taskQueue: if !ok { + select { + case s.exit <- struct{}{}: + default: + } return nil, ErrQueueHasBeenClosed } return task, nil @@ -147,6 +155,7 @@ func NewConsumer(opts ...Option) *Consumer { w := &Consumer{ taskQueue: make(chan core.QueuedMessage, o.queueSize), stop: make(chan struct{}), + exit: make(chan struct{}), logger: o.logger, runFunc: o.fn, } diff --git a/consumer_test.go b/consumer_test.go index ba013b4..7ea1fac 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -10,6 +10,8 @@ import ( "time" "github.com/golang-queue/queue/core" + "github.com/golang-queue/queue/mocks" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) @@ -237,8 +239,6 @@ func TestHandleTimeout(t *testing.T) { done <- w.handle(job) }() - assert.NoError(t, w.Shutdown()) - err = <-done assert.Error(t, err) assert.Equal(t, context.DeadlineExceeded, err) @@ -276,8 +276,6 @@ func TestJobComplete(t *testing.T) { done <- w.handle(job) }() - assert.NoError(t, w.Shutdown()) - err = <-done assert.Error(t, err) assert.Equal(t, errors.New("job completed"), err) @@ -308,7 +306,7 @@ func TestTaskJobComplete(t *testing.T) { go func() { done <- w.handle(job) }() - assert.NoError(t, w.Shutdown()) + err = <-done assert.NoError(t, err) @@ -385,3 +383,68 @@ func TestDecreaseWorkerCount(t *testing.T) { assert.Equal(t, 2, q.BusyWorkers()) q.Release() } + +func TestHandleAllJobBeforeShutdownConsumer(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + m := mocks.NewMockQueuedMessage(controller) + + w := NewConsumer( + WithFn(func(ctx context.Context, m core.QueuedMessage) error { + time.Sleep(10 * time.Millisecond) + return nil + }), + ) + + done := make(chan struct{}) + assert.NoError(t, w.Queue(m)) + assert.NoError(t, w.Queue(m)) + go func() { + assert.NoError(t, w.Shutdown()) + done <- struct{}{} + }() + + task, err := w.Request() + assert.NotNil(t, task) + assert.NoError(t, err) + task, err = w.Request() + assert.NotNil(t, task) + assert.NoError(t, err) + task, err = w.Request() + assert.Nil(t, task) + assert.True(t, errors.Is(err, ErrQueueHasBeenClosed)) + <-done +} + +func TestHandleAllJobBeforeShutdownConsumerInQueue(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + + m := mocks.NewMockQueuedMessage(controller) + m.EXPECT().Bytes().Return([]byte("test")).AnyTimes() + + messages := make(chan string, 10) + + w := NewConsumer( + WithFn(func(ctx context.Context, m core.QueuedMessage) error { + time.Sleep(10 * time.Millisecond) + messages <- string(m.Bytes()) + return nil + }), + ) + + q, err := NewQueue( + WithLogger(NewLogger()), + WithWorker(w), + WithWorkerCount(1), + ) + assert.NoError(t, err) + + assert.NoError(t, q.Queue(m)) + assert.NoError(t, q.Queue(m)) + assert.Len(t, messages, 0) + q.Start() + q.Release() + assert.Len(t, messages, 2) +}