diff --git a/client/xclient.go b/client/xclient.go index 2b0676ef..cf19af4b 100644 --- a/client/xclient.go +++ b/client/xclient.go @@ -16,11 +16,12 @@ import ( "time" "github.com/juju/ratelimit" + "golang.org/x/sync/singleflight" + ex "github.com/smallnest/rpcx/errors" "github.com/smallnest/rpcx/log" "github.com/smallnest/rpcx/protocol" "github.com/smallnest/rpcx/share" - "golang.org/x/sync/singleflight" ) const ( @@ -944,6 +945,9 @@ func (c *xClient) Broadcast(ctx context.Context, serviceMethod string, args inte var replyOnce sync.Once ctx = setServerTimeout(ctx) + // add timeout after set server timeout, only prevent client hanging + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() callPlugins := make([]RPCClient, 0, len(c.servers)) clients := make(map[string]RPCClient) c.mu.Lock() @@ -982,7 +986,9 @@ func (c *xClient) Broadcast(ctx context.Context, serviceMethod string, args inte } e := c.wrapCall(ctx, client, serviceMethod, args, clonedReply) - done <- (e == nil) + defer func() { + done <- (e == nil) + }() if e != nil { if uncoverError(e) { c.removeClient(k, c.servicePath, serviceMethod, client) @@ -998,7 +1004,6 @@ func (c *xClient) Broadcast(ctx context.Context, serviceMethod string, args inte }() } - timeout := time.NewTimer(time.Minute) check: for { select { @@ -1007,12 +1012,14 @@ check: if l == 0 || !result { // all returns or some one returns an error break check } - case <-timeout.C: - err.Append(errors.New(("timeout"))) - break check } } - timeout.Stop() + + select { + case <-ctx.Done(): + err.Append(errors.New(("timeout"))) + default: + } return err.ErrorOrNil() } @@ -1035,6 +1042,10 @@ func (c *xClient) Fork(ctx context.Context, serviceMethod string, args interface } ctx = setServerTimeout(ctx) + + // add timeout after set server timeout, only prevent client hanging + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() callPlugins := make([]RPCClient, 0, len(c.servers)) clients := make(map[string]RPCClient) c.mu.Lock() @@ -1080,7 +1091,9 @@ func (c *xClient) Fork(ctx context.Context, serviceMethod string, args interface reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(clonedReply).Elem()) }) } - done <- (e == nil) + defer func() { + done <- (e == nil) + }() if e != nil { if uncoverError(e) { c.removeClient(k, c.servicePath, serviceMethod, client) @@ -1090,7 +1103,6 @@ func (c *xClient) Fork(ctx context.Context, serviceMethod string, args interface }() } - timeout := time.NewTimer(time.Minute) check: for { select { @@ -1102,13 +1114,14 @@ check: if l == 0 { // all returns or some one returns an error break check } - - case <-timeout.C: - err.Append(errors.New(("timeout"))) - break check } } - timeout.Stop() + + select { + case <-ctx.Done(): + err.Append(errors.New(("timeout"))) + default: + } return err.ErrorOrNil() } @@ -1132,6 +1145,10 @@ func (c *xClient) Inform(ctx context.Context, serviceMethod string, args interfa } ctx = setServerTimeout(ctx) + + // add timeout after set server timeout, only prevent client hanging + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() callPlugins := make([]RPCClient, 0, len(c.servers)) clients := make(map[string]RPCClient) c.mu.Lock() @@ -1175,7 +1192,9 @@ func (c *xClient) Inform(ctx context.Context, serviceMethod string, args interfa } e := c.wrapCall(ctx, client, serviceMethod, args, clonedReply) - done <- (e == nil) + defer func() { + done <- (e == nil) + }() if e != nil { if uncoverError(e) { c.removeClient(k, c.servicePath, serviceMethod, client) @@ -1204,7 +1223,6 @@ func (c *xClient) Inform(ctx context.Context, serviceMethod string, args interfa }() } - timeout := time.NewTimer(time.Minute) check: for { select { @@ -1213,12 +1231,14 @@ check: if l == 0 { // all returns or some one returns an error break check } - case <-timeout.C: - err.Append(errors.New(("timeout"))) - break check } } - timeout.Stop() + + select { + case <-ctx.Done(): + err.Append(errors.New(("timeout"))) + default: + } return receipts, err.ErrorOrNil() }