diff --git a/client/client.go b/client/client.go
index 5e0baaa..9bf3603 100644
--- a/client/client.go
+++ b/client/client.go
@@ -15,7 +15,6 @@ import (
 	"time"
 	"unsafe"
 
-	"github.com/appscode/go/log"
 	rt "github.com/quantcast/g2/pkg/runtime"
 )
 
@@ -32,53 +31,148 @@ type connection struct {
 	// of `atomic.SwapPointer` and `atomic.CompareAndSwapPointer` much more
 	// convenient.
 	net.Conn
+	connVersion int
 }
 
+type channels struct {
+	outbound chan *request
+	expected chan *Response
+}
+
+type ConnCloseHandler func(conn net.Conn) (err error)
+type ConnOpenHandler func() (conn net.Conn, err error)
+
 // One client connect to one server.
 // Use Pool for multi-connections.
 type Client struct {
-	net, addr string
-	handlers  sync.Map
-	expected  chan *Response
-	outbound  chan *request
-	conn      *connection
+	reconnectState uint32
+	net, addr      string
+	handlers       sync.Map
+	conn           *connection
 	//rw *bufio.ReadWriter
-
+	chans        *channels
 	responsePool *sync.Pool
 	requestPool  *sync.Pool
 
 	ResponseTimeout time.Duration
 
-	ErrorHandler ErrorHandler
+	ErrorHandler     ErrorHandler
+	connCloseHandler ConnCloseHandler
+	connOpenHandler  ConnOpenHandler
+	logHandler       LogHandler
+}
+
+type LogLevel int
+
+const (
+	Error   LogLevel = 0
+	Warning LogLevel = 1
+	Info    LogLevel = 2
+	Debug   LogLevel = 3
+)
+
+type LogHandler func(level LogLevel, message ...string)
+
+func (client *Client) Log(level LogLevel, message ...string) {
+	if client.logHandler != nil {
+		client.logHandler(level, message...)
+	}
+}
+
+func NewConnected(conn net.Conn) (client *Client) {
+
+	existingConnection := &connection{conn, 0}
+
+	connOpenHandler := func() (conn net.Conn, err error) {
+		if existingConnection != nil {
+			conn = existingConnection.Conn
+			existingConnection = nil
+		} else {
+			err = errors.New("cannot reopen the one-shot connection supplied to NewConnected()")
+		}
+		return
+	}
+
+	client, _ = NewClient(nil, connOpenHandler, nil)
+
+	return client
+}
 
 // Return a client.
-func New(network, addr string) (client *Client, err error) {
-	conn, err := net.Dial(network, addr)
+func New(network, addr string, logHandler LogHandler) (client *Client, err error) {
+
+	if logHandler == nil {
+		logHandler = func(level LogLevel, message ...string) {}
+	}
+
+	retryPeriod := 3 * time.Second
+
+	connOpenHandler := func() (conn net.Conn, err error) {
+		logHandler(Info, fmt.Sprintf("Trying to connect to server %v ...", addr))
+		for {
+			for numTries := 1; ; numTries++ {
+				if numTries%100 == 0 {
+					logHandler(Info, fmt.Sprintf("Still trying to connect to server %v, attempt# %v ...", addr, numTries))
+				}
+				conn, err = net.Dial(network, addr)
+				if err != nil {
+					time.Sleep(retryPeriod)
+					continue
+				}
+				break
+			}
+			// the server is reachable again, but drop this connection and dial once more:
+			// connecting too soon after gearmand starts can yield a dud connection
+			_ = conn.Close()
+			time.Sleep(retryPeriod)
+
+			// todo: come up with a more reliable way to determine if we have a working connection to gearman, perhaps by performing a test
+			conn, err = net.Dial(network, addr)
+			if err != nil {
+				// looks like there is another problem, go back to the main loop
+				time.Sleep(retryPeriod)
+				continue
+			}
+
+			break
+		}
+		logHandler(Info, fmt.Sprintf("Connected to server %v", addr))
 
-	if err != nil {
 		return
 	}
 
-	client = NewConnected(conn)
+	client, err = NewClient(nil, connOpenHandler, logHandler)
 
 	return
 }
 
-// Return a new client from an established connection. Largely used for
-// testing, though other use-cases can be imagined.
-func NewConnected(conn net.Conn) (client *Client) {
+// NewClient builds a client around the supplied connection handlers; connCloseHandler may be nil.
+func NewClient(connCloseHandler ConnCloseHandler,
+	connOpenHandler ConnOpenHandler,
+	logHandler LogHandler) (client *Client, err error) {
+
+	conn, err := connOpenHandler()
+	if err != nil {
+		// errors are returned rather than logged here; the caller decides how to log them
+		err = fmt.Errorf("failed to create new client: %v", err)
+		return
+	}
+
 	addr := conn.RemoteAddr()
 	client = &Client{
-		net:             addr.Network(),
-		addr:            addr.String(),
-		conn:            &connection{conn},
-		outbound:        make(chan *request),
-		expected:        make(chan *Response),
-		ResponseTimeout: DefaultTimeout,
-		responsePool:    &sync.Pool{New: func() interface{} { return &Response{} }},
-		requestPool:     &sync.Pool{New: func() interface{} { return &request{} }},
+		net:  addr.Network(),
+		addr: addr.String(),
+		conn: &connection{Conn: conn},
+		chans: &channels{
+			expected: make(chan *Response),
+			outbound: make(chan *request)},
+		ResponseTimeout:  DefaultTimeout,
+		responsePool:     &sync.Pool{New: func() interface{} { return &Response{} }},
+		requestPool:      &sync.Pool{New: func() interface{} { return &request{} }},
+		connCloseHandler: connCloseHandler,
+		connOpenHandler:  connOpenHandler,
+		logHandler:       logHandler,
 	}
 
 	go client.readLoop()
@@ -87,6 +181,25 @@ func NewConnected(conn net.Conn) (client *Client) {
 	return
 }
 
+func (client *Client) IsConnectionSet() bool {
+	return client.loadConn() != nil
+}
+
+func (client *Client) writeReconnectCleanup(conn *connection, req *request, ibufs ...[]byte) bool {
+	for _, ibuf := range ibufs {
+		if _, err := conn.Write(ibuf); err != nil {
+			client.requestPool.Put(req)
+			go client.reconnect(err)
+			return true // a true return makes writeLoop exit; it is restarted after a successful reconnect
+		}
+	}
+	return false
+}
+
+func (client *Client) loadChans() *channels {
+	return (*channels)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&client.chans))))
+}
+
 func (client *Client) writeLoop() {
 	ibuf := make([]byte, 4)
 	length := uint32(0)
@@ -95,14 +208,23 @@ func (client *Client) writeLoop() {
 	// Pipeline requests; but only write them one at a time. To allow multiple
 	// goroutines to all write as quickly as possible, uses a channel and the
 	// writeLoop lives in a separate goroutine.
-	for req := range client.outbound {
-		client.conn.Write([]byte(rt.ReqStr))
+	for req := range client.loadChans().outbound {
 
-		// todo handle errors.
+		conn := client.loadConn()
+		if conn == nil {
+			client.requestPool.Put(req)
+			return
+		}
+
+		if exit := client.writeReconnectCleanup(conn, req, []byte(rt.ReqStr)); exit {
+			return
+		}
 
 		binary.BigEndian.PutUint32(ibuf, req.pt.Uint32())
 
-		client.conn.Write(ibuf)
+		if exit := client.writeReconnectCleanup(conn, req, ibuf); exit {
+			return
+		}
 
 		length = 0
 
@@ -115,13 +237,14 @@ func (client *Client) writeLoop() {
 
 		binary.BigEndian.PutUint32(ibuf, length)
 
-		client.conn.Write(ibuf)
-
-		client.conn.Write(req.data[0])
+		if client.writeReconnectCleanup(conn, req, ibuf, req.data[0]) {
+			return
+		}
 
 		for i = 1; i < len(req.data); i++ {
-			client.conn.Write(NullBytes)
-			client.conn.Write(req.data[i])
+			if exit := client.writeReconnectCleanup(conn, req, NullBytes, req.data[i]); exit {
+				return
+			}
 		}
 
 		client.requestPool.Put(req)
@@ -136,77 +259,108 @@ func decodeHeader(header []byte) (code []byte, pt uint32, length int) {
 	return
 }
 
+func (client *Client) lockReconnect() (success bool) {
+	return atomic.CompareAndSwapUint32(&client.reconnectState, 0, 1)
+}
+
+// called by the owner of the reconnect lock to signal that reconnection has finished
+func (client *Client) resetReconnectState() {
+	atomic.StoreUint32(&client.reconnectState, 0)
+}
+
 func (client *Client) reconnect(err error) error {
-	if client.conn != nil {
-		return nil
-	}
 
-	// TODO I doubt this error handling is right because it looks
-	// really complicated.
+	// Temporary errors are not acted on; we might want to note a timestamp
+	// and eventually recycle this connection if such errors keep recurring
+	// for too long (even though classified as Temporary here)
 	if opErr, ok := err.(*net.OpError); ok {
-		if opErr.Timeout() {
-			client.err(err)
-		}
-
 		if opErr.Temporary() {
 			return nil
		}
-
-		return err
 	}
 
-	if err != nil {
-		client.err(err)
+	ownReconnect := client.lockReconnect()
+
+	if !ownReconnect {
+		// reconnect collision: another goroutine already owns the reconnect lock and will finish the job
+		return nil
 	}
 
-	// If it is unexpected error and the connection wasn't
-	// closed by Gearmand, the client should close the conection
-	// and reconnect to job server.
-	client.Close()
+	defer client.resetReconnectState() // clear the reconnect lock once reconnection is complete
+
+	connVersion := client.loadConn().connVersion
 
-	conn, err := net.Dial(client.net, client.addr)
+	client.Log(Error, fmt.Sprintf("Closing connection to %v due to error %v, will reconnect...", client.addr, err))
+	if closeErr := client.Close(); closeErr != nil {
+		client.Log(Warning, fmt.Sprintf("Non-fatal error %v while closing connection to %v", closeErr, client.addr))
+	}
+
+	oldChans := client.loadChans()
+	close(oldChans.expected)
+	close(oldChans.outbound)
 
+	conn, err := client.connOpenHandler()
 	if err != nil {
 		client.err(err)
 		return err
 	}
 
-	swapped := atomic.CompareAndSwapPointer(
-		(*unsafe.Pointer)(unsafe.Pointer(&client.conn)),
-		unsafe.Pointer(nil),
-		unsafe.Pointer(&connection{conn}))
+	newConn := &connection{conn, connVersion + 1}
 
-	if !swapped {
-		conn.Close()
+	if swapped := atomic.CompareAndSwapPointer(
+		(*unsafe.Pointer)(unsafe.Pointer(&client.conn)),
+		unsafe.Pointer(nil), unsafe.Pointer(newConn)); !swapped {
+		return errors.New("expected a nil connection while swapping in the new one")
 	}
 
+	// replace the closed channels with fresh ones
+	_ = (*channels)(atomic.SwapPointer(
+		(*unsafe.Pointer)(unsafe.Pointer(&client.chans)),
+		unsafe.Pointer(&channels{expected: make(chan *Response), outbound: make(chan *request)})))
+
+	go client.readLoop()
+	go client.writeLoop()
+
 	return nil
 }
 
+func (client *Client) loadConn() *connection {
+	return (*connection)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&client.conn))))
+}
+
+func (client *Client) readReconnect(startConn *connection, buf []byte) (n int, exit bool) {
+	conn := client.loadConn()
+	if startConn != conn {
+		return 0, true
+	}
+	var err error
+	if n, err = io.ReadFull(conn, buf); err != nil {
+		go client.reconnect(err)
+		return n, true
+	} else {
+		return n, false
+	}
+}
+
 func (client *Client) readLoop() {
 	header := make([]byte, rt.HeaderSize)
 	var err error
 	var resp *Response
+	startConn := client.loadConn()
 
-	for client.conn != nil {
-		if _, err = io.ReadFull(client.conn, header); err != nil {
-			if err = client.reconnect(err); err != nil {
-				break
-			}
+	for startConn == client.loadConn() {
 
-			continue
+		if _, exit := client.readReconnect(startConn, header); exit {
+			return
 		}
 
 		_, pt, length := decodeHeader(header)
 
 		contents := make([]byte, length)
 
-		if _, err = io.ReadFull(client.conn, contents); err != nil {
-			if err = client.reconnect(err); err != nil {
-				break
-			}
-
-			continue
+		if _, exit := client.readReconnect(startConn, contents); exit {
			return
 		}
 
 		resp = client.responsePool.Get().(*Response)
@@ -245,22 +399,22 @@ func (client *Client) readLoop() {
 
 		client.process(resp)
 	}
+
 }
 
 func (client *Client) process(resp *Response) {
-	// NOTE Any waiting goroutine which reads from `expected` should return the
+	// NOTE Any waiting goroutine which reads from `channels` should return the
 	// response object to the pool; but the conditions which handle it
 	// terminally should return it here.
 	switch resp.DataType {
 	case rt.PT_Error:
-		log.Errorln("Received error", resp.Data)
-
 		client.err(getError(resp.Data))
 
-		client.expected <- resp
+		client.loadChans().expected <- resp
 	case rt.PT_StatusRes, rt.PT_JobCreated, rt.PT_EchoRes:
-		client.expected <- resp
+		client.loadChans().expected <- resp
 	case rt.PT_WorkComplete, rt.PT_WorkFail, rt.PT_WorkException:
 		defer client.handlers.Delete(resp.Handle)
 		fallthrough
@@ -270,9 +424,11 @@ func (client *Client) process(resp *Response) {
 		if handler, ok := client.handlers.Load(resp.Handle); ok {
 			if h, ok := handler.(ResponseHandler); ok {
 				h(resp)
+			} else {
+				client.err(fmt.Errorf("could not cast handler to ResponseHandler for %v", resp.Handle))
 			}
 		} else {
 			client.err(fmt.Errorf("unexpected %s response for \"%s\" with no handler", resp.DataType, resp.Handle))
 		}
 
 		client.responsePool.Put(resp)
@@ -282,6 +438,8 @@ func (client *Client) process(resp *Response) {
 func (client *Client) err(e error) {
 	if client.ErrorHandler != nil {
 		client.ErrorHandler(e)
+	} else {
+		client.Log(Error, e.Error()) // fall back to the log handler when no ErrorHandler is supplied
 	}
 }
 
@@ -289,20 +447,27 @@ func (client *Client) request() *request {
 	return client.requestPool.Get().(*request)
 }
 
-func (client *Client) submit(pt rt.PT, funcname string, payload []byte) (string, error) {
-	var err error
+func (client *Client) submit(pt rt.PT, funcname string, payload []byte) (handle string, err error) {
 
-	client.outbound <- client.request().submitJob(pt, funcname, IdGen.Id(), payload)
+	// a send on an outbound channel closed by reconnect panics; recover converts that into an error
+	defer func() {
+		if e := safeCastError(recover(), "panic in submit()"); e != nil {
+			err = e
+		}
+	}()
 
-	res := <-client.expected
+	chans := client.loadChans()
+	chans.outbound <- client.request().submitJob(pt, funcname, IdGen.Id(), payload)
 
-	if res.DataType == rt.PT_Error {
-		err = getError(res.Data)
+	if res := <-chans.expected; res != nil {
+		defer client.responsePool.Put(res)
+		if res.DataType == rt.PT_Error {
+			err = getError(res.Data)
+		}
+		return res.Handle, err
 	}
 
-	defer client.responsePool.Put(res)
-
-	return res.Handle, err
+	return "", errors.New("channels were closed by a reconnect, please resubmit your message")
 }
 
 // Call the function and get a response.
@@ -380,7 +545,7 @@ func (client *Client) doCron(funcname string, cronExpr string, funcParam []byte)
 }
 
 func (client *Client) DoAt(funcname string, epoch int64, funcParam []byte) (handle string, err error) {
-	if client.conn == nil {
+	if client.loadConn() == nil {
 		return "", ErrLostConn
 	}
 
@@ -393,14 +558,21 @@ func (client *Client) DoAt(funcname string, epoch int64, funcParam []byte) (hand
 
 // Get job status from job server.
 func (client *Client) Status(handle string) (status *Status, err error) {
-	if err = client.reconnect(nil); err != nil {
-		return
-	}
 
-	client.outbound <- client.request().status(handle)
+	defer func() {
+		if e := safeCastError(recover(), "panic in Status"); e != nil {
+			err = e
+		}
+	}()
+
+	chans := client.loadChans()
+	chans.outbound <- client.request().status(handle)
 
-	res := <-client.expected
+	res := <-chans.expected
+	if res == nil {
+		return nil, errors.New("status channel was closed by a reconnect, please resend")
+	}
 
 	status, err = res.Status()
 
 	client.responsePool.Put(res)
@@ -410,13 +582,21 @@ func (client *Client) Status(handle string) (status *Status, err error) {
 
 // Echo.
 func (client *Client) Echo(data []byte) (echo []byte, err error) {
-	if err = client.reconnect(nil); err != nil {
-		return
-	}
 
-	client.outbound <- client.request().echo(data)
+	defer func() {
+		if e := safeCastError(recover(), "panic in Echo"); e != nil {
+			err = e
+		}
+	}()
 
-	res := <-client.expected
+	chans := client.loadChans()
+	chans.outbound <- client.request().echo(data)
+
+	res := <-chans.expected
+
+	if res == nil {
+		return nil, errors.New("echo channel was closed by a reconnect, please resend")
+	}
 
 	echo = res.Data
 
@@ -432,8 +612,11 @@ func (client *Client) Close() (err error) {
 	conn := (*connection)(ptr)
 
 	if conn != nil {
-		err = conn.Close()
-
+		if client.connCloseHandler != nil {
+			err = client.connCloseHandler(conn)
+		} else {
+			err = conn.Close()
+		}
 		return
 	}
 
diff --git a/client/client_test.go b/client/client_test.go
index fc18df8..bb9ec3f 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -6,14 +6,14 @@ import (
 	"encoding/hex"
 	"errors"
 	"fmt"
+	"github.com/appscode/go/log"
+	rt "github.com/quantcast/g2/pkg/runtime"
 	"io"
 	"net"
 	"os"
 	"reflect"
 	"strings"
 	"testing"
-
-	rt "github.com/quantcast/g2/pkg/runtime"
 )
 
 type snapshot struct {
@@ -187,9 +187,18 @@ func drain(observed io.Reader) {
 }
 
 func TestClose(test *testing.T) {
-	client, _ := net.Pipe()
-	gearmanc := NewConnected(client)
+	handlerConnOpen := func() (conn net.Conn, err error) {
+		log.Infoln("Creating net.Pipe connection...")
+		conn, _ = net.Pipe()
+		return
+	}
+
+	handlerConnClose := func(conn net.Conn) (err error) {
+		return conn.Close()
+	}
+
+	gearmanc, _ := NewClient(handlerConnClose, handlerConnOpen, nil)
 
 	if gearmanc.Close() != nil {
 		test.Fatalf("expected no error in closing connected client")
@@ -209,9 +218,17 @@ func TestSnapshot(test *testing.T) {
 		test.Fatalf("error loading snapshot: %s\n", err)
 	}
 
+	handlerConnOpen := func() (conn net.Conn, err error) {
+		return client, nil // return the pre-created pipe connection
+	}
+
+	handlerConnClose := func(conn net.Conn) (err error) {
+		return conn.Close()
+	}
+
 	// This has to be done in another go-routine since all of the reads/writes
 	// are synchronous
-	gearmanClient := NewConnected(client)
+	gearmanClient, _ := NewClient(handlerConnClose, handlerConnOpen, nil)
 
 	if err = snapshot.replay(server, "server", "client"); err != nil {
 		test.Fatalf("error loading snapshot: %s", err)
diff --git a/client/error.go b/client/error.go
index eabbf51..cc7fc1b 100644
--- a/client/error.go
+++ b/client/error.go
@@ -28,3 +28,15 @@ func getError(data []byte) (err error) {
 
 // Error handler
 type ErrorHandler func(error)
+
+// safeCastError converts a recover() result into an error, falling back to
+// defaultMessage when the recovered value is not an error.
+func safeCastError(e interface{}, defaultMessage string) error {
+	if e == nil {
+		return nil
+	}
+	if err, ok := e.(error); ok {
+		return err
+	}
+	return errors.New(defaultMessage)
+}
diff --git a/client/pool.go b/client/pool.go
index ec6a306..e380ff5 100644
--- a/client/pool.go
+++ b/client/pool.go
@@ -75,7 +75,7 @@ func (pool *Pool) Add(net, addr string, rate int) (err error) {
 		item.Rate = rate
 	} else {
 		var client *Client
-		client, err = New(net, addr)
+		client, err = New(net, addr, nil)
 		if err == nil {
 			item = &PoolClient{Client: client, Rate: rate}
 			pool.Clients[addr] = item
diff --git a/client/response.go b/client/response.go
index 6157e04..17e0cac 100644
--- a/client/response.go
+++ b/client/response.go
@@ -58,7 +58,7 @@ func (resp *Response) Update() (data []byte, err error) {
 func (resp *Response) Status() (status *Status, err error) {
 	data := bytes.SplitN(resp.Data, []byte{'\x00'}, 2)
 	if len(data) != 2 {
-		err = fmt.Errorf("Invalid data: %v", resp.Data)
+		err = fmt.Errorf("invalid data: %v, split resulted in fewer than 2 elements", resp.Data)
 		return
 	}
 	status = &Status{}
diff --git a/example/client/client.go b/example/client/client.go
index 5bb27cd..329f098 100644
--- a/example/client/client.go
+++ b/example/client/client.go
@@ -15,7 +15,7 @@ func main() {
 	// by implementing IdGenerator interface.
 	// client.IdGen = client.NewAutoIncId()
 
-	c, err := client.New(rt.Network, "127.0.0.1:4730")
+	c, err := client.New(rt.Network, "127.0.0.1:4730", nil)
 	if err != nil {
 		log.Fatalln(err)
 	}
diff --git a/example/persistent_client/persistent_client.go b/example/persistent_client/persistent_client.go
new file mode 100644
index 0000000..bd934f6
--- /dev/null
+++ b/example/persistent_client/persistent_client.go
@@ -0,0 +1,121 @@
+package main
+
+import (
+	"log"
+	"time"
+
+	"github.com/quantcast/g2/client"
+	rt "github.com/quantcast/g2/pkg/runtime"
+)
+
+func logHandler(level client.LogLevel, message ...string) {
+	switch level {
+	case client.Error:
+		log.Println("Error:", message)
+	case client.Warning:
+		log.Println("Warning:", message)
+	case client.Info:
+		log.Println("Info:", message)
+	case client.Debug:
+		log.Println("Debug:", message)
+	}
+}
+
+func main() {
+	// Set the autoinc id generator
+	// You can write your own id generator
+	// by implementing IdGenerator interface.
+	// client.IdGen = client.NewAutoIncId()
+
+	c, err := client.New(rt.Network, "127.0.0.1:4730", logHandler)
+	if err != nil {
+		log.Fatalln(err)
+	}
+	defer c.Close()
+	c.ErrorHandler = func(e error) {
+		log.Println("ErrorHandler Received:", e)
+	}
+	echo := []byte("Hello\x00 world")
+	echomsg, err := c.Echo(echo)
+	if err != nil {
+		log.Println("Error in Echo:", err)
+	} else {
+		log.Println("EchoMsg:", string(echomsg))
+	}
+
+	printResult := true
+	printUpdate := false
+	printStatus := false
+
+	jobHandler := func(resp *client.Response) {
+		switch resp.DataType {
+		case rt.PT_WorkException:
+			fallthrough
+		case rt.PT_WorkFail:
+			fallthrough
+		case rt.PT_WorkComplete:
+			if printResult {
+				if data, err := resp.Result(); err == nil {
+					log.Printf("RESULT: %v, string:%v\n", data, string(data))
+				} else {
+					log.Printf("RESULT: %s\n", err)
+				}
+			}
+		case rt.PT_WorkWarning:
+			fallthrough
+		case rt.PT_WorkData:
+			if printUpdate {
+				if data, err := resp.Update(); err == nil {
+					log.Printf("UPDATE: %v\n", data)
+				} else {
+					log.Printf("UPDATE: %v, %s\n", data, err)
+				}
+			}
+		case rt.PT_WorkStatus:
+			if printStatus {
+				if data, err := resp.Status(); err == nil {
+					log.Printf("STATUS: %v\n", data)
+				} else {
+					log.Printf("STATUS: %s\n", err)
+				}
+			}
+		default:
+			log.Printf("UNKNOWN: %v", resp.Data)
+		}
+	}
+
+	log.Println("Press Ctrl-C to exit ...")
+
+	for i := 0; ; i++ {
+
+		if !c.IsConnectionSet() {
+			log.Printf("No active connection to server, waiting...")
+			time.Sleep(5 * time.Second)
+			continue
+		}
+
+		funcName := "ToUpper"
+		log.Println("Calling function", funcName, "with data:", echo)
+		handle, err := c.Do(funcName, echo, rt.JobNormal, jobHandler)
+		if err != nil {
+			log.Printf("Do %v ERROR: %v", funcName, err)
+		}
+
+		log.Printf("Calling Status for handle %v", handle)
+		status, err := c.Status(handle)
+		if err != nil {
+			log.Printf("Status: %v, ERROR: %v", status, err)
+		}
+
+		funcName = "Foobar"
+		log.Println("Calling function", funcName, "with data:", echo)
+		_, err = c.Do(funcName, echo, rt.JobNormal, jobHandler)
+		if err != nil {
+			log.Printf("Do %v ERROR: %v", funcName, err)
+		}
+
+		sleepSeconds := 0
+		log.Printf("Finished cycle %v, sleeping %v seconds", i, sleepSeconds)
+		time.Sleep(time.Duration(sleepSeconds) * time.Second)
+	}
+}
diff --git a/example/worker/worker.go b/example/worker/worker.go
index fb0429d..3a46baf 100644
--- a/example/worker/worker.go
+++ b/example/worker/worker.go
@@ -7,8 +7,8 @@ import (
 	"strings"
 	"time"
 
-	"github.com/quantcast/g2/worker"
 	"github.com/mikespook/golib/signal"
+	"github.com/quantcast/g2/worker"
 )
 
 func ToUpper(job worker.Job) ([]byte, error) {
@@ -34,13 +34,30 @@ func Foobar(job worker.Job) ([]byte, error) {
 	return job.Data(), nil
 }
 
+func logHandler(level worker.LogLevel, message ...string) {
+	switch level {
+	case worker.Error:
+		log.Println("Error:", message)
+	case worker.Warning:
+		log.Println("Warning:", message)
+	case worker.Info:
+		log.Println("Info:", message)
+	case worker.Debug:
+		log.Println("Debug:", message)
+	}
+}
+
 func main() {
 	log.Println("Starting ...")
 	defer log.Println("Shutdown complete!")
+
 	w := worker.New(worker.Unlimited)
+	w.SetLogHandler(logHandler)
+
 	defer w.Close()
 	w.ErrorHandler = func(e error) {
-		log.Println(e)
+		log.Println("ErrorHandler Received:", e)
+
 		if opErr, ok := e.(*net.OpError); ok {
 			if !opErr.Temporary() {
 				proc, err := os.FindProcess(os.Getpid())
@@ -53,6 +70,7 @@ func main() {
 			}
 		}
 	}
+
 	w.JobHandler = func(job worker.Job) error {
 		log.Printf("Data=%s\n", job.Data())
 		return nil
@@ -69,6 +87,12 @@ func main() {
 		return
 	}
 	go w.Work()
+
+	go func() {
+		for range time.Tick(10 * time.Second) {
+			log.Printf("Current job count: %v", w.GetActiveJobCount())
+		}
+	}()
 	signal.Bind(os.Interrupt, func() uint { return signal.BreakExit })
 	signal.Wait()
 }
diff --git a/go.mod b/go.mod
index c6c03aa..b787c9e 100644
--- a/go.mod
+++ b/go.mod
@@ -4,7 +4,9 @@ require (
 	github.com/appscode/go v0.0.0-20180628092646-df3c57fca2be
 	github.com/appscode/pat v0.0.0-20170521084856-48ff78925b79
 	github.com/beorn7/perks v0.0.0-20160229213445-3ac7bf7a47d1 // indirect
+	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/golang/glog v0.0.0-20141105023935-44145f04b68c // indirect
+	github.com/golang/protobuf v1.3.1 // indirect
 	github.com/golang/snappy v0.0.0-20160529050041-d9eb7a3d35ec // indirect
 	github.com/google/uuid v0.0.0-20171113160352-8c31c18f31ed // indirect
 	github.com/inconshreveable/mousetrap v1.0.0 // indirect
@@ -21,7 +23,6 @@ require (
 	github.com/spf13/pflag v1.0.1
 	github.com/stretchr/testify v1.3.0
 	github.com/syndtr/goleveldb v0.0.0-20180815032940-ae2bd5eed72d
-	golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 // indirect
 	gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
 	gopkg.in/mgo.v2 v2.0.0-20180705113604-9856a29383ce // indirect
 	gopkg.in/robfig/cron.v2 v2.0.0-20150107220207-be2e0b0deed5
diff --git a/worker/agent.go b/worker/agent.go
index c26076a..4c127ec 100644
--- a/worker/agent.go
+++ b/worker/agent.go
@@ -4,9 +4,13 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/binary"
-	"io"
+	"errors"
+	"fmt"
 	"net"
 	"sync"
+	"sync/atomic"
+	"time"
+	"unsafe"
 
 	rt "github.com/quantcast/g2/pkg/runtime"
 )
@@ -14,11 +18,13 @@ import (
 // The agent of job server.
 type agent struct {
 	sync.Mutex
-	conn      net.Conn
-	rw        *bufio.ReadWriter
-	worker    *Worker
-	in        chan []byte
-	net, addr string
+	reconnectState    uint32
+	conn              net.Conn
+	connectionVersion uint32
+	rw                *bufio.ReadWriter
+	worker            *Worker
+	in                chan []byte
+	net, addr         string
 }
 
 // Create the agent of job server.
@@ -32,23 +38,15 @@ func newAgent(net, addr string, worker *Worker) (a *agent, err error) {
 	return
 }
 
-func (a *agent) Connect() (err error) {
-	a.Lock()
-	defer a.Unlock()
-	a.conn, err = net.Dial(a.net, a.addr)
-	if err != nil {
-		return
-	}
-	a.rw = bufio.NewReadWriter(bufio.NewReader(a.conn),
-		bufio.NewWriter(a.conn))
-	go a.work()
-	return
+func (a *agent) loadRw() *bufio.ReadWriter {
+	return (*bufio.ReadWriter)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&a.rw))))
 }
 
 func (a *agent) work() {
+	a.worker.Log(Info, "Starting agent work for:", a.addr)
 	defer func() {
-		if err := recover(); err != nil {
-			a.worker.err(err.(error))
+		if err := safeCastError(recover(), "panic in work()"); err != nil {
+			a.reconnectError(err)
 		}
 	}()
 
@@ -56,65 +54,56 @@ func (a *agent) work() {
 	var l int
 	var err error
 	var data, leftdata []byte
-	for {
-		if !a.worker.isShuttingDown() {
-			if data, err = a.read(); err != nil {
-				if opErr, ok := err.(*net.OpError); ok {
-					if opErr.Temporary() {
-						continue
-					} else {
-						a.disconnect_error(err)
-						// else - we're probably dc'ing due to a Close()
-
-						break
-					}
-
-				} else if err == io.EOF {
-					a.disconnect_error(err)
-					break
-				}
-				a.worker.err(err)
-				// If it is unexpected error and the connection wasn't
-				// closed by Gearmand, the agent should close the conection
-				// and reconnect to job server.
-				a.Close()
-				a.conn, err = net.Dial(a.net, a.addr)
-				if err != nil {
-					a.worker.err(err)
-					break
-				}
-				a.rw = bufio.NewReadWriter(bufio.NewReader(a.conn),
-					bufio.NewWriter(a.conn))
-			}
-			if len(leftdata) > 0 { // some data left for processing
-				data = append(leftdata, data...)
-			}
-			if len(data) < rt.MinPacketLength { // not enough data
-				leftdata = data
-				continue
-			}
-			for {
-				if inpack, l, err = decodeInPack(data); err != nil {
-					a.worker.err(err)
-					leftdata = data
-					break
-				} else {
-					leftdata = nil
-					inpack.a = a
-					a.worker.in <- inpack
-					if len(data) == l {
-						break
-					}
-					if len(data) > l {
-						data = data[l:]
-					}
-				}
-			}
-		}
-	}
+	startRw := a.loadRw()
+
+	// exit the loop if the connection has been replaced; reconnect launches a new work() goroutine
+	for startRw == a.loadRw() && !a.worker.isShuttingDown() {
+
+		if data, err = a.read(); err != nil {
+			if opErr, ok := err.(*net.OpError); ok {
+				if opErr.Temporary() {
+					a.worker.Log(Info, "opErr.Temporary():", a.addr)
+					continue
+				} else {
+					// permanent network error; we may also land here when disconnecting due to a Close()
+					a.worker.Log(Info, "got permanent network error with server:", a.addr, "comm thread exiting.")
+					a.reconnectError(err)
+					break
+				}
+			} else {
+				a.worker.Log(Info, "got error", err.Error(), "with server:", a.addr, "comm thread exiting...")
+				a.reconnectError(err)
+				break
+			}
+		}
+		if len(leftdata) > 0 { // some data left for processing
+			data = append(leftdata, data...)
+		}
+		if len(data) < rt.MinPacketLength { // not enough data
+			leftdata = data
+			continue
+		}
+		for {
+			if inpack, l, err = decodeInPack(data); err != nil {
+				a.reconnectError(err)
+				break
+			} else {
+				leftdata = nil
+				inpack.a = a
+				a.worker.in <- inpack
+				if len(data) == l {
+					break
+				}
+				if len(data) > l {
+					data = data[l:]
+				}
+			}
+		}
+	}
 }
 
-func (a *agent) disconnect_error(err error) {
+func (a *agent) reconnectError(err error) {
 	if a.conn != nil {
 		err = &WorkerDisconnectError{
 			err:   err,
 			agent: a,
 		}
 		a.worker.err(err)
 	}
+	a.Connect()
 }
 
 func (a *agent) Close() {
+	if a.conn == nil {
+		return
+	}
 	a.Lock()
 	defer a.Unlock()
 	if a.conn != nil {
-		a.conn.Close()
+		_ = a.conn.Close()
 		a.conn = nil
 	}
 }
 
-func (a *agent) Grab() {
-	a.Lock()
-	defer a.Unlock()
-	a.grab()
+func (a *agent) Grab() (err error) {
+	if a.conn == nil {
+		return errors.New("no connection")
+	}
+	return a.grab()
 }
 
-func (a *agent) grab() {
+func (a *agent) grab() error {
 	outpack := getOutPack()
 	outpack.dataType = rt.PT_GrabJobUniq
-	a.write(outpack)
+	return a.Write(outpack)
 }
 
-func (a *agent) PreSleep() {
-	a.Lock()
-	defer a.Unlock()
+func (a *agent) PreSleep() (err error) {
+	if a.conn == nil {
+		return errors.New("no connection")
+	}
 	outpack := getOutPack()
 	outpack.dataType = rt.PT_PreSleep
-	a.write(outpack)
+	return a.Write(outpack)
 }
 
-func (a *agent) reconnect() error {
-	a.Lock()
-	defer a.Unlock()
-	conn, err := net.Dial(a.net, a.addr)
-	if err != nil {
-		return err
+func (a *agent) lockReconnect() (success bool) {
+	return atomic.CompareAndSwapUint32(&a.reconnectState, 0, 1)
+}
+
+// called by the owner of the reconnect lock to signal that reconnection has finished
+func (a *agent) resetReconnectState() {
+	atomic.StoreUint32(&a.reconnectState, 0)
+}
+
+func (a *agent) Connect() {
+
+	ownReconnect := a.lockReconnect()
+
+	if !ownReconnect {
+		// reconnect collision: another goroutine already owns the reconnect lock and will finish the job
+		return
 	}
-	a.conn = conn
-	a.rw = bufio.NewReadWriter(bufio.NewReader(a.conn),
-		bufio.NewWriter(a.conn))
+
+	defer a.resetReconnectState() // clear the reconnect lock once reconnection is complete
+
+	a.worker.Log(Info, "Trying to connect to server:", a.addr, "...")
+
+	var conn net.Conn
+	var err error
 
-	a.worker.reRegisterFuncsForAgent(a)
-	a.grab()
+	for !a.worker.isShuttingDown() {
+		for numTries := 1; !a.worker.isShuttingDown(); numTries++ {
 
-	go a.work()
-	return nil
+			if a.conn != nil {
+				_ = a.conn.Close()
+				a.conn = nil
+			}
+
+			// nil-out the rw pointer since it is no longer valid
+			_ = atomic.SwapPointer((*unsafe.Pointer)(unsafe.Pointer(&a.rw)), nil)
+
+			if numTries%100 == 0 {
+				a.worker.Log(Info, fmt.Sprintf("Still trying to connect to server %v, attempt# %v ...", a.addr, numTries))
+			}
+			conn, err = net.Dial(a.net, a.addr)
+			if err != nil {
+				time.Sleep(500 * time.Millisecond)
+				continue
+			}
+
+			break
+		}
+
+		if conn == nil {
+			// the worker is shutting down
+			break
+		}
+		// the server is reachable again, but drop this connection and dial once more:
+		// connecting too soon after gearmand starts can yield a dud connection
+		_ = conn.Close()
+		time.Sleep(3 * time.Second)
+
+		// todo: come up with a more reliable way to determine if we have a working connection to gearman, perhaps by performing a test
+		conn, err = net.Dial(a.net, a.addr)
+		if err != nil {
+			// looks like there is another problem, go back to the main loop
+			time.Sleep(time.Second)
+			continue
+		}
+
+		a.conn = conn
+		a.connectionVersion++
+
+		a.worker.Log(Info, "Successfully connected to:", a.addr)
+
+		newRw := bufio.NewReadWriter(bufio.NewReader(a.conn), bufio.NewWriter(a.conn))
+
+		if swapped := atomic.CompareAndSwapPointer(
+			(*unsafe.Pointer)(unsafe.Pointer(&a.rw)),
+			unsafe.Pointer(nil), unsafe.Pointer(newRw)); !swapped {
+			a.worker.Log(Warning, fmt.Sprintf("expected nil when swapping in the new ReadWriter, server: %v", a.addr))
+		}
+
+		if err := a.worker.reRegisterFuncsForAgent(a); err != nil {
+			a.worker.Log(Error, fmt.Sprintf("Failed to register funcs for agent, error=%v, will reconnect...", err))
+			continue
+		}
+
+		if err := a.grab(); err != nil {
+			a.worker.Log(Error, fmt.Sprintf("Failed to request a new job assignment, error=%v, will reconnect", err))
+			continue
+		}
+
+		// a.work() is the only goroutine reading from this connection; it exits
+		// once the rw pointer has been swapped by this reconnect, so it is
+		// safe to launch its replacement here
+		go a.work()
+
+		break
+	}
+
+	return
+}
 
 // read length bytes from the socket
@@ -178,8 +253,9 @@ func (a *agent) read() (data []byte, err error) {
 	tmp := rt.NewBuffer(rt.BufferSize)
 	var buf bytes.Buffer
 
+	myRw := a.loadRw()
 	// read the header so we can get the length of the data
-	if n, err = a.rw.Read(tmp); err != nil {
+	if n, err = myRw.Read(tmp); err != nil {
 		return
 	}
 	dl := int(binary.BigEndian.Uint32(tmp[8:12]))
@@ -189,10 +265,9 @@ func (a *agent) read() (data []byte, err error) {
 
 	// read until we receive all the data
 	for buf.Len() < dl+rt.MinPacketLength {
-		if n, err = a.rw.Read(tmp); err != nil {
+		if n, err = myRw.Read(tmp); err != nil {
 			return buf.Bytes(), err
 		}
-
 		buf.Write(tmp[:n])
 	}
 
@@ -200,21 +275,23 @@ func (a *agent) read() (data []byte, err error) {
 }
 
-// Internal write the encoded job.
+// Write sends the encoded packet, serializing concurrent writers with the agent lock.
-func (a *agent) write(outpack *outPack) (err error) {
+func (a *agent) Write(outpack *outPack) (err error) {
+
+	myRw := a.loadRw()
+	if myRw == nil {
+		return errors.New("reconnect in progress, discarding the write")
+	}
+	a.Lock()
+	defer a.Unlock()
+
 	var n int
 	buf := outpack.Encode()
 	for i := 0; i < len(buf); i += n {
-		n, err = a.rw.Write(buf[i:])
+		n, err = myRw.Write(buf[i:])
 		if err != nil {
 			return err
 		}
 	}
-	return a.rw.Flush()
-}
-
-// Write with lock
-func (a *agent) Write(outpack *outPack) (err error) {
-	a.Lock()
-	defer a.Unlock()
-	return a.write(outpack)
+	return myRw.Flush()
 }
diff --git a/worker/error.go b/worker/error.go
index b65a3cd..47012fd 100644
--- a/worker/error.go
+++ b/worker/error.go
@@ -26,3 +26,15 @@ func getError(data []byte) (err error) {
 
 // An error handler
 type ErrorHandler func(error)
+
+// safeCastError converts a recover() result into an error, falling back to
+// defaultMessage when the recovered value is not an error.
+func safeCastError(e interface{}, defaultMessage string) error {
+	if e == nil {
+		return nil
+	}
+	if err, ok := e.(error); ok {
+		return err
+	}
+	return errors.New(defaultMessage)
+}
diff --git a/worker/inpack.go b/worker/inpack.go
index 0306253..26d867f 100644
--- a/worker/inpack.go
+++ b/worker/inpack.go
@@ -55,7 +55,7 @@ func (inpack *inPack) SendData(data []byte) {
 	outpack.data = rt.NewBuffer(l)
 	copy(outpack.data, []byte(inpack.handle))
 	copy(outpack.data[hl+1:], data)
-	inpack.a.write(outpack)
+	inpack.a.Write(outpack)
 }
 
 func (inpack *inPack) SendWarning(data []byte) {
@@ -66,7 +66,7 @@ func (inpack *inPack) SendWarning(data []byte) {
 	outpack.data = rt.NewBuffer(l)
 	copy(outpack.data, []byte(inpack.handle))
 	copy(outpack.data[hl+1:], data)
-	inpack.a.write(outpack)
+	inpack.a.Write(outpack)
 }
 
 // Update status.
@@ -83,7 +83,7 @@ func (inpack *inPack) UpdateStatus(numerator, denominator int) {
 	copy(outpack.data, []byte(inpack.handle))
 	copy(outpack.data[hl+1:], n)
 	copy(outpack.data[hl+nl+2:], d)
-	inpack.a.write(outpack)
+	inpack.a.Write(outpack)
 }
 
 // Decode job from byte slice
diff --git a/worker/worker.go b/worker/worker.go
index e23ee4e..e929463 100644
--- a/worker/worker.go
+++ b/worker/worker.go
@@ -6,6 +6,7 @@ import (
 	"encoding/binary"
 	"fmt"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	rt "github.com/quantcast/g2/pkg/runtime"
@@ -16,6 +17,17 @@ const (
 	OneByOne
 )
 
+type LogLevel int
+
+const (
+	Error   LogLevel = 0
+	Warning LogLevel = 1
+	Info    LogLevel = 2
+	Debug   LogLevel = 3
+)
+
+type LogHandler func(level LogLevel, message ...string)
+
 // Worker is the only structure needed by worker side developing.
 // It can connect to multi-server and grab jobs.
 type Worker struct {
@@ -28,15 +40,24 @@ type Worker struct {
 	// The shuttingDown variable is protected by the Worker lock
 	shuttingDown bool
 	// Used during shutdown to wait for all active jobs to finish
-	activeJobs sync.WaitGroup
-
-	// once protects registering jobs multiple times
-	once sync.Once
+	activeJobs      sync.WaitGroup
+	activeJobsCount int32
 
 	Id           string
 	ErrorHandler ErrorHandler
 	JobHandler   JobHandler
 	limit        chan bool
+	logHandler   LogHandler
+}
+
+func (worker *Worker) Log(level LogLevel, message ...string) {
+	if worker.logHandler != nil {
+		worker.logHandler(level, message...)
+	}
+}
+
+func (worker *Worker) GetActiveJobCount() int32 {
+	return atomic.LoadInt32(&worker.activeJobsCount)
 }
 
 // Return a worker.
@@ -48,9 +69,10 @@
 // OneByOne(=1), there will be only one job executed in a time.
 func New(limit int) (worker *Worker) {
 	worker = &Worker{
-		agents: make([]*agent, 0, limit),
-		funcs:  make(jobFuncs),
-		in:     make(chan *inPack, rt.QueueSize),
+		agents:     make([]*agent, 0, limit),
+		funcs:      make(jobFuncs),
+		in:         make(chan *inPack, rt.QueueSize),
+		logHandler: nil,
 	}
 	if limit != Unlimited {
 		worker.limit = make(chan bool, limit-1)
@@ -58,6 +80,10 @@ func New(limit int) (worker *Worker) {
 	return
 }
 
+func (worker *Worker) SetLogHandler(logHandler LogHandler) {
+	worker.logHandler = logHandler
+}
+
 // inner error handling
 func (worker *Worker) err(e error) {
 	if worker.ErrorHandler != nil {
@@ -81,7 +107,7 @@ func (worker *Worker) AddServer(net, addr string) (err error) {
 // Broadcast an outpack to all Gearman server.
 func (worker *Worker) broadcast(outpack *outPack) {
 	for _, v := range worker.agents {
-		v.write(outpack)
+		v.Write(outpack)
 	}
 }
 
@@ -149,22 +175,23 @@ func (worker *Worker) removeFunc(funcname string) {
 func (worker *Worker) handleInPack(inpack *inPack) {
 	switch inpack.dataType {
 	case rt.PT_NoJob:
-		inpack.a.PreSleep()
+		_ = inpack.a.PreSleep()
 	case rt.PT_Noop:
 		if !worker.isShuttingDown() {
-			inpack.a.Grab()
+			_ = inpack.a.Grab()
 		}
 	case rt.PT_JobAssign, rt.PT_JobAssignUniq:
 		go func() {
 			if err := worker.exec(inpack); err != nil {
-				worker.err(err)
+				worker.Log(Error, fmt.Sprintf("error %v in handleInPack (server: %v, job %v); discarding the result because it cannot be sent back to gearman", err, inpack.a.addr, inpack.handle))
+				inpack.a.Connect()
 			}
 		}()
 		if worker.limit != nil {
 			worker.limit <- true
 		}
 		if !worker.isShuttingDown() {
-			inpack.a.Grab()
+			_ = inpack.a.Grab()
 		}
 	case rt.PT_Error:
 		worker.err(inpack.Err())
@@ -186,17 +213,9 @@ func (worker *Worker) Ready() (err error) {
 		return ErrNoneFuncs
 	}
 	for _, a := range worker.agents {
-		if err = a.Connect(); err != nil {
-			return
-		}
+		go a.Connect()
 	}
 
-	// `once` protects registering worker functions multiple times.
-	worker.once.Do(func() {
-		for funcname, f := range worker.funcs {
-			worker.addFunc(funcname, f.timeout)
-		}
-	})
 	worker.ready = true
 	return
 }
@@ -213,9 +232,7 @@ func (worker *Worker) Work() {
 	}
 	worker.running = true
 
-	for _, a := range worker.agents {
-		a.Grab()
-	}
+
 	var inpack *inPack
 	for inpack = range worker.in {
 		worker.handleInPack(inpack)
@@ -244,14 +261,12 @@ func (worker *Worker) Close() {
 	}
 }
 
-func (worker *Worker) Reconnect() error {
+func (worker *Worker) ReconnectAllAgents() error {
 	worker.Lock()
 	defer worker.Unlock()
 	if worker.running == true {
 		for _, a := range worker.agents {
-			if err := a.reconnect(); err != nil {
-				return err
-			}
+			a.Connect()
 		}
 	}
 	return nil
 }
@@ -287,9 +302,8 @@ func (worker *Worker) SetId(id string) {
 func (worker *Worker) exec(inpack *inPack) (err error) {
 	defer func() {
 		// decrement job counter in completion of this job
-		worker.Lock()
 		worker.activeJobs.Done()
-		worker.Unlock()
+		atomic.AddInt32(&worker.activeJobsCount, -1)
 		if worker.limit != nil {
 			<-worker.limit
 		}
@@ -302,6 +316,7 @@ func (worker *Worker) exec(inpack *inPack) (err error) {
 		}
 	}()
 	worker.activeJobs.Add(1)
+	atomic.AddInt32(&worker.activeJobsCount, 1)
 	if worker.isShuttingDown() {
 		return
 	}
@@ -329,21 +344,26 @@ func (worker *Worker) exec(inpack *inPack) (err error) {
 			outpack.dataType = rt.PT_WorkException
 		}
 		err = r.err
+		if err != nil {
+			return
+		}
 	}
 	outpack.handle = inpack.handle
 	outpack.data = r.data
-	inpack.a.Write(outpack)
+	err = inpack.a.Write(outpack)
 	return
 }
 
-func (worker *Worker) reRegisterFuncsForAgent(a *agent) {
+func (worker *Worker) reRegisterFuncsForAgent(a *agent) (err error) {
 	worker.Lock()
 	defer worker.Unlock()
 	for funcname, f := range worker.funcs {
 		outpack := prepFuncOutpack(funcname, f.timeout)
-		a.write(outpack)
+		if err := a.Write(outpack); err != nil {
+			return err
+		}
 	}
-
+	return
 }
 
 func (worker *Worker) Shutdown() {
@@ -397,7 +417,8 @@ func (e *WorkerDisconnectError) Error() string {
 
 // Responds to the error by asking the worker to reconnect
 func (e *WorkerDisconnectError) Reconnect() (err error) {
-	return e.agent.reconnect()
+	e.agent.Connect()
+	return nil
 }
 
 // Which server was this for?
diff --git a/worker/worker_disconnect_test.go b/worker/worker_disconnect_test.go
index bb0f922..ff37637 100644
--- a/worker/worker_disconnect_test.go
+++ b/worker/worker_disconnect_test.go
@@ -225,7 +225,7 @@ func TestDcRc(t *testing.T) {
 }
 
 func send_client_request() {
-	c, err := client.New(rt.Network, "127.0.0.1:"+port)
+	c, err := client.New(rt.Network, "127.0.0.1:"+port, nil)
 	if err == nil {
 		_, err = c.DoBg("gearman-go-workertest", []byte{}, rt.JobHigh)
 		if err != nil {
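
Client usage sketch. A minimal example of the connection-handler API introduced by this patch, assuming a gearmand reachable at 127.0.0.1:4730 (the address and the "tcp" network are illustrative only). NewClient takes open/close handlers plus a log handler; passing nil for the close handler falls back to the connection's own Close, as Close() in the patch shows.

package main

import (
	"log"
	"net"

	"github.com/quantcast/g2/client"
)

func main() {
	logHandler := func(level client.LogLevel, message ...string) {
		log.Println(level, message)
	}

	// the open handler decides how the client (re)establishes its connection
	connOpen := func() (net.Conn, error) {
		return net.Dial("tcp", "127.0.0.1:4730")
	}

	// nil close handler: the connection's own Close is used
	c, err := client.NewClient(nil, connOpen, logHandler)
	if err != nil {
		log.Fatalln(err)
	}
	defer c.Close()

	// Echo round-trips a payload through the server
	if echo, err := c.Echo([]byte("ping")); err == nil {
		log.Println("echo:", string(echo))
	}
}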
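
Worker usage sketch. A corresponding worker-side example of the log-handler and job-count APIs from this patch; New, SetLogHandler, AddServer, Ready, Work, Close and GetActiveJobCount all appear above, while the AddFunc call and its zero timeout argument are assumed from the upstream gearman-go API that g2 forks. Note that with this patch Ready() starts agent connections in the background instead of failing on an unreachable server.

package main

import (
	"bytes"
	"log"
	"time"

	"github.com/quantcast/g2/worker"
)

func main() {
	w := worker.New(worker.Unlimited)
	defer w.Close()

	w.SetLogHandler(func(level worker.LogLevel, message ...string) {
		log.Println(level, message)
	})

	// assumed AddFunc signature: name, handler, timeout (0 = none)
	if err := w.AddFunc("ToUpper", func(job worker.Job) ([]byte, error) {
		return bytes.ToUpper(job.Data()), nil
	}, 0); err != nil {
		log.Fatalln(err)
	}

	if err := w.AddServer("tcp", "127.0.0.1:4730"); err != nil { // illustrative address
		log.Fatalln(err)
	}
	if err := w.Ready(); err != nil { // agents connect asynchronously
		log.Fatalln(err)
	}
	go w.Work()

	// poll the new active-job counter
	for range time.Tick(10 * time.Second) {
		log.Printf("active jobs: %v", w.GetActiveJobCount())
	}
}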