-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvector.go
138 lines (118 loc) · 3.29 KB
/
vector.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package libsqlvector
import (
"database/sql"
"database/sql/driver"
"encoding/binary"
"encoding/json"
"fmt"
"math"
"slices"
"strconv"
"strings"
)
// Vector is a wrapper for []float32 to implement sql.Scanner and driver.Valuer.
// It also implements json.Marshaler and json.Unmarshaler (encoding the
// elements as a JSON array), and provides a textual form (String / Parse) and
// a binary form (EncodeBinary / DecodeBinary).
type Vector struct {
	// vec holds the elements; in the zero value it is nil.
	vec []float32
}
// NewVector wraps vec in a Vector. The slice is stored as-is rather than
// copied, so later writes to vec's elements are visible through the result.
func NewVector(vec []float32) Vector {
	var wrapped Vector
	wrapped.vec = vec
	return wrapped
}
// Slice returns the vector's elements. The result shares storage with the
// Vector; it is not a copy.
func (v Vector) Slice() []float32 {
	elems := v.vec
	return elems
}
// FormatFloats returns the elements as a bracketed, comma-separated list with
// no spaces, e.g. "[1,2.5,-3]". Each element is written in plain decimal
// notation ('f' format) using the shortest digits that round-trip as a
// float32. An empty or nil vector yields "[]".
func (v Vector) FormatFloats() string {
	// Pre-size for the brackets plus a rough per-element estimate.
	out := make([]byte, 1, 2+16*len(v.vec))
	out[0] = '['
	for idx, f := range v.vec {
		if idx != 0 {
			out = append(out, ',')
		}
		out = strconv.AppendFloat(out, float64(f), 'f', -1, 32)
	}
	return string(append(out, ']'))
}
// String returns the vector wrapped in a vector('...') call around the
// FormatFloats list, e.g. "vector('[1,2.5]')".
func (v Vector) String() string {
	return "vector('" + v.FormatFloats() + "')"
}
// Parse decodes the textual form produced by String, e.g.
// "vector('[1,2.5,-3]')", and stores the elements in v. The bare bracketed
// list "[1,2.5,-3]" is accepted too. Whitespace around the whole string and
// around each element is ignored; "[]" yields an empty, non-nil vector.
//
// Malformed input yields an error instead of a panic. v is modified only
// when the entire string parses successfully. Element parse failures wrap
// the underlying *strconv.NumError (use errors.As to inspect it).
func (v *Vector) Parse(s string) error {
	text := strings.TrimSpace(s)

	// Strip the optional vector('...') wrapper; if the prefix is present the
	// matching suffix must be too.
	if inner, wrapped := strings.CutPrefix(text, "vector('"); wrapped {
		var closed bool
		text, closed = strings.CutSuffix(inner, "')")
		if !closed {
			return fmt.Errorf("parsing vector %q: missing closing \"')\"", s)
		}
	}

	if len(text) < 2 || text[0] != '[' || text[len(text)-1] != ']' {
		return fmt.Errorf("parsing vector %q: expected a bracketed list of numbers", s)
	}
	text = text[1 : len(text)-1]

	if strings.TrimSpace(text) == "" {
		v.vec = []float32{}
		return nil
	}

	parts := strings.Split(text, ",")
	// Build into a local slice so a failure part-way through leaves v intact.
	vec := make([]float32, 0, len(parts))
	for i, part := range parts {
		n, err := strconv.ParseFloat(strings.TrimSpace(part), 32)
		if err != nil {
			return fmt.Errorf("parsing vector element %d: %w", i, err)
		}
		vec = append(vec, float32(n))
	}
	v.vec = vec
	return nil
}
// EncodeBinary appends the binary representation of the vector to buf and
// returns the extended slice.
//
// Layout, all big-endian: a uint16 element count, a uint16 reserved field
// (always 0), then the IEEE-754 bit pattern of each float32 element.
//
// The count field is 16 bits wide, so a vector with more than
// math.MaxUint16 elements cannot be encoded. In that case an error is
// returned together with the unmodified buf, instead of silently writing a
// truncated count.
func (v Vector) EncodeBinary(buf []byte) (newBuf []byte, err error) {
	dim := len(v.vec)
	if dim > math.MaxUint16 {
		return buf, fmt.Errorf("encoding vector: %d dimensions exceeds the maximum of %d", dim, math.MaxUint16)
	}
	buf = slices.Grow(buf, 4+4*dim)
	buf = binary.BigEndian.AppendUint16(buf, uint16(dim))
	buf = binary.BigEndian.AppendUint16(buf, 0) // reserved
	for _, f := range v.vec {
		buf = binary.BigEndian.AppendUint32(buf, math.Float32bits(f))
	}
	return buf, nil
}
// DecodeBinary decodes the binary representation written by EncodeBinary and
// stores the elements in v.
//
// Layout, all big-endian: a uint16 element count, a uint16 reserved field
// that must be 0, then 4 bytes per float32 element. Bytes after the last
// element are ignored.
//
// A buffer that is too short for its header or for the declared number of
// elements yields an error instead of a panic. v is modified only on
// success.
func (v *Vector) DecodeBinary(buf []byte) error {
	if len(buf) < 4 {
		return fmt.Errorf("decoding vector: header needs 4 bytes, got %d", len(buf))
	}
	dim := int(binary.BigEndian.Uint16(buf[0:2]))
	unused := binary.BigEndian.Uint16(buf[2:4])
	if unused != 0 {
		return fmt.Errorf("expected unused to be 0")
	}
	payload := buf[4:]
	if len(payload) < 4*dim {
		return fmt.Errorf("decoding vector: %d dimensions need %d bytes, got %d", dim, 4*dim, len(payload))
	}
	vec := make([]float32, dim)
	for i := range vec {
		vec[i] = math.Float32frombits(binary.BigEndian.Uint32(payload[4*i:]))
	}
	v.vec = vec
	return nil
}
// Compile-time check that *Vector satisfies sql.Scanner.
var _ sql.Scanner = (*Vector)(nil)

// Scan implements the sql.Scanner interface. Text values, whether delivered
// as []byte or string, are decoded with Parse. Any other source type,
// including nil, is rejected with an error.
func (v *Vector) Scan(src interface{}) error {
	if raw, ok := src.([]byte); ok {
		return v.Parse(string(raw))
	}
	if text, ok := src.(string); ok {
		return v.Parse(text)
	}
	return fmt.Errorf("unsupported data type: %T", src)
}
// Compile-time check that *Vector satisfies driver.Valuer.
var _ driver.Valuer = (*Vector)(nil)

// Value implements the driver.Valuer interface. It hands the driver the
// String form, e.g. "vector('[1,2.5]')", as a plain string value.
//
// NOTE(review): because this is bound as a string argument rather than
// spliced into SQL, confirm that callers' queries interpret it as a vector
// (and not as literal text) on the target database.
func (v Vector) Value() (driver.Value, error) {
	var out driver.Value = v.String()
	return out, nil
}
// Compile-time check that *Vector satisfies json.Marshaler.
var _ json.Marshaler = (*Vector)(nil)

// MarshalJSON implements the json.Marshaler interface by encoding the
// elements as a JSON array. A nil vector (the zero value) encodes as null.
func (v Vector) MarshalJSON() ([]byte, error) {
	encoded, err := json.Marshal(v.vec)
	if err != nil {
		return nil, err
	}
	return encoded, nil
}
// statically assert that Vector implements json.Unmarshaler.
var _ json.Unmarshaler = (*Vector)(nil)

// UnmarshalJSON implements the json.Unmarshaler interface. It accepts a JSON
// array of numbers; JSON null yields a nil vector.
//
// The array is decoded into a freshly allocated slice. Decoding straight into
// v.vec would let json.Unmarshal reuse v.vec's backing array and overwrite
// memory that may be shared with a slice passed to NewVector or obtained
// from Slice. Decoding separately also leaves v unchanged when the input is
// invalid.
func (v *Vector) UnmarshalJSON(data []byte) error {
	var vec []float32
	if err := json.Unmarshal(data, &vec); err != nil {
		return err
	}
	v.vec = vec
	return nil
}