Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[relay] Code cleaning in message marshalling #3074

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions relay/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ func (c *Client) handShake() error {
return fmt.Errorf("validate version: %w", err)
}

msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
msgType, err := messages.DetermineServerMessageType(buf[:n])
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
return err
Expand All @@ -317,7 +317,7 @@ func (c *Client) handShake() error {
return fmt.Errorf("unexpected message type")
}

addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
addr, err := messages.UnmarshalAuthResponse(buf[:n])
if err != nil {
return err
}
Expand Down Expand Up @@ -348,24 +348,27 @@ func (c *Client) readLoop(relayConn net.Conn) {
c.log.Debugf("failed to read message from relay server: %s", errExit)
}
c.mu.Unlock()
c.bufPool.Put(bufPtr)
break
}

_, err := messages.ValidateVersion(buf[:n])
buf = buf[:n]

_, err := messages.ValidateVersion(buf)
if err != nil {
c.log.Errorf("failed to validate protocol version: %s", err)
c.bufPool.Put(bufPtr)
continue
}

msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
msgType, err := messages.DetermineServerMessageType(buf)
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
c.bufPool.Put(bufPtr)
continue
}

if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) {
if !c.handleMsg(msgType, buf, bufPtr, hc, internallyStoppedFlag) {
break
}
}
Expand Down
99 changes: 50 additions & 49 deletions relay/messages/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,26 @@ const (
MsgTypeAuth = 6
MsgTypeAuthResponse = 7

SizeOfVersionByte = 1
SizeOfMsgType = 1

SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType

sizeOfMagicByte = 4

headerSizeTransport = IDSize

// base size of the message
sizeOfVersionByte = 1
sizeOfMsgType = 1
sizeOfProtoHeader = sizeOfVersionByte + sizeOfMsgType

// auth message
sizeOfMagicByte = 4
headerSizeAuth = sizeOfMagicByte + IDSize
offsetMagicByte = sizeOfProtoHeader
offsetAuthPeerID = sizeOfProtoHeader + sizeOfMagicByte
headerTotalSizeAuth = sizeOfProtoHeader + headerSizeAuth

// hello message
headerSizeHello = sizeOfMagicByte + IDSize
headerSizeHelloResp = 0

headerSizeAuth = sizeOfMagicByte + IDSize
headerSizeAuthResp = 0
// transport
headerSizeTransport = IDSize
offsetTransportID = sizeOfProtoHeader
headerTotalSizeTransport = sizeOfProtoHeader + headerSizeTransport
)

var (
Expand Down Expand Up @@ -73,7 +79,7 @@ func (m MsgType) String() string {

// ValidateVersion checks if the given version is supported by the protocol
func ValidateVersion(msg []byte) (int, error) {
if len(msg) < SizeOfVersionByte {
if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength
}
version := int(msg[0])
Expand All @@ -85,11 +91,11 @@ func ValidateVersion(msg []byte) (int, error) {

// DetermineClientMessageType determines the message type from the first the message
func DetermineClientMessageType(msg []byte) (MsgType, error) {
if len(msg) < SizeOfMsgType {
if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength
}

msgType := MsgType(msg[0])
msgType := MsgType(msg[1])
switch msgType {
case
MsgTypeHello,
Expand All @@ -105,11 +111,11 @@ func DetermineClientMessageType(msg []byte) (MsgType, error) {

// DetermineServerMessageType determines the message type from the first the message
func DetermineServerMessageType(msg []byte) (MsgType, error) {
if len(msg) < SizeOfMsgType {
if len(msg) < sizeOfProtoHeader {
return 0, ErrInvalidMessageLength
}

msgType := MsgType(msg[0])
msgType := MsgType(msg[1])
switch msgType {
case
MsgTypeHelloResponse,
Expand All @@ -134,12 +140,12 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}

msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions))
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, sizeOfProtoHeader+headerSizeHello+len(additions))

msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHello)

copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
copy(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader)

msg = append(msg, peerID...)
msg = append(msg, additions...)
Expand All @@ -151,14 +157,14 @@ func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeHello {
if len(msg) < sizeOfProtoHeader+headerSizeHello {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
if !bytes.Equal(msg[sizeOfProtoHeader:sizeOfProtoHeader+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}

return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
return msg[sizeOfProtoHeader+sizeOfMagicByte : sizeOfProtoHeader+headerSizeHello], msg[headerSizeHello:], nil
}

// Deprecated: Use MarshalAuthResponse instead.
Expand All @@ -167,7 +173,7 @@ func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData))
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+headerSizeHelloResp+len(additionalData))

msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeHelloResponse)
Expand All @@ -180,7 +186,7 @@ func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
// Deprecated: Use UnmarshalAuthResponse instead.
// UnmarshalHelloResponse extracts the additional data from the hello response message.
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
if len(msg) < headerSizeHelloResp {
if len(msg) < sizeOfProtoHeader+headerSizeHelloResp {
return nil, ErrInvalidMessageLength
}
return msg, nil
Expand All @@ -196,12 +202,12 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}

msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload))
msg := make([]byte, sizeOfProtoHeader+sizeOfMagicByte, headerTotalSizeAuth+len(authPayload))

msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuth)

copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)
copy(msg[sizeOfProtoHeader:], magicHeader)

msg = append(msg, peerID...)
msg = append(msg, authPayload...)
Expand All @@ -211,14 +217,14 @@ func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {

// UnmarshalAuthMsg extracts peerID and the auth payload from the message
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
if len(msg) < headerSizeAuth {
if len(msg) < headerTotalSizeAuth {
return nil, nil, ErrInvalidMessageLength
}
if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
if !bytes.Equal(msg[offsetMagicByte:offsetMagicByte+sizeOfMagicByte], magicHeader) {
return nil, nil, errors.New("invalid magic header")
}

return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil
return msg[offsetAuthPeerID:headerTotalSizeAuth], msg[headerTotalSizeAuth:], nil
}

// MarshalAuthResponse creates a response message to the auth.
Expand All @@ -227,7 +233,7 @@ func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
// servers.
func MarshalAuthResponse(address string) ([]byte, error) {
ab := []byte(address)
msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab))
msg := make([]byte, sizeOfProtoHeader, sizeOfProtoHeader+len(ab))

msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeAuthResponse)
Expand All @@ -243,69 +249,64 @@ func MarshalAuthResponse(address string) ([]byte, error) {

// UnmarshalAuthResponse it is a confirmation message to auth success
func UnmarshalAuthResponse(msg []byte) (string, error) {
if len(msg) < headerSizeAuthResp+1 {
if len(msg) < sizeOfProtoHeader+1 {
return "", ErrInvalidMessageLength
}
return string(msg), nil
return string(msg[sizeOfProtoHeader:]), nil
}

// MarshalCloseMsg creates a close message.
// The close message is used to close the connection gracefully between the client and the server. The server and the
// client can send this message. After receiving this message, the server or client will close the connection.
func MarshalCloseMsg() []byte {
msg := make([]byte, SizeOfProtoHeader)

msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeClose)

return msg
return []byte{
byte(CurrentProtocolVersion),
byte(MsgTypeClose),
}
}

// MarshalTransportMsg creates a transport message.
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
// destination peer hashed ID.
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
func MarshalTransportMsg(peerID, payload []byte) ([]byte, error) {
if len(peerID) != IDSize {
return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
}

msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload))

msg := make([]byte, headerTotalSizeTransport, headerTotalSizeTransport+len(payload))
msg[0] = byte(CurrentProtocolVersion)
msg[1] = byte(MsgTypeTransport)

copy(msg[SizeOfProtoHeader:], peerID)

copy(msg[sizeOfProtoHeader:], peerID)
msg = append(msg, payload...)

return msg, nil
}

// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
if len(buf) < headerSizeTransport {
if len(buf) < headerTotalSizeTransport {
return nil, nil, ErrInvalidMessageLength
}

return buf[:headerSizeTransport], buf[headerSizeTransport:], nil
return buf[offsetTransportID:headerTotalSizeTransport], buf[headerTotalSizeTransport:], nil
}

// UnmarshalTransportID extracts the peerID from the transport message.
func UnmarshalTransportID(buf []byte) ([]byte, error) {
if len(buf) < headerSizeTransport {
if len(buf) < headerTotalSizeTransport {
return nil, ErrInvalidMessageLength
}
return buf[:headerSizeTransport], nil
return buf[offsetTransportID:headerTotalSizeTransport], nil
}

// UpdateTransportMsg updates the peerID in the transport message.
// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do
// need to allocate a new byte slice.
func UpdateTransportMsg(msg []byte, peerID []byte) error {
if len(msg) < len(peerID) {
if len(msg) < offsetTransportID+len(peerID) {
return ErrInvalidMessageLength
}
copy(msg, peerID)
copy(msg[offsetTransportID:], peerID)
return nil
}

Expand Down
Loading
Loading