Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add context to graceful shutdown #535

Open
wants to merge 1 commit into
base: v3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions chain.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cron

import (
"context"
"fmt"
"runtime"
"sync"
Expand All @@ -24,9 +25,12 @@ func NewChain(c ...JobWrapper) Chain {
// Then decorates the given job with all JobWrappers in the chain.
//
// This:
// NewChain(m1, m2, m3).Then(job)
//
// NewChain(m1, m2, m3).Then(job)
//
// is equivalent to:
// m1(m2(m3(job)))
//
// m1(m2(m3(job)))
func (c Chain) Then(j Job) Job {
for i := range c.wrappers {
j = c.wrappers[len(c.wrappers)-i-1](j)
Expand All @@ -37,7 +41,7 @@ func (c Chain) Then(j Job) Job {
// Recover panics in wrapped jobs and log them with the provided logger.
func Recover(logger Logger) JobWrapper {
return func(j Job) Job {
return FuncJob(func() {
return FuncJob(func(ctx context.Context) {
defer func() {
if r := recover(); r != nil {
const size = 64 << 10
Expand All @@ -50,7 +54,7 @@ func Recover(logger Logger) JobWrapper {
logger.Error(err, "panic", "stack", "...\n"+string(buf))
}
}()
j.Run()
j.Run(ctx)
})
}
}
Expand All @@ -61,14 +65,14 @@ func Recover(logger Logger) JobWrapper {
func DelayIfStillRunning(logger Logger) JobWrapper {
return func(j Job) Job {
var mu sync.Mutex
return FuncJob(func() {
return FuncJob(func(ctx context.Context) {
start := time.Now()
mu.Lock()
defer mu.Unlock()
if dur := time.Since(start); dur > time.Minute {
logger.Info("delay", "duration", dur)
}
j.Run()
j.Run(ctx)
})
}
}
Expand All @@ -79,10 +83,10 @@ func SkipIfStillRunning(logger Logger) JobWrapper {
var ch = make(chan struct{}, 1)
ch <- struct{}{}
return func(j Job) Job {
return FuncJob(func() {
return FuncJob(func(ctx context.Context) {
select {
case v := <-ch:
j.Run()
j.Run(ctx)
ch <- v
default:
logger.Info("skip")
Expand Down
43 changes: 22 additions & 21 deletions chain_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cron

import (
"context"
"io/ioutil"
"log"
"reflect"
Expand All @@ -11,7 +12,7 @@ import (

func appendingJob(slice *[]int, value int) Job {
var m sync.Mutex
return FuncJob(func() {
return FuncJob(func(ctx context.Context) {
m.Lock()
*slice = append(*slice, value)
m.Unlock()
Expand All @@ -20,9 +21,9 @@ func appendingJob(slice *[]int, value int) Job {

func appendingWrapper(slice *[]int, value int) JobWrapper {
return func(j Job) Job {
return FuncJob(func() {
appendingJob(slice, value).Run()
j.Run()
return FuncJob(func(ctx context.Context) {
appendingJob(slice, value).Run(ctx)
j.Run(ctx)
})
}
}
Expand All @@ -35,14 +36,14 @@ func TestChain(t *testing.T) {
append3 = appendingWrapper(&nums, 3)
append4 = appendingJob(&nums, 4)
)
NewChain(append1, append2, append3).Then(append4).Run()
NewChain(append1, append2, append3).Then(append4).Run(context.Background())
if !reflect.DeepEqual(nums, []int{1, 2, 3, 4}) {
t.Error("unexpected order of calls:", nums)
}
}

func TestChainRecover(t *testing.T) {
panickingJob := FuncJob(func() {
panickingJob := FuncJob(func(ctx context.Context) {
panic("panickingJob panics")
})

Expand All @@ -53,19 +54,19 @@ func TestChainRecover(t *testing.T) {
}
}()
NewChain().Then(panickingJob).
Run()
Run(context.Background())
})

t.Run("Recovering JobWrapper recovers", func(t *testing.T) {
NewChain(Recover(PrintfLogger(log.New(ioutil.Discard, "", 0)))).
Then(panickingJob).
Run()
Run(context.Background())
})

t.Run("composed with the *IfStillRunning wrappers", func(t *testing.T) {
NewChain(Recover(PrintfLogger(log.New(ioutil.Discard, "", 0)))).
Then(panickingJob).
Run()
Run(context.Background())
})
}

Expand All @@ -76,7 +77,7 @@ type countJob struct {
delay time.Duration
}

func (j *countJob) Run() {
func (j *countJob) Run(context.Context) {
j.m.Lock()
j.started++
j.m.Unlock()
Expand All @@ -103,7 +104,7 @@ func TestChainDelayIfStillRunning(t *testing.T) {
t.Run("runs immediately", func(t *testing.T) {
var j countJob
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
if c := j.Done(); c != 1 {
t.Errorf("expected job run once, immediately, got %d", c)
Expand All @@ -114,9 +115,9 @@ func TestChainDelayIfStillRunning(t *testing.T) {
var j countJob
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(time.Millisecond)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}()
time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
if c := j.Done(); c != 2 {
Expand All @@ -129,9 +130,9 @@ func TestChainDelayIfStillRunning(t *testing.T) {
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(DelayIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(time.Millisecond)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}()

// After 5ms, the first job is still in progress, and the second job was
Expand All @@ -157,7 +158,7 @@ func TestChainSkipIfStillRunning(t *testing.T) {
t.Run("runs immediately", func(t *testing.T) {
var j countJob
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(2 * time.Millisecond) // Give the job 2ms to complete.
if c := j.Done(); c != 1 {
t.Errorf("expected job run once, immediately, got %d", c)
Expand All @@ -168,9 +169,9 @@ func TestChainSkipIfStillRunning(t *testing.T) {
var j countJob
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(time.Millisecond)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}()
time.Sleep(3 * time.Millisecond) // Give both jobs 3ms to complete.
if c := j.Done(); c != 2 {
Expand All @@ -183,9 +184,9 @@ func TestChainSkipIfStillRunning(t *testing.T) {
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
go func() {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
time.Sleep(time.Millisecond)
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}()

// After 5ms, the first job is still in progress, and the second job was
Expand All @@ -209,7 +210,7 @@ func TestChainSkipIfStillRunning(t *testing.T) {
j.delay = 10 * time.Millisecond
wrappedJob := NewChain(SkipIfStillRunning(DiscardLogger)).Then(&j)
for i := 0; i < 11; i++ {
go wrappedJob.Run()
go wrappedJob.Run(context.Background())
}
time.Sleep(200 * time.Millisecond)
done := j.Done()
Expand Down
34 changes: 20 additions & 14 deletions cron.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ type Cron struct {
parser Parser
nextID EntryID
jobWaiter sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}

// Job is an interface for submitted cron jobs.
type Job interface {
Run()
Run(ctx context.Context)
}

// Schedule describes a job's duty cycle.
Expand Down Expand Up @@ -92,20 +94,21 @@ func (s byTime) Less(i, j int) bool {
//
// Available Settings
//
// Time Zone
// Description: The time zone in which schedules are interpreted
// Default: time.Local
// Time Zone
// Description: The time zone in which schedules are interpreted
// Default: time.Local
//
// Parser
// Description: Parser converts cron spec strings into cron.Schedules.
// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron
// Parser
// Description: Parser converts cron spec strings into cron.Schedules.
// Default: Accepts this spec: https://en.wikipedia.org/wiki/Cron
//
// Chain
// Description: Wrap submitted jobs to customize behavior.
// Default: A chain that recovers panics and logs them to stderr.
// Chain
// Description: Wrap submitted jobs to customize behavior.
// Default: A chain that recovers panics and logs them to stderr.
//
// See "cron.With*" to modify the default behavior.
func New(opts ...Option) *Cron {
ctx, cancel := context.WithCancel(context.Background())
c := &Cron{
entries: nil,
chain: NewChain(),
Expand All @@ -118,6 +121,8 @@ func New(opts ...Option) *Cron {
logger: DefaultLogger,
location: time.Local,
parser: standardParser,
ctx: ctx,
cancel: cancel,
}
for _, opt := range opts {
opt(c)
Expand All @@ -126,14 +131,14 @@ func New(opts ...Option) *Cron {
}

// FuncJob is a wrapper that turns a func() into a cron.Job
type FuncJob func()
type FuncJob func(ctx context.Context)

func (f FuncJob) Run() { f() }
func (f FuncJob) Run(ctx context.Context) { f(ctx) }

// AddFunc adds a func to the Cron to be run on the given schedule.
// The spec is parsed using the time zone of this Cron instance as the default.
// An opaque ID is returned that can be used to later remove it.
func (c *Cron) AddFunc(spec string, cmd func()) (EntryID, error) {
func (c *Cron) AddFunc(spec string, cmd func(ctx context.Context)) (EntryID, error) {
return c.AddJob(spec, FuncJob(cmd))
}

Expand Down Expand Up @@ -304,7 +309,7 @@ func (c *Cron) startJob(j Job) {
c.jobWaiter.Add(1)
go func() {
defer c.jobWaiter.Done()
j.Run()
j.Run(c.ctx)
}()
}

Expand All @@ -319,6 +324,7 @@ func (c *Cron) Stop() context.Context {
c.runningMu.Lock()
defer c.runningMu.Unlock()
if c.running {
c.cancel()
c.stop <- struct{}{}
c.running = false
}
Expand Down
Loading