From a8a2c382ee521da362ae65cb37c2a6bf00a36d9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Putra?= Date: Wed, 14 Sep 2022 11:21:21 +0200 Subject: [PATCH] frame: cqlvalue, added missing length checks in AsX methods Fixes #277 --- frame/cqlvalue.go | 46 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/frame/cqlvalue.go b/frame/cqlvalue.go index 16c81dc2..9bd8eb8d 100644 --- a/frame/cqlvalue.go +++ b/frame/cqlvalue.go @@ -199,6 +199,17 @@ func (c CqlValue) AsFloat64() (float64, error) { uint64(c.Value[7])), nil } +func readString(raw []byte) (newRaw []byte, read string, err error) { + if len(raw) < 4 { + return nil, "", fmt.Errorf("expected at least 4 bytes, got %d", len(raw)) + } + size := binary.BigEndian.Uint32(raw) + if len(raw) < int(size+4) { + return nil, "", fmt.Errorf("expected at least %d bytes, got %d", size+4, len(raw)) + } + return raw[size+4:], string(raw[4 : size+4]), nil +} + func (c CqlValue) AsStringSlice() ([]string, error) { if c.Type.ID != SetID && c.Type.ID != ListID { return nil, fmt.Errorf("%v can't be interpreted as a slice", c) @@ -224,12 +235,15 @@ func (c CqlValue) AsStringSlice() ([]string, error) { raw = raw[4:] for i := range res { - if len(raw) < 4 { - return nil, fmt.Errorf("expected at least 4 bytes, got %d", len(raw)) + var err error + raw, res[i], err = readString(raw) + if err != nil { + return nil, err } - size := binary.BigEndian.Uint32(raw) - res[i] = string(raw[4 : size+4]) - raw = raw[size+4:] + } + + if len(raw) > 0 { + return nil, fmt.Errorf("got %d unexpected trailing bytes", len(raw)) } return res, nil @@ -252,16 +266,24 @@ func (c CqlValue) AsStringMap() (map[string]string, error) { res := make(map[string]string, n) for i := 0; i < n; i++ { - keyN := binary.BigEndian.Uint32(raw) - key := string(raw[4 : keyN+4]) - raw = raw[keyN+4:] - - valueN := binary.BigEndian.Uint32(raw) - value := string(raw[4 : valueN+4]) - raw = raw[valueN+4:] + var key, value string + var err error + raw, key, err = readString(raw) + if err != nil { + return nil, err + } + raw, value, err = readString(raw) + if err != nil { + return nil, err + } res[key] = value } + + if len(raw) > 0 { + return nil, fmt.Errorf("got %d unexpected trailing bytes", len(raw)) + } + return res, nil }