diff --git a/gocql/cluster.go b/gocql/cluster.go new file mode 100644 index 00000000..b867afbd --- /dev/null +++ b/gocql/cluster.go @@ -0,0 +1,268 @@ +package gocql + +import ( + "log" + "time" + + "github.com/scylladb/scylla-go-driver" + scyllalog "github.com/scylladb/scylla-go-driver/log" + "github.com/scylladb/scylla-go-driver/transport" +) + +type ClusterConfig struct { + // addresses for the initial connections. It is recommended to use the value set in + // the Cassandra config for broadcast_address or listen_address, an IP address not + // a domain name. This is because events from Cassandra will use the configured IP + // address, which is used to index connected hosts. If the domain name specified + // resolves to more than 1 IP address then the driver may connect multiple times to + // the same host, and will not mark the node being down or up from events. + Hosts []string + + // CQL version (default: 3.0.0) + CQLVersion string + + // ProtoVersion sets the version of the native protocol to use, this will + // enable features in the driver for specific protocol versions, generally this + // should be set to a known version (2,3,4) for the cluster being connected to. + // + // If it is 0 or unset (the default) then the driver will attempt to discover the + // highest supported protocol for the cluster. In clusters with nodes of different + // versions the protocol selected is not defined (ie, it can be any of the supported in the cluster) + ProtoVersion int + + // Connection timeout (default: 600ms) + Timeout time.Duration + + // Initial connection timeout, used during initial dial to server (default: 600ms) + // ConnectTimeout is used to set up the default dialer and is ignored if Dialer or HostDialer is provided. + ConnectTimeout time.Duration + + // Port used when dialing. + // Default: 9042 + Port int + + // Initial keyspace. Optional. + Keyspace string + + // Number of connections per host. + // Default: 2 + NumConns int + + // Default consistency level. + // Default: Quorum + Consistency Consistency + + // Compression algorithm. + // Default: nil + Compressor Compressor + + // Default: nil + Authenticator Authenticator + + // An Authenticator factory. Can be used to create alternative authenticators. + // Default: nil + // AuthProvider func(h *HostInfo) (Authenticator, error) + + // Default retry policy to use for queries. + // Default: no retries. + RetryPolicy RetryPolicy + + // ConvictionPolicy decides whether to mark host as down based on the error and host info. + // Default: SimpleConvictionPolicy + ConvictionPolicy ConvictionPolicy // TODO: use it? + + // Default reconnection policy to use for reconnecting before trying to mark host as down. + // ReconnectionPolicy ReconnectionPolicy + + // The keepalive period to use, enabled if > 0 (default: 0) + // SocketKeepalive is used to set up the default dialer and is ignored if Dialer or HostDialer is provided. + SocketKeepalive time.Duration + + // Maximum cache size for prepared statements globally for gocql. + // Default: 1000 + MaxPreparedStmts int + + // Maximum cache size for query info about statements for each session. + // Default: 1000 + MaxRoutingKeyInfo int + + // Default page size to use for created sessions. + // Default: 5000 + PageSize int + + // Consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL. + // Default: unset + // SerialConsistency SerialConsistency + + // SslOpts configures TLS use when HostDialer is not set. + // SslOpts is ignored if HostDialer is set. + SslOpts *SslOptions + + // Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server. + // Default: true, only enabled for protocol 3 and above. + DefaultTimestamp bool + + // PoolConfig configures the underlying connection pool, allowing the + // configuration of host selection and connection selection policies. + PoolConfig PoolConfig + + // If not zero, gocql attempt to reconnect known DOWN nodes in every ReconnectInterval. + ReconnectInterval time.Duration // FIXME: unused + + // The maximum amount of time to wait for schema agreement in a cluster after + // receiving a schema change frame. (default: 60s) + MaxWaitSchemaAgreement time.Duration + + // HostFilter will filter all incoming events for host, any which don't pass + // the filter will be ignored. If set will take precedence over any options set + // via Discovery + // HostFilter HostFilter + + // AddressTranslator will translate addresses found on peer discovery and/or + // node change events. + // AddressTranslator AddressTranslator + + // If IgnorePeerAddr is true and the address in system.peers does not match + // the supplied host by either initial hosts or discovered via events then the + // host will be replaced with the supplied address. + // + // For example if an event comes in with host=10.0.0.1 but when looking up that + // address in system.local or system.peers returns 127.0.0.1, the peer will be + // set to 10.0.0.1 which is what will be used to connect to. + IgnorePeerAddr bool + + // If DisableInitialHostLookup then the driver will not attempt to get host info + // from the system.peers table, this will mean that the driver will connect to + // hosts supplied and will not attempt to lookup the hosts information, this will + // mean that data_centre, rack and token information will not be available and as + // such host filtering and token aware query routing will not be available. + DisableInitialHostLookup bool + + // Configure events the driver will register for + Events struct { + // disable registering for status events (node up/down) + DisableNodeStatusEvents bool + // disable registering for topology events (node added/removed/moved) + DisableTopologyEvents bool + // disable registering for schema events (keyspace/table/function removed/created/updated) + DisableSchemaEvents bool + } + + // DisableSkipMetadata will override the internal result metadata cache so that the driver does not + // send skip_metadata for queries, this means that the result will always contain + // the metadata to parse the rows and will not reuse the metadata from the prepared + // statement. + // + // See https://issues.apache.org/jira/browse/CASSANDRA-10786 + DisableSkipMetadata bool + + // QueryObserver will set the provided query observer on all queries created from this session. + // Use it to collect metrics / stats from queries by providing an implementation of QueryObserver. + // QueryObserver QueryObserver + + // BatchObserver will set the provided batch observer on all queries created from this session. + // Use it to collect metrics / stats from batch queries by providing an implementation of BatchObserver. + // BatchObserver BatchObserver + + // ConnectObserver will set the provided connect observer on all queries + // created from this session. + // ConnectObserver ConnectObserver + + // FrameHeaderObserver will set the provided frame header observer on all frames' headers created from this session. + // Use it to collect metrics / stats from frames by providing an implementation of FrameHeaderObserver. + // FrameHeaderObserver FrameHeaderObserver + + // Default idempotence for queries + DefaultIdempotence bool + + // The time to wait for frames before flushing the frames connection to Cassandra. + // Can help reduce syscall overhead by making less calls to write. Set to 0 to + // disable. + // + // (default: 200 microseconds) + WriteCoalesceWaitTime time.Duration + + // Dialer will be used to establish all connections created for this Cluster. + // If not provided, a default dialer configured with ConnectTimeout will be used. + // Dialer is ignored if HostDialer is provided. + // Dialer Dialer + + // HostDialer will be used to establish all connections for this Cluster. + // Unlike Dialer, HostDialer is responsible for setting up the entire connection, including the TLS session. + // To support shard-aware port, HostDialer should implement ShardDialer. + // If not provided, Dialer will be used instead. + // HostDialer HostDialer + + // DisableShardAwarePort will prevent the driver from connecting to Scylla's shard-aware port, + // even if there are nodes in the cluster that support it. + // + // It is generally recommended to leave this option turned off because gocql can use + // the shard-aware port to make the process of establishing more robust. + // However, if you have a cluster with nodes which expose shard-aware port + // but the port is unreachable due to network configuration issues, you can use + // this option to work around the issue. Set it to true only if you neither can fix + // your network nor disable shard-aware port on your nodes. + DisableShardAwarePort bool + + // Logger for this ClusterConfig. + // If not specified, defaults to the global gocql.Logger. + Logger StdLogger + + // internal config for testing + disableControlConn bool + disableInit bool +} + +func NewCluster(hosts ...string) *ClusterConfig { + cfg := ClusterConfig{Hosts: hosts, WriteCoalesceWaitTime: 200 * time.Microsecond} + return &cfg +} + +func sessionConfigFromGocql(cfg *ClusterConfig) (scylla.SessionConfig, error) { + scfg := scylla.DefaultSessionConfig(cfg.Keyspace, cfg.Hosts...) + scfg.Hosts = cfg.Hosts + scfg.WriteCoalesceWaitTime = cfg.WriteCoalesceWaitTime + if _, ok := cfg.Compressor.(SnappyCompressor); ok { + scfg.Compression = scylla.Snappy + } + + if auth, ok := cfg.Authenticator.(PasswordAuthenticator); ok { + scfg.Username = auth.Username + scfg.Password = auth.Password + } + + if policy, ok := cfg.PoolConfig.HostSelectionPolicy.(transport.HostSelectionPolicy); ok { + scfg.HostSelectionPolicy = policy + } + + if retryPolicy, ok := cfg.RetryPolicy.(transport.RetryPolicy); ok { + scfg.RetryPolicy = retryPolicy + } + + if cfg.Logger == nil { + if Logger == nil { + scfg.Logger = scyllalog.NewDefaultLogger() + } else { + scfg.Logger = stdLoggerWrapper{Logger} + } + } else { + if cfg.Logger == nil { + cfg.Logger = log.Default() + } + scfg.Logger = stdLoggerWrapper{cfg.Logger} + } + + if cfg.SslOpts != nil { + tlsConfig, err := setupTLSConfig(cfg.SslOpts) + if err != nil { + return scylla.SessionConfig{}, err + } + scfg.TLSConfig = tlsConfig + } + + return scfg, nil +} + +func (cfg *ClusterConfig) CreateSession() (*Session, error) { + return NewSession(*cfg) +} diff --git a/gocql/exec.go b/gocql/exec.go new file mode 100644 index 00000000..ca039d85 --- /dev/null +++ b/gocql/exec.go @@ -0,0 +1,191 @@ +package gocql + +import ( + "context" + "errors" + "fmt" + + "github.com/scylladb/scylla-go-driver" + "github.com/scylladb/scylla-go-driver/frame" + "github.com/scylladb/scylla-go-driver/transport" +) + +// SingleHostQueryExecutor allows to quickly execute diagnostic queries while +// connected to only a single node. +// The executor opens only a single connection to a node and does not use +// connection pools. +// Consistency level used is ONE. +// Retry policy is applied, attempts are visible in query metrics but query +// observer is not notified. +type SingleHostQueryExecutor struct { + conn *transport.Conn +} + +func bind(stmt *transport.Statement, values []interface{}) error { + if len(stmt.Values) != len(values) { + return fmt.Errorf("bind: expected %d columns, got %d", len(stmt.Values), len(values)) + } + + for i := range values { + v := anyWrapper{values[i]} + var err error + stmt.Values[i].N, stmt.Values[i].Bytes, err = v.Serialize(&stmt.Metadata.Columns[i].Type) + if err != nil { + return err + } + } + + return nil +} + +// Exec executes the query without returning any rows. +func (e SingleHostQueryExecutor) Exec(stmt string, values ...interface{}) error { + qStmt := transport.Statement{Content: stmt, Consistency: frame.ONE} + qStmt, err := e.conn.Prepare(context.Background(), qStmt) + if err != nil { + return err + } + + if err := bind(&qStmt, values); err != nil { + return err + } + _, err = e.conn.Query(context.Background(), qStmt, nil) + return err +} + +// Iter executes the query and returns an iterator capable of iterating +// over all results. +func (e SingleHostQueryExecutor) Iter(stmt string, values ...interface{}) *Iter { + qStmt := transport.Statement{Content: stmt, Consistency: frame.ONE} + qStmt, err := e.conn.Prepare(context.Background(), qStmt) + if err == nil { + err = bind(&qStmt, values) + } + it := newIter(newSingleHostIter(qStmt, e.conn)) + it.err = err + return it +} + +func (e SingleHostQueryExecutor) Close() { + if e.conn != nil { + e.conn.Close() + } +} + +// NewSingleHostQueryExecutor creates a SingleHostQueryExecutor by connecting +// to one of the hosts specified in the ClusterConfig. +// If ProtoVersion is not specified version 4 is used. +// Caller is responsible for closing the executor after use. +func NewSingleHostQueryExecutor(cfg *ClusterConfig) (e SingleHostQueryExecutor, err error) { + if len(cfg.Hosts) < 1 { + return + } + + var scfg scylla.SessionConfig + scfg, err = sessionConfigFromGocql(cfg) + if err != nil { + return + } + + host := cfg.Hosts[0] + var control *transport.Conn + control, err = transport.OpenConn(context.Background(), host, nil, scfg.ConnConfig) + if err != nil { + if control != nil { + control.Close() + } + return + } + e = SingleHostQueryExecutor{control} + return +} + +type singleHostIter struct { + conn *transport.Conn + result transport.QueryResult + pos int + rowCnt int + closed bool + err error + rd transport.RetryDecider + stmt transport.Statement +} + +func newSingleHostIter(stmt transport.Statement, conn *transport.Conn) *singleHostIter { + return &singleHostIter{ + conn: conn, + stmt: stmt, + rd: &transport.DefaultRetryDecider{}, + result: transport.QueryResult{HasMorePages: true}, + } +} + +func (it *singleHostIter) fetch() (transport.QueryResult, error) { + if !it.result.HasMorePages { + return transport.QueryResult{}, scylla.ErrNoMoreRows + } + for { + res, err := it.conn.Query(context.Background(), it.stmt, it.result.PagingState) + if err == nil { + return res, nil + } else if err != nil { + ri := transport.RetryInfo{ + Error: err, + Idempotent: it.stmt.Idempotent, + Consistency: 1, + } + if it.rd.Decide(ri) != transport.RetrySameNode { + return transport.QueryResult{}, err + } + } + } +} + +func (it *singleHostIter) Next() (frame.Row, error) { + if it.closed { + return nil, nil + } + + if it.pos >= it.rowCnt { + var err error + it.result, err = it.fetch() + if err != nil { + if !errors.Is(err, scylla.ErrNoMoreRows) { + it.err = err + } + return nil, it.Close() + } + + it.pos = 0 + it.rowCnt = len(it.result.Rows) + } + + // We probably got a zero-sized last page, retry to be sure + if it.rowCnt == 0 { + return it.Next() + } + + res := it.result.Rows[it.pos] + it.pos++ + return res, nil +} + +func (it *singleHostIter) Close() error { + if it.closed { + return it.err + } + it.closed = true + return it.err +} + +func (it *singleHostIter) Columns() []frame.ColumnSpec { + return it.stmt.Metadata.Columns +} + +func (it *singleHostIter) NumRows() int { + return it.rowCnt +} + +func (it *singleHostIter) PageState() []byte { + return it.result.PagingState +} diff --git a/gocql/go.mod b/gocql/go.mod new file mode 100644 index 00000000..91de4447 --- /dev/null +++ b/gocql/go.mod @@ -0,0 +1,12 @@ +module github.com/gocql/gocql + +go 1.18 + +require gopkg.in/inf.v0 v0.9.1 + +require ( + github.com/klauspost/compress v1.15.1 // indirect + github.com/pierrec/lz4/v4 v4.1.14 // indirect + github.com/scylladb/scylla-go-driver v0.0.0-20220926172207-fd9f96861a45 // indirect + go.uber.org/atomic v1.9.0 // indirect +) diff --git a/gocql/go.sum b/gocql/go.sum new file mode 100644 index 00000000..de79ab98 --- /dev/null +++ b/gocql/go.sum @@ -0,0 +1,28 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= +github.com/klauspost/compress v1.15.1 h1:y9FcTHGyrebwfP0ZZqFiaxTaiDnUrGkJkI+f583BL1A= +github.com/klauspost/compress v1.15.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/pierrec/lz4/v4 v4.1.14 h1:+fL8AQEZtz/ijeNnpduH0bROTu0O3NZAlPjQxGn8LwE= +github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/scylladb/scylla-go-driver v0.0.0-20220923093004-5369a731b3dc h1:ZUJ82dVV+X3ZjX4y8KfjnZb5difEyWsbQueuesTmKZE= +github.com/scylladb/scylla-go-driver v0.0.0-20220923093004-5369a731b3dc/go.mod h1:4szg2aNdnU8KlAY9PmgAuu5bxqklsHXv8IAfyYskq5A= +github.com/scylladb/scylla-go-driver v0.0.0-20220926170552-3f45c1970f1b h1:n5xBIohQMWYWb7GQAkfDe/xQSyiiyb2IpymLujc+NCA= +github.com/scylladb/scylla-go-driver v0.0.0-20220926170552-3f45c1970f1b/go.mod h1:4szg2aNdnU8KlAY9PmgAuu5bxqklsHXv8IAfyYskq5A= +github.com/scylladb/scylla-go-driver v0.0.0-20220926171312-6e3b8f3bc98e h1:ndgMD8dHGcA2L8vG3ncB8hAU0fmD8mGSuMp6OTgQUrQ= +github.com/scylladb/scylla-go-driver v0.0.0-20220926171312-6e3b8f3bc98e/go.mod h1:4szg2aNdnU8KlAY9PmgAuu5bxqklsHXv8IAfyYskq5A= +github.com/scylladb/scylla-go-driver v0.0.0-20220926172207-fd9f96861a45 h1:5e1RWvo6/TrREaYjCALqFWETLLV3/988IOuH9WQhS3k= +github.com/scylladb/scylla-go-driver v0.0.0-20220926172207-fd9f96861a45/go.mod h1:4szg2aNdnU8KlAY9PmgAuu5bxqklsHXv8IAfyYskq5A= +github.com/scylladb/scylla-go-driver v0.1.12 h1:iEE4vlil0yHkMAoOmznnGiIT8U5IZ3684NBoXPzDJUs= +github.com/scylladb/scylla-go-driver v0.1.12/go.mod h1:MadciHUoixPASBKr0qXVtPYtl2WOx0bWpWheIdzmYxA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= +gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= diff --git a/gocql/helpers.go b/gocql/helpers.go new file mode 100644 index 00000000..0870f85c --- /dev/null +++ b/gocql/helpers.go @@ -0,0 +1,84 @@ +package gocql + +import ( + "math/big" + "reflect" + "time" + + "gopkg.in/inf.v0" +) + +func copyBytes(p []byte) []byte { + b := make([]byte, len(p)) + copy(b, p) + return b +} + +func appendInt(p []byte, n int32) []byte { + return append(p, byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n)) +} + +func readInt(p []byte) int32 { + return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3]) +} + +func appendBytes(p []byte, d []byte) []byte { + if d == nil { + return appendInt(p, -1) + } + p = appendInt(p, int32(len(d))) + p = append(p, d...) + return p +} + +func goType(t TypeInfo) reflect.Type { + switch t.Type() { + case TypeVarchar, TypeAscii, TypeInet, TypeText: + return reflect.TypeOf(*new(string)) + case TypeBigInt, TypeCounter: + return reflect.TypeOf(*new(int64)) + case TypeTime: + return reflect.TypeOf(*new(time.Duration)) + case TypeTimestamp: + return reflect.TypeOf(*new(time.Time)) + case TypeBlob: + return reflect.TypeOf(*new([]byte)) + case TypeBoolean: + return reflect.TypeOf(*new(bool)) + case TypeFloat: + return reflect.TypeOf(*new(float32)) + case TypeDouble: + return reflect.TypeOf(*new(float64)) + case TypeInt: + return reflect.TypeOf(*new(int)) + case TypeSmallInt: + return reflect.TypeOf(*new(int16)) + case TypeTinyInt: + return reflect.TypeOf(*new(int8)) + case TypeDecimal: + return reflect.TypeOf(*new(*inf.Dec)) + case TypeUUID, TypeTimeUUID: + return reflect.TypeOf(*new(UUID)) + case TypeList, TypeSet: + return reflect.SliceOf(goType(t.(CollectionType).Elem)) + case TypeMap: + return reflect.MapOf(goType(t.(CollectionType).Key), goType(t.(CollectionType).Elem)) + case TypeVarint: + return reflect.TypeOf(*new(*big.Int)) + case TypeTuple: + // what can we do here? all there is to do is to make a list of interface{} + tuple := t.(TupleTypeInfo) + return reflect.TypeOf(make([]interface{}, len(tuple.Elems))) + case TypeUDT: + return reflect.TypeOf(make(map[string]interface{})) + case TypeDate: + return reflect.TypeOf(*new(time.Time)) + case TypeDuration: + return reflect.TypeOf(*new(Duration)) + default: + return nil + } +} diff --git a/gocql/iter.go b/gocql/iter.go new file mode 100644 index 00000000..513f7738 --- /dev/null +++ b/gocql/iter.go @@ -0,0 +1,176 @@ +package gocql + +import ( + "fmt" + "reflect" + + "github.com/scylladb/scylla-go-driver/frame" +) + +type iterr interface { + Columns() []frame.ColumnSpec + NumRows() int + Close() error + Next() (frame.Row, error) + PageState() []byte +} + +type Iter struct { + it iterr + err error + row frame.Row + // First Next call should be done on iter creation, rest on Scan + dontSkipNext bool +} + +func newIter(it iterr) *Iter { + ret := &Iter{it: it} + ret.row, ret.err = it.Next() + return ret +} + +func (it *Iter) Columns() []ColumnInfo { + if it.err != nil { + return nil + } + + c := it.it.Columns() + cols := make([]ColumnInfo, len(c)) + for i, v := range c { + typ := WrapOption(&v.Type) + cols[i] = ColumnInfo{ + Keyspace: v.Keyspace, + Table: v.Table, + Name: v.Name, + TypeInfo: typ, + } + } + + return cols +} + +func (it *Iter) NumRows() int { + return it.it.NumRows() +} + +func (it *Iter) Close() error { + if it.err != nil { + return it.err + } + return it.it.Close() +} + +func (it *Iter) Scan(dest ...interface{}) bool { + if it.err != nil { + return false + } + + if it.dontSkipNext { + it.row, it.err = it.it.Next() + } else { + it.dontSkipNext = true + } + if it.err != nil { + return false + } + + if len(it.row) == 0 { + return false + } + + if len(dest) != len(it.row) { + it.err = fmt.Errorf("expected %d columns, got %d", len(dest), len(it.row)) + return false + } + + for i := range dest { + if dest[i] == nil { + continue + } + it.err = Unmarshal(WrapOption(it.row[i].Type), it.row[i].Value, dest[i]) + if it.err != nil { + return false + } + } + + return true +} + +func (it *Iter) PageState() []byte { + return it.it.PageState() +} + +type RowData struct { + Columns []string + Values []interface{} +} + +// TupeColumnName will return the column name of a tuple value in a column named +// c at index n. It should be used if a specific element within a tuple is needed +// to be extracted from a map returned from SliceMap or MapScan. +func TupleColumnName(c string, n int) string { + return fmt.Sprintf("%s[%d]", c, n) +} + +func (iter *Iter) RowData() (RowData, error) { + if iter.err != nil { + return RowData{}, iter.err + } + + columns := make([]string, 0, len(iter.Columns())) + values := make([]interface{}, 0, len(iter.Columns())) + + for _, column := range iter.Columns() { + if c, ok := column.TypeInfo.(TupleTypeInfo); !ok { + val := column.TypeInfo.New() + columns = append(columns, column.Name) + values = append(values, val) + } else { + for i, elem := range c.Elems { + columns = append(columns, TupleColumnName(column.Name, i)) + values = append(values, elem.New()) + } + } + } + + return RowData{ + Columns: columns, + Values: values, + }, nil +} + +func (it *Iter) MapScan(m map[string]interface{}) bool { + if it.err != nil { + return false + } + + // Not checking for the error because we just did + rowData, _ := it.RowData() + + for i, col := range rowData.Columns { + if dest, ok := m[col]; ok { + rowData.Values[i] = dest + } + } + if it.Scan(rowData.Values...) { + rowData.rowMap(m) + return true + } + return false +} + +func dereference(i interface{}) interface{} { + return reflect.Indirect(reflect.ValueOf(i)).Interface() +} +func (r *RowData) rowMap(m map[string]interface{}) { + for i, column := range r.Columns { + val := dereference(r.Values[i]) + if valVal := reflect.ValueOf(val); valVal.Kind() == reflect.Slice { + valCopy := reflect.MakeSlice(valVal.Type(), valVal.Len(), valVal.Cap()) + reflect.Copy(valCopy, valVal) + m[column] = valCopy.Interface() + } else { + m[column] = val + } + } +} diff --git a/gocql/logger.go b/gocql/logger.go new file mode 100644 index 00000000..00a3b143 --- /dev/null +++ b/gocql/logger.go @@ -0,0 +1,39 @@ +package gocql + +import "log" + +type StdLogger interface { + Print(v ...interface{}) + Printf(format string, v ...interface{}) + Println(v ...interface{}) +} + +type stdLoggerWrapper struct { + StdLogger +} + +var Logger StdLogger = log.Default() + +func (s stdLoggerWrapper) Info(v ...interface{}) { + s.Print(v...) +} + +func (s stdLoggerWrapper) Infof(format string, v ...interface{}) { + s.Printf(format, v...) +} + +func (s stdLoggerWrapper) Infoln(v ...interface{}) { + s.Println(v...) +} + +func (s stdLoggerWrapper) Warn(v ...interface{}) { + s.Print(v...) +} + +func (s stdLoggerWrapper) Warnf(format string, v ...interface{}) { + s.Printf(format, v...) +} + +func (s stdLoggerWrapper) Warnln(v ...interface{}) { + s.Println(v...) +} diff --git a/gocql/logs b/gocql/logs new file mode 100644 index 00000000..e69de29b diff --git a/gocql/marshal.go b/gocql/marshal.go new file mode 100644 index 00000000..77bb7607 --- /dev/null +++ b/gocql/marshal.go @@ -0,0 +1,2592 @@ +// Copyright (c) 2012 The gocql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gocql + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "math" + "math/big" + "math/bits" + "net" + "reflect" + "strconv" + "strings" + "time" + + "gopkg.in/inf.v0" +) + +var ( + bigOne = big.NewInt(1) + emptyValue reflect.Value +) + +var ( + ErrorUDTUnavailable = errors.New("UDT are not available on protocols less than 3, please update config") +) + +// Marshaler is the interface implemented by objects that can marshal +// themselves into values understood by Cassandra. +type Marshaler interface { + MarshalCQL(info TypeInfo) ([]byte, error) +} + +// Unmarshaler is the interface implemented by objects that can unmarshal +// a Cassandra specific description of themselves. +type Unmarshaler interface { + UnmarshalCQL(info TypeInfo, data []byte) error +} + +// Marshal returns the CQL encoding of the value for the Cassandra +// internal type described by the info parameter. +// +// nil is serialized as CQL null. +// If value implements Marshaler, its MarshalCQL method is called to marshal the data. +// If value is a pointer, the pointed-to value is marshaled. +// +// Supported conversions are as follows, other type combinations may be added in the future: +// +// CQL type | Go type (value) | Note +// varchar, ascii, blob, text | string, []byte | +// boolean | bool | +// tinyint, smallint, int | integer types | +// tinyint, smallint, int | string | formatted as base 10 number +// bigint, counter | integer types | +// bigint, counter | big.Int | +// bigint, counter | string | formatted as base 10 number +// float | float32 | +// double | float64 | +// decimal | inf.Dec | +// time | int64 | nanoseconds since start of day +// time | time.Duration | duration since start of day +// timestamp | int64 | milliseconds since Unix epoch +// timestamp | time.Time | +// list, set | slice, array | +// list, set | map[X]struct{} | +// map | map[X]Y | +// uuid, timeuuid | gocql.UUID | +// uuid, timeuuid | [16]byte | raw UUID bytes +// uuid, timeuuid | []byte | raw UUID bytes, length must be 16 bytes +// uuid, timeuuid | string | hex representation, see ParseUUID +// varint | integer types | +// varint | big.Int | +// varint | string | value of number in decimal notation +// inet | net.IP | +// inet | string | IPv4 or IPv6 address string +// tuple | slice, array | +// tuple | struct | fields are marshaled in order of declaration +// user-defined type | gocql.UDTMarshaler | MarshalUDT is called +// user-defined type | map[string]interface{} | +// user-defined type | struct | struct fields' cql tags are used for column names +// date | int64 | milliseconds since Unix epoch to start of day (in UTC) +// date | time.Time | start of day (in UTC) +// date | string | parsed using "2006-01-02" format +// duration | int64 | duration in nanoseconds +// duration | time.Duration | +// duration | gocql.Duration | +// duration | string | parsed with time.ParseDuration +func Marshal(info TypeInfo, value interface{}) ([]byte, error) { + if info.Version() < protoVersion1 { + panic("protocol version not set") + } + + if valueRef := reflect.ValueOf(value); valueRef.Kind() == reflect.Ptr { + if valueRef.IsNil() { + return nil, nil + } else if v, ok := value.(Marshaler); ok { + return v.MarshalCQL(info) + } else { + return Marshal(info, valueRef.Elem().Interface()) + } + } + + if v, ok := value.(Marshaler); ok { + return v.MarshalCQL(info) + } + + switch info.Type() { + case TypeVarchar, TypeAscii, TypeBlob, TypeText: + return marshalVarchar(info, value) + case TypeBoolean: + return marshalBool(info, value) + case TypeTinyInt: + return marshalTinyInt(info, value) + case TypeSmallInt: + return marshalSmallInt(info, value) + case TypeInt: + return marshalInt(info, value) + case TypeBigInt, TypeCounter: + return marshalBigInt(info, value) + case TypeFloat: + return marshalFloat(info, value) + case TypeDouble: + return marshalDouble(info, value) + case TypeDecimal: + return marshalDecimal(info, value) + case TypeTime: + return marshalTime(info, value) + case TypeTimestamp: + return marshalTimestamp(info, value) + case TypeList, TypeSet: + return marshalList(info, value) + case TypeMap: + return marshalMap(info, value) + case TypeUUID, TypeTimeUUID: + return marshalUUID(info, value) + case TypeVarint: + return marshalVarint(info, value) + case TypeInet: + return marshalInet(info, value) + case TypeTuple: + return marshalTuple(info, value) + case TypeUDT: + return marshalUDT(info, value) + case TypeDate: + return marshalDate(info, value) + case TypeDuration: + return marshalDuration(info, value) + } + + // detect protocol 2 UDT + if strings.HasPrefix(info.Custom(), "org.apache.cassandra.db.marshal.UserType") && info.Version() < 3 { + return nil, ErrorUDTUnavailable + } + + // TODO(tux21b): add the remaining types + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +// Unmarshal parses the CQL encoded data based on the info parameter that +// describes the Cassandra internal data type and stores the result in the +// value pointed by value. +// +// If value implements Unmarshaler, it's UnmarshalCQL method is called to +// unmarshal the data. +// If value is a pointer to pointer, it is set to nil if the CQL value is +// null. Otherwise, nulls are unmarshalled as zero value. +// +// Supported conversions are as follows, other type combinations may be added in the future: +// +// CQL type | Go type (value) | Note +// varchar, ascii, blob, text | *string | +// varchar, ascii, blob, text | *[]byte | non-nil buffer is reused +// bool | *bool | +// tinyint, smallint, int, bigint, counter | *integer types | +// tinyint, smallint, int, bigint, counter | *big.Int | +// tinyint, smallint, int, bigint, counter | *string | formatted as base 10 number +// float | *float32 | +// double | *float64 | +// decimal | *inf.Dec | +// time | *int64 | nanoseconds since start of day +// time | *time.Duration | +// timestamp | *int64 | milliseconds since Unix epoch +// timestamp | *time.Time | +// list, set | *slice, *array | +// map | *map[X]Y | +// uuid, timeuuid | *string | see UUID.String +// uuid, timeuuid | *[]byte | raw UUID bytes +// uuid, timeuuid | *gocql.UUID | +// timeuuid | *time.Time | timestamp of the UUID +// inet | *net.IP | +// inet | *string | IPv4 or IPv6 address string +// tuple | *slice, *array | +// tuple | *struct | struct fields are set in order of declaration +// user-defined types | gocql.UDTUnmarshaler | UnmarshalUDT is called +// user-defined types | *map[string]interface{} | +// user-defined types | *struct | cql tag is used to determine field name +// date | *time.Time | time of beginning of the day (in UTC) +// date | *string | formatted with 2006-01-02 format +// duration | *gocql.Duration | +func Unmarshal(info TypeInfo, data []byte, value interface{}) error { + if v, ok := value.(Unmarshaler); ok { + return v.UnmarshalCQL(info, data) + } + + if isNullableValue(value) { + return unmarshalNullable(info, data, value) + } + + switch info.Type() { + case TypeVarchar, TypeAscii, TypeBlob, TypeText: + return unmarshalVarchar(info, data, value) + case TypeBoolean: + return unmarshalBool(info, data, value) + case TypeInt: + return unmarshalInt(info, data, value) + case TypeBigInt, TypeCounter: + return unmarshalBigInt(info, data, value) + case TypeVarint: + return unmarshalVarint(info, data, value) + case TypeSmallInt: + return unmarshalSmallInt(info, data, value) + case TypeTinyInt: + return unmarshalTinyInt(info, data, value) + case TypeFloat: + return unmarshalFloat(info, data, value) + case TypeDouble: + return unmarshalDouble(info, data, value) + case TypeDecimal: + return unmarshalDecimal(info, data, value) + case TypeTime: + return unmarshalTime(info, data, value) + case TypeTimestamp: + return unmarshalTimestamp(info, data, value) + case TypeList, TypeSet: + return unmarshalList(info, data, value) + case TypeMap: + return unmarshalMap(info, data, value) + case TypeTimeUUID: + return unmarshalTimeUUID(info, data, value) + case TypeUUID: + return unmarshalUUID(info, data, value) + case TypeInet: + return unmarshalInet(info, data, value) + case TypeTuple: + return unmarshalTuple(info, data, value) + case TypeUDT: + return unmarshalUDT(info, data, value) + case TypeDate: + return unmarshalDate(info, data, value) + case TypeDuration: + return unmarshalDuration(info, data, value) + } + + // detect protocol 2 UDT + if strings.HasPrefix(info.Custom(), "org.apache.cassandra.db.marshal.UserType") && info.Version() < 3 { + return ErrorUDTUnavailable + } + + // TODO(tux21b): add the remaining types + return fmt.Errorf("can not unmarshal %s into %T", info, value) +} + +func isNullableValue(value interface{}) bool { + v := reflect.ValueOf(value) + return v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Ptr +} + +func isNullData(info TypeInfo, data []byte) bool { + return data == nil +} + +func unmarshalNullable(info TypeInfo, data []byte, value interface{}) error { + valueRef := reflect.ValueOf(value) + + if isNullData(info, data) { + nilValue := reflect.Zero(valueRef.Type().Elem()) + valueRef.Elem().Set(nilValue) + return nil + } + + newValue := reflect.New(valueRef.Type().Elem().Elem()) + valueRef.Elem().Set(newValue) + return Unmarshal(info, data, newValue.Interface()) +} + +func marshalVarchar(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case string: + return []byte(v), nil + case []byte: + return v, nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + t := rv.Type() + k := t.Kind() + switch { + case k == reflect.String: + return []byte(rv.String()), nil + case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8: + return rv.Bytes(), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *string: + *v = string(data) + return nil + case *[]byte: + if data != nil { + *v = append((*v)[:0], data...) + } else { + *v = nil + } + return nil + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + t := rv.Type() + k := t.Kind() + switch { + case k == reflect.String: + rv.SetString(string(data)) + return nil + case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8, k == reflect.Interface: + var dataCopy []byte + if data != nil { + dataCopy = make([]byte, len(data)) + copy(dataCopy, data) + } + if k == reflect.Slice { + rv.SetBytes(dataCopy) + } else { + rv.Set(reflect.ValueOf(dataCopy)) + } + return nil + } + + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case int16: + return encShort(v), nil + case uint16: + return encShort(int16(v)), nil + case int8: + return encShort(int16(v)), nil + case uint8: + return encShort(int16(v)), nil + case int: + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case int32: + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case int64: + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case uint: + if v > math.MaxUint16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case uint32: + if v > math.MaxUint16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case uint64: + if v > math.MaxUint16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case string: + n, err := strconv.ParseInt(v, 10, 16) + if err != nil { + return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err) + } + return encShort(int16(n)), nil + } + + if value == nil { + return nil, nil + } + + switch rv := reflect.ValueOf(value); rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxUint16 { + return nil, marshalErrorf("marshal smallint: value %d out of range", v) + } + return encShort(int16(v)), nil + case reflect.Ptr: + if rv.IsNil() { + return nil, nil + } + } + + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case int8: + return []byte{byte(v)}, nil + case uint8: + return []byte{byte(v)}, nil + case int16: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint16: + if v > math.MaxUint8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case int: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case int32: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case int64: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint: + if v > math.MaxUint8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint32: + if v > math.MaxUint8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint64: + if v > math.MaxUint8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case string: + n, err := strconv.ParseInt(v, 10, 8) + if err != nil { + return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err) + } + return []byte{byte(n)}, nil + } + + if value == nil { + return nil, nil + } + + switch rv := reflect.ValueOf(value); rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxUint8 { + return nil, marshalErrorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case reflect.Ptr: + if rv.IsNil() { + return nil, nil + } + } + + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func marshalInt(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case int: + if v > math.MaxInt32 || v < math.MinInt32 { + return nil, marshalErrorf("marshal int: value %d out of range", v) + } + return encInt(int32(v)), nil + case uint: + if v > math.MaxUint32 { + return nil, marshalErrorf("marshal int: value %d out of range", v) + } + return encInt(int32(v)), nil + case int64: + if v > math.MaxInt32 || v < math.MinInt32 { + return nil, marshalErrorf("marshal int: value %d out of range", v) + } + return encInt(int32(v)), nil + case uint64: + if v > math.MaxUint32 { + return nil, marshalErrorf("marshal int: value %d out of range", v) + } + return encInt(int32(v)), nil + case int32: + return encInt(v), nil + case uint32: + return encInt(int32(v)), nil + case int16: + return encInt(int32(v)), nil + case uint16: + return encInt(int32(v)), nil + case int8: + return encInt(int32(v)), nil + case uint8: + return encInt(int32(v)), nil + case string: + i, err := strconv.ParseInt(v, 10, 32) + if err != nil { + return nil, marshalErrorf("can not marshal string to int: %s", err) + } + return encInt(int32(i)), nil + } + + if value == nil { + return nil, nil + } + + switch rv := reflect.ValueOf(value); rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + if v > math.MaxInt32 || v < math.MinInt32 { + return nil, marshalErrorf("marshal int: value %d out of range", v) + } + return encInt(int32(v)), nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxInt32 { + return nil, marshalErrorf("marshal int: value %d out of range", v) + } + return encInt(int32(v)), nil + case reflect.Ptr: + if rv.IsNil() { + return nil, nil + } + } + + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func encInt(x int32) []byte { + return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} +} + +func decInt(x []byte) int32 { + if len(x) != 4 { + return 0 + } + return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3]) +} + +func encShort(x int16) []byte { + p := make([]byte, 2) + p[0] = byte(x >> 8) + p[1] = byte(x) + return p +} + +func decShort(p []byte) int16 { + if len(p) != 2 { + return 0 + } + return int16(p[0])<<8 | int16(p[1]) +} + +func decTiny(p []byte) int8 { + if len(p) != 1 { + return 0 + } + return int8(p[0]) +} + +func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case int: + return encBigInt(int64(v)), nil + case uint: + if uint64(v) > math.MaxInt64 { + return nil, marshalErrorf("marshal bigint: value %d out of range", v) + } + return encBigInt(int64(v)), nil + case int64: + return encBigInt(v), nil + case uint64: + return encBigInt(int64(v)), nil + case int32: + return encBigInt(int64(v)), nil + case uint32: + return encBigInt(int64(v)), nil + case int16: + return encBigInt(int64(v)), nil + case uint16: + return encBigInt(int64(v)), nil + case int8: + return encBigInt(int64(v)), nil + case uint8: + return encBigInt(int64(v)), nil + case big.Int: + return encBigInt2C(&v), nil + case string: + i, err := strconv.ParseInt(value.(string), 10, 64) + if err != nil { + return nil, marshalErrorf("can not marshal string to bigint: %s", err) + } + return encBigInt(i), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + return encBigInt(v), nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxInt64 { + return nil, marshalErrorf("marshal bigint: value %d out of range", v) + } + return encBigInt(int64(v)), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func encBigInt(x int64) []byte { + return []byte{byte(x >> 56), byte(x >> 48), byte(x >> 40), byte(x >> 32), + byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} +} + +func bytesToInt64(data []byte) (ret int64) { + for i := range data { + ret |= int64(data[i]) << (8 * uint(len(data)-i-1)) + } + return ret +} + +func bytesToUint64(data []byte) (ret uint64) { + for i := range data { + ret |= uint64(data[i]) << (8 * uint(len(data)-i-1)) + } + return ret +} + +func unmarshalBigInt(info TypeInfo, data []byte, value interface{}) error { + return unmarshalIntlike(info, decBigInt(data), data, value) +} + +func unmarshalInt(info TypeInfo, data []byte, value interface{}) error { + return unmarshalIntlike(info, int64(decInt(data)), data, value) +} + +func unmarshalSmallInt(info TypeInfo, data []byte, value interface{}) error { + return unmarshalIntlike(info, int64(decShort(data)), data, value) +} + +func unmarshalTinyInt(info TypeInfo, data []byte, value interface{}) error { + return unmarshalIntlike(info, int64(decTiny(data)), data, value) +} + +func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case *big.Int: + return unmarshalIntlike(info, 0, data, value) + case *uint64: + if len(data) == 9 && data[0] == 0 { + *v = bytesToUint64(data[1:]) + return nil + } + } + + if len(data) > 8 { + return unmarshalErrorf("unmarshal int: varint value %v out of range for %T (use big.Int)", data, value) + } + + int64Val := bytesToInt64(data) + if len(data) > 0 && len(data) < 8 && data[0]&0x80 > 0 { + int64Val -= (1 << uint(len(data)*8)) + } + return unmarshalIntlike(info, int64Val, data, value) +} + +func marshalVarint(info TypeInfo, value interface{}) ([]byte, error) { + var ( + retBytes []byte + err error + ) + + switch v := value.(type) { + case unsetColumn: + return nil, nil + case uint64: + if v > uint64(math.MaxInt64) { + retBytes = make([]byte, 9) + binary.BigEndian.PutUint64(retBytes[1:], v) + } else { + retBytes = make([]byte, 8) + binary.BigEndian.PutUint64(retBytes, v) + } + default: + retBytes, err = marshalBigInt(info, value) + } + + if err == nil { + // trim down to most significant byte + i := 0 + for ; i < len(retBytes)-1; i++ { + b0 := retBytes[i] + if b0 != 0 && b0 != 0xFF { + break + } + + b1 := retBytes[i+1] + if b0 == 0 && b1 != 0 { + if b1&0x80 == 0 { + i++ + } + break + } + + if b0 == 0xFF && b1 != 0xFF { + if b1&0x80 > 0 { + i++ + } + break + } + } + retBytes = retBytes[i:] + } + + return retBytes, err +} + +func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interface{}) error { + switch v := value.(type) { + case *int: + if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) { + return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } + *v = int(int64Val) + return nil + case *uint: + unitVal := uint64(int64Val) + switch info.Type() { + case TypeInt: + *v = uint(unitVal) & 0xFFFFFFFF + case TypeSmallInt: + *v = uint(unitVal) & 0xFFFF + case TypeTinyInt: + *v = uint(unitVal) & 0xFF + default: + if ^uint(0) == math.MaxUint32 && (int64Val < 0 || int64Val > math.MaxUint32) { + return unmarshalErrorf("unmarshal int: value %d out of range for %T", unitVal, *v) + } + *v = uint(unitVal) + } + return nil + case *int64: + *v = int64Val + return nil + case *uint64: + switch info.Type() { + case TypeInt: + *v = uint64(int64Val) & 0xFFFFFFFF + case TypeSmallInt: + *v = uint64(int64Val) & 0xFFFF + case TypeTinyInt: + *v = uint64(int64Val) & 0xFF + default: + *v = uint64(int64Val) + } + return nil + case *int32: + if int64Val < math.MinInt32 || int64Val > math.MaxInt32 { + return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } + *v = int32(int64Val) + return nil + case *uint32: + switch info.Type() { + case TypeInt: + *v = uint32(int64Val) & 0xFFFFFFFF + case TypeSmallInt: + *v = uint32(int64Val) & 0xFFFF + case TypeTinyInt: + *v = uint32(int64Val) & 0xFF + default: + if int64Val < 0 || int64Val > math.MaxUint32 { + return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } + *v = uint32(int64Val) & 0xFFFFFFFF + } + return nil + case *int16: + if int64Val < math.MinInt16 || int64Val > math.MaxInt16 { + return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } + *v = int16(int64Val) + return nil + case *uint16: + switch info.Type() { + case TypeSmallInt: + *v = uint16(int64Val) & 0xFFFF + case TypeTinyInt: + *v = uint16(int64Val) & 0xFF + default: + if int64Val < 0 || int64Val > math.MaxUint16 { + return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } + *v = uint16(int64Val) & 0xFFFF + } + return nil + case *int8: + if int64Val < math.MinInt8 || int64Val > math.MaxInt8 { + return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } + *v = int8(int64Val) + return nil + case *uint8: + if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) { + return unmarshalErrorf("unmarshal int: value %d out of range for %T", int64Val, *v) + } + *v = uint8(int64Val) & 0xFF + return nil + case *big.Int: + decBigInt2C(data, v) + return nil + case *string: + *v = strconv.FormatInt(int64Val, 10) + return nil + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + + switch rv.Type().Kind() { + case reflect.Int: + if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) { + return unmarshalErrorf("unmarshal int: value %d out of range", int64Val) + } + rv.SetInt(int64Val) + return nil + case reflect.Int64: + rv.SetInt(int64Val) + return nil + case reflect.Int32: + if int64Val < math.MinInt32 || int64Val > math.MaxInt32 { + return unmarshalErrorf("unmarshal int: value %d out of range", int64Val) + } + rv.SetInt(int64Val) + return nil + case reflect.Int16: + if int64Val < math.MinInt16 || int64Val > math.MaxInt16 { + return unmarshalErrorf("unmarshal int: value %d out of range", int64Val) + } + rv.SetInt(int64Val) + return nil + case reflect.Int8: + if int64Val < math.MinInt8 || int64Val > math.MaxInt8 { + return unmarshalErrorf("unmarshal int: value %d out of range", int64Val) + } + rv.SetInt(int64Val) + return nil + case reflect.Uint: + unitVal := uint64(int64Val) + switch info.Type() { + case TypeInt: + rv.SetUint(unitVal & 0xFFFFFFFF) + case TypeSmallInt: + rv.SetUint(unitVal & 0xFFFF) + case TypeTinyInt: + rv.SetUint(unitVal & 0xFF) + default: + if ^uint(0) == math.MaxUint32 && (int64Val < 0 || int64Val > math.MaxUint32) { + return unmarshalErrorf("unmarshal int: value %d out of range for %s", unitVal, rv.Type()) + } + rv.SetUint(unitVal) + } + return nil + case reflect.Uint64: + unitVal := uint64(int64Val) + switch info.Type() { + case TypeInt: + rv.SetUint(unitVal & 0xFFFFFFFF) + case TypeSmallInt: + rv.SetUint(unitVal & 0xFFFF) + case TypeTinyInt: + rv.SetUint(unitVal & 0xFF) + default: + rv.SetUint(unitVal) + } + return nil + case reflect.Uint32: + unitVal := uint64(int64Val) + switch info.Type() { + case TypeInt: + rv.SetUint(unitVal & 0xFFFFFFFF) + case TypeSmallInt: + rv.SetUint(unitVal & 0xFFFF) + case TypeTinyInt: + rv.SetUint(unitVal & 0xFF) + default: + if int64Val < 0 || int64Val > math.MaxUint32 { + return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type()) + } + rv.SetUint(unitVal & 0xFFFFFFFF) + } + return nil + case reflect.Uint16: + unitVal := uint64(int64Val) + switch info.Type() { + case TypeSmallInt: + rv.SetUint(unitVal & 0xFFFF) + case TypeTinyInt: + rv.SetUint(unitVal & 0xFF) + default: + if int64Val < 0 || int64Val > math.MaxUint16 { + return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type()) + } + rv.SetUint(unitVal & 0xFFFF) + } + return nil + case reflect.Uint8: + if info.Type() != TypeTinyInt && (int64Val < 0 || int64Val > math.MaxUint8) { + return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, rv.Type()) + } + rv.SetUint(uint64(int64Val) & 0xff) + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func decBigInt(data []byte) int64 { + if len(data) != 8 { + return 0 + } + return int64(data[0])<<56 | int64(data[1])<<48 | + int64(data[2])<<40 | int64(data[3])<<32 | + int64(data[4])<<24 | int64(data[5])<<16 | + int64(data[6])<<8 | int64(data[7]) +} + +func marshalBool(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case bool: + return encBool(v), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Bool: + return encBool(rv.Bool()), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func encBool(v bool) []byte { + if v { + return []byte{1} + } + return []byte{0} +} + +func unmarshalBool(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *bool: + *v = decBool(data) + return nil + } + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + switch rv.Type().Kind() { + case reflect.Bool: + rv.SetBool(decBool(data)) + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func decBool(v []byte) bool { + if len(v) == 0 { + return false + } + return v[0] != 0 +} + +func marshalFloat(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case float32: + return encInt(int32(math.Float32bits(v))), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Float32: + return encInt(int32(math.Float32bits(float32(rv.Float())))), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *float32: + *v = math.Float32frombits(uint32(decInt(data))) + return nil + } + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + switch rv.Type().Kind() { + case reflect.Float32: + rv.SetFloat(float64(math.Float32frombits(uint32(decInt(data))))) + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func marshalDouble(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case float64: + return encBigInt(int64(math.Float64bits(v))), nil + } + if value == nil { + return nil, nil + } + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Float64: + return encBigInt(int64(math.Float64bits(rv.Float()))), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *float64: + *v = math.Float64frombits(uint64(decBigInt(data))) + return nil + } + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + switch rv.Type().Kind() { + case reflect.Float64: + rv.SetFloat(math.Float64frombits(uint64(decBigInt(data)))) + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func marshalDecimal(info TypeInfo, value interface{}) ([]byte, error) { + if value == nil { + return nil, nil + } + + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case inf.Dec: + unscaled := encBigInt2C(v.UnscaledBig()) + if unscaled == nil { + return nil, marshalErrorf("can not marshal %T into %s", value, info) + } + + buf := make([]byte, 4+len(unscaled)) + copy(buf[0:4], encInt(int32(v.Scale()))) + copy(buf[4:], unscaled) + return buf, nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *inf.Dec: + scale := decInt(data[0:4]) + unscaled := decBigInt2C(data[4:], nil) + *v = *inf.NewDecBig(unscaled, inf.Scale(scale)) + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +// decBigInt2C sets the value of n to the big-endian two's complement +// value stored in the given data. If data[0]&80 != 0, the number +// is negative. If data is empty, the result will be 0. +func decBigInt2C(data []byte, n *big.Int) *big.Int { + if n == nil { + n = new(big.Int) + } + n.SetBytes(data) + if len(data) > 0 && data[0]&0x80 > 0 { + n.Sub(n, new(big.Int).Lsh(bigOne, uint(len(data))*8)) + } + return n +} + +// encBigInt2C returns the big-endian two's complement +// form of n. +func encBigInt2C(n *big.Int) []byte { + switch n.Sign() { + case 0: + return []byte{0} + case 1: + b := n.Bytes() + if b[0]&0x80 > 0 { + b = append([]byte{0}, b...) + } + return b + case -1: + length := uint(n.BitLen()/8+1) * 8 + b := new(big.Int).Add(n, new(big.Int).Lsh(bigOne, length)).Bytes() + // When the most significant bit is on a byte + // boundary, we can get some extra significant + // bits, so strip them off when that happens. + if len(b) >= 2 && b[0] == 0xff && b[1]&0x80 != 0 { + b = b[1:] + } + return b + } + return nil +} + +func marshalTime(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case int64: + return encBigInt(v), nil + case time.Duration: + return encBigInt(v.Nanoseconds()), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int64: + return encBigInt(rv.Int()), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case int64: + return encBigInt(v), nil + case time.Time: + if v.IsZero() { + return []byte{}, nil + } + x := int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) + return encBigInt(x), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int64: + return encBigInt(rv.Int()), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func unmarshalTime(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *int64: + *v = decBigInt(data) + return nil + case *time.Duration: + *v = time.Duration(decBigInt(data)) + return nil + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + switch rv.Type().Kind() { + case reflect.Int64: + rv.SetInt(decBigInt(data)) + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *int64: + *v = decBigInt(data) + return nil + case *time.Time: + if len(data) == 0 { + *v = time.Time{} + return nil + } + x := decBigInt(data) + sec := x / 1000 + nsec := (x - sec*1000) * 1000000 + *v = time.Unix(sec, nsec).In(time.UTC) + return nil + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + switch rv.Type().Kind() { + case reflect.Int64: + rv.SetInt(decBigInt(data)) + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func marshalDate(info TypeInfo, value interface{}) ([]byte, error) { + var timestamp int64 + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case int64: + timestamp = v + x := timestamp/86400000 + int64(1<<31) + return encInt(int32(x)), nil + case time.Time: + if v.IsZero() { + return []byte{}, nil + } + timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) + x := timestamp/86400000 + int64(1<<31) + return encInt(int32(x)), nil + case *time.Time: + if v.IsZero() { + return []byte{}, nil + } + timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) + x := timestamp/86400000 + int64(1<<31) + return encInt(int32(x)), nil + case string: + if v == "" { + return []byte{}, nil + } + t, err := time.Parse("2006-01-02", v) + if err != nil { + return nil, marshalErrorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info) + } + timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6) + x := timestamp/86400000 + int64(1<<31) + return encInt(int32(x)), nil + } + + if value == nil { + return nil, nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *time.Time: + if len(data) == 0 { + *v = time.Time{} + return nil + } + var origin uint32 = 1 << 31 + var current uint32 = binary.BigEndian.Uint32(data) + timestamp := (int64(current) - int64(origin)) * 86400000 + *v = time.Unix(0, timestamp*int64(time.Millisecond)).In(time.UTC) + return nil + case *string: + if len(data) == 0 { + *v = "" + return nil + } + var origin uint32 = 1 << 31 + var current uint32 = binary.BigEndian.Uint32(data) + timestamp := (int64(current) - int64(origin)) * 86400000 + *v = time.Unix(0, timestamp*int64(time.Millisecond)).In(time.UTC).Format("2006-01-02") + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, nil + case int64: + return encVints(0, 0, v), nil + case time.Duration: + return encVints(0, 0, v.Nanoseconds()), nil + case string: + d, err := time.ParseDuration(v) + if err != nil { + return nil, err + } + return encVints(0, 0, d.Nanoseconds()), nil + case Duration: + return encVints(v.Months, v.Days, v.Nanoseconds), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int64: + return encBigInt(rv.Int()), nil + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *Duration: + if len(data) == 0 { + *v = Duration{ + Months: 0, + Days: 0, + Nanoseconds: 0, + } + return nil + } + months, days, nanos := decVints(data) + *v = Duration{ + Months: months, + Days: days, + Nanoseconds: nanos, + } + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func decVints(data []byte) (int32, int32, int64) { + month, i := decVint(data) + days, j := decVint(data[i:]) + nanos, _ := decVint(data[i+j:]) + return int32(month), int32(days), nanos +} + +func decVint(data []byte) (int64, int) { + firstByte := data[0] + if firstByte&0x80 == 0 { + return decIntZigZag(uint64(firstByte)), 1 + } + numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 + ret := uint64(firstByte & (0xff >> uint(numBytes))) + for i := 0; i < numBytes; i++ { + ret <<= 8 + ret |= uint64(data[i+1] & 0xff) + } + return decIntZigZag(ret), numBytes + 1 +} + +func decIntZigZag(n uint64) int64 { + return int64((n >> 1) ^ -(n & 1)) +} + +func encIntZigZag(n int64) uint64 { + return uint64((n >> 63) ^ (n << 1)) +} + +func encVints(months int32, seconds int32, nanos int64) []byte { + buf := append(encVint(int64(months)), encVint(int64(seconds))...) + return append(buf, encVint(nanos)...) +} + +func encVint(v int64) []byte { + vEnc := encIntZigZag(v) + lead0 := bits.LeadingZeros64(vEnc) + numBytes := (639 - lead0*9) >> 6 + + // It can be 1 or 0 is v ==0 + if numBytes <= 1 { + return []byte{byte(vEnc)} + } + extraBytes := numBytes - 1 + var buf = make([]byte, numBytes) + for i := extraBytes; i >= 0; i-- { + buf[i] = byte(vEnc) + vEnc >>= 8 + } + buf[0] |= byte(^(0xff >> uint(extraBytes))) + return buf +} + +func writeCollectionSize(info CollectionType, n int, buf *bytes.Buffer) error { + if info.proto > protoVersion2 { + if n > math.MaxInt32 { + return marshalErrorf("marshal: collection too large") + } + + buf.WriteByte(byte(n >> 24)) + buf.WriteByte(byte(n >> 16)) + buf.WriteByte(byte(n >> 8)) + buf.WriteByte(byte(n)) + } else { + if n > math.MaxUint16 { + return marshalErrorf("marshal: collection too large") + } + + buf.WriteByte(byte(n >> 8)) + buf.WriteByte(byte(n)) + } + + return nil +} + +func marshalList(info TypeInfo, value interface{}) ([]byte, error) { + listInfo, ok := info.(CollectionType) + if !ok { + return nil, marshalErrorf("marshal: can not marshal non collection type into list") + } + + if value == nil { + return nil, nil + } else if _, ok := value.(unsetColumn); ok { + return nil, nil + } + + rv := reflect.ValueOf(value) + t := rv.Type() + k := t.Kind() + if k == reflect.Slice && rv.IsNil() { + return nil, nil + } + + switch k { + case reflect.Slice, reflect.Array: + buf := &bytes.Buffer{} + n := rv.Len() + + if err := writeCollectionSize(listInfo, n, buf); err != nil { + return nil, err + } + + for i := 0; i < n; i++ { + item, err := Marshal(listInfo.Elem, rv.Index(i).Interface()) + if err != nil { + return nil, err + } + if err := writeCollectionSize(listInfo, len(item), buf); err != nil { + return nil, err + } + buf.Write(item) + } + return buf.Bytes(), nil + case reflect.Map: + elem := t.Elem() + if elem.Kind() == reflect.Struct && elem.NumField() == 0 { + rkeys := rv.MapKeys() + keys := make([]interface{}, len(rkeys)) + for i := 0; i < len(keys); i++ { + keys[i] = rkeys[i].Interface() + } + return marshalList(listInfo, keys) + } + } + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func readCollectionSize(info CollectionType, data []byte) (size, read int, err error) { + if info.proto > protoVersion2 { + if len(data) < 4 { + return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof") + } + size = int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + read = 4 + } else { + if len(data) < 2 { + return 0, 0, unmarshalErrorf("unmarshal list: unexpected eof") + } + size = int(data[0])<<8 | int(data[1]) + read = 2 + } + return +} + +func unmarshalList(info TypeInfo, data []byte, value interface{}) error { + listInfo, ok := info.(CollectionType) + if !ok { + return unmarshalErrorf("unmarshal: can not unmarshal none collection type into list") + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + t := rv.Type() + k := t.Kind() + + switch k { + case reflect.Slice, reflect.Array: + if data == nil { + if k == reflect.Array { + return unmarshalErrorf("unmarshal list: can not store nil in array value") + } + if rv.IsNil() { + return nil + } + rv.Set(reflect.Zero(t)) + return nil + } + n, p, err := readCollectionSize(listInfo, data) + if err != nil { + return err + } + data = data[p:] + if k == reflect.Array { + if rv.Len() != n { + return unmarshalErrorf("unmarshal list: array with wrong size") + } + } else { + rv.Set(reflect.MakeSlice(t, n, n)) + } + for i := 0; i < n; i++ { + m, p, err := readCollectionSize(listInfo, data) + if err != nil { + return err + } + data = data[p:] + if len(data) < m { + return unmarshalErrorf("unmarshal list: unexpected eof") + } + if err := Unmarshal(listInfo.Elem, data[:m], rv.Index(i).Addr().Interface()); err != nil { + return err + } + data = data[m:] + } + return nil + } + return unmarshalErrorf("can not unmarshal %s into %T", info, value) +} + +func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { + mapInfo, ok := info.(CollectionType) + if !ok { + return nil, marshalErrorf("marshal: can not marshal none collection type into map") + } + + if value == nil { + return nil, nil + } else if _, ok := value.(unsetColumn); ok { + return nil, nil + } + + rv := reflect.ValueOf(value) + + t := rv.Type() + if t.Kind() != reflect.Map { + return nil, marshalErrorf("can not marshal %T into %s", value, info) + } + + if rv.IsNil() { + return nil, nil + } + + buf := &bytes.Buffer{} + n := rv.Len() + + if err := writeCollectionSize(mapInfo, n, buf); err != nil { + return nil, err + } + + keys := rv.MapKeys() + for _, key := range keys { + item, err := Marshal(mapInfo.Key, key.Interface()) + if err != nil { + return nil, err + } + if err := writeCollectionSize(mapInfo, len(item), buf); err != nil { + return nil, err + } + buf.Write(item) + + item, err = Marshal(mapInfo.Elem, rv.MapIndex(key).Interface()) + if err != nil { + return nil, err + } + if err := writeCollectionSize(mapInfo, len(item), buf); err != nil { + return nil, err + } + buf.Write(item) + } + return buf.Bytes(), nil +} + +func unmarshalMap(info TypeInfo, data []byte, value interface{}) error { + mapInfo, ok := info.(CollectionType) + if !ok { + return unmarshalErrorf("unmarshal: can not unmarshal none collection type into map") + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + rv = rv.Elem() + t := rv.Type() + if t.Kind() != reflect.Map { + return unmarshalErrorf("can not unmarshal %s into %T", info, value) + } + if data == nil { + rv.Set(reflect.Zero(t)) + return nil + } + rv.Set(reflect.MakeMap(t)) + n, p, err := readCollectionSize(mapInfo, data) + if err != nil { + return err + } + data = data[p:] + for i := 0; i < n; i++ { + m, p, err := readCollectionSize(mapInfo, data) + if err != nil { + return err + } + data = data[p:] + if len(data) < m { + return unmarshalErrorf("unmarshal map: unexpected eof") + } + key := reflect.New(t.Key()) + if err := Unmarshal(mapInfo.Key, data[:m], key.Interface()); err != nil { + return err + } + data = data[m:] + + m, p, err = readCollectionSize(mapInfo, data) + if err != nil { + return err + } + data = data[p:] + if len(data) < m { + return unmarshalErrorf("unmarshal map: unexpected eof") + } + val := reflect.New(t.Elem()) + if err := Unmarshal(mapInfo.Elem, data[:m], val.Interface()); err != nil { + return err + } + data = data[m:] + + rv.SetMapIndex(key.Elem(), val.Elem()) + } + return nil +} + +func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) { + switch val := value.(type) { + case unsetColumn: + return nil, nil + case UUID: + return val.Bytes(), nil + case [16]byte: + return val[:], nil + case []byte: + if len(val) != 16 { + return nil, marshalErrorf("can not marshal []byte %d bytes long into %s, must be exactly 16 bytes long", len(val), info) + } + return val, nil + case string: + b, err := ParseUUID(val) + if err != nil { + return nil, err + } + return b[:], nil + } + + if value == nil { + return nil, nil + } + + return nil, marshalErrorf("can not marshal %T into %s", value, info) +} + +func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error { + if len(data) == 0 { + switch v := value.(type) { + case *string: + *v = "" + case *[]byte: + *v = nil + case *UUID: + *v = UUID{} + default: + return unmarshalErrorf("can not unmarshal X %s into %T", info, value) + } + + return nil + } + + if len(data) != 16 { + return unmarshalErrorf("unable to parse UUID: UUIDs must be exactly 16 bytes long") + } + + switch v := value.(type) { + case *[16]byte: + copy((*v)[:], data) + return nil + case *UUID: + copy((*v)[:], data) + return nil + } + + u, err := UUIDFromBytes(data) + if err != nil { + return unmarshalErrorf("unable to parse UUID: %s", err) + } + + switch v := value.(type) { + case *string: + *v = u.String() + return nil + case *[]byte: + *v = u[:] + return nil + } + return unmarshalErrorf("can not unmarshal X %s into %T", info, value) +} + +func unmarshalTimeUUID(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *time.Time: + id, err := UUIDFromBytes(data) + if err != nil { + return err + } else if id.Version() != 1 { + return unmarshalErrorf("invalid timeuuid") + } + *v = id.Time() + return nil + default: + return unmarshalUUID(info, data, value) + } +} + +func marshalInet(info TypeInfo, value interface{}) ([]byte, error) { + // we return either the 4 or 16 byte representation of an + // ip address here otherwise the db value will be prefixed + // with the remaining byte values e.g. ::ffff:127.0.0.1 and not 127.0.0.1 + switch val := value.(type) { + case unsetColumn: + return nil, nil + case net.IP: + t := val.To4() + if t == nil { + return val.To16(), nil + } + return t, nil + case string: + b := net.ParseIP(val) + if b != nil { + t := b.To4() + if t == nil { + return b.To16(), nil + } + return t, nil + } + return nil, marshalErrorf("cannot marshal. invalid ip string %s", val) + } + + if value == nil { + return nil, nil + } + + return nil, marshalErrorf("cannot marshal %T into %s", value, info) +} + +func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case *net.IP: + if x := len(data); !(x == 4 || x == 16) { + return unmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x) + } + buf := copyBytes(data) + ip := net.IP(buf) + if v4 := ip.To4(); v4 != nil { + *v = v4 + return nil + } + *v = ip + return nil + case *string: + if len(data) == 0 { + *v = "" + return nil + } + ip := net.IP(data) + if v4 := ip.To4(); v4 != nil { + *v = v4.String() + return nil + } + *v = ip.String() + return nil + } + return unmarshalErrorf("cannot unmarshal %s into %T", info, value) +} + +func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { + tuple := info.(TupleTypeInfo) + switch v := value.(type) { + case unsetColumn: + return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for tuples") + case []interface{}: + if len(v) != len(tuple.Elems) { + return nil, unmarshalErrorf("cannont marshal tuple: wrong number of elements") + } + + var buf []byte + for i, elem := range v { + if elem == nil { + buf = appendInt(buf, int32(-1)) + continue + } + + data, err := Marshal(tuple.Elems[i], elem) + if err != nil { + return nil, err + } + + n := len(data) + buf = appendInt(buf, int32(n)) + buf = append(buf, data...) + } + + return buf, nil + } + + rv := reflect.ValueOf(value) + t := rv.Type() + k := t.Kind() + + switch k { + case reflect.Struct: + if v := t.NumField(); v != len(tuple.Elems) { + return nil, marshalErrorf("can not marshal tuple into struct %v, not enough fields have %d need %d", t, v, len(tuple.Elems)) + } + + var buf []byte + for i, elem := range tuple.Elems { + field := rv.Field(i) + + if field.Kind() == reflect.Ptr && field.IsNil() { + buf = appendInt(buf, int32(-1)) + continue + } + + data, err := Marshal(elem, field.Interface()) + if err != nil { + return nil, err + } + + n := len(data) + buf = appendInt(buf, int32(n)) + buf = append(buf, data...) + } + + return buf, nil + case reflect.Slice, reflect.Array: + size := rv.Len() + if size != len(tuple.Elems) { + return nil, marshalErrorf("can not marshal tuple into %v of length %d need %d elements", k, size, len(tuple.Elems)) + } + + var buf []byte + for i, elem := range tuple.Elems { + item := rv.Index(i) + + if item.Kind() == reflect.Ptr && item.IsNil() { + buf = appendInt(buf, int32(-1)) + continue + } + + data, err := Marshal(elem, item.Interface()) + if err != nil { + return nil, err + } + + n := len(data) + buf = appendInt(buf, int32(n)) + buf = append(buf, data...) + } + + return buf, nil + } + + return nil, marshalErrorf("cannot marshal %T into %s", value, tuple) +} + +func readBytes(p []byte) ([]byte, []byte) { + // TODO: really should use a framer + size := readInt(p) + p = p[4:] + if size < 0 { + return nil, p + } + return p[:size], p[size:] +} + +// currently only support unmarshal into a list of values, this makes it possible +// to support tuples without changing the query API. In the future this can be extend +// to allow unmarshalling into custom tuple types. +func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { + if v, ok := value.(Unmarshaler); ok { + return v.UnmarshalCQL(info, data) + } + + tuple := info.(TupleTypeInfo) + switch v := value.(type) { + case []interface{}: + for i, elem := range tuple.Elems { + // each element inside data is a [bytes] + var p []byte + if len(data) >= 4 { + p, data = readBytes(data) + } + err := Unmarshal(elem, p, v[i]) + if err != nil { + return err + } + } + + return nil + } + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + + rv = rv.Elem() + t := rv.Type() + k := t.Kind() + + switch k { + case reflect.Struct: + if v := t.NumField(); v != len(tuple.Elems) { + return unmarshalErrorf("can not unmarshal tuple into struct %v, not enough fields have %d need %d", t, v, len(tuple.Elems)) + } + + for i, elem := range tuple.Elems { + var p []byte + if len(data) >= 4 { + p, data = readBytes(data) + } + + v := elem.New() + if err := Unmarshal(elem, p, v); err != nil { + return err + } + + switch rv.Field(i).Kind() { + case reflect.Ptr: + if p != nil { + rv.Field(i).Set(reflect.ValueOf(v)) + } else { + rv.Field(i).Set(reflect.Zero(reflect.TypeOf(v))) + } + default: + rv.Field(i).Set(reflect.ValueOf(v).Elem()) + } + } + + return nil + case reflect.Slice, reflect.Array: + if k == reflect.Array { + size := rv.Len() + if size != len(tuple.Elems) { + return unmarshalErrorf("can not unmarshal tuple into array of length %d need %d elements", size, len(tuple.Elems)) + } + } else { + rv.Set(reflect.MakeSlice(t, len(tuple.Elems), len(tuple.Elems))) + } + + for i, elem := range tuple.Elems { + var p []byte + if len(data) >= 4 { + p, data = readBytes(data) + } + + v := elem.New() + if err := Unmarshal(elem, p, v); err != nil { + return err + } + + switch rv.Index(i).Kind() { + case reflect.Ptr: + if p != nil { + rv.Index(i).Set(reflect.ValueOf(v)) + } else { + rv.Index(i).Set(reflect.Zero(reflect.TypeOf(v))) + } + default: + rv.Index(i).Set(reflect.ValueOf(v).Elem()) + } + } + + return nil + } + + return unmarshalErrorf("cannot unmarshal %s into %T", info, value) +} + +// UDTMarshaler is an interface which should be implemented by users wishing to +// handle encoding UDT types to sent to Cassandra. Note: due to current implentations +// methods defined for this interface must be value receivers not pointer receivers. +type UDTMarshaler interface { + // MarshalUDT will be called for each field in the the UDT returned by Cassandra, + // the implementor should marshal the type to return by for example calling + // Marshal. + MarshalUDT(name string, info TypeInfo) ([]byte, error) +} + +// UDTUnmarshaler should be implemented by users wanting to implement custom +// UDT unmarshaling. +type UDTUnmarshaler interface { + // UnmarshalUDT will be called for each field in the UDT return by Cassandra, + // the implementor should unmarshal the data into the value of their chosing, + // for example by calling Unmarshal. + UnmarshalUDT(name string, info TypeInfo, data []byte) error +} + +func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { + udt := info.(UDTTypeInfo) + + switch v := value.(type) { + case Marshaler: + return v.MarshalCQL(info) + case unsetColumn: + return nil, unmarshalErrorf("invalid request: UnsetValue is unsupported for user defined types") + case UDTMarshaler: + var buf []byte + for _, e := range udt.Elements { + data, err := v.MarshalUDT(e.Name, e.Type) + if err != nil { + return nil, err + } + + buf = appendBytes(buf, data) + } + + return buf, nil + case map[string]interface{}: + var buf []byte + for _, e := range udt.Elements { + val, ok := v[e.Name] + if !ok { + return nil, marshalErrorf("marshal missing map key %q", e.Name) + } + + data, err := Marshal(e.Type, val) + if err != nil { + return nil, err + } + + buf = appendBytes(buf, data) + } + + return buf, nil + } + + k := reflect.ValueOf(value) + if k.Kind() == reflect.Ptr { + if k.IsNil() { + return nil, marshalErrorf("cannot marshal %T into %s", value, info) + } + k = k.Elem() + } + + if k.Kind() != reflect.Struct || !k.IsValid() { + return nil, marshalErrorf("cannot marshal %T into %s", value, info) + } + + fields := make(map[string]reflect.Value) + t := reflect.TypeOf(value) + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + + if tag := sf.Tag.Get("cql"); tag != "" { + fields[tag] = k.Field(i) + } + } + + var buf []byte + for _, e := range udt.Elements { + f, ok := fields[e.Name] + if !ok { + f = k.FieldByName(e.Name) + } + + var data []byte + if f.IsValid() && f.CanInterface() { + var err error + data, err = Marshal(e.Type, f.Interface()) + if err != nil { + return nil, err + } + } + + buf = appendBytes(buf, data) + } + + return buf, nil +} + +func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { + switch v := value.(type) { + case Unmarshaler: + return v.UnmarshalCQL(info, data) + case UDTUnmarshaler: + udt := info.(UDTTypeInfo) + + for _, e := range udt.Elements { + if len(data) == 0 { + return nil + } + + var p []byte + p, data = readBytes(data) + + if err := v.UnmarshalUDT(e.Name, e.Type, p); err != nil { + return err + } + } + + return nil + case *map[string]interface{}: + udt := info.(UDTTypeInfo) + + rv := reflect.ValueOf(value) + if rv.Kind() != reflect.Ptr { + return unmarshalErrorf("can not unmarshal into non-pointer %T", value) + } + + rv = rv.Elem() + t := rv.Type() + if t.Kind() != reflect.Map { + return unmarshalErrorf("can not unmarshal %s into %T", info, value) + } else if data == nil { + rv.Set(reflect.Zero(t)) + return nil + } + + rv.Set(reflect.MakeMap(t)) + m := *v + + for _, e := range udt.Elements { + if len(data) == 0 { + return nil + } + + val := reflect.New(goType(e.Type)) + + var p []byte + p, data = readBytes(data) + + if err := Unmarshal(e.Type, p, val.Interface()); err != nil { + return err + } + + m[e.Name] = val.Elem().Interface() + } + + return nil + } + + k := reflect.ValueOf(value).Elem() + if k.Kind() != reflect.Struct || !k.IsValid() { + return unmarshalErrorf("cannot unmarshal %s into %T", info, value) + } + + if len(data) == 0 { + if k.CanSet() { + k.Set(reflect.Zero(k.Type())) + } + + return nil + } + + t := k.Type() + fields := make(map[string]reflect.Value, t.NumField()) + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + + if tag := sf.Tag.Get("cql"); tag != "" { + fields[tag] = k.Field(i) + } + } + + udt := info.(UDTTypeInfo) + for _, e := range udt.Elements { + if len(data) < 4 { + // UDT def does not match the column value + return nil + } + + var p []byte + p, data = readBytes(data) + + f, ok := fields[e.Name] + if !ok { + f = k.FieldByName(e.Name) + if f == emptyValue { + // skip fields which exist in the UDT but not in + // the struct passed in + continue + } + } + + if !f.IsValid() || !f.CanAddr() { + return unmarshalErrorf("cannot unmarshal %s into %T: field %v is not valid", info, value, e.Name) + } + + fk := f.Addr().Interface() + if err := Unmarshal(e.Type, p, fk); err != nil { + return err + } + } + + return nil +} + +// TypeInfo describes a Cassandra specific data type. +type TypeInfo interface { + Type() Type + Version() byte + Custom() string + + // New creates a pointer to an empty version of whatever type + // is referenced by the TypeInfo receiver + New() interface{} +} + +type NativeType struct { + proto byte + typ Type + custom string // only used for TypeCustom +} + +func NewNativeType(proto byte, typ Type, custom string) NativeType { + return NativeType{proto, typ, custom} +} + +func (t NativeType) New() interface{} { + return reflect.New(goType(t)).Interface() +} + +func (s NativeType) Type() Type { + return s.typ +} + +func (s NativeType) Version() byte { + return s.proto +} + +func (s NativeType) Custom() string { + return s.custom +} + +func (s NativeType) String() string { + switch s.typ { + case TypeCustom: + return fmt.Sprintf("%s(%s)", s.typ, s.custom) + default: + return s.typ.String() + } +} + +type CollectionType struct { + NativeType + Key TypeInfo // only used for TypeMap + Elem TypeInfo // only used for TypeMap, TypeList and TypeSet +} + +func (t CollectionType) New() interface{} { + return reflect.New(goType(t)).Interface() +} + +func (c CollectionType) String() string { + switch c.typ { + case TypeMap: + return fmt.Sprintf("%s(%s, %s)", c.typ, c.Key, c.Elem) + case TypeList, TypeSet: + return fmt.Sprintf("%s(%s)", c.typ, c.Elem) + case TypeCustom: + return fmt.Sprintf("%s(%s)", c.typ, c.custom) + default: + return c.typ.String() + } +} + +type TupleTypeInfo struct { + NativeType + Elems []TypeInfo +} + +func (t TupleTypeInfo) String() string { + var buf bytes.Buffer + buf.WriteString(fmt.Sprintf("%s(", t.typ)) + for _, elem := range t.Elems { + buf.WriteString(fmt.Sprintf("%s, ", elem)) + } + buf.Truncate(buf.Len() - 2) + buf.WriteByte(')') + return buf.String() +} + +func (t TupleTypeInfo) New() interface{} { + return reflect.New(goType(t)).Interface() +} + +type UDTField struct { + Name string + Type TypeInfo +} + +type UDTTypeInfo struct { + NativeType + KeySpace string + Name string + Elements []UDTField +} + +func (u UDTTypeInfo) New() interface{} { + return reflect.New(goType(u)).Interface() +} + +func (u UDTTypeInfo) String() string { + buf := &bytes.Buffer{} + + fmt.Fprintf(buf, "%s.%s{", u.KeySpace, u.Name) + first := true + for _, e := range u.Elements { + if !first { + fmt.Fprint(buf, ",") + } else { + first = false + } + + fmt.Fprintf(buf, "%s=%v", e.Name, e.Type) + } + fmt.Fprint(buf, "}") + + return buf.String() +} + +// String returns a human readable name for the Cassandra datatype +// described by t. +// Type is the identifier of a Cassandra internal datatype. +type Type int + +const ( + TypeCustom Type = 0x0000 + TypeAscii Type = 0x0001 + TypeBigInt Type = 0x0002 + TypeBlob Type = 0x0003 + TypeBoolean Type = 0x0004 + TypeCounter Type = 0x0005 + TypeDecimal Type = 0x0006 + TypeDouble Type = 0x0007 + TypeFloat Type = 0x0008 + TypeInt Type = 0x0009 + TypeText Type = 0x000A + TypeTimestamp Type = 0x000B + TypeUUID Type = 0x000C + TypeVarchar Type = 0x000D + TypeVarint Type = 0x000E + TypeTimeUUID Type = 0x000F + TypeInet Type = 0x0010 + TypeDate Type = 0x0011 + TypeTime Type = 0x0012 + TypeSmallInt Type = 0x0013 + TypeTinyInt Type = 0x0014 + TypeDuration Type = 0x0015 + TypeList Type = 0x0020 + TypeMap Type = 0x0021 + TypeSet Type = 0x0022 + TypeUDT Type = 0x0030 + TypeTuple Type = 0x0031 +) + +// String returns the name of the identifier. +func (t Type) String() string { + switch t { + case TypeCustom: + return "custom" + case TypeAscii: + return "ascii" + case TypeBigInt: + return "bigint" + case TypeBlob: + return "blob" + case TypeBoolean: + return "boolean" + case TypeCounter: + return "counter" + case TypeDecimal: + return "decimal" + case TypeDouble: + return "double" + case TypeFloat: + return "float" + case TypeInt: + return "int" + case TypeText: + return "text" + case TypeTimestamp: + return "timestamp" + case TypeUUID: + return "uuid" + case TypeVarchar: + return "varchar" + case TypeTimeUUID: + return "timeuuid" + case TypeInet: + return "inet" + case TypeDate: + return "date" + case TypeDuration: + return "duration" + case TypeTime: + return "time" + case TypeSmallInt: + return "smallint" + case TypeTinyInt: + return "tinyint" + case TypeList: + return "list" + case TypeMap: + return "map" + case TypeSet: + return "set" + case TypeVarint: + return "varint" + case TypeTuple: + return "tuple" + default: + return fmt.Sprintf("unknown_type_%d", t) + } +} + +type MarshalError string + +func (m MarshalError) Error() string { + return string(m) +} + +func marshalErrorf(format string, args ...interface{}) MarshalError { + return MarshalError(fmt.Sprintf(format, args...)) +} + +type UnmarshalError string + +func (m UnmarshalError) Error() string { + return string(m) +} + +func unmarshalErrorf(format string, args ...interface{}) UnmarshalError { + return UnmarshalError(fmt.Sprintf(format, args...)) +} diff --git a/gocql/metadata_scylla.go b/gocql/metadata_scylla.go new file mode 100644 index 00000000..e1e3e24d --- /dev/null +++ b/gocql/metadata_scylla.go @@ -0,0 +1,709 @@ +// Copyright (c) 2015 The gocql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gocql + +import ( + "fmt" + "strings" + "sync" +) + +// schema metadata for a keyspace +type KeyspaceMetadata struct { + Name string + DurableWrites bool + StrategyClass string + StrategyOptions map[string]interface{} + Tables map[string]*TableMetadata + Functions map[string]*FunctionMetadata + Aggregates map[string]*AggregateMetadata + Types map[string]*TypeMetadata + Indexes map[string]*IndexMetadata + Views map[string]*ViewMetadata +} + +// schema metadata for a table (a.k.a. column family) +type TableMetadata struct { + Keyspace string + Name string + PartitionKey []*ColumnMetadata + ClusteringColumns []*ColumnMetadata + Columns map[string]*ColumnMetadata + OrderedColumns []string + Options TableMetadataOptions + Flags []string + Extensions map[string]interface{} +} + +type TableMetadataOptions struct { + BloomFilterFpChance float64 + Caching map[string]string + Comment string + Compaction map[string]string + Compression map[string]string + CrcCheckChance float64 + DcLocalReadRepairChance float64 + DefaultTimeToLive int + GcGraceSeconds int + MaxIndexInterval int + MemtableFlushPeriodInMs int + MinIndexInterval int + ReadRepairChance float64 + SpeculativeRetry string + CDC map[string]string + InMemory bool + Partitioner string + Version string +} + +type ViewMetadata struct { + KeyspaceName string + ViewName string + BaseTableID string + BaseTableName string + ID string + IncludeAllColumns bool + Columns map[string]*ColumnMetadata + OrderedColumns []string + PartitionKey []*ColumnMetadata + ClusteringColumns []*ColumnMetadata + WhereClause string + Options TableMetadataOptions + Extensions map[string]interface{} +} + +// schema metadata for a column +type ColumnMetadata struct { + Keyspace string + Table string + Name string + ComponentIndex int + Kind ColumnKind + Type string + ClusteringOrder string + Order ColumnOrder + Index ColumnIndexMetadata +} + +// FunctionMetadata holds metadata for function constructs +type FunctionMetadata struct { + Keyspace string + Name string + ArgumentTypes []string + ArgumentNames []string + Body string + CalledOnNullInput bool + Language string + ReturnType string +} + +// AggregateMetadata holds metadata for aggregate constructs +type AggregateMetadata struct { + Keyspace string + Name string + ArgumentTypes []string + FinalFunc FunctionMetadata + InitCond string + ReturnType string + StateFunc FunctionMetadata + StateType string + + stateFunc string + finalFunc string +} + +// TypeMetadata holds the metadata for views. +type TypeMetadata struct { + Keyspace string + Name string + FieldNames []string + FieldTypes []string +} + +type IndexMetadata struct { + Name string + KeyspaceName string + TableName string + Kind string + Options map[string]string +} + +const ( + IndexKindCustom = "CUSTOM" +) + +const ( + TableFlagDense = "dense" + TableFlagSuper = "super" + TableFlagCompound = "compound" +) + +// the ordering of the column with regard to its comparator +type ColumnOrder bool + +const ( + ASC ColumnOrder = false + DESC = true +) + +type ColumnIndexMetadata struct { + Name string + Type string + Options map[string]interface{} +} + +type ColumnKind int + +const ( + ColumnUnkownKind ColumnKind = iota + ColumnPartitionKey + ColumnClusteringKey + ColumnRegular + ColumnCompact + ColumnStatic +) + +func (c ColumnKind) String() string { + switch c { + case ColumnPartitionKey: + return "partition_key" + case ColumnClusteringKey: + return "clustering_key" + case ColumnRegular: + return "regular" + case ColumnCompact: + return "compact" + case ColumnStatic: + return "static" + default: + return fmt.Sprintf("unknown_column_%d", c) + } +} + +func (c *ColumnKind) UnmarshalCQL(typ TypeInfo, p []byte) error { + if typ.Type() != TypeVarchar { + return unmarshalErrorf("unable to marshall %s into ColumnKind, expected Varchar", typ) + } + + kind, err := columnKindFromSchema(string(p)) + if err != nil { + return err + } + *c = kind + + return nil +} + +func columnKindFromSchema(kind string) (ColumnKind, error) { + switch kind { + case "partition_key": + return ColumnPartitionKey, nil + case "clustering_key", "clustering": + return ColumnClusteringKey, nil + case "regular": + return ColumnRegular, nil + case "compact_value": + return ColumnCompact, nil + case "static": + return ColumnStatic, nil + default: + return -1, fmt.Errorf("unknown column kind: %q", kind) + } +} + +// queries the cluster for schema information for a specific keyspace +type schemaDescriber struct { + session *Session + mu sync.Mutex + + cache map[string]*KeyspaceMetadata +} + +// creates a session bound schema describer which will query and cache +// keyspace metadata +func newSchemaDescriber(session *Session) *schemaDescriber { + return &schemaDescriber{ + session: session, + cache: map[string]*KeyspaceMetadata{}, + } +} + +// returns the cached KeyspaceMetadata held by the describer for the named +// keyspace. +func (s *schemaDescriber) getSchema(keyspaceName string) (*KeyspaceMetadata, error) { + s.mu.Lock() + defer s.mu.Unlock() + + metadata, found := s.cache[keyspaceName] + if !found { + // refresh the cache for this keyspace + err := s.refreshSchema(keyspaceName) + if err != nil { + return nil, err + } + + metadata = s.cache[keyspaceName] + } + + return metadata, nil +} + +// clears the already cached keyspace metadata +func (s *schemaDescriber) clearSchema(keyspaceName string) { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.cache, keyspaceName) +} + +// forcibly updates the current KeyspaceMetadata held by the schema describer +// for a given named keyspace. +func (s *schemaDescriber) refreshSchema(keyspaceName string) error { + var err error + + // query the system keyspace for schema data + // TODO retrieve concurrently + keyspace, err := getKeyspaceMetadata(s.session, keyspaceName) + if err != nil { + return err + } + tables, err := getTableMetadata(s.session, keyspaceName) + if err != nil { + return err + } + columns, err := getColumnMetadata(s.session, keyspaceName) + if err != nil { + return err + } + functions, err := getFunctionsMetadata(s.session, keyspaceName) + if err != nil { + return err + } + aggregates, err := getAggregatesMetadata(s.session, keyspaceName) + if err != nil { + return err + } + types, err := getTypeMetadata(s.session, keyspaceName) + if err != nil { + return err + } + indexes, err := getIndexMetadata(s.session, keyspaceName) + if err != nil { + return err + } + views, err := getViewMetadata(s.session, keyspaceName) + if err != nil { + return err + } + + // organize the schema data + compileMetadata(keyspace, tables, columns, functions, aggregates, types, indexes, views) + + // update the cache + s.cache[keyspaceName] = keyspace + + return nil +} + +// "compiles" derived information about keyspace, table, and column metadata +// for a keyspace from the basic queried metadata objects returned by +// getKeyspaceMetadata, getTableMetadata, and getColumnMetadata respectively; +// Links the metadata objects together and derives the column composition of +// the partition key and clustering key for a table. +func compileMetadata( + keyspace *KeyspaceMetadata, + tables []TableMetadata, + columns []ColumnMetadata, + functions []FunctionMetadata, + aggregates []AggregateMetadata, + types []TypeMetadata, + indexes []IndexMetadata, + views []ViewMetadata, +) { + keyspace.Tables = make(map[string]*TableMetadata) + for i := range tables { + tables[i].Columns = make(map[string]*ColumnMetadata) + keyspace.Tables[tables[i].Name] = &tables[i] + } + keyspace.Functions = make(map[string]*FunctionMetadata, len(functions)) + for i := range functions { + keyspace.Functions[functions[i].Name] = &functions[i] + } + keyspace.Aggregates = make(map[string]*AggregateMetadata, len(aggregates)) + for _, aggregate := range aggregates { + aggregate.FinalFunc = *keyspace.Functions[aggregate.finalFunc] + aggregate.StateFunc = *keyspace.Functions[aggregate.stateFunc] + keyspace.Aggregates[aggregate.Name] = &aggregate + } + keyspace.Types = make(map[string]*TypeMetadata, len(types)) + for i := range types { + keyspace.Types[types[i].Name] = &types[i] + } + keyspace.Indexes = make(map[string]*IndexMetadata, len(indexes)) + for i := range indexes { + keyspace.Indexes[indexes[i].Name] = &indexes[i] + } + keyspace.Views = make(map[string]*ViewMetadata, len(views)) + for i := range views { + v := &views[i] + if _, ok := keyspace.Indexes[strings.TrimSuffix(v.ViewName, "_index")]; ok { + continue + } + + v.Columns = make(map[string]*ColumnMetadata) + keyspace.Views[v.ViewName] = v + } + + // add columns from the schema data + for i := range columns { + col := &columns[i] + col.Order = ASC + if col.ClusteringOrder == "desc" { + col.Order = DESC + } + + table, ok := keyspace.Tables[col.Table] + if !ok { + view, ok := keyspace.Views[col.Table] + if !ok { + // if the schema is being updated we will race between seeing + // the metadata be complete. Potentially we should check for + // schema versions before and after reading the metadata and + // if they dont match try again. + continue + } + + view.Columns[col.Name] = col + view.OrderedColumns = append(view.OrderedColumns, col.Name) + continue + } + + table.Columns[col.Name] = col + table.OrderedColumns = append(table.OrderedColumns, col.Name) + } + + for i := range tables { + t := &tables[i] + t.PartitionKey, t.ClusteringColumns, t.OrderedColumns = compileColumns(t.Columns, t.OrderedColumns) + } + for i := range views { + v := &views[i] + v.PartitionKey, v.ClusteringColumns, v.OrderedColumns = compileColumns(v.Columns, v.OrderedColumns) + } +} + +func compileColumns(columns map[string]*ColumnMetadata, orderedColumns []string) ( + partitionKey, clusteringColumns []*ColumnMetadata, sortedColumns []string) { + clusteringColumnCount := componentColumnCountOfType(columns, ColumnClusteringKey) + clusteringColumns = make([]*ColumnMetadata, clusteringColumnCount) + + partitionKeyCount := componentColumnCountOfType(columns, ColumnPartitionKey) + partitionKey = make([]*ColumnMetadata, partitionKeyCount) + + var otherColumns []string + for _, columnName := range orderedColumns { + column := columns[columnName] + if column.Kind == ColumnPartitionKey { + partitionKey[column.ComponentIndex] = column + } else if column.Kind == ColumnClusteringKey { + clusteringColumns[column.ComponentIndex] = column + } else { + otherColumns = append(otherColumns, columnName) + } + } + + sortedColumns = orderedColumns[:0] + for _, pk := range partitionKey { + sortedColumns = append(sortedColumns, pk.Name) + } + for _, ck := range clusteringColumns { + sortedColumns = append(sortedColumns, ck.Name) + } + for _, oc := range otherColumns { + sortedColumns = append(sortedColumns, oc) + } + + return +} + +// returns the count of coluns with the given "kind" value. +func componentColumnCountOfType(columns map[string]*ColumnMetadata, kind ColumnKind) int { + maxComponentIndex := -1 + for _, column := range columns { + if column.Kind == kind && column.ComponentIndex > maxComponentIndex { + maxComponentIndex = column.ComponentIndex + } + } + return maxComponentIndex + 1 +} + +// query for keyspace metadata in the system_schema.keyspaces +func getKeyspaceMetadata(session *Session, keyspaceName string) (*KeyspaceMetadata, error) { + keyspace := &KeyspaceMetadata{Name: keyspaceName} + + const stmt = ` + SELECT durable_writes, replication + FROM system_schema.keyspaces + WHERE keyspace_name = ?` + + var replication map[string]string + + iter := session.control.Iter(stmt, keyspaceName) + if iter.NumRows() == 0 { + return nil, ErrKeyspaceDoesNotExist + } + iter.Scan(&keyspace.DurableWrites, &replication) + err := iter.Close() + if err != nil { + return nil, fmt.Errorf("error querying keyspace schema: %v", err) + } + + keyspace.StrategyClass = replication["class"] + delete(replication, "class") + + keyspace.StrategyOptions = make(map[string]interface{}, len(replication)) + for k, v := range replication { + keyspace.StrategyOptions[k] = v + } + + return keyspace, nil +} + +// query for table metadata in the system_schema.tables and system_schema.scylla_tables +func getTableMetadata(session *Session, keyspaceName string) ([]TableMetadata, error) { + stmt := `SELECT * FROM system_schema.tables WHERE keyspace_name = ?` + iter := session.control.Iter(stmt, keyspaceName) + + var tables []TableMetadata + table := TableMetadata{Keyspace: keyspaceName} + for iter.MapScan(map[string]interface{}{ + "table_name": &table.Name, + "bloom_filter_fp_chance": &table.Options.BloomFilterFpChance, + "caching": &table.Options.Caching, + "comment": &table.Options.Comment, + "compaction": &table.Options.Compaction, + "compression": &table.Options.Compression, + "crc_check_chance": &table.Options.CrcCheckChance, + "dclocal_read_repair_chance": &table.Options.DcLocalReadRepairChance, + "default_time_to_live": &table.Options.DefaultTimeToLive, + "gc_grace_seconds": &table.Options.GcGraceSeconds, + "max_index_interval": &table.Options.MaxIndexInterval, + "memtable_flush_period_in_ms": &table.Options.MemtableFlushPeriodInMs, + "min_index_interval": &table.Options.MinIndexInterval, + "read_repair_chance": &table.Options.ReadRepairChance, + "speculative_retry": &table.Options.SpeculativeRetry, + "flags": &table.Flags, + "extensions": &table.Extensions, + }) { + tables = append(tables, table) + table = TableMetadata{Keyspace: keyspaceName} + } + + err := iter.Close() + if err != nil && err != ErrNotFound { + return nil, fmt.Errorf("error querying table schema: %v", err) + } + + stmt = `SELECT * FROM system_schema.scylla_tables WHERE keyspace_name = ? AND table_name = ?` + for i, t := range tables { + iter := session.control.Iter(stmt, keyspaceName, t.Name) + + table := TableMetadata{} + if iter.MapScan(map[string]interface{}{ + "cdc": &table.Options.CDC, + "in_memory": &table.Options.InMemory, + "partitioner": &table.Options.Partitioner, + "version": &table.Options.Version, + }) { + tables[i].Options.CDC = table.Options.CDC + tables[i].Options.Version = table.Options.Version + tables[i].Options.Partitioner = table.Options.Partitioner + tables[i].Options.InMemory = table.Options.InMemory + } + if err := iter.Close(); err != nil && err != ErrNotFound { + return nil, fmt.Errorf("error querying scylla table schema: %v", err) + } + } + + return tables, nil +} + +// query for column metadata in the system_schema.columns +func getColumnMetadata(session *Session, keyspaceName string) ([]ColumnMetadata, error) { + const stmt = `SELECT * FROM system_schema.columns WHERE keyspace_name = ?` + + var columns []ColumnMetadata + + iter := session.control.Iter(stmt, keyspaceName) + column := ColumnMetadata{Keyspace: keyspaceName} + + for iter.MapScan(map[string]interface{}{ + "table_name": &column.Table, + "column_name": &column.Name, + "clustering_order": &column.ClusteringOrder, + "type": &column.Type, + "kind": &column.Kind, + "position": &column.ComponentIndex, + }) { + columns = append(columns, column) + column = ColumnMetadata{Keyspace: keyspaceName} + } + + if err := iter.Close(); err != nil && err != ErrNotFound { + return nil, fmt.Errorf("error querying column schema: %v", err) + } + + return columns, nil +} + +// query for type metadata in the system_schema.types +func getTypeMetadata(session *Session, keyspaceName string) ([]TypeMetadata, error) { + stmt := `SELECT * FROM system_schema.types WHERE keyspace_name = ?` + iter := session.control.Iter(stmt, keyspaceName) + + var types []TypeMetadata + tm := TypeMetadata{Keyspace: keyspaceName} + + for iter.MapScan(map[string]interface{}{ + "type_name": &tm.Name, + "field_names": &tm.FieldNames, + "field_types": &tm.FieldTypes, + }) { + types = append(types, tm) + tm = TypeMetadata{Keyspace: keyspaceName} + } + + if err := iter.Close(); err != nil { + return nil, err + } + + return types, nil +} + +// query for function metadata in the system_schema.functions +func getFunctionsMetadata(session *Session, keyspaceName string) ([]FunctionMetadata, error) { + stmt := `SELECT * FROM system_schema.functions WHERE keyspace_name = ?` + + var functions []FunctionMetadata + function := FunctionMetadata{Keyspace: keyspaceName} + + iter := session.control.Iter(stmt, keyspaceName) + for iter.MapScan(map[string]interface{}{ + "function_name": &function.Name, + "argument_types": &function.ArgumentTypes, + "argument_names": &function.ArgumentNames, + "body": &function.Body, + "called_on_null_input": &function.CalledOnNullInput, + "language": &function.Language, + "return_type": &function.ReturnType, + }) { + functions = append(functions, function) + function = FunctionMetadata{Keyspace: keyspaceName} + } + + if err := iter.Close(); err != nil { + return nil, err + } + + return functions, nil +} + +// query for aggregate metadata in the system_schema.aggregates +func getAggregatesMetadata(session *Session, keyspaceName string) ([]AggregateMetadata, error) { + const stmt = `SELECT * FROM system_schema.aggregates WHERE keyspace_name = ?` + + var aggregates []AggregateMetadata + aggregate := AggregateMetadata{Keyspace: keyspaceName} + + iter := session.control.Iter(stmt, keyspaceName) + for iter.MapScan(map[string]interface{}{ + "aggregate_name": &aggregate.Name, + "argument_types": &aggregate.ArgumentTypes, + "final_func": &aggregate.finalFunc, + "initcond": &aggregate.InitCond, + "return_type": &aggregate.ReturnType, + "state_func": &aggregate.stateFunc, + "state_type": &aggregate.StateType, + }) { + aggregates = append(aggregates, aggregate) + aggregate = AggregateMetadata{Keyspace: keyspaceName} + } + + if err := iter.Close(); err != nil { + return nil, err + } + + return aggregates, nil +} + +// query for index metadata in the system_schema.indexes +func getIndexMetadata(session *Session, keyspaceName string) ([]IndexMetadata, error) { + const stmt = `SELECT * FROM system_schema.indexes WHERE keyspace_name = ?` + + var indexes []IndexMetadata + index := IndexMetadata{} + + iter := session.control.Iter(stmt, keyspaceName) + for iter.MapScan(map[string]interface{}{ + "index_name": &index.Name, + "keyspace_name": &index.KeyspaceName, + "table_name": &index.TableName, + "kind": &index.Kind, + "options": &index.Options, + }) { + indexes = append(indexes, index) + index = IndexMetadata{} + } + + if err := iter.Close(); err != nil { + return nil, err + } + + return indexes, nil +} + +// query for view metadata in the system_schema.views +func getViewMetadata(session *Session, keyspaceName string) ([]ViewMetadata, error) { + stmt := `SELECT * FROM system_schema.views WHERE keyspace_name = ?` + + iter := session.control.Iter(stmt, keyspaceName) + + var views []ViewMetadata + view := ViewMetadata{KeyspaceName: keyspaceName} + + for iter.MapScan(map[string]interface{}{ + "id": &view.ID, + "view_name": &view.ViewName, + "base_table_id": &view.BaseTableID, + "base_table_name": &view.BaseTableName, + "include_all_columns": &view.IncludeAllColumns, + "where_clause": &view.WhereClause, + "bloom_filter_fp_chance": &view.Options.BloomFilterFpChance, + "caching": &view.Options.Caching, + "comment": &view.Options.Comment, + "compaction": &view.Options.Compaction, + "compression": &view.Options.Compression, + "crc_check_chance": &view.Options.CrcCheckChance, + "dclocal_read_repair_chance": &view.Options.DcLocalReadRepairChance, + "default_time_to_live": &view.Options.DefaultTimeToLive, + "gc_grace_seconds": &view.Options.GcGraceSeconds, + "max_index_interval": &view.Options.MaxIndexInterval, + "memtable_flush_period_in_ms": &view.Options.MemtableFlushPeriodInMs, + "min_index_interval": &view.Options.MinIndexInterval, + "read_repair_chance": &view.Options.ReadRepairChance, + "speculative_retry": &view.Options.SpeculativeRetry, + "extensions": &view.Extensions, + }) { + views = append(views, view) + view = ViewMetadata{KeyspaceName: keyspaceName} + } + + err := iter.Close() + if err != nil && err != ErrNotFound { + return nil, fmt.Errorf("error querying view schema: %v", err) + } + + return views, nil +} diff --git a/gocql/metadata_scylla_test.go b/gocql/metadata_scylla_test.go new file mode 100644 index 00000000..8978cbc5 --- /dev/null +++ b/gocql/metadata_scylla_test.go @@ -0,0 +1,394 @@ +// Copyright (c) 2015 The gocql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gocql + +import ( + "testing" +) + +// Tests metadata "compilation" from example data which might be returned +// from metadata schema queries (see getKeyspaceMetadata, getTableMetadata, and getColumnMetadata) +func TestCompileMetadata(t *testing.T) { + keyspace := &KeyspaceMetadata{ + Name: "V2Keyspace", + } + tables := []TableMetadata{ + { + Keyspace: "V2Keyspace", + Name: "Table1", + }, + { + Keyspace: "V2Keyspace", + Name: "Table2", + }, + } + columns := []ColumnMetadata{ + { + Keyspace: "V2Keyspace", + Table: "Table1", + Name: "KEY1", + Kind: ColumnPartitionKey, + ComponentIndex: 0, + Type: "text", + }, + { + Keyspace: "V2Keyspace", + Table: "Table1", + Name: "Key1", + Kind: ColumnPartitionKey, + ComponentIndex: 0, + Type: "text", + }, + { + Keyspace: "V2Keyspace", + Table: "Table2", + Name: "Column1", + Kind: ColumnPartitionKey, + ComponentIndex: 0, + Type: "text", + }, + { + Keyspace: "V2Keyspace", + Table: "Table2", + Name: "Column2", + Kind: ColumnClusteringKey, + ComponentIndex: 0, + Type: "text", + }, + { + Keyspace: "V2Keyspace", + Table: "Table2", + Name: "Column3", + Kind: ColumnClusteringKey, + ComponentIndex: 1, + Type: "text", + ClusteringOrder: "desc", + }, + { + Keyspace: "V2Keyspace", + Table: "Table2", + Name: "Column4", + Kind: ColumnRegular, + Type: "text", + }, + { + Keyspace: "V2Keyspace", + Table: "view", + Name: "ColReg", + Kind: ColumnRegular, + Type: "text", + }, + { + Keyspace: "V2Keyspace", + Table: "view", + Name: "ColCK", + Kind: ColumnClusteringKey, + Type: "text", + }, + { + Keyspace: "V2Keyspace", + Table: "view", + Name: "ColPK", + Kind: ColumnPartitionKey, + Type: "text", + }, + } + indexes := []IndexMetadata{ + { + Name: "sec_idx", + }, + } + views := []ViewMetadata{ + { + KeyspaceName: "V2Keyspace", + ViewName: "view", + }, + { + KeyspaceName: "V2Keyspace", + ViewName: "sec_idx_index", + }, + } + compileMetadata(keyspace, tables, columns, nil, nil, nil, indexes, views) + assertKeyspaceMetadata( + t, + keyspace, + &KeyspaceMetadata{ + Name: "V2Keyspace", + Views: map[string]*ViewMetadata{ + "view": { + PartitionKey: []*ColumnMetadata{ + { + Name: "ColPK", + Type: "text", + }, + }, + ClusteringColumns: []*ColumnMetadata{ + { + Name: "ColCK", + Type: "text", + }, + }, + OrderedColumns: []string{ + "ColPK", "ColCK", "ColReg", + }, + Columns: map[string]*ColumnMetadata{ + "ColPK": { + Name: "ColPK", + Kind: ColumnPartitionKey, + Type: "text", + }, + "ColCK": { + Name: "ColCK", + Kind: ColumnClusteringKey, + Type: "text", + }, + "ColReg": { + Name: "ColReg", + Kind: ColumnRegular, + Type: "text", + }, + }, + }, + }, + Tables: map[string]*TableMetadata{ + "Table1": { + PartitionKey: []*ColumnMetadata{ + { + Name: "Key1", + Type: "text", + }, + }, + ClusteringColumns: []*ColumnMetadata{}, + Columns: map[string]*ColumnMetadata{ + "KEY1": { + Name: "KEY1", + Type: "text", + Kind: ColumnPartitionKey, + }, + "Key1": { + Name: "Key1", + Type: "text", + Kind: ColumnPartitionKey, + }, + }, + OrderedColumns: []string{ + "Key1", + }, + }, + "Table2": { + PartitionKey: []*ColumnMetadata{ + { + Name: "Column1", + Type: "text", + }, + }, + ClusteringColumns: []*ColumnMetadata{ + { + Name: "Column2", + Type: "text", + Order: ASC, + }, + { + Name: "Column3", + Type: "text", + Order: DESC, + }, + }, + Columns: map[string]*ColumnMetadata{ + "Column1": { + Name: "Column1", + Type: "text", + Kind: ColumnPartitionKey, + }, + "Column2": { + Name: "Column2", + Type: "text", + Order: ASC, + Kind: ColumnClusteringKey, + }, + "Column3": { + Name: "Column3", + Type: "text", + Order: DESC, + Kind: ColumnClusteringKey, + }, + "Column4": { + Name: "Column4", + Type: "text", + Kind: ColumnRegular, + }, + }, + OrderedColumns: []string{ + "Column1", "Column2", "Column3", "Column4", + }, + }, + }, + }, + ) +} + +func assertPartitionKey(t *testing.T, keyspaceName, tableName string, actual, expected []*ColumnMetadata) { + if len(expected) != len(actual) { + t.Errorf("Expected len(%s.Tables[%s].PartitionKey) to be %v but was %v", keyspaceName, tableName, len(expected), len(actual)) + } else { + for i := range expected { + if expected[i].Name != actual[i].Name { + t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].Name to be '%v' but was '%v'", keyspaceName, tableName, i, expected[i].Name, actual[i].Name) + } + if keyspaceName != actual[i].Keyspace { + t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].Keyspace to be '%v' but was '%v'", keyspaceName, tableName, i, keyspaceName, actual[i].Keyspace) + } + if tableName != actual[i].Table { + t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].Table to be '%v' but was '%v'", keyspaceName, tableName, i, tableName, actual[i].Table) + } + if expected[i].Type != actual[i].Type { + t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].Type.Type to be %v but was %v", keyspaceName, tableName, i, expected[i].Type, actual[i].Type) + } + if i != actual[i].ComponentIndex { + t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].ComponentIndex to be %v but was %v", keyspaceName, tableName, i, i, actual[i].ComponentIndex) + } + if ColumnPartitionKey != actual[i].Kind { + t.Errorf("Expected %s.Tables[%s].PartitionKey[%d].Kind to be '%v' but was '%v'", keyspaceName, tableName, i, ColumnPartitionKey, actual[i].Kind) + } + } + } +} +func assertClusteringColumns(t *testing.T, keyspaceName, tableName string, actual, expected []*ColumnMetadata) { + if len(expected) != len(actual) { + t.Errorf("Expected len(%s.Tables[%s].ClusteringColumns) to be %v but was %v", keyspaceName, tableName, len(expected), len(actual)) + } else { + for i := range expected { + if actual[i] == nil { + t.Fatalf("Unexpected nil value: %s.Tables[%s].ClusteringColumns[%d]", keyspaceName, tableName, i) + } + if expected[i].Name != actual[i].Name { + t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Name to be '%v' but was '%v'", keyspaceName, tableName, i, expected[i].Name, actual[i].Name) + } + if keyspaceName != actual[i].Keyspace { + t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Keyspace to be '%v' but was '%v'", keyspaceName, tableName, i, keyspaceName, actual[i].Keyspace) + } + if tableName != actual[i].Table { + t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Table to be '%v' but was '%v'", keyspaceName, tableName, i, tableName, actual[i].Table) + } + if expected[i].Type != actual[i].Type { + t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Type.Type to be %v but was %v", keyspaceName, tableName, i, expected[i].Type, actual[i].Type) + } + if i != actual[i].ComponentIndex { + t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].ComponentIndex to be %v but was %v", keyspaceName, tableName, i, i, actual[i].ComponentIndex) + } + if expected[i].Order != actual[i].Order { + t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Order to be %v but was %v", keyspaceName, tableName, i, expected[i].Order, actual[i].Order) + } + if ColumnClusteringKey != actual[i].Kind { + t.Errorf("Expected %s.Tables[%s].ClusteringColumns[%d].Kind to be '%v' but was '%v'", keyspaceName, tableName, i, ColumnClusteringKey, actual[i].Kind) + } + } + } +} + +func assertColumns(t *testing.T, keyspaceName, tableName string, actual, expected map[string]*ColumnMetadata) { + if len(expected) != len(actual) { + eKeys := make([]string, 0, len(expected)) + for key := range expected { + eKeys = append(eKeys, key) + } + aKeys := make([]string, 0, len(actual)) + for key := range actual { + aKeys = append(aKeys, key) + } + t.Errorf("Expected len(%s.Tables[%s].Columns) to be %v (keys:%v) but was %v (keys:%v)", keyspaceName, tableName, len(expected), eKeys, len(actual), aKeys) + } else { + for keyC := range expected { + ec := expected[keyC] + ac, found := actual[keyC] + + if !found { + t.Errorf("Expected %s.Tables[%s].Columns[%s] but was not found", keyspaceName, tableName, keyC) + } else { + if keyC != ac.Name { + t.Errorf("Expected %s.Tables[%s].Columns[%s].Name to be '%v' but was '%v'", keyspaceName, tableName, keyC, keyC, tableName) + } + if keyspaceName != ac.Keyspace { + t.Errorf("Expected %s.Tables[%s].Columns[%s].Keyspace to be '%v' but was '%v'", keyspaceName, tableName, keyC, keyspaceName, ac.Keyspace) + } + if tableName != ac.Table { + t.Errorf("Expected %s.Tables[%s].Columns[%s].Table to be '%v' but was '%v'", keyspaceName, tableName, keyC, tableName, ac.Table) + } + if ec.Type != ac.Type { + t.Errorf("Expected %s.Tables[%s].Columns[%s].Type.Type to be %v but was %v", keyspaceName, tableName, keyC, ec.Type, ac.Type) + } + if ec.Order != ac.Order { + t.Errorf("Expected %s.Tables[%s].Columns[%s].Order to be %v but was %v", keyspaceName, tableName, keyC, ec.Order, ac.Order) + } + if ec.Kind != ac.Kind { + t.Errorf("Expected %s.Tables[%s].Columns[%s].Kind to be '%v' but was '%v'", keyspaceName, tableName, keyC, ec.Kind, ac.Kind) + } + } + } + } +} + +func assertOrderedColumns(t *testing.T, keyspaceName, tableName string, actual, expected []string) { + if len(expected) != len(actual) { + t.Errorf("Expected len(%s.Tables[%s].OrderedColumns to be %v but was %v", keyspaceName, tableName, len(expected), len(actual)) + } else { + for i, eoc := range expected { + aoc := actual[i] + if eoc != aoc { + t.Errorf("Expected %s.Tables[%s].OrderedColumns[%d] to be %s, but was %s", keyspaceName, tableName, i, eoc, aoc) + } + } + } +} + +func assertTableMetadata(t *testing.T, keyspaceName string, actual, expected map[string]*TableMetadata) { + if len(expected) != len(actual) { + t.Errorf("Expected len(%s.Tables) to be %v but was %v", keyspaceName, len(expected), len(actual)) + } + for keyT := range expected { + et := expected[keyT] + at, found := actual[keyT] + + if !found { + t.Errorf("Expected %s.Tables[%s] but was not found", keyspaceName, keyT) + } else { + if keyT != at.Name { + t.Errorf("Expected %s.Tables[%s].Name to be %v but was %v", keyspaceName, keyT, keyT, at.Name) + } + assertPartitionKey(t, keyspaceName, keyT, at.PartitionKey, et.PartitionKey) + assertClusteringColumns(t, keyspaceName, keyT, at.ClusteringColumns, et.ClusteringColumns) + assertColumns(t, keyspaceName, keyT, at.Columns, et.Columns) + assertOrderedColumns(t, keyspaceName, keyT, at.OrderedColumns, et.OrderedColumns) + } + } +} + +func assertViewsMetadata(t *testing.T, keyspaceName string, actual, expected map[string]*ViewMetadata) { + if len(expected) != len(actual) { + t.Errorf("Expected len(%s.Views) to be %v but was %v", keyspaceName, len(expected), len(actual)) + } + for keyT := range expected { + et := expected[keyT] + at, found := actual[keyT] + + if !found { + t.Errorf("Expected %s.Views[%s] but was not found", keyspaceName, keyT) + } else { + if keyT != at.ViewName { + t.Errorf("Expected %s.Views[%s].Name to be %v but was %v", keyspaceName, keyT, keyT, at.ViewName) + } + assertPartitionKey(t, keyspaceName, keyT, at.PartitionKey, et.PartitionKey) + assertClusteringColumns(t, keyspaceName, keyT, at.ClusteringColumns, et.ClusteringColumns) + assertColumns(t, keyspaceName, keyT, at.Columns, et.Columns) + assertOrderedColumns(t, keyspaceName, keyT, at.OrderedColumns, et.OrderedColumns) + } + } +} + +// Helper function for asserting that actual metadata returned was as expected +func assertKeyspaceMetadata(t *testing.T, actual, expected *KeyspaceMetadata) { + assertTableMetadata(t, expected.Name, actual.Tables, expected.Tables) + assertViewsMetadata(t, expected.Name, actual.Views, expected.Views) +} diff --git a/gocql/query.go b/gocql/query.go new file mode 100644 index 00000000..08dba75e --- /dev/null +++ b/gocql/query.go @@ -0,0 +1,176 @@ +package gocql + +import ( + "context" + "fmt" + + "github.com/scylladb/scylla-go-driver" + "github.com/scylladb/scylla-go-driver/frame" +) + +type Query struct { + ctx context.Context + query scylla.Query + err error + values []interface{} + prepared bool +} + +type anyWrapper struct { + val any +} + +func (w anyWrapper) Serialize(o *frame.Option) (n int32, bytes []byte, err error) { + ti := WrapOption(o) + bytes, err = Marshal(ti, w.val) + if bytes == nil { + n = -1 + } else { + n = int32(len(bytes)) + } + return +} + +func (q *Query) Bind(values ...interface{}) *Query { + if !q.prepared { + q.values = values + return q + } + + for i, v := range values { + q.query.Bind(i, anyWrapper{v}) + } + return q +} + +func (q *Query) Exec() error { + q.prepare() + _, err := q.query.Exec(q.ctx) + return err +} + +func unmarshalCqlValue(c frame.CqlValue, dst interface{}) error { + return Unmarshal(WrapOption(c.Type), c.Value, dst) +} + +func (q *Query) Scan(values ...interface{}) error { + q.prepare() + res, err := q.query.Exec(q.ctx) + if err != nil { + return err + } + + if len(res.Rows[0]) != len(values) { + return fmt.Errorf("column count mismatch expected %d, got %d", len(values), len(res.Rows)) + } + + for i, v := range res.Rows[0] { + if err := unmarshalCqlValue(v, values[i]); err != nil { + return err + } + } + + return nil +} + +func (q *Query) prepare() { + if q.prepared || q.err != nil { + return + } + + q.err = q.query.Prepare(q.ctx) + q.prepared = true + for i, v := range q.values { + q.query.Bind(i, anyWrapper{v}) + } + q.values = nil +} + +func (q *Query) Iter() *Iter { + q.prepare() + if q.err != nil { + return &Iter{it: &scylla.Iter{}, err: q.err} + } + it := q.query.Iter(q.ctx) + return newIter(&it) +} + +func (q *Query) Release() { + // TODO: does this need to do anything, new driver doesn't have a pool of queries. +} + +func (q *Query) WithContext(ctx context.Context) *Query { + q.ctx = ctx + return q +} + +func (q *Query) Consistency(c Consistency) *Query { + panic("unimplemented") +} + +// CustomPayload sets the custom payload level for this query. +func (q *Query) CustomPayload(customPayload map[string][]byte) *Query { + panic("unimplemented") +} + +// Trace enables tracing of this query. Look at the documentation of the +// Tracer interface to learn more about tracing. +func (q *Query) Trace(trace Tracer) *Query { + panic("unimplemented") +} + +// Observer enables query-level observer on this query. +// The provided observer will be called every time this query is executed. +func (q *Query) Observer(observer QueryObserver) *Query { + panic("unimplemented") +} + +func (q *Query) PageSize(n int) *Query { + q.query.SetPageSize(int32(n)) + return q +} + +func (q *Query) DefaultTimestamp(enable bool) *Query { + panic("unimplemented") +} + +func (q *Query) WithTimestamp(timestamp int64) *Query { + panic("unimplemented") +} + +func (q *Query) RoutingKey(routingKey []byte) *Query { + panic("unimplemented") +} + +func (q *Query) Prefetch(p float64) *Query { + panic("unimplemented") +} + +func (q *Query) RetryPolicy(r RetryPolicy) *Query { + q.query.SetRetryPolicy(transformRetryPolicy(r)) + return q +} + +func (q *Query) SetSpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Query { + panic("unimplemented") +} + +func (q *Query) Idempotent(value bool) *Query { + panic("unimplemented") + // q.query.SetIdempotent(value) +} + +func (q *Query) SerialConsistency(cons Consistency) *Query { + q.query.SetSerialConsistency(frame.Consistency(cons)) + return q +} + +func (q *Query) PageState(state []byte) *Query { + q.query.SetPageState(state) + return q +} + +func (q *Query) NoSkipMetadata() *Query { + q.query.NoSkipMetadata() + return q +} diff --git a/gocql/recreate.go b/gocql/recreate.go new file mode 100644 index 00000000..2eeb6b4c --- /dev/null +++ b/gocql/recreate.go @@ -0,0 +1,539 @@ +//go:build !cassandra +// +build !cassandra + +// Copyright (C) 2017 ScyllaDB + +package gocql + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" + "sort" + "strconv" + "strings" + "text/template" +) + +// ToCQL returns a CQL query that ca be used to recreate keyspace with all +// user defined types, tables, indexes, functions, aggregates and views associated +// with this keyspace. +func (km *KeyspaceMetadata) ToCQL() (string, error) { + var sb strings.Builder + + if err := km.keyspaceToCQL(&sb); err != nil { + return "", err + } + + sortedTypes := km.typesSortedTopologically() + for _, tm := range sortedTypes { + if err := km.userTypeToCQL(&sb, tm); err != nil { + return "", err + } + } + + for _, tm := range km.Tables { + if err := km.tableToCQL(&sb, km.Name, tm); err != nil { + return "", err + } + } + + for _, im := range km.Indexes { + if err := km.indexToCQL(&sb, im); err != nil { + return "", err + } + } + + for _, fm := range km.Functions { + if err := km.functionToCQL(&sb, km.Name, fm); err != nil { + return "", err + } + } + + for _, am := range km.Aggregates { + if err := km.aggregateToCQL(&sb, am); err != nil { + return "", err + } + } + + for _, vm := range km.Views { + if err := km.viewToCQL(&sb, vm); err != nil { + return "", err + } + } + + return sb.String(), nil +} + +func (km *KeyspaceMetadata) typesSortedTopologically() []*TypeMetadata { + sortedTypes := make([]*TypeMetadata, 0, len(km.Types)) + for _, tm := range km.Types { + sortedTypes = append(sortedTypes, tm) + } + sort.Slice(sortedTypes, func(i, j int) bool { + for _, ft := range sortedTypes[j].FieldTypes { + if strings.Contains(ft, sortedTypes[i].Name) { + return true + } + } + return false + }) + return sortedTypes +} + +var tableCQLTemplate = template.Must(template.New("table"). + Funcs(map[string]interface{}{ + "escape": cqlHelpers.escape, + "tableColumnToCQL": cqlHelpers.tableColumnToCQL, + "tablePropertiesToCQL": cqlHelpers.tablePropertiesToCQL, + }). + Parse(` +CREATE TABLE {{ .KeyspaceName }}.{{ .Tm.Name }} ( + {{ tableColumnToCQL .Tm }} +) WITH {{ tablePropertiesToCQL .Tm.ClusteringColumns .Tm.Options .Tm.Flags .Tm.Extensions }}; +`)) + +func (km *KeyspaceMetadata) tableToCQL(w io.Writer, kn string, tm *TableMetadata) error { + if err := tableCQLTemplate.Execute(w, map[string]interface{}{ + "Tm": tm, + "KeyspaceName": kn, + }); err != nil { + return err + } + return nil +} + +var functionTemplate = template.Must(template.New("functions"). + Funcs(map[string]interface{}{ + "escape": cqlHelpers.escape, + "zip": cqlHelpers.zip, + "stripFrozen": cqlHelpers.stripFrozen, + }). + Parse(` +CREATE FUNCTION {{ escape .keyspaceName }}.{{ escape .fm.Name }} ( + {{- range $i, $args := zip .fm.ArgumentNames .fm.ArgumentTypes }} + {{- if ne $i 0 }}, {{ end }} + {{- escape (index $args 0) }} + {{ stripFrozen (index $args 1) }} + {{- end -}}) + {{ if .fm.CalledOnNullInput }}CALLED{{ else }}RETURNS NULL{{ end }} ON NULL INPUT + RETURNS {{ .fm.ReturnType }} + LANGUAGE {{ .fm.Language }} + AS $${{ .fm.Body }}$$; +`)) + +func (km *KeyspaceMetadata) functionToCQL(w io.Writer, keyspaceName string, fm *FunctionMetadata) error { + if err := functionTemplate.Execute(w, map[string]interface{}{ + "fm": fm, + "keyspaceName": keyspaceName, + }); err != nil { + return err + } + return nil +} + +var viewTemplate = template.Must(template.New("views"). + Funcs(map[string]interface{}{ + "zip": cqlHelpers.zip, + "partitionKeyString": cqlHelpers.partitionKeyString, + "tablePropertiesToCQL": cqlHelpers.tablePropertiesToCQL, + }). + Parse(` +CREATE MATERIALIZED VIEW {{ .vm.KeyspaceName }}.{{ .vm.ViewName }} AS + SELECT {{ if .vm.IncludeAllColumns }}*{{ else }} + {{- range $i, $col := .vm.OrderedColumns }} + {{- if ne $i 0 }}, {{ end }} + {{ $col }} + {{- end }} + {{- end }} + FROM {{ .vm.KeyspaceName }}.{{ .vm.BaseTableName }} + WHERE {{ .vm.WhereClause }} + PRIMARY KEY ({{ partitionKeyString .vm.PartitionKey .vm.ClusteringColumns }}) + WITH {{ tablePropertiesToCQL .vm.ClusteringColumns .vm.Options .flags .vm.Extensions }}; +`)) + +func (km *KeyspaceMetadata) viewToCQL(w io.Writer, vm *ViewMetadata) error { + if err := viewTemplate.Execute(w, map[string]interface{}{ + "vm": vm, + "flags": []string{}, + }); err != nil { + return err + } + return nil +} + +var aggregatesTemplate = template.Must(template.New("aggregate"). + Funcs(map[string]interface{}{ + "stripFrozen": cqlHelpers.stripFrozen, + }). + Parse(` +CREATE AGGREGATE {{ .Keyspace }}.{{ .Name }}( + {{- range $arg, $i := .ArgumentTypes }} + {{- if ne $i 0 }}, {{ end }} + {{ stripFrozen $arg }} + {{- end -}}) + SFUNC {{ .StateFunc.Name }} + STYPE {{ stripFrozen .State }} + {{- if ne .FinalFunc.Name "" }} + FINALFUNC {{ .FinalFunc.Name }} + {{- end -}} + {{- if ne .InitCond "" }} + INITCOND {{ .InitCond }} + {{- end -}} +); +`)) + +func (km *KeyspaceMetadata) aggregateToCQL(w io.Writer, am *AggregateMetadata) error { + if err := aggregatesTemplate.Execute(w, am); err != nil { + return err + } + return nil +} + +var typeCQLTemplate = template.Must(template.New("types"). + Funcs(map[string]interface{}{ + "zip": cqlHelpers.zip, + }). + Parse(` +CREATE TYPE {{ .Keyspace }}.{{ .Name }} ( + {{- range $i, $fields := zip .FieldNames .FieldTypes }} {{- if ne $i 0 }},{{ end }} + {{ index $fields 0 }} {{ index $fields 1 }} + {{- end }} +); +`)) + +func (km *KeyspaceMetadata) userTypeToCQL(w io.Writer, tm *TypeMetadata) error { + if err := typeCQLTemplate.Execute(w, tm); err != nil { + return err + } + return nil +} + +func (km *KeyspaceMetadata) indexToCQL(w io.Writer, im *IndexMetadata) error { + // Scylla doesn't support any custom indexes + if im.Kind == IndexKindCustom { + return nil + } + + options := im.Options + indexTarget := options["target"] + + // secondary index + si := struct { + ClusteringKeys []string `json:"ck"` + PartitionKeys []string `json:"pk"` + }{} + + if err := json.Unmarshal([]byte(indexTarget), &si); err == nil { + indexTarget = fmt.Sprintf("(%s), %s", + strings.Join(si.PartitionKeys, ","), + strings.Join(si.ClusteringKeys, ","), + ) + } + + _, err := fmt.Fprintf(w, "\nCREATE INDEX %s ON %s.%s (%s);\n", + im.Name, + im.KeyspaceName, + im.TableName, + indexTarget, + ) + if err != nil { + return err + } + + return nil +} + +var keyspaceCQLTemplate = template.Must(template.New("keyspace"). + Funcs(map[string]interface{}{ + "escape": cqlHelpers.escape, + "fixStrategy": cqlHelpers.fixStrategy, + }). + Parse(`CREATE KEYSPACE {{ .Name }} WITH replication = { + 'class': {{ escape ( fixStrategy .StrategyClass) }} + {{- range $key, $value := .StrategyOptions }}, + {{ escape $key }}: {{ escape $value }} + {{- end }} +}{{ if not .DurableWrites }} AND durable_writes = 'false'{{ end }}; +`)) + +func (km *KeyspaceMetadata) keyspaceToCQL(w io.Writer) error { + if err := keyspaceCQLTemplate.Execute(w, km); err != nil { + return err + } + return nil +} + +func contains(in []string, v string) bool { + for _, e := range in { + if e == v { + return true + } + } + return false +} + +type toCQLHelpers struct{} + +var cqlHelpers = toCQLHelpers{} + +func (h toCQLHelpers) zip(a []string, b []string) [][]string { + m := make([][]string, len(a)) + for i := range a { + m[i] = []string{a[i], b[i]} + } + return m +} + +func (h toCQLHelpers) escape(e interface{}) string { + switch v := e.(type) { + case int, float64: + return fmt.Sprint(v) + case bool: + if v { + return "true" + } + return "false" + case string: + return "'" + strings.ReplaceAll(v, "'", "''") + "'" + case []byte: + return string(v) + } + return "" +} + +func (h toCQLHelpers) stripFrozen(v string) string { + return strings.TrimSuffix(strings.TrimPrefix(v, "frozen<"), ">") +} +func (h toCQLHelpers) fixStrategy(v string) string { + return strings.TrimPrefix(v, "org.apache.cassandra.locator.") +} + +func (h toCQLHelpers) fixQuote(v string) string { + return strings.ReplaceAll(v, `"`, `'`) +} + +func (h toCQLHelpers) tableOptionsToCQL(ops TableMetadataOptions) ([]string, error) { + opts := map[string]interface{}{ + "bloom_filter_fp_chance": ops.BloomFilterFpChance, + "comment": ops.Comment, + "crc_check_chance": ops.CrcCheckChance, + "dclocal_read_repair_chance": ops.DcLocalReadRepairChance, + "default_time_to_live": ops.DefaultTimeToLive, + "gc_grace_seconds": ops.GcGraceSeconds, + "max_index_interval": ops.MaxIndexInterval, + "memtable_flush_period_in_ms": ops.MemtableFlushPeriodInMs, + "min_index_interval": ops.MinIndexInterval, + "read_repair_chance": ops.ReadRepairChance, + "speculative_retry": ops.SpeculativeRetry, + } + + var err error + opts["caching"], err = json.Marshal(ops.Caching) + if err != nil { + return nil, err + } + + opts["compaction"], err = json.Marshal(ops.Compaction) + if err != nil { + return nil, err + } + + opts["compression"], err = json.Marshal(ops.Compression) + if err != nil { + return nil, err + } + + cdc, err := json.Marshal(ops.CDC) + if err != nil { + return nil, err + } + + if string(cdc) != "null" { + opts["cdc"] = cdc + } + + if ops.InMemory { + opts["in_memory"] = ops.InMemory + } + + out := make([]string, 0, len(opts)) + for key, opt := range opts { + out = append(out, fmt.Sprintf("%s = %s", key, h.fixQuote(h.escape(opt)))) + } + + sort.Strings(out) + return out, nil +} + +func (h toCQLHelpers) tableExtensionsToCQL(extensions map[string]interface{}) ([]string, error) { + exts := map[string]interface{}{} + + if blob, ok := extensions["scylla_encryption_options"]; ok { + encOpts := &scyllaEncryptionOptions{} + if err := encOpts.UnmarshalBinary(blob.([]byte)); err != nil { + return nil, err + } + + var err error + exts["scylla_encryption_options"], err = json.Marshal(encOpts) + if err != nil { + return nil, err + } + + } + + out := make([]string, 0, len(exts)) + for key, ext := range exts { + out = append(out, fmt.Sprintf("%s = %s", key, h.fixQuote(h.escape(ext)))) + } + + sort.Strings(out) + return out, nil +} + +func (h toCQLHelpers) tablePropertiesToCQL(cks []*ColumnMetadata, opts TableMetadataOptions, + flags []string, extensions map[string]interface{}) (string, error) { + var sb strings.Builder + + var properties []string + + compactStorage := len(flags) > 0 && (contains(flags, TableFlagDense) || + contains(flags, TableFlagSuper) || + !contains(flags, TableFlagCompound)) + + if compactStorage { + properties = append(properties, "COMPACT STORAGE") + } + + if len(cks) > 0 { + var inner []string + for _, col := range cks { + inner = append(inner, fmt.Sprintf("%s %s", col.Name, col.ClusteringOrder)) + } + properties = append(properties, fmt.Sprintf("CLUSTERING ORDER BY (%s)", strings.Join(inner, ", "))) + } + + options, err := h.tableOptionsToCQL(opts) + if err != nil { + return "", err + } + properties = append(properties, options...) + + exts, err := h.tableExtensionsToCQL(extensions) + if err != nil { + return "", err + } + properties = append(properties, exts...) + + sb.WriteString(strings.Join(properties, "\n AND ")) + return sb.String(), nil +} + +func (h toCQLHelpers) tableColumnToCQL(tm *TableMetadata) string { + var sb strings.Builder + + var columns []string + for _, cn := range tm.OrderedColumns { + cm := tm.Columns[cn] + column := fmt.Sprintf("%s %s", cn, cm.Type) + if cm.Kind == ColumnStatic { + column += " static" + } + columns = append(columns, column) + } + if len(tm.PartitionKey) == 1 && len(tm.ClusteringColumns) == 0 && len(columns) > 0 { + columns[0] += " PRIMARY KEY" + } + + sb.WriteString(strings.Join(columns, ",\n ")) + + if len(tm.PartitionKey) > 1 || len(tm.ClusteringColumns) > 0 { + sb.WriteString(",\n PRIMARY KEY (") + sb.WriteString(h.partitionKeyString(tm.PartitionKey, tm.ClusteringColumns)) + sb.WriteRune(')') + } + + return sb.String() +} + +func (h toCQLHelpers) partitionKeyString(pks, cks []*ColumnMetadata) string { + var sb strings.Builder + + if len(pks) > 1 { + sb.WriteRune('(') + for i, pk := range pks { + if i != 0 { + sb.WriteString(", ") + } + sb.WriteString(pk.Name) + } + sb.WriteRune(')') + } else { + sb.WriteString(pks[0].Name) + } + + if len(cks) > 0 { + sb.WriteString(", ") + for i, ck := range cks { + if i != 0 { + sb.WriteString(", ") + } + sb.WriteString(ck.Name) + } + } + + return sb.String() +} + +type scyllaEncryptionOptions struct { + CipherAlgorithm string `json:"cipher_algorithm"` + SecretKeyStrength int `json:"secret_key_strength"` + KeyProvider string `json:"key_provider"` + SecretKeyFile string `json:"secret_key_file"` +} + +// UnmarshalBinary deserializes blob into scyllaEncryptionOptions. +// Format: +// - 4 bytes - size of KV map +// Size times: +// - 4 bytes - length of key +// - len_of_key bytes - key +// - 4 bytes - length of value +// - len_of_value bytes - value +func (enc *scyllaEncryptionOptions) UnmarshalBinary(data []byte) error { + size := binary.LittleEndian.Uint32(data[0:4]) + + m := make(map[string]string, size) + + off := uint32(4) + for i := uint32(0); i < size; i++ { + keyLen := binary.LittleEndian.Uint32(data[off : off+4]) + off += 4 + + key := string(data[off : off+keyLen]) + off += keyLen + + valueLen := binary.LittleEndian.Uint32(data[off : off+4]) + off += 4 + + value := string(data[off : off+valueLen]) + off += valueLen + + m[key] = value + } + + enc.CipherAlgorithm = m["cipher_algorithm"] + enc.KeyProvider = m["key_provider"] + enc.SecretKeyFile = m["secret_key_file"] + if secretKeyStrength, ok := m["secret_key_strength"]; ok { + sks, err := strconv.Atoi(secretKeyStrength) + if err != nil { + return err + } + enc.SecretKeyStrength = sks + } + + return nil +} diff --git a/gocql/retry.go b/gocql/retry.go new file mode 100644 index 00000000..c23093fb --- /dev/null +++ b/gocql/retry.go @@ -0,0 +1,77 @@ +package gocql + +import ( + "math" + "math/rand" + "time" + + "github.com/scylladb/scylla-go-driver/transport" +) + +// ExponentialBackoffRetryPolicy sleeps between attempts +type ExponentialBackoffRetryPolicy struct { + NumRetries int + attempts int + Min, Max time.Duration +} + +func (e *ExponentialBackoffRetryPolicy) NewRetryDecider() transport.RetryDecider { + return e +} + +func (e *ExponentialBackoffRetryPolicy) Decide(transport.RetryInfo) transport.RetryDecision { + if e.attempt() { + return transport.RetryNextNode + } + return transport.DontRetry +} + +func (e *ExponentialBackoffRetryPolicy) Reset() { + e.attempts = 0 +} + +func (e *ExponentialBackoffRetryPolicy) attempt() bool { + if e.attempts > e.NumRetries { + return false + } + time.Sleep(e.napTime(e.attempts)) + e.attempts++ + return true +} + +func (e *ExponentialBackoffRetryPolicy) napTime(attempts int) time.Duration { + return getExponentialTime(e.Min, e.Max, attempts) +} + +// used to calculate exponentially growing time +func getExponentialTime(min time.Duration, max time.Duration, attempts int) time.Duration { + if min <= 0 { + min = 100 * time.Millisecond + } + if max <= 0 { + max = 10 * time.Second + } + minFloat := float64(min) + napDuration := minFloat * math.Pow(2, float64(attempts-1)) + // add some jitter + napDuration += rand.Float64()*minFloat - (minFloat / 2) + if napDuration > float64(max) { + return time.Duration(max) + } + return time.Duration(napDuration) +} + +func transformRetryPolicy(rp RetryPolicy) transport.RetryPolicy { + if ret, ok := rp.(transport.RetryPolicy); ok { + return ret + } + if rp == nil { + return transport.NewFallthroughRetryPolicy() + } + + return transport.NewDefaultRetryPolicy() +} + +type SimpleRetryPolicy struct { + NumRetries int +} diff --git a/gocql/scanner.go b/gocql/scanner.go new file mode 100644 index 00000000..2f4a6f76 --- /dev/null +++ b/gocql/scanner.go @@ -0,0 +1,116 @@ +package gocql + +import ( + "errors" + "fmt" + + "github.com/scylladb/scylla-go-driver/frame" +) + +type Scanner interface { + // Next advances the row pointer to point at the next row, the row is valid until + // the next call of Next. It returns true if there is a row which is available to be + // scanned into with Scan. + // Next must be called before every call to Scan. + Next() bool + + // Scan copies the current row's columns into dest. If the length of dest does not equal + // the number of columns returned in the row an error is returned. If an error is encountered + // when unmarshalling a column into the value in dest an error is returned and the row is invalidated + // until the next call to Next. + // Next must be called before calling Scan, if it is not an error is returned. + Scan(...interface{}) error + + // Err returns the if there was one during iteration that resulted in iteration being unable to complete. + // Err will also release resources held by the iterator, the Scanner should not used after being called. + Err() error +} + +type iterScanner struct { + iter *Iter + result frame.Row + valid bool +} + +func (is *iterScanner) Next() bool { + iter := is.iter + if iter.err != nil { + return false + } + + var err error + is.result, err = iter.it.Next() + is.valid = err != nil + return is.valid +} + +func scanColumn(p []byte, col ColumnInfo, dest []interface{}) (int, error) { + if dest[0] == nil { + return 1, nil + } + + if col.TypeInfo.Type() == TypeTuple { + // this will panic, actually a bug, please report + tuple := col.TypeInfo.(TupleTypeInfo) + + count := len(tuple.Elems) + // here we pass in a slice of the struct which has the number number of + // values as elements in the tuple + if err := Unmarshal(col.TypeInfo, p, dest[:count]); err != nil { + return 0, err + } + return count, nil + } else { + if err := Unmarshal(col.TypeInfo, p, dest[0]); err != nil { + return 0, err + } + return 1, nil + } +} + +func (is *iterScanner) Scan(dest ...interface{}) error { + if !is.valid { + return errors.New("gocql: Scan called without calling Next") + } + + iter := is.iter + // currently only support scanning into an expand tuple, such that its the same + // as scanning in more values from a single column + if len(dest) != len(iter.Columns()) { + return fmt.Errorf("gocql: not enough columns to scan into: have %d want %d", len(dest), len(iter.Columns())) + } + + // i is the current position in dest, could posible replace it and just use + // slices of dest + i := 0 + var err error + for _, col := range iter.Columns() { + var n int + n, err = scanColumn(is.result[i].Value, col, dest[i:]) + if err != nil { + break + } + i += n + } + + is.valid = false + return err +} + +func (is *iterScanner) Err() error { + iter := is.iter + is.iter = nil + is.result = nil + is.valid = false + return iter.Close() +} + +// Scanner returns a row Scanner which provides an interface to scan rows in a manner which is +// similar to database/sql. The iter should NOT be used again after calling this method. +func (iter *Iter) Scanner() Scanner { + if iter == nil { + return nil + } + + return &iterScanner{iter: iter, result: make(frame.Row, len(iter.Columns()))} +} diff --git a/gocql/session.go b/gocql/session.go new file mode 100644 index 00000000..f669e5d2 --- /dev/null +++ b/gocql/session.go @@ -0,0 +1,80 @@ +package gocql + +import ( + "context" + "errors" + "time" + + "github.com/scylladb/scylla-go-driver" +) + +type Session struct { + session *scylla.Session + cfg scylla.SessionConfig + control SingleHostQueryExecutor + schemaDescriber *schemaDescriber +} + +func NewSession(cfg ClusterConfig) (*Session, error) { + scfg, err := sessionConfigFromGocql(&cfg) + if err != nil { + return nil, err + } + + session, err := scylla.NewSession(context.Background(), scfg) + if err != nil { + return nil, err + } + s := &Session{ + session: session, + cfg: scfg, + } + + s.control, err = NewSingleHostQueryExecutor(&cfg) + if err != nil { + return nil, err + } + + s.cfg.RetryPolicy = transformRetryPolicy(cfg.RetryPolicy) + s.schemaDescriber = newSchemaDescriber(s) + return s, nil +} + +func (s *Session) Query(stmt string, values ...interface{}) *Query { + return &Query{ + ctx: context.Background(), + query: s.session.Query(stmt), + values: values, + } +} + +func (s *Session) Close() { + s.session.Close() + s.control.Close() +} + +func (s *Session) Closed() bool { + return s.session.Closed() +} + +func (s *Session) AwaitSchemaAgreement(ctx context.Context) error { + s.session.AwaitSchemaAgreement(context.Background(), time.Minute) + return nil +} + +var ( + ErrSessionClosed = errors.New("session closed") + ErrNoKeyspace = errors.New("no keyspace") +) + +// KeyspaceMetadata returns the schema metadata for the keyspace specified. Returns an error if the keyspace does not exist. +func (s *Session) KeyspaceMetadata(keyspace string) (*KeyspaceMetadata, error) { + // fail fast + if s.Closed() { + return nil, ErrSessionClosed + } else if keyspace == "" { + return nil, ErrNoKeyspace + } + + return s.schemaDescriber.getSchema(keyspace) +} diff --git a/gocql/types.go b/gocql/types.go new file mode 100644 index 00000000..07a3810b --- /dev/null +++ b/gocql/types.go @@ -0,0 +1,237 @@ +package gocql + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io/ioutil" + + "github.com/scylladb/scylla-go-driver" + "github.com/scylladb/scylla-go-driver/frame" + "github.com/scylladb/scylla-go-driver/transport" +) + +type unsetColumn struct{} + +// UnsetValue represents a value used in a query binding that will be ignored by Cassandra. +// +// By setting a field to the unset value Cassandra will ignore the write completely. +// The main advantage is the ability to keep the same prepared statement even when you don't +// want to update some fields, where before you needed to make another prepared statement. +// +// UnsetValue is only available when using the version 4 of the protocol. +var UnsetValue = unsetColumn{} + +const ( + protoDirectionMask = 0x80 + protoVersionMask = 0x7F + protoVersion1 = 0x01 + protoVersion2 = 0x02 + protoVersion3 = 0x03 + protoVersion4 = 0x04 + protoVersion5 = 0x05 +) + +type Duration struct { + Months int32 + Days int32 + Nanoseconds int64 +} + +type PoolConfig struct { + // HostSelectionPolicy sets the policy for selecting which host to use for a + // given query (default: RoundRobinHostPolicy()) + HostSelectionPolicy HostSelectionPolicy +} + +type HostSelectionPolicy interface{} + +func TokenAwareHostPolicy(hsp HostSelectionPolicy) HostSelectionPolicy { + return hsp +} + +func RoundRobinHostPolicy() HostSelectionPolicy { + return transport.NewTokenAwarePolicy("") +} + +func DCAwareRoundRobinPolicy(localDC string) HostSelectionPolicy { + return transport.NewTokenAwarePolicy(localDC) +} + +type RetryPolicy interface{} // TODO: use retry policy +type SpeculativeExecutionPolicy interface{} +type ConvictionPolicy interface { + // Implementations should return `true` if the host should be convicted, `false` otherwise. + AddFailure(error error, host *HostInfo) bool + //Implementations should clear out any convictions or state regarding the host. + Reset(host *HostInfo) +} +type HostInfo interface{} + +// SimpleConvictionPolicy implements a ConvictionPolicy which convicts all hosts +// regardless of error +type SimpleConvictionPolicy struct { +} + +func (e *SimpleConvictionPolicy) AddFailure(error error, host *HostInfo) bool { + return true +} + +func (e *SimpleConvictionPolicy) Reset(host *HostInfo) {} + +type SerialConsistency = Consistency +type QueryObserver interface{} +type Tracer interface{} +type Compressor interface{} + +type ColumnInfo struct { + Keyspace string + Table string + Name string + TypeInfo TypeInfo +} + +type optionWrapper frame.Option + +func WrapOption(o *frame.Option) TypeInfo { + nt := NewNativeType(0x04, Type(o.ID), "") + switch o.ID { + case frame.ListID: + return CollectionType{ + NativeType: nt, + Elem: WrapOption(&o.List.Element), + } + case frame.SetID: + return CollectionType{ + NativeType: nt, + Elem: WrapOption(&o.Set.Element), + } + case frame.MapID: + return CollectionType{ + NativeType: nt, + Key: WrapOption(&o.Map.Key), + Elem: WrapOption(&o.Map.Value), + } + case frame.UDTID: + return UDTTypeInfo{ + NativeType: nt, + KeySpace: o.UDT.Keyspace, + Name: o.UDT.Name, + Elements: getUDTFields(o.UDT), + } + case frame.CustomID: + panic("unimplemented") + default: + return NewNativeType(0x04, Type(o.ID), "") + } +} + +func getUDTFields(udt *frame.UDTOption) []UDTField { + res := make([]UDTField, len(udt.FieldNames)) + for i := range res { + res[i] = UDTField{ + Name: udt.FieldNames[i], + Type: WrapOption(&udt.FieldTypes[i]), + } + } + + return res +} + +var ErrNotFound = errors.New("not found") + +type Consistency scylla.Consistency + +const ( + Any Consistency = 0x00 + One Consistency = 0x01 + Two Consistency = 0x02 + Three Consistency = 0x03 + Quorum Consistency = 0x04 + All Consistency = 0x05 + LocalQuorum Consistency = 0x06 + EachQuorum Consistency = 0x07 + Serial Consistency = 0x08 + LocalSerial Consistency = 0x09 + LocalOne Consistency = 0x0A +) + +type SnappyCompressor struct{} + +type Authenticator interface{} + +var ErrKeyspaceDoesNotExist = errors.New("keyspace doesn't exist") + +type PasswordAuthenticator struct { + Username, Password string +} + +type SslOptions struct { + *tls.Config + + // CertPath and KeyPath are optional depending on server + // config, but both fields must be omitted to avoid using a + // client certificate + CertPath string + KeyPath string + CaPath string //optional depending on server config + // If you want to verify the hostname and server cert (like a wildcard for cass cluster) then you should turn this + // on. + // This option is basically the inverse of tls.Config.InsecureSkipVerify. + // See InsecureSkipVerify in http://golang.org/pkg/crypto/tls/ for more info. + // + // See SslOptions documentation to see how EnableHostVerification interacts with the provided tls.Config. + EnableHostVerification bool +} + +func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) { + // Config.InsecureSkipVerify | EnableHostVerification | Result + // Config is nil | true | verify host + // Config is nil | false | do not verify host + // false | false | verify host + // true | false | do not verify host + // false | true | verify host + // true | true | verify host + var tlsConfig *tls.Config + if sslOpts.Config == nil { + tlsConfig = &tls.Config{ + InsecureSkipVerify: !sslOpts.EnableHostVerification, + } + } else { + // use clone to avoid race. + tlsConfig = sslOpts.Config.Clone() + } + + if tlsConfig.InsecureSkipVerify && sslOpts.EnableHostVerification { + tlsConfig.InsecureSkipVerify = false + } + + // ca cert is optional + if sslOpts.CaPath != "" { + if tlsConfig.RootCAs == nil { + tlsConfig.RootCAs = x509.NewCertPool() + } + + pem, err := ioutil.ReadFile(sslOpts.CaPath) + if err != nil { + return nil, fmt.Errorf("connectionpool: unable to open CA certs: %v", err) + } + + if !tlsConfig.RootCAs.AppendCertsFromPEM(pem) { + return nil, errors.New("connectionpool: failed parsing or CA certs") + } + } + + if sslOpts.CertPath != "" || sslOpts.KeyPath != "" { + mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath) + if err != nil { + return nil, fmt.Errorf("connectionpool: unable to load X509 key pair: %v", err) + } + tlsConfig.Certificates = append(tlsConfig.Certificates, mycert) + } + + return tlsConfig, nil +} + +var ErrNoHosts = errors.New("no hosts provided") diff --git a/gocql/uuid.go b/gocql/uuid.go new file mode 100644 index 00000000..acdd81f9 --- /dev/null +++ b/gocql/uuid.go @@ -0,0 +1,324 @@ +// Copyright (c) 2012 The gocql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gocql + +// The uuid package can be used to generate and parse universally unique +// identifiers, a standardized format in the form of a 128 bit number. +// +// http://tools.ietf.org/html/rfc4122 + +import ( + "crypto/rand" + "errors" + "fmt" + "io" + "net" + "strings" + "sync/atomic" + "time" +) + +type UUID [16]byte + +var hardwareAddr []byte +var clockSeq uint32 + +const ( + VariantNCSCompat = 0 + VariantIETF = 2 + VariantMicrosoft = 6 + VariantFuture = 7 +) + +func init() { + if interfaces, err := net.Interfaces(); err == nil { + for _, i := range interfaces { + if i.Flags&net.FlagLoopback == 0 && len(i.HardwareAddr) > 0 { + hardwareAddr = i.HardwareAddr + break + } + } + } + if hardwareAddr == nil { + // If we failed to obtain the MAC address of the current computer, + // we will use a randomly generated 6 byte sequence instead and set + // the multicast bit as recommended in RFC 4122. + hardwareAddr = make([]byte, 6) + _, err := io.ReadFull(rand.Reader, hardwareAddr) + if err != nil { + panic(err) + } + hardwareAddr[0] = hardwareAddr[0] | 0x01 + } + + // initialize the clock sequence with a random number + var clockSeqRand [2]byte + io.ReadFull(rand.Reader, clockSeqRand[:]) + clockSeq = uint32(clockSeqRand[1])<<8 | uint32(clockSeqRand[0]) +} + +// ParseUUID parses a 32 digit hexadecimal number (that might contain hypens) +// representing an UUID. +func ParseUUID(input string) (UUID, error) { + var u UUID + j := 0 + for _, r := range input { + switch { + case r == '-' && j&1 == 0: + continue + case r >= '0' && r <= '9' && j < 32: + u[j/2] |= byte(r-'0') << uint(4-j&1*4) + case r >= 'a' && r <= 'f' && j < 32: + u[j/2] |= byte(r-'a'+10) << uint(4-j&1*4) + case r >= 'A' && r <= 'F' && j < 32: + u[j/2] |= byte(r-'A'+10) << uint(4-j&1*4) + default: + return UUID{}, fmt.Errorf("invalid UUID %q", input) + } + j += 1 + } + if j != 32 { + return UUID{}, fmt.Errorf("invalid UUID %q", input) + } + return u, nil +} + +// UUIDFromBytes converts a raw byte slice to an UUID. +func UUIDFromBytes(input []byte) (UUID, error) { + var u UUID + if len(input) != 16 { + return u, errors.New("UUIDs must be exactly 16 bytes long") + } + + copy(u[:], input) + return u, nil +} + +func MustRandomUUID() UUID { + uuid, err := RandomUUID() + if err != nil { + panic(err) + } + return uuid +} + +// RandomUUID generates a totally random UUID (version 4) as described in +// RFC 4122. +func RandomUUID() (UUID, error) { + var u UUID + _, err := io.ReadFull(rand.Reader, u[:]) + if err != nil { + return u, err + } + u[6] &= 0x0F // clear version + u[6] |= 0x40 // set version to 4 (random uuid) + u[8] &= 0x3F // clear variant + u[8] |= 0x80 // set to IETF variant + return u, nil +} + +var timeBase = time.Date(1582, time.October, 15, 0, 0, 0, 0, time.UTC).Unix() + +// getTimestamp converts time to UUID (version 1) timestamp. +// It must be an interval of 100-nanoseconds since timeBase. +func getTimestamp(t time.Time) int64 { + utcTime := t.In(time.UTC) + ts := int64(utcTime.Unix()-timeBase)*10000000 + int64(utcTime.Nanosecond()/100) + + return ts +} + +// TimeUUID generates a new time based UUID (version 1) using the current +// time as the timestamp. +func TimeUUID() UUID { + return UUIDFromTime(time.Now()) +} + +// The min and max clock values for a UUID. +// +// Cassandra's TimeUUIDType compares the lsb parts as signed byte arrays. +// Thus, the min value for each byte is -128 and the max is +127. +const ( + minClock = 0x8080 + maxClock = 0x7f7f +) + +// The min and max node values for a UUID. +// +// See explanation about Cassandra's TimeUUIDType comparison logic above. +var ( + minNode = []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80} + maxNode = []byte{0x7f, 0x7f, 0x7f, 0x7f, 0x7f, 0x7f} +) + +// MinTimeUUID generates a "fake" time based UUID (version 1) which will be +// the smallest possible UUID generated for the provided timestamp. +// +// UUIDs generated by this function are not unique and are mostly suitable only +// in queries to select a time range of a Cassandra's TimeUUID column. +func MinTimeUUID(t time.Time) UUID { + return TimeUUIDWith(getTimestamp(t), minClock, minNode) +} + +// MaxTimeUUID generates a "fake" time based UUID (version 1) which will be +// the biggest possible UUID generated for the provided timestamp. +// +// UUIDs generated by this function are not unique and are mostly suitable only +// in queries to select a time range of a Cassandra's TimeUUID column. +func MaxTimeUUID(t time.Time) UUID { + return TimeUUIDWith(getTimestamp(t), maxClock, maxNode) +} + +// UUIDFromTime generates a new time based UUID (version 1) as described in +// RFC 4122. This UUID contains the MAC address of the node that generated +// the UUID, the given timestamp and a sequence number. +func UUIDFromTime(t time.Time) UUID { + ts := getTimestamp(t) + clock := atomic.AddUint32(&clockSeq, 1) + + return TimeUUIDWith(ts, clock, hardwareAddr) +} + +// TimeUUIDWith generates a new time based UUID (version 1) as described in +// RFC4122 with given parameters. t is the number of 100's of nanoseconds +// since 15 Oct 1582 (60bits). clock is the number of clock sequence (14bits). +// node is a slice to gurarantee the uniqueness of the UUID (up to 6bytes). +// Note: calling this function does not increment the static clock sequence. +func TimeUUIDWith(t int64, clock uint32, node []byte) UUID { + var u UUID + + u[0], u[1], u[2], u[3] = byte(t>>24), byte(t>>16), byte(t>>8), byte(t) + u[4], u[5] = byte(t>>40), byte(t>>32) + u[6], u[7] = byte(t>>56)&0x0F, byte(t>>48) + + u[8] = byte(clock >> 8) + u[9] = byte(clock) + + copy(u[10:], node) + + u[6] |= 0x10 // set version to 1 (time based uuid) + u[8] &= 0x3F // clear variant + u[8] |= 0x80 // set to IETF variant + + return u +} + +// String returns the UUID in it's canonical form, a 32 digit hexadecimal +// number in the form of xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx. +func (u UUID) String() string { + var offsets = [...]int{0, 2, 4, 6, 9, 11, 14, 16, 19, 21, 24, 26, 28, 30, 32, 34} + const hexString = "0123456789abcdef" + r := make([]byte, 36) + for i, b := range u { + r[offsets[i]] = hexString[b>>4] + r[offsets[i]+1] = hexString[b&0xF] + } + r[8] = '-' + r[13] = '-' + r[18] = '-' + r[23] = '-' + return string(r) + +} + +// Bytes returns the raw byte slice for this UUID. A UUID is always 128 bits +// (16 bytes) long. +func (u UUID) Bytes() []byte { + return u[:] +} + +// Variant returns the variant of this UUID. This package will only generate +// UUIDs in the IETF variant. +func (u UUID) Variant() int { + x := u[8] + if x&0x80 == 0 { + return VariantNCSCompat + } + if x&0x40 == 0 { + return VariantIETF + } + if x&0x20 == 0 { + return VariantMicrosoft + } + return VariantFuture +} + +// Version extracts the version of this UUID variant. The RFC 4122 describes +// five kinds of UUIDs. +func (u UUID) Version() int { + return int(u[6] & 0xF0 >> 4) +} + +// Node extracts the MAC address of the node who generated this UUID. It will +// return nil if the UUID is not a time based UUID (version 1). +func (u UUID) Node() []byte { + if u.Version() != 1 { + return nil + } + return u[10:] +} + +// Clock extracts the clock sequence of this UUID. It will return zero if the +// UUID is not a time based UUID (version 1). +func (u UUID) Clock() uint32 { + if u.Version() != 1 { + return 0 + } + + // Clock sequence is the lower 14bits of u[8:10] + return uint32(u[8]&0x3F)<<8 | uint32(u[9]) +} + +// Timestamp extracts the timestamp information from a time based UUID +// (version 1). +func (u UUID) Timestamp() int64 { + if u.Version() != 1 { + return 0 + } + return int64(uint64(u[0])<<24|uint64(u[1])<<16| + uint64(u[2])<<8|uint64(u[3])) + + int64(uint64(u[4])<<40|uint64(u[5])<<32) + + int64(uint64(u[6]&0x0F)<<56|uint64(u[7])<<48) +} + +// Time is like Timestamp, except that it returns a time.Time. +func (u UUID) Time() time.Time { + if u.Version() != 1 { + return time.Time{} + } + t := u.Timestamp() + sec := t / 1e7 + nsec := (t % 1e7) * 100 + return time.Unix(sec+timeBase, nsec).UTC() +} + +// Marshaling for JSON +func (u UUID) MarshalJSON() ([]byte, error) { + return []byte(`"` + u.String() + `"`), nil +} + +// Unmarshaling for JSON +func (u *UUID) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), `"`) + if len(str) > 36 { + return fmt.Errorf("invalid JSON UUID %s", str) + } + + parsed, err := ParseUUID(str) + if err == nil { + copy(u[:], parsed[:]) + } + + return err +} + +func (u UUID) MarshalText() ([]byte, error) { + return []byte(u.String()), nil +} + +func (u *UUID) UnmarshalText(text []byte) (err error) { + *u, err = ParseUUID(string(text)) + return +}