Skip to content

Commit

Permalink
refactor: workerpool logic (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
ramzeng authored Jul 29, 2024
1 parent 111e331 commit 8695261
Show file tree
Hide file tree
Showing 15 changed files with 116 additions and 104 deletions.
5 changes: 1 addition & 4 deletions README-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,9 @@ func main() {

server := ramix.NewServer(
ramix.WithPort(8899),
// 协程池模式
// 如果你不想使用协程池,可以删除以下两行
ramix.UseWorkerPool(),
ramix.WithWorkersCount(100),
)

server.UseWorkerPool(ramix.NewRoundRobinWorkerPool(100, 1024))
server.Use(ramix.Recovery(), ramix.Logger())

server.RegisterRoute(0, func(context *ramix.Context) {
Expand Down
5 changes: 1 addition & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,9 @@ func main() {

server := ramix.NewServer(
ramix.WithPort(8899),
// worker pool mode
// if you don't want to use worker pool, you can remove the following two lines
ramix.UseWorkerPool(),
ramix.WithWorkersCount(100),
)

server.UseWorkerPool(ramix.NewRoundRobinWorkerPool(100, 1024))
server.Use(ramix.Recovery(), ramix.Logger())

server.RegisterRoute(0, func(context *ramix.Context) {
Expand Down
4 changes: 2 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Connection interface {
close(syncConnectionManager bool)
refreshLastActiveTime()
isAlive() bool
pushTask(ctx *Context)
submitTask(ctx *Context)
}

type netConnection struct {
Expand All @@ -43,7 +43,7 @@ func (c *netConnection) isAlive() bool {
return !c.isClosed && c.lastActiveTime.Add(c.server.HeartbeatTimeout).After(time.Now())
}

func (c *netConnection) pushTask(ctx *Context) {
func (c *netConnection) submitTask(ctx *Context) {
c.worker.tasks <- ctx
}

Expand Down
8 changes: 8 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,11 @@ func (c *Context) Get(key string) any {

return c.keys[key]
}

func newContext(connection Connection, request *Request) *Context {
return &Context{
Connection: connection,
Request: request,
step: -1,
}
}
3 changes: 1 addition & 2 deletions examples/barrage/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ func main() {

server := ramix.NewServer(
ramix.WithPort(8899),
ramix.UseWorkerPool(),
ramix.WithWorkersCount(100),
)

server.UseWorkerPool(ramix.NewRoundRobinWorkerPool(100, 1024))
server.Use(ramix.Recovery(), ramix.Logger())

server.OnConnectionOpen(func(connection ramix.Connection) {
Expand Down
3 changes: 1 addition & 2 deletions examples/startup/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@ func main() {

server := ramix.NewServer(
ramix.WithPort(8899),
ramix.UseWorkerPool(),
ramix.WithWorkersCount(100),
)

server.UseWorkerPool(ramix.NewRoundRobinWorkerPool(1000, 1024))
server.Use(ramix.Recovery(), ramix.Logger())

server.RegisterRoute(0, func(context *ramix.Context) {
Expand Down
3 changes: 1 addition & 2 deletions examples/tls/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ func main() {

server := ramix.NewServer(
ramix.WithPort(8899),
ramix.UseWorkerPool(),
ramix.WithWorkersCount(100),
ramix.WithCertFile("./public_certificate.pem"),
ramix.WithPrivateKeyFile("./private_key.pem"),
)

server.UseWorkerPool(ramix.NewRoundRobinWorkerPool(100, 1024))
server.Use(ramix.Recovery(), ramix.Logger())

server.RegisterRoute(0, func(context *ramix.Context) {
Expand Down
3 changes: 1 addition & 2 deletions examples/websocket/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@ func main() {

server := ramix.NewServer(
ramix.WithPort(8899),
ramix.UseWorkerPool(),
ramix.WithWorkersCount(100),
)

server.UseWorkerPool(ramix.NewRoundRobinWorkerPool(100, 1024))
server.Use(ramix.Recovery(), ramix.Logger())

server.RegisterRoute(0, func(context *ramix.Context) {
Expand Down
16 changes: 0 additions & 16 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ var defaultServerOptions = ServerOptions{
ConnectionGroupsCount: 10,
ConnectionReadBufferSize: 1024,
ConnectionWriteBufferSize: 1024,
UseWorkerPool: false,
WorkersCount: 10,
MaxWorkerTasksCount: 1024,
HeartbeatInterval: 5 * time.Second,
HeartbeatTimeout: 60 * time.Second,
Expand All @@ -37,8 +35,6 @@ type ServerOptions struct {
ConnectionGroupsCount int
ConnectionReadBufferSize uint32
ConnectionWriteBufferSize uint32
UseWorkerPool bool // true: all connections share a worker pool, false: each connection has a worker
WorkersCount uint32
MaxWorkerTasksCount uint32
HeartbeatInterval time.Duration
HeartbeatTimeout time.Duration
Expand Down Expand Up @@ -130,18 +126,6 @@ func WithConnectionWriteBufferSize(connectionWriteBufferSize uint32) ServerOptio
}
}

func UseWorkerPool() ServerOption {
return func(o *ServerOptions) {
o.UseWorkerPool = true
}
}

func WithWorkersCount(workersCount uint32) ServerOption {
return func(o *ServerOptions) {
o.WorkersCount = workersCount
}
}

func WithMaxWorkerTasksCount(maxTasksCount uint32) ServerOption {
return func(o *ServerOptions) {
o.MaxWorkerTasksCount = maxTasksCount
Expand Down
9 changes: 0 additions & 9 deletions option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,6 @@ func TestWithMaxTasksCount(t *testing.T) {
}
}

func TestWithWorkersCount(t *testing.T) {
serverOptions := defaultServerOptions
serverOption := WithWorkersCount(2048)
serverOption(&serverOptions)
if serverOptions.WorkersCount != 2048 {
t.Error("serverOptions.WorkersCount should be 2048")
}
}

func TestWithHeartbeatInterval(t *testing.T) {
serverOptions := defaultServerOptions
serverOption := WithHeartbeatInterval(10 * time.Second)
Expand Down
98 changes: 45 additions & 53 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,23 @@ import (
type Server struct {
ServerOptions
*routeGroup
upgrader *websocket.Upgrader
connectionID uint64
ctx context.Context
cancel context.CancelFunc
router *router
workers []*worker
decoder DecoderInterface
encoder EncoderInterface
heartbeatChecker *heartbeatChecker
connectionManager *connectionManager
connectionOpen func(connection Connection)
connectionClose func(connection Connection)
upgrader *websocket.Upgrader
currentConnectionID uint64
ctx context.Context
cancel context.CancelFunc
router *router
workerPool WorkerPool
decoder DecoderInterface
encoder EncoderInterface
heartbeatChecker *heartbeatChecker
connectionManager *connectionManager
connectionOpen func(connection Connection)
connectionClose func(connection Connection)
}

func (s *Server) Serve() {
s.selfCheck()

if s.UseWorkerPool {
s.startWorkers()
if s.UsingWorkerPool() {
s.startWorkerPool()
}

switch {
Expand Down Expand Up @@ -66,9 +64,9 @@ func (s *Server) listenWebSocket() {
return
}

atomic.AddUint64(&s.connectionID, 1)
atomic.AddUint64(&s.currentConnectionID, 1)

go s.openWebSocketConnection(socket, s.connectionID)
go s.openWebSocketConnection(socket, s.currentConnectionID)
})

debug("WebSocket server is starting on %s:%d", s.IP, s.WebSocketPort)
Expand Down Expand Up @@ -136,9 +134,9 @@ func (s *Server) listenTCP() {
continue
}

atomic.AddUint64(&s.connectionID, 1)
atomic.AddUint64(&s.currentConnectionID, 1)

go s.openTCPConnection(socket, s.connectionID)
go s.openTCPConnection(socket, s.currentConnectionID)
}
}
}
Expand All @@ -157,8 +155,8 @@ func (s *Server) stop() <-chan struct{} {

s.cancel()

s.stopWorkers()
s.connectionManager.clearConnections()
s.stopWorkerPool()
s.clearConnections()

debug("Server stopped")

Expand All @@ -185,20 +183,17 @@ func (s *Server) monitor() {
}
}

func (s *Server) startWorkers() {
s.workers = make([]*worker, s.WorkersCount)
func (s *Server) startWorkerPool() {
s.workerPool.init()
s.workerPool.start()
}

for i := 0; i < int(s.WorkersCount); i++ {
w := newWorker(i, s.MaxWorkerTasksCount, s.ctx)
w.start()
s.workers[i] = w
}
func (s *Server) stopWorkerPool() {
s.workerPool.stop()
}

func (s *Server) stopWorkers() {
for _, w := range s.workers {
w.stop()
}
func (s *Server) clearConnections() {
s.connectionManager.clearConnections()
}

func (s *Server) openWebSocketConnection(socket *websocket.Conn, connectionID uint64) {
Expand All @@ -211,10 +206,8 @@ func (s *Server) openWebSocketConnection(socket *websocket.Conn, connectionID ui

connection.ctx, connection.cancel = context.WithCancel(context.Background())

if s.UseWorkerPool {
c.worker = s.workers[connectionID%uint64(s.WorkersCount)]
} else {
w := newWorker(int(connectionID), s.MaxWorkerTasksCount, connection.ctx)
if !s.UsingWorkerPool() {
w := newWorker(int(connectionID), s.MaxWorkerTasksCount)
w.start()
c.worker = w
}
Expand All @@ -240,10 +233,8 @@ func (s *Server) openTCPConnection(socket net.Conn, connectionID uint64) {

connection.ctx, connection.cancel = context.WithCancel(context.Background())

if s.UseWorkerPool {
c.worker = s.workers[connectionID%uint64(s.WorkersCount)]
} else {
w := newWorker(int(connectionID), s.MaxWorkerTasksCount, connection.ctx)
if !s.UsingWorkerPool() {
w := newWorker(int(connectionID), s.MaxWorkerTasksCount)
w.start()
c.worker = w
}
Expand All @@ -256,15 +247,11 @@ func (s *Server) openTCPConnection(socket net.Conn, connectionID uint64) {

connection.open()

debug("TCPConnection %d opened, worker %d assigned", connection.ID(), connection.worker.id)
debug("TCPConnection %d opened", connection.ID())
}

func (s *Server) handleRequest(connection Connection, request *Request) {
ctx := &Context{
Connection: connection,
Request: request,
step: -1,
}
ctx := newContext(connection, request)

if handlers, ok := s.router.routes[ctx.Request.Message.Event]; ok {
ctx.handlers = append(ctx.handlers, handlers...)
Expand All @@ -274,8 +261,11 @@ func (s *Server) handleRequest(connection Connection, request *Request) {
})
}

// push task to logic worker
connection.pushTask(ctx)
if s.UsingWorkerPool() {
s.workerPool.submitTask(ctx)
} else {
connection.submitTask(ctx)
}
}

func (s *Server) OnConnectionOpen(callback func(connection Connection)) {
Expand All @@ -286,10 +276,12 @@ func (s *Server) OnConnectionClose(callback func(connection Connection)) {
s.connectionClose = callback
}

func (s *Server) selfCheck() {
if s.UseWorkerPool && s.WorkersCount <= 0 {
panic("Workers count must be greater than 0")
}
func (s *Server) UseWorkerPool(workerPool WorkerPool) {
s.workerPool = workerPool
}

func (s *Server) UsingWorkerPool() bool {
return s.workerPool != nil
}

func NewServer(serverOptions ...ServerOption) *Server {
Expand Down
2 changes: 1 addition & 1 deletion tcp_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (c *TCPConnection) close(syncConnectionManager bool) {
}

// If the worker pool is not used, need to stop the worker by self
if !c.server.UseWorkerPool {
if !c.server.UsingWorkerPool() {
c.worker.stop()
}

Expand Down
17 changes: 11 additions & 6 deletions worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package ramix
import "context"

type worker struct {
id int
tasks chan *Context
ctx context.Context
id int
tasks chan *Context
ctx context.Context
cancel context.CancelFunc
}

func (w *worker) start() {
Expand All @@ -32,13 +33,17 @@ func (w *worker) start() {
}

func (w *worker) stop() {
w.cancel()
close(w.tasks)
}

func newWorker(workerID int, maxTasksCount uint32, ctx context.Context) *worker {
return &worker{
func newWorker(workerID int, maxTasksCount uint32) *worker {
w := &worker{
id: workerID,
tasks: make(chan *Context, maxTasksCount),
ctx: ctx,
}

w.ctx, w.cancel = context.WithCancel(context.Background())

return w
}
Loading

0 comments on commit 8695261

Please sign in to comment.