Skip to content

Commit

Permalink
Merge pull request smallnest#886 from rekyyang/fix/inform-chan
Browse files Browse the repository at this point in the history
fix: done channel for broadcast, fork and inform
  • Loading branch information
smallnest authored Jan 10, 2025
2 parents d9e2dbe + 5bf2c1a commit a09a362
Showing 1 changed file with 40 additions and 20 deletions.
60 changes: 40 additions & 20 deletions client/xclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ import (
"time"

"github.com/juju/ratelimit"
"golang.org/x/sync/singleflight"

ex "github.com/smallnest/rpcx/errors"
"github.com/smallnest/rpcx/log"
"github.com/smallnest/rpcx/protocol"
"github.com/smallnest/rpcx/share"
"golang.org/x/sync/singleflight"
)

const (
Expand Down Expand Up @@ -944,6 +945,9 @@ func (c *xClient) Broadcast(ctx context.Context, serviceMethod string, args inte
var replyOnce sync.Once

ctx = setServerTimeout(ctx)
// add timeout after set server timeout, only prevent client hanging
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
callPlugins := make([]RPCClient, 0, len(c.servers))
clients := make(map[string]RPCClient)
c.mu.Lock()
Expand Down Expand Up @@ -982,7 +986,9 @@ func (c *xClient) Broadcast(ctx context.Context, serviceMethod string, args inte
}

e := c.wrapCall(ctx, client, serviceMethod, args, clonedReply)
done <- (e == nil)
defer func() {
done <- (e == nil)
}()
if e != nil {
if uncoverError(e) {
c.removeClient(k, c.servicePath, serviceMethod, client)
Expand All @@ -998,7 +1004,6 @@ func (c *xClient) Broadcast(ctx context.Context, serviceMethod string, args inte
}()
}

timeout := time.NewTimer(time.Minute)
check:
for {
select {
Expand All @@ -1007,12 +1012,14 @@ check:
if l == 0 || !result { // all returns or some one returns an error
break check
}
case <-timeout.C:
err.Append(errors.New(("timeout")))
break check
}
}
timeout.Stop()

select {
case <-ctx.Done():
err.Append(errors.New(("timeout")))
default:
}

return err.ErrorOrNil()
}
Expand All @@ -1035,6 +1042,10 @@ func (c *xClient) Fork(ctx context.Context, serviceMethod string, args interface
}

ctx = setServerTimeout(ctx)

// add timeout after set server timeout, only prevent client hanging
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
callPlugins := make([]RPCClient, 0, len(c.servers))
clients := make(map[string]RPCClient)
c.mu.Lock()
Expand Down Expand Up @@ -1080,7 +1091,9 @@ func (c *xClient) Fork(ctx context.Context, serviceMethod string, args interface
reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(clonedReply).Elem())
})
}
done <- (e == nil)
defer func() {
done <- (e == nil)
}()
if e != nil {
if uncoverError(e) {
c.removeClient(k, c.servicePath, serviceMethod, client)
Expand All @@ -1090,7 +1103,6 @@ func (c *xClient) Fork(ctx context.Context, serviceMethod string, args interface
}()
}

timeout := time.NewTimer(time.Minute)
check:
for {
select {
Expand All @@ -1102,13 +1114,14 @@ check:
if l == 0 { // all returns or some one returns an error
break check
}

case <-timeout.C:
err.Append(errors.New(("timeout")))
break check
}
}
timeout.Stop()

select {
case <-ctx.Done():
err.Append(errors.New(("timeout")))
default:
}

return err.ErrorOrNil()
}
Expand All @@ -1132,6 +1145,10 @@ func (c *xClient) Inform(ctx context.Context, serviceMethod string, args interfa
}

ctx = setServerTimeout(ctx)

// add timeout after set server timeout, only prevent client hanging
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
callPlugins := make([]RPCClient, 0, len(c.servers))
clients := make(map[string]RPCClient)
c.mu.Lock()
Expand Down Expand Up @@ -1175,7 +1192,9 @@ func (c *xClient) Inform(ctx context.Context, serviceMethod string, args interfa
}

e := c.wrapCall(ctx, client, serviceMethod, args, clonedReply)
done <- (e == nil)
defer func() {
done <- (e == nil)
}()
if e != nil {
if uncoverError(e) {
c.removeClient(k, c.servicePath, serviceMethod, client)
Expand Down Expand Up @@ -1204,7 +1223,6 @@ func (c *xClient) Inform(ctx context.Context, serviceMethod string, args interfa
}()
}

timeout := time.NewTimer(time.Minute)
check:
for {
select {
Expand All @@ -1213,12 +1231,14 @@ check:
if l == 0 { // all returns or some one returns an error
break check
}
case <-timeout.C:
err.Append(errors.New(("timeout")))
break check
}
}
timeout.Stop()

select {
case <-ctx.Done():
err.Append(errors.New(("timeout")))
default:
}

return receipts, err.ErrorOrNil()
}
Expand Down

0 comments on commit a09a362

Please sign in to comment.