// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gocql

import (
	"bufio"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gocql/gocql/internal/streams"
)

var (
	approvedAuthenticators = [...]string{
		"org.apache.cassandra.auth.PasswordAuthenticator",
		"com.instaclustr.cassandra.auth.SharedSecretAuthenticator",
	}
)

func approve(authenticator string) bool {
	for _, s := range approvedAuthenticators {
		if authenticator == s {
			return true
		}
	}
	return false
}

// JoinHostPort is a utility to return an address string that can be used by
// gocql.Conn to form a connection with a host.
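//
// Illustrative behavior (hypothetical inputs):
//
//	JoinHostPort("192.168.1.1", 9042)      // "192.168.1.1:9042"
//	JoinHostPort("192.168.1.1:9160", 9042) // address already has a port; returned as-is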
func JoinHostPort(addr string, port int) string {
	addr = strings.TrimSpace(addr)
	if _, _, err := net.SplitHostPort(addr); err != nil {
		addr = net.JoinHostPort(addr, strconv.Itoa(port))
	}
	return addr
}

type Authenticator interface {
	Challenge(req []byte) (resp []byte, auth Authenticator, err error)
	Success(data []byte) error
}

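// PasswordAuthenticator implements Authenticator for Cassandra's built-in
// PasswordAuthenticator. A minimal wiring sketch (assumes a reachable node;
// NewCluster and ClusterConfig are defined elsewhere in this package):
//
//	cluster := NewCluster("192.168.1.1")
//	cluster.Authenticator = PasswordAuthenticator{Username: "user", Password: "pass"}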
type PasswordAuthenticator struct {
	Username string
	Password string
}

func (p PasswordAuthenticator) Challenge(req []byte) ([]byte, Authenticator, error) {
	if !approve(string(req)) {
		return nil, nil, fmt.Errorf("unexpected authenticator %q", req)
	}
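	// Build the response in SASL/PLAIN layout: NUL byte, username, NUL byte, password.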
	resp := make([]byte, 2+len(p.Username)+len(p.Password))
	resp[0] = 0
	copy(resp[1:], p.Username)
	resp[len(p.Username)+1] = 0
	copy(resp[2+len(p.Username):], p.Password)
	return resp, nil, nil
}

func (p PasswordAuthenticator) Success(data []byte) error {
	return nil
}

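// SslOptions configures TLS for a connection. A sketch of typical usage
// (illustrative paths; adjust to your deployment):
//
//	sslOpts := &SslOptions{
//		CertPath:               "/etc/gocql/client.crt",
//		KeyPath:                "/etc/gocql/client.key",
//		CaPath:                 "/etc/gocql/ca.crt",
//		EnableHostVerification: true,
//	}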
type SslOptions struct {
	tls.Config

	// CertPath and KeyPath are optional depending on server
	// config, but both fields must be omitted to avoid using a
	// client certificate
	CertPath string
	KeyPath  string
	CaPath   string // optional depending on server config

	// EnableHostVerification turns on verification of the server's hostname
	// and certificate (e.g. against a wildcard certificate for the cluster).
	// It is effectively the inverse of InsecureSkipVerify; see
	// https://golang.org/pkg/crypto/tls/ for details.
	EnableHostVerification bool
}

type ConnConfig struct {
	ProtoVersion  int
	CQLVersion    string
	Timeout       time.Duration
	Compressor    Compressor
	Authenticator Authenticator
	Keepalive     time.Duration
	tlsConfig     *tls.Config
}

type ConnErrorHandler interface {
	HandleError(conn *Conn, err error, closed bool)
}

// TimeoutLimit is how many timeouts we will allow to occur before the
// connection is closed and restarted. This prevents a single query timeout
// from killing a connection which may be serving other queries just fine.
// The default is 10; it should not be changed concurrently with queries.
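//
// For example, to tolerate more timeouts per connection (an illustrative
// value; set before any queries run):
//
//	TimeoutLimit = 20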
var TimeoutLimit int64 = 10

// Conn is a single connection to a Cassandra node. It can be used to execute
// queries, but users are usually advised to use a more reliable, higher
// level API.
type Conn struct {
	conn    net.Conn
	r       *bufio.Reader
	timeout time.Duration
	cfg     *ConnConfig

	headerBuf []byte

	streams *streams.IDGenerator
	mu      sync.RWMutex
	calls   map[int]*callReq

	errorHandler    ConnErrorHandler
	compressor      Compressor
	auth            Authenticator
	addr            string
	version         uint8
	currentKeyspace string
	started         bool

	session *Session

	closed int32
	quit   chan struct{}

	timeouts int64
}

// Connect establishes a connection to a Cassandra node and starts the
// connection's serve loop in a goroutine, after which queries can be
// executed.
func Connect(addr string, cfg *ConnConfig, errorHandler ConnErrorHandler, session *Session) (*Conn, error) {
	var (
		err  error
		conn net.Conn
	)

	dialer := &net.Dialer{
		Timeout: cfg.Timeout,
	}

	if cfg.tlsConfig != nil {
		// the TLS config is safe to be reused by connections but it must not
		// be modified after being used.
		conn, err = tls.DialWithDialer(dialer, "tcp", addr, cfg.tlsConfig)
	} else {
		conn, err = dialer.Dial("tcp", addr)
	}

	if err != nil {
		return nil, err
	}

	// going to default to proto 2
	if cfg.ProtoVersion < protoVersion1 || cfg.ProtoVersion > protoVersion4 {
		log.Printf("unsupported protocol version: %d using 2\n", cfg.ProtoVersion)
		cfg.ProtoVersion = 2
	}

	headerSize := 8
	if cfg.ProtoVersion > protoVersion2 {
		headerSize = 9
	}

	c := &Conn{
		conn:         conn,
		r:            bufio.NewReader(conn),
		cfg:          cfg,
		calls:        make(map[int]*callReq),
		timeout:      cfg.Timeout,
		version:      uint8(cfg.ProtoVersion),
		addr:         conn.RemoteAddr().String(),
		errorHandler: errorHandler,
		compressor:   cfg.Compressor,
		auth:         cfg.Authenticator,
		headerBuf:    make([]byte, headerSize),
		quit:         make(chan struct{}),
		session:      session,
		streams:      streams.New(cfg.ProtoVersion),
	}

	if cfg.Keepalive > 0 {
		c.setKeepalive(cfg.Keepalive)
	}

	go c.serve()

	if err := c.startup(); err != nil {
		conn.Close()
		return nil, err
	}
	c.started = true

	return c, nil
}

func (c *Conn) Write(p []byte) (int, error) {
	if c.timeout > 0 {
		c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
	}

	return c.conn.Write(p)
}

func (c *Conn) Read(p []byte) (n int, err error) {
	const maxAttempts = 5

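	// A read can come back short with a temporary network error part-way
	// through a frame; retry from the current offset a bounded number of
	// times before giving up.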
	for i := 0; i < maxAttempts; i++ {
		var nn int
		if c.timeout > 0 {
			c.conn.SetReadDeadline(time.Now().Add(c.timeout))
		}

		nn, err = io.ReadFull(c.r, p[n:])
		n += nn
		if err == nil {
			break
		}

		if verr, ok := err.(net.Error); !ok || !verr.Temporary() {
			break
		}
	}

	return
}

func (c *Conn) startup() error {
	m := map[string]string{
		"CQL_VERSION": c.cfg.CQLVersion,
	}

	if c.compressor != nil {
		m["COMPRESSION"] = c.compressor.Name()
	}

	framer, err := c.exec(&writeStartupFrame{opts: m}, nil)
	if err != nil {
		return err
	}

	frame, err := framer.parseFrame()
	if err != nil {
		return err
	}

	switch v := frame.(type) {
	case error:
		return v
	case *readyFrame:
		return nil
	case *authenticateFrame:
		return c.authenticateHandshake(v)
	default:
		return NewErrProtocol("Unknown type of response to startup frame: %s", v)
	}
}

func (c *Conn) authenticateHandshake(authFrame *authenticateFrame) error {
	if c.auth == nil {
		return fmt.Errorf("authentication required (using %q)", authFrame.class)
	}

	resp, challenger, err := c.auth.Challenge([]byte(authFrame.class))
	if err != nil {
		return err
	}

	req := &writeAuthResponseFrame{data: resp}

	for {
		framer, err := c.exec(req, nil)
		if err != nil {
			return err
		}

		frame, err := framer.parseFrame()
		if err != nil {
			return err
		}

		switch v := frame.(type) {
		case error:
			return v
		case *authSuccessFrame:
			if challenger != nil {
				return challenger.Success(v.data)
			}
			return nil
		case *authChallengeFrame:
			resp, challenger, err = challenger.Challenge(v.data)
			if err != nil {
				return err
			}

			req = &writeAuthResponseFrame{
				data: resp,
			}
		default:
			return fmt.Errorf("unknown frame response during authentication: %v", v)
		}

		framerPool.Put(framer)
	}
}

func (c *Conn) closeWithError(err error) {
	if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
		return
	}

	if err != nil {
		// we should attempt to deliver the error back to the caller if it
		// exists
		c.mu.RLock()
		for _, req := range c.calls {
			// we need to send the error to all waiting queries; the conn was
			// already marked closed above so it cannot execute any new queries.
			if err != nil {
				select {
				case req.resp <- err:
				default:
				}
			}
		}
		c.mu.RUnlock()
	}

	// unblock anything waiting on the quit channel, then close the socket
	close(c.quit)
	c.conn.Close()

	if c.started && err != nil {
		c.errorHandler.HandleError(c, err, true)
	}
}

func (c *Conn) Close() {
	c.closeWithError(nil)
}

// serve starts the stream multiplexer for this connection, which is required
// to execute any queries. It runs for as long as the connection is open, and
// is therefore run in its own goroutine by Connect.
func (c *Conn) serve() {
	var (
		err error
	)

	for {
		err = c.recv()
		if err != nil {
			break
		}
	}

	c.closeWithError(err)
}

func (c *Conn) discardFrame(head frameHeader) error {
	_, err := io.CopyN(ioutil.Discard, c, int64(head.length))
	if err != nil {
		return err
	}
	return nil
}

func (c *Conn) recv() error {
	// not safe for concurrent reads

	// read a full header, ignore timeouts, as this is being run in a loop
	// TODO: TCP level deadlines? or just query level deadlines?
	if c.timeout > 0 {
		c.conn.SetReadDeadline(time.Time{})
	}

	// we're just reading headers over and over and copying bodies
	head, err := readHeader(c.r, c.headerBuf)
	if err != nil {
		return err
	}

	if head.stream > c.streams.NumStreams {
		return fmt.Errorf("gocql: frame header stream is beyond expected bounds: %d", head.stream)
	} else if head.stream == -1 {
		// TODO: handle cassandra event frames, we shouldn't get any currently
		return c.discardFrame(head)
	} else if head.stream <= 0 {
		// reserved stream that we don't use, probably due to a protocol error
		// or a bug in Cassandra; this should be an error, parse it and return.
		framer := newFramer(c, c, c.compressor, c.version)
		if err := framer.readFrame(&head); err != nil {
			return err
		}
		defer framerPool.Put(framer)

		frame, err := framer.parseFrame()
		if err != nil {
			return err
		}

		switch v := frame.(type) {
		case error:
			return fmt.Errorf("gocql: error on stream %d: %v", head.stream, v)
		default:
			return fmt.Errorf("gocql: received frame on stream %d: %v", head.stream, frame)
		}
	}

	c.mu.RLock()
	call, ok := c.calls[head.stream]
	c.mu.RUnlock()
	if call == nil || call.framer == nil || !ok {
		log.Printf("gocql: received response for stream which has no handler: header=%v\n", head)
		return c.discardFrame(head)
	}

	err = call.framer.readFrame(&head)
	if err != nil {
		// only net errors should cause the connection to be closed, though
		// corrupt frames returned by Cassandra will surface here as well.
		if _, ok := err.(net.Error); ok {
			return err
		}
	}

	// we either return a response to the caller, the caller timed out, or the
	// connection has closed. Either way we should never block indefinitely here
	select {
	case call.resp <- err:
	case <-call.timeout:
		c.releaseStream(head.stream)
	case <-c.quit:
	}

	return nil
}

type callReq struct {
	// could use a waitgroup but this allows us to do timeouts on the read/send
	resp     chan error
	framer   *framer
	timeout  chan struct{} // indicates to recv() that a call has timed out
	streamID int           // current stream in use
}

func (c *Conn) releaseStream(stream int) {
	c.mu.Lock()
	call := c.calls[stream]
	if call != nil && stream != call.streamID {
		panic(fmt.Sprintf("attempt to release streamID with invalid stream: %d -> %+v\n", stream, call))
	} else if call == nil {
		panic(fmt.Sprintf("releasing a stream not in use: %d", stream))
	}
	delete(c.calls, stream)
	c.mu.Unlock()

	streamPool.Put(call)
	c.streams.Clear(stream)
}

func (c *Conn) handleTimeout() {
	if atomic.AddInt64(&c.timeouts, 1) > TimeoutLimit {
		c.closeWithError(ErrTooManyTimeouts)
	}
}

var (
	streamPool = sync.Pool{
		New: func() interface{} {
			return &callReq{
				resp: make(chan error),
			}
		},
	}
)

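// exec writes req on a free stream ID and blocks until recv delivers the
// matching response on call.resp, the per-call timeout fires, or the
// connection is torn down.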
func (c *Conn) exec(req frameWriter, tracer Tracer) (*framer, error) {
	// TODO: move tracer onto conn
	stream, ok := c.streams.GetStream()
	if !ok {
		fmt.Println(c.streams)
		return nil, ErrNoStreams
	}

	// resp is basically a waiting semaphore protecting the framer
	framer := newFramer(c, c, c.compressor, c.version)

	c.mu.Lock()
	call := c.calls[stream]
	if call != nil {
		c.mu.Unlock()
		return nil, fmt.Errorf("attempting to use stream already in use: %d -> %d", stream, call.streamID)
	} else {
		call = streamPool.Get().(*callReq)
	}
	c.calls[stream] = call
	c.mu.Unlock()

	call.framer = framer
	call.timeout = make(chan struct{})
	call.streamID = stream

	if tracer != nil {
		framer.trace()
	}

	err := req.writeFrame(framer, stream)
	if err != nil {
		// I think this is the correct thing to do, I'm not entirely sure. It
		// is not ideal as readers might still get some data, but they probably
		// won't. Here we need to be careful as the stream is not available and
		// if all writes just time out or fail then the pool might use this
		// connection to send a frame on, with all the streams used up and not
		// returned.
		c.closeWithError(err)
		return nil, err
	}

	select {
	case err := <-call.resp:
		if err != nil {
			if !c.Closed() {
				// if the connection is closed then we can't release the
				// stream; the request is still outstanding and we have been
				// handed an error from another stream which caused the
				// connection to close.
				c.releaseStream(stream)
			}
			return nil, err
		}
	case <-time.After(c.timeout):
		close(call.timeout)
		c.handleTimeout()
		return nil, ErrTimeoutNoResponse
	case <-c.quit:
		return nil, ErrConnectionClosed
	}

	// don't release the stream if we detect a timeout, as another request can
	// reuse that stream and get a response for the old request, which we have
	// no easy way of detecting.
	//
	// Ensure that the stream is not released if there are potentially
	// outstanding requests on the stream, to prevent nil pointer dereferences
	// in recv().
	defer c.releaseStream(stream)

	if v := framer.header.version.version(); v != c.version {
		return nil, NewErrProtocol("unexpected protocol version in response: got %d expected %d", v, c.version)
	}

	return framer, nil
}

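// prepareStatement returns metadata for a prepared statement, consulting the
// global prepared-statement LRU first; concurrent callers preparing the same
// statement share a single in-flight PREPARE via inflightPrepare.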
func (c *Conn) prepareStatement(stmt string, tracer Tracer) (*QueryInfo, error) {
	stmtsLRU.Lock()
	if stmtsLRU.lru == nil {
		initStmtsLRU(defaultMaxPreparedStmts)
	}

	stmtCacheKey := c.addr + c.currentKeyspace + stmt

	if val, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
		stmtsLRU.Unlock()
		flight := val.(*inflightPrepare)
		flight.wg.Wait()
		return &flight.info, flight.err
	}

	flight := new(inflightPrepare)
	flight.wg.Add(1)
	stmtsLRU.lru.Add(stmtCacheKey, flight)
	stmtsLRU.Unlock()

	prep := &writePrepareFrame{
		statement: stmt,
	}

	framer, err := c.exec(prep, tracer)
	if err != nil {
		flight.err = err
		flight.wg.Done()
		return nil, err
	}

	frame, err := framer.parseFrame()
	if err != nil {
		flight.err = err
		flight.wg.Done()
		return nil, err
	}

	// TODO(zariel): tidy this up, simplify handling of frame parsing so it's
	// not duplicated every time we need to parse a frame.
	if len(framer.traceID) > 0 {
		tracer.Trace(framer.traceID)
	}

	switch x := frame.(type) {
	case *resultPreparedFrame:
		// defensively copy as we will recycle the underlying buffer after we
		// return.
		flight.info.Id = copyBytes(x.preparedID)
		// the type infos should _not_ have a reference to the framer's read
		// buffer, therefore we can just copy them directly.
		flight.info.Args = x.reqMeta.columns
		flight.info.PKeyColumns = x.reqMeta.pkeyColumns
		flight.info.Rval = x.respMeta.columns
	case error:
		flight.err = x
	default:
		flight.err = NewErrProtocol("Unknown type in response to prepare frame: %s", x)
	}
	flight.wg.Done()

	if flight.err != nil {
		stmtsLRU.Lock()
		stmtsLRU.lru.Remove(stmtCacheKey)
		stmtsLRU.Unlock()
	}

	framerPool.Put(framer)

	return &flight.info, flight.err
}

func (c *Conn) executeQuery(qry *Query) *Iter {
	params := queryParams{
		consistency: qry.cons,
	}

	// frame checks that it is not 0
	params.serialConsistency = qry.serialCons
	params.defaultTimestamp = qry.defaultTimestamp

	if len(qry.pageState) > 0 {
		params.pagingState = qry.pageState
	}
	if qry.pageSize > 0 {
		params.pageSize = qry.pageSize
	}

	var frame frameWriter
	if qry.shouldPrepare() {
		// Prepare all DML queries. Other queries cannot be prepared.
		info, err := c.prepareStatement(qry.stmt, qry.trace)
		if err != nil {
			return &Iter{err: err}
		}

		var values []interface{}

		if qry.binding == nil {
			values = qry.values
		} else {
			values, err = qry.binding(info)
			if err != nil {
				return &Iter{err: err}
			}
		}

		if len(values) != len(info.Args) {
			return &Iter{err: ErrQueryArgLength}
		}

		params.values = make([]queryValues, len(values))
		for i := 0; i < len(values); i++ {
			val, err := Marshal(info.Args[i].TypeInfo, values[i])
			if err != nil {
				return &Iter{err: err}
			}

			v := &params.values[i]
			v.value = val
			// TODO: handle query binding names
		}

		frame = &writeExecuteFrame{
			preparedID: info.Id,
			params:     params,
		}
	} else {
		frame = &writeQueryFrame{
			statement: qry.stmt,
			params:    params,
		}
	}

	framer, err := c.exec(frame, qry.trace)
	if err != nil {
		return &Iter{err: err}
	}

	resp, err := framer.parseFrame()
	if err != nil {
		return &Iter{err: err}
	}

	if len(framer.traceID) > 0 {
		qry.trace.Trace(framer.traceID)
	}

	switch x := resp.(type) {
	case *resultVoidFrame:
		return &Iter{framer: framer}
	case *resultRowsFrame:
		iter := &Iter{
			meta:   x.meta,
			rows:   x.rows,
			framer: framer,
		}

		if len(x.meta.pagingState) > 0 && !qry.disableAutoPage {
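			// pos marks the row index at which the fetch of the next page
			// should begin, a (1 - prefetch) fraction into the current page;
			// it is clamped to at least 1 below.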
			iter.next = &nextIter{
				qry: *qry,
				pos: int((1 - qry.prefetch) * float64(len(iter.rows))),
			}

			iter.next.qry.pageState = x.meta.pagingState
			if iter.next.pos < 1 {
				iter.next.pos = 1
			}
		}

		return iter
	case *resultKeyspaceFrame:
		return &Iter{framer: framer}
	case *resultSchemaChangeFrame, *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction:
		iter := &Iter{framer: framer}
		c.awaitSchemaAgreement()
		// don't return an error from this, though it might be a good idea to
		// give a warning. The impact of returning an error here would be that
		// the cluster is not consistent with regards to its schema.
		return iter
	case *RequestErrUnprepared:
		stmtsLRU.Lock()
		stmtCacheKey := c.addr + c.currentKeyspace + qry.stmt
		if _, ok := stmtsLRU.lru.Get(stmtCacheKey); ok {
			stmtsLRU.lru.Remove(stmtCacheKey)
			stmtsLRU.Unlock()
			return c.executeQuery(qry)
		}
		stmtsLRU.Unlock()
		return &Iter{err: x, framer: framer}
	case error:
		return &Iter{err: x, framer: framer}
	default:
		return &Iter{
			err:    NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x),
			framer: framer,
		}
	}
}

func (c *Conn) Pick(qry *Query) *Conn {
	if c.Closed() {
		return nil
	}
	return c
}

func (c *Conn) Closed() bool {
	return atomic.LoadInt32(&c.closed) == 1
}

func (c *Conn) Address() string {
	return c.addr
}

func (c *Conn) AvailableStreams() int {
	return c.streams.Available()
}

func (c *Conn) UseKeyspace(keyspace string) error {
	q := &writeQueryFrame{statement: `USE "` + keyspace + `"`}
	q.params.consistency = Any

	framer, err := c.exec(q, nil)
	if err != nil {
		return err
	}

	resp, err := framer.parseFrame()
	if err != nil {
		return err
	}

	switch x := resp.(type) {
	case *resultKeyspaceFrame:
	case error:
		return x
	default:
		return NewErrProtocol("unknown frame in response to USE: %v", x)
	}

	c.currentKeyspace = keyspace

	return nil
}

func (c *Conn) executeBatch(batch *Batch) (*Iter, error) {
	if c.version == protoVersion1 {
		return nil, ErrUnsupported
	}

	n := len(batch.Entries)
	req := &writeBatchFrame{
		typ:               batch.Type,
		statements:        make([]batchStatment, n),
		consistency:       batch.Cons,
		serialConsistency: batch.serialCons,
		defaultTimestamp:  batch.defaultTimestamp,
	}

	stmts := make(map[string]string)

	for i := 0; i < n; i++ {
		entry := &batch.Entries[i]
		b := &req.statements[i]
		if len(entry.Args) > 0 || entry.binding != nil {
			info, err := c.prepareStatement(entry.Stmt, nil)
			if err != nil {
				return nil, err
			}

			var args []interface{}
			if entry.binding == nil {
				args = entry.Args
			} else {
				args, err = entry.binding(info)
				if err != nil {
					return nil, err
				}
			}

			if len(args) != len(info.Args) {
				return nil, ErrQueryArgLength
			}

			b.preparedID = info.Id
			stmts[string(info.Id)] = entry.Stmt

			b.values = make([]queryValues, len(info.Args))

			for j := 0; j < len(info.Args); j++ {
				val, err := Marshal(info.Args[j].TypeInfo, args[j])
				if err != nil {
					return nil, err
				}

				b.values[j].value = val
				// TODO: add names
			}
		} else {
			b.statement = entry.Stmt
		}
	}

	// TODO: should batch support tracing?
	framer, err := c.exec(req, nil)
	if err != nil {
		return nil, err
	}

	resp, err := framer.parseFrame()
	if err != nil {
		return nil, err
	}

	switch x := resp.(type) {
	case *resultVoidFrame:
		framerPool.Put(framer)
		return nil, nil
	case *RequestErrUnprepared:
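		// A node may have evicted one of the batch's prepared statements;
		// drop the stale LRU entry and retry the batch so it is re-prepared.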
		stmt, found := stmts[string(x.StatementId)]
		if found {
			stmtsLRU.Lock()
			stmtsLRU.lru.Remove(c.addr + c.currentKeyspace + stmt)
			stmtsLRU.Unlock()
		}

		framerPool.Put(framer)

		if found {
			return c.executeBatch(batch)
		} else {
			return nil, x
		}
	case *resultRowsFrame:
		iter := &Iter{
			meta:   x.meta,
			rows:   x.rows,
			framer: framer,
		}

		return iter, nil
	case error:
		framerPool.Put(framer)
		return nil, x
	default:
		framerPool.Put(framer)
		return nil, NewErrProtocol("Unknown type in response to batch statement: %s", x)
	}
}

func (c *Conn) setKeepalive(d time.Duration) error {
	if tc, ok := c.conn.(*net.TCPConn); ok {
		err := tc.SetKeepAlivePeriod(d)
		if err != nil {
			return err
		}

		return tc.SetKeepAlive(true)
	}

	return nil
}

func (c *Conn) query(statement string, values ...interface{}) (iter *Iter) {
	q := c.session.Query(statement, values...).Consistency(One)
	return c.executeQuery(q)
}

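// awaitSchemaAgreement polls system.peers and system.local until every node
// reports the same schema_version, or the session's MaxWaitSchemaAgreement
// window elapses.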
func (c *Conn) awaitSchemaAgreement() (err error) {
	const (
		peerSchemas  = "SELECT schema_version FROM system.peers"
		localSchemas = "SELECT schema_version FROM system.local WHERE key='local'"
	)

	endDeadline := time.Now().Add(c.session.cfg.MaxWaitSchemaAgreement)
	for time.Now().Before(endDeadline) {
		iter := c.query(peerSchemas)

		versions := make(map[string]struct{})

		var schemaVersion string
		for iter.Scan(&schemaVersion) {
			versions[schemaVersion] = struct{}{}
			schemaVersion = ""
		}

		if err = iter.Close(); err != nil {
			goto cont
		}

		iter = c.query(localSchemas)
		for iter.Scan(&schemaVersion) {
			versions[schemaVersion] = struct{}{}
			schemaVersion = ""
		}

		if err = iter.Close(); err != nil {
			goto cont
		}

		if len(versions) <= 1 {
			return nil
		}

	cont:
		time.Sleep(200 * time.Millisecond)
	}

	if err != nil {
		return
	}

	// this error is intentionally not an exported variable
	return errors.New("gocql: cluster schema versions not consistent")
}

type inflightPrepare struct {
	info QueryInfo
	err  error
	wg   sync.WaitGroup
}

var (
	ErrQueryArgLength    = errors.New("gocql: query argument length mismatch")
	ErrTimeoutNoResponse = errors.New("gocql: no response received from cassandra within timeout period")
	ErrTooManyTimeouts   = errors.New("gocql: too many query timeouts on the connection")
	ErrConnectionClosed  = errors.New("gocql: connection closed waiting for response")
	ErrNoStreams         = errors.New("gocql: no streams available on connection")
)