diff --git a/client.go b/client.go index 751a5bde..a7da60a9 100644 --- a/client.go +++ b/client.go @@ -66,7 +66,7 @@ type Client struct { responseChannels *cache.Cache lock sync.RWMutex - closed bool + isClosed bool conn *websocket.Conn node *Node sharedKeys map[string]*[sharedKeySize]byte @@ -161,20 +161,28 @@ func (c *Client) Address() string { func (c *Client) IsClosed() bool { c.lock.RLock() defer c.lock.RUnlock() - return c.closed + return c.isClosed } // Close closes the client. -func (c *Client) Close() { +func (c *Client) Close() error { c.lock.Lock() defer c.lock.Unlock() - if !c.closed { - c.closed = true - close(c.OnConnect.C) - close(c.OnMessage.C) - close(c.reconnectChan) - c.conn.Close() + + if c.isClosed { + return nil } + + c.isClosed = true + + c.OnConnect.close() + c.OnMessage.close() + + close(c.reconnectChan) + + c.conn.Close() + + return nil } // GetNode returns the node that client is currently connected to. @@ -395,7 +403,7 @@ func (c *Client) handleMessage(msgType int, data []byte) error { c.lock.RLock() defer c.lock.RUnlock() - if c.closed { + if c.isClosed { return nil } @@ -500,7 +508,7 @@ func (c *Client) handleMessage(msgType int, data []byte) error { c.lock.RLock() defer c.lock.RUnlock() - if c.closed { + if c.isClosed { return nil } diff --git a/multiclient.go b/multiclient.go index 503c45d1..6e628dbf 100644 --- a/multiclient.go +++ b/multiclient.go @@ -159,7 +159,14 @@ func NewMultiClient(account *Account, baseIdentifier string, numSubClients int, if !ok { return } + + m.lock.RLock() + if m.isClosed { + m.lock.RUnlock() + return + } m.OnConnect.receive(node) + m.lock.RUnlock() for { select { @@ -202,7 +209,14 @@ func NewMultiClient(account *Account, baseIdentifier string, numSubClients int, return nil } } + + m.lock.RLock() + if m.isClosed { + m.lock.RUnlock() + return + } m.OnMessage.receive(msg, true) + m.lock.RUnlock() } case <-m.onClose: return @@ -679,12 +693,19 @@ func (m *MultiClient) Close() error { time.AfterFunc(time.Duration(m.config.SessionConfig.Linger)*time.Millisecond, func() { for _, client := range m.GetClients() { - client.Close() + err := client.Close() + if err != nil { + log.Println(err) + continue + } } }) m.isClosed = true + m.OnConnect.close() + m.OnMessage.close() + close(m.onClose) return nil diff --git a/util.go b/util.go index 7b360835..a119b777 100644 --- a/util.go +++ b/util.go @@ -212,6 +212,10 @@ func (c *OnConnect) receive(node *Node) { } } +func (c *OnConnect) close() { + close(c.C) +} + // OnMessageFunc is a wrapper type for gomobile compatibility. type OnMessageFunc interface{ OnMessage(*Message) } @@ -249,6 +253,10 @@ func (c *OnMessage) receive(msg *Message, verbose bool) { } } +func (c *OnMessage) close() { + close(c.C) +} + // OnErrorFunc is a wrapper type for gomobile compatibility. type OnErrorFunc interface{ OnError(error) } @@ -284,6 +292,10 @@ func (c *OnError) receive(err error) { } } +func (c *OnError) close() { + close(c.C) +} + // ClientAddr represents NKN client address. It implements net.Addr interface. type ClientAddr struct { addr string