From 468f052c8988bab54a1ad77cd191f256a0d7d0a7 Mon Sep 17 00:00:00 2001 From: Yilun Date: Wed, 1 Apr 2020 03:32:36 -0700 Subject: [PATCH] NewMultiClient will return immediately when at least one client is created Signed-off-by: Yilun --- examples/client/main.go | 2 + multiclient.go | 261 +++++++++++++++++++++------------------- 2 files changed, 138 insertions(+), 125 deletions(-) diff --git a/examples/client/main.go b/examples/client/main.go index 921d040e..89c6749a 100644 --- a/examples/client/main.go +++ b/examples/client/main.go @@ -42,6 +42,8 @@ func main() { defer toClient.Close() <-toClient.OnConnect.C + time.Sleep(time.Second) + timeSent := time.Now().UnixNano() / int64(time.Millisecond) var timeReceived int64 go func() { diff --git a/multiclient.go b/multiclient.go index 3726d21e..4a722f1e 100644 --- a/multiclient.go +++ b/multiclient.go @@ -6,7 +6,6 @@ import ( "fmt" "log" "net" - "reflect" "regexp" "sort" "strconv" @@ -36,18 +35,21 @@ var ( type MultiClient struct { config *ClientConfig offset int - Clients map[int]*Client - DefaultClient *Client addr *ClientAddr OnConnect *OnConnect OnMessage *OnMessage acceptSession chan *ncp.Session onClose chan struct{} + msgCache *cache.Cache sync.RWMutex - acceptAddrs []*regexp.Regexp + clients map[int]*Client + defaultClient *Client + acceptAddrs []*regexp.Regexp + isClosed bool + + sessionLock sync.Mutex sessions map[string]*ncp.Session - isClosed bool } func NewMultiClient(account *Account, baseIdentifier string, numSubClients int, originalClient bool, config *ClientConfig) (*MultiClient, error) { @@ -63,138 +65,144 @@ func NewMultiClient(account *Account, baseIdentifier string, numSubClients int, offset = 1 } - clients := make(map[int]*Client, numClients) - - var wg sync.WaitGroup - var lock sync.Mutex - success := false - for i := -offset; i < numSubClients; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - client, err := NewClient(account, addIdentifier(baseIdentifier, i), config) - if err != nil { - log.Println(err) - return - } - lock.Lock() - clients[i] = client - success = true - lock.Unlock() - }(i) - } - wg.Wait() - if !success { - return nil, errors.New("failed to create any client") - } - - var defaultClient *Client - if originalClient { - defaultClient = clients[-1] - } else { - defaultClient = clients[0] - } - addr := address.MakeAddressString(account.PublicKey.EncodePoint(), baseIdentifier) - onConnect := NewOnConnect(1, nil) - go func() { - cases := make([]reflect.SelectCase, numClients) - for i := 0; i < numClients; i++ { - if clients[i-offset] != nil { - cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(clients[i-offset].OnConnect.C)} - } else { - cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv} - } - } - if _, value, ok := reflect.Select(cases); ok { - nodeInfo := value.Interface().(*NodeInfo) - onConnect.receive(nodeInfo) - } - }() - m := &MultiClient{ config: config, offset: offset, - Clients: clients, - DefaultClient: defaultClient, addr: NewClientAddr(addr), - OnConnect: onConnect, + OnConnect: NewOnConnect(1, nil), OnMessage: NewOnMessage(int(config.MsgChanLen), nil), acceptSession: make(chan *ncp.Session, acceptSessionBufSize), - sessions: make(map[string]*ncp.Session, 0), onClose: make(chan struct{}, 0), + msgCache: cache.New(time.Duration(config.MsgCacheExpiration)*time.Millisecond, time.Duration(config.MsgCacheExpiration)*time.Millisecond), + clients: make(map[int]*Client, numClients), + defaultClient: nil, + sessions: make(map[string]*ncp.Session, 0), } - msgCache := cache.New(time.Duration(config.MsgCacheExpiration)*time.Millisecond, time.Duration(config.MsgCacheExpiration)*time.Millisecond) + var wg sync.WaitGroup + defaultClientIdx := numSubClients + success := make(chan struct{}, 1) + fail := make(chan struct{}, 1) - go func() { - cases := make([]reflect.SelectCase, numClients) - for i := 0; i < numClients; i++ { - if clients[i-offset] != nil { - cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(clients[i-offset].OnMessage.C)} - } else { - cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv} + for i := -offset; i < numSubClients; i++ { + wg.Add(1) + go func(i int) { + client, err := NewClient(account, addIdentifier(baseIdentifier, i), config) + if err != nil { + log.Println(err) + wg.Done() + return } - } - for { + + m.Lock() + m.clients[i] = client + if i < defaultClientIdx { + m.defaultClient = client + defaultClientIdx = i + } + m.Unlock() + select { - case _, ok := <-m.onClose: - if !ok { - return - } + case success <- struct{}{}: default: } - if i, value, ok := reflect.Select(cases); ok { - msg := value.Interface().(*Message) - if msg.Type == SessionType { - if !msg.Encrypted { - continue - } - err := m.handleSessionMsg(addIdentifier("", i-offset), msg.Src, msg.MessageId, msg.Data) - if err != nil { - if err != ncp.ErrSessionClosed && err != errAddrNotAllowed { - log.Println(err) - } - continue - } - } else { - cacheKey := string(msg.MessageId) - if _, ok := msgCache.Get(cacheKey); ok { - continue - } - msgCache.Set(cacheKey, struct{}{}, cache.DefaultExpiration) - msg.Src, _ = removeIdentifier(msg.Src) - if msg.NoReply { - msg.reply = func(data interface{}) error { - return nil + wg.Done() + + nodeInfo := <-client.OnConnect.C + m.OnConnect.receive(nodeInfo) + + for { + select { + case msg := <-client.OnMessage.C: + if msg.Type == SessionType { + if !msg.Encrypted { + continue + } + err := m.handleSessionMsg(addIdentifier("", i-offset), msg.Src, msg.MessageId, msg.Data) + if err != nil { + if err != ncp.ErrSessionClosed && err != errAddrNotAllowed { + log.Println(err) + } + continue } } else { - msg.reply = func(data interface{}) error { - payload, err := newReplyPayload(data, msg.MessageId) - if err != nil { - return err + cacheKey := string(msg.MessageId) + if _, ok := m.msgCache.Get(cacheKey); ok { + continue + } + m.msgCache.Set(cacheKey, nil, cache.DefaultExpiration) + + msg.Src, _ = removeIdentifier(msg.Src) + if msg.NoReply { + msg.reply = func(data interface{}) error { + return nil } - if err := m.send([]string{msg.Src}, payload, msg.Encrypted, 0); err != nil { - return err + } else { + msg.reply = func(data interface{}) error { + payload, err := newReplyPayload(data, msg.MessageId) + if err != nil { + return err + } + if err := m.send([]string{msg.Src}, payload, msg.Encrypted, 0); err != nil { + return err + } + return nil } - return nil } + m.OnMessage.receive(msg, true) } - m.OnMessage.receive(msg, true) + case <-m.onClose: + return } } + }(i) + } + + go func() { + wg.Wait() + select { + case fail <- struct{}{}: + default: } }() - return m, nil + select { + case <-success: + return m, nil + case <-fail: + return nil, errors.New("failed to create any client") + } +} + +func (m *MultiClient) GetClients() map[int]*Client { + m.RLock() + defer m.RUnlock() + clients := make(map[int]*Client, len(m.clients)) + for i, client := range m.clients { + clients[i] = client + } + return clients +} + +func (m *MultiClient) GetClient(i int) *Client { + m.RLock() + defer m.RUnlock() + return m.clients[i] +} + +func (m *MultiClient) GetDefaultClient() *Client { + m.RLock() + defer m.RUnlock() + return m.defaultClient } func (m *MultiClient) SendWithClient(clientID int, dests *StringArray, data interface{}, config *MessageConfig) (*OnMessage, error) { - client, ok := m.Clients[clientID] - if !ok { - return nil, fmt.Errorf("clientID %d not found", clientID) + client := m.GetClient(clientID) + if client == nil { + return nil, fmt.Errorf("client %d is not created or not ready", clientID) } config, err := MergeMessageConfig(m.config.MessageConfig, config) @@ -238,9 +246,9 @@ func addMultiClientPrefix(dest []string, clientID int) []string { } func (m *MultiClient) sendWithClient(clientID int, dests []string, payload *payloads.Payload, encrypted bool, maxHoldingSeconds int32) error { - client, ok := m.Clients[clientID] - if !ok { - return fmt.Errorf("clientID %d not found", clientID) + client := m.GetClient(clientID) + if client == nil { + return fmt.Errorf("client %d is not created or not ready", clientID) } return client.send(addMultiClientPrefix(dests, clientID), payload, encrypted, maxHoldingSeconds) } @@ -260,11 +268,12 @@ func (m *MultiClient) Send(dests *StringArray, data interface{}, config *Message var errMsg []string var onRawReply *OnMessage onReply := NewOnMessage(1, nil) + clients := m.GetClients() if !config.NoReply { onRawReply = NewOnMessage(1, nil) // response channel is added first to prevent some client fail to handle response if send finish before receive response - for _, client := range m.Clients { + for _, client := range clients { client.responseChannels.Add(string(payload.MessageId), onRawReply, cache.DefaultExpiration) } } @@ -274,7 +283,7 @@ func (m *MultiClient) Send(dests *StringArray, data interface{}, config *Message go func() { sent := 0 - for clientID := range m.Clients { + for clientID := range clients { err := m.sendWithClient(clientID, dests.Elems, payload, !config.Unencrypted, config.MaxHoldingSeconds) if err == nil { select { @@ -327,7 +336,7 @@ func (m *MultiClient) send(dests []string, payload *payloads.Payload, encrypted fail := make(chan struct{}, 1) go func() { sent := 0 - for clientID := range m.Clients { + for clientID := range m.GetClients() { err := m.sendWithClient(clientID, dests, payload, encrypted, maxHoldingSeconds) if err == nil { select { @@ -372,9 +381,10 @@ func (m *MultiClient) PublishText(topic string, data string, config *MessageConf } func (m *MultiClient) newSession(remoteAddr string, sessionID []byte, config *ncp.Config) (*ncp.Session, error) { - clientIDs := make([]string, 0, len(m.Clients)) - clients := make(map[string]*Client, len(m.Clients)) - for id, client := range m.Clients { + rawClients := m.GetClients() + clientIDs := make([]string, 0, len(rawClients)) + clients := make(map[string]*Client, len(rawClients)) + for id, client := range rawClients { clientID := addIdentifier("", id) clientIDs = append(clientIDs, clientID) clients[clientID] = client @@ -409,21 +419,22 @@ func (m *MultiClient) handleSessionMsg(localClientID, src string, sessionID, dat sessionKey := sessionKey(remoteAddr, sessionID) var err error - m.Lock() + m.sessionLock.Lock() session, ok := m.sessions[sessionKey] if !ok { if !m.shouldAcceptAddr(remoteAddr) { + m.sessionLock.Unlock() return errAddrNotAllowed } session, err = m.newSession(remoteAddr, sessionID, m.config.SessionConfig) if err != nil { - m.Unlock() + m.sessionLock.Unlock() return err } m.sessions[sessionKey] = session } - m.Unlock() + m.sessionLock.Unlock() err = session.ReceiveWith(localClientID, remoteClientID, data) if err != nil { @@ -497,9 +508,9 @@ func (m *MultiClient) DialWithConfig(remoteAddr string, config *DialConfig) (*nc return nil, err } - m.Lock() + m.sessionLock.Lock() m.sessions[sessionKey] = session - m.Unlock() + m.sessionLock.Unlock() ctx := context.Background() var cancel context.CancelFunc @@ -526,10 +537,8 @@ func (m *MultiClient) AcceptSession() (*ncp.Session, error) { continue } return session, nil - case _, ok := <-m.onClose: - if !ok { - return nil, ErrClosed - } + case <-m.onClose: + return nil, ErrClosed } } } @@ -546,6 +555,7 @@ func (m *MultiClient) Close() error { return nil } + m.sessionLock.Lock() for _, session := range m.sessions { err := session.Close() if err != nil { @@ -553,9 +563,10 @@ func (m *MultiClient) Close() error { continue } } + m.sessionLock.Unlock() time.AfterFunc(time.Duration(m.config.SessionConfig.Linger)*time.Millisecond, func() { - for _, client := range m.Clients { + for _, client := range m.GetClients() { client.Close() } })