summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/denisenkom/go-mssqldb/mssql.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/mssql.go')
-rw-r--r--vendor/github.com/denisenkom/go-mssqldb/mssql.go385
1 files changed, 294 insertions, 91 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql.go b/vendor/github.com/denisenkom/go-mssqldb/mssql.go
index 8f5ff2d0ce..cf84f3a12b 100644
--- a/vendor/github.com/denisenkom/go-mssqldb/mssql.go
+++ b/vendor/github.com/denisenkom/go-mssqldb/mssql.go
@@ -13,32 +13,37 @@ import (
"reflect"
"strings"
"time"
+ "unicode"
+
+ "github.com/denisenkom/go-mssqldb/internal/querytext"
)
+// ReturnStatus may be used to return the return value from a proc.
+//
+// var rs mssql.ReturnStatus
+// _, err := db.Exec("theproc", &rs)
+// log.Printf("return status = %d", rs)
+type ReturnStatus int32
+
var driverInstance = &Driver{processQueryText: true}
var driverInstanceNoProcess = &Driver{processQueryText: false}
func init() {
sql.Register("mssql", driverInstance)
sql.Register("sqlserver", driverInstanceNoProcess)
- createDialer = func(p *connectParams) dialer {
- return tcpDialer{&net.Dialer{Timeout: p.dial_timeout, KeepAlive: p.keepAlive}}
+ createDialer = func(p *connectParams) Dialer {
+ return netDialer{&net.Dialer{KeepAlive: p.keepAlive}}
}
}
-// Abstract the dialer for testing and for non-TCP based connections.
-type dialer interface {
- Dial(ctx context.Context, addr string) (net.Conn, error)
-}
+var createDialer func(p *connectParams) Dialer
-var createDialer func(p *connectParams) dialer
-
-type tcpDialer struct {
+type netDialer struct {
nd *net.Dialer
}
-func (d tcpDialer) Dial(ctx context.Context, addr string) (net.Conn, error) {
- return d.nd.DialContext(ctx, "tcp", addr)
+func (d netDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
+ return d.nd.DialContext(ctx, network, addr)
}
type Driver struct {
@@ -63,6 +68,29 @@ func (d *Driver) Open(dsn string) (driver.Conn, error) {
return d.open(context.Background(), dsn)
}
+func SetLogger(logger Logger) {
+ driverInstance.SetLogger(logger)
+ driverInstanceNoProcess.SetLogger(logger)
+}
+
+func (d *Driver) SetLogger(logger Logger) {
+ d.log = optionalLogger{logger}
+}
+
+// NewConnector creates a new connector from a DSN.
+// The returned connector may be used with sql.OpenDB.
+func NewConnector(dsn string) (*Connector, error) {
+ params, err := parseConnectParams(dsn)
+ if err != nil {
+ return nil, err
+ }
+ c := &Connector{
+ params: params,
+ driver: driverInstanceNoProcess,
+ }
+ return c, nil
+}
+
// Connector holds the parsed DSN and is ready to make a new connection
// at any time.
//
@@ -71,35 +99,64 @@ func (d *Driver) Open(dsn string) (driver.Conn, error) {
type Connector struct {
params connectParams
driver *Driver
-}
-
-// Connect to the server and return a TDS connection.
-func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
- return c.driver.connect(ctx, c.params)
-}
-// Driver underlying the Connector.
-func (c *Connector) Driver() driver.Driver {
- return c.driver
-}
-
-func SetLogger(logger Logger) {
- driverInstance.SetLogger(logger)
- driverInstanceNoProcess.SetLogger(logger)
-}
-
-func (d *Driver) SetLogger(logger Logger) {
- d.log = optionalLogger{logger}
+ // SessionInitSQL is executed after marking a given session to be reset.
+ // When not present, the next query will still reset the session to the
+ // database defaults.
+ //
+ // When present the connection will immediately mark the session to
+ // be reset, then execute the SessionInitSQL text to setup the session
+ // that may be different from the base database defaults.
+ //
+ // For Example, the application relies on the following defaults
+ // but is not allowed to set them at the database system level.
+ //
+ // SET XACT_ABORT ON;
+ // SET TEXTSIZE -1;
+ // SET ANSI_NULLS ON;
+ // SET LOCK_TIMEOUT 10000;
+ //
+ // SessionInitSQL should not attempt to manually call sp_reset_connection.
+ // This will happen at the TDS layer.
+ //
+ // SessionInitSQL is optional. The session will be reset even if
+ // SessionInitSQL is empty.
+ SessionInitSQL string
+
+ // Dialer sets a custom dialer for all network operations.
+ // If Dialer is not set, normal net dialers are used.
+ Dialer Dialer
+}
+
+type Dialer interface {
+ DialContext(ctx context.Context, network string, addr string) (net.Conn, error)
+}
+
+func (c *Connector) getDialer(p *connectParams) Dialer {
+ if c != nil && c.Dialer != nil {
+ return c.Dialer
+ }
+ return createDialer(p)
}
type Conn struct {
+ connector *Connector
sess *tdsSession
transactionCtx context.Context
+ resetSession bool
processQueryText bool
connectionGood bool
- outs map[string]interface{}
+ outs map[string]interface{}
+ returnStatus *ReturnStatus
+}
+
+func (c *Conn) setReturnStatus(s ReturnStatus) {
+ if c.returnStatus == nil {
+ return
+ }
+ *c.returnStatus = s
}
func (c *Conn) checkBadConn(err error) error {
@@ -117,6 +174,7 @@ func (c *Conn) checkBadConn(err error) error {
case nil:
return nil
case io.EOF:
+ c.connectionGood = false
return driver.ErrBadConn
case driver.ErrBadConn:
// It is an internal programming error if driver.ErrBadConn
@@ -174,7 +232,9 @@ func (c *Conn) sendCommitRequest() error {
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{c.sess.tranid, 1}.pack()},
}
- if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
+ reset := c.resetSession
+ c.resetSession = false
+ if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.log.Printf("Failed to send CommitXact with %v", err)
}
@@ -199,7 +259,9 @@ func (c *Conn) sendRollbackRequest() error {
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{c.sess.tranid, 1}.pack()},
}
- if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
+ reset := c.resetSession
+ c.resetSession = false
+ if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.log.Printf("Failed to send RollbackXact with %v", err)
}
@@ -234,12 +296,14 @@ func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) erro
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{0, 1}.pack()},
}
- if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil {
+ reset := c.resetSession
+ c.resetSession = false
+ if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, "", reset); err != nil {
if c.sess.logFlags&logErrors != 0 {
c.sess.log.Printf("Failed to send BeginXact with %v", err)
}
c.connectionGood = false
- return fmt.Errorf("Failed to send BiginXant: %v", err)
+ return fmt.Errorf("Failed to send BeginXact: %v", err)
}
return nil
}
@@ -258,12 +322,12 @@ func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) {
if err != nil {
return nil, err
}
- return d.connect(ctx, params)
+ return d.connect(ctx, nil, params)
}
// connect to the server, using the provided context for dialing only.
-func (d *Driver) connect(ctx context.Context, params connectParams) (*Conn, error) {
- sess, err := connect(ctx, d.log, params)
+func (d *Driver) connect(ctx context.Context, c *Connector, params connectParams) (*Conn, error) {
+ sess, err := connect(ctx, c, d.log, params)
if err != nil {
// main server failed, try fail-over partner
if params.failOverPartner == "" {
@@ -275,7 +339,7 @@ func (d *Driver) connect(ctx context.Context, params connectParams) (*Conn, erro
params.port = params.failOverPort
}
- sess, err = connect(ctx, d.log, params)
+ sess, err = connect(ctx, c, d.log, params)
if err != nil {
// fail-over partner also failed, now fail
return nil, err
@@ -283,12 +347,13 @@ func (d *Driver) connect(ctx context.Context, params connectParams) (*Conn, erro
}
conn := &Conn{
+ connector: c,
sess: sess,
transactionCtx: context.Background(),
processQueryText: d.processQueryText,
connectionGood: true,
}
- conn.sess.log = d.log
+
return conn, nil
}
@@ -314,16 +379,15 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) {
return nil, driver.ErrBadConn
}
if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
- return c.prepareCopyIn(query)
+ return c.prepareCopyIn(context.Background(), query)
}
-
return c.prepareContext(context.Background(), query)
}
func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) {
paramCount := -1
if c.processQueryText {
- query, paramCount = parseParams(query)
+ query, paramCount = querytext.ParseParams(query)
}
return &Stmt{c, query, paramCount, nil}, nil
}
@@ -362,11 +426,13 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) {
})
}
+ conn := s.c
+
// no need to check number of parameters here, it is checked by database/sql
- if s.c.sess.logFlags&logSQL != 0 {
- s.c.sess.log.Println(s.query)
+ if conn.sess.logFlags&logSQL != 0 {
+ conn.sess.log.Println(s.query)
}
- if s.c.sess.logFlags&logParams != 0 && len(args) > 0 {
+ if conn.sess.logFlags&logParams != 0 && len(args) > 0 {
for i := 0; i < len(args); i++ {
if len(args[i].Name) > 0 {
s.c.sess.log.Printf("\t@%s\t%v\n", args[i].Name, args[i].Value)
@@ -374,36 +440,41 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) {
s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i].Value)
}
}
-
}
+
+ reset := conn.resetSession
+ conn.resetSession = false
if len(args) == 0 {
- if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {
- if s.c.sess.logFlags&logErrors != 0 {
- s.c.sess.log.Printf("Failed to send SqlBatch with %v", err)
+ if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil {
+ if conn.sess.logFlags&logErrors != 0 {
+ conn.sess.log.Printf("Failed to send SqlBatch with %v", err)
}
- s.c.connectionGood = false
+ conn.connectionGood = false
return fmt.Errorf("failed to send SQL Batch: %v", err)
}
} else {
- proc := Sp_ExecuteSql
- var params []Param
+ proc := sp_ExecuteSql
+ var params []param
if isProc(s.query) {
proc.name = s.query
- params, _, err = s.makeRPCParams(args, 0)
+ params, _, err = s.makeRPCParams(args, true)
+ if err != nil {
+ return
+ }
} else {
var decls []string
- params, decls, err = s.makeRPCParams(args, 2)
+ params, decls, err = s.makeRPCParams(args, false)
if err != nil {
return
}
params[0] = makeStrParam(s.query)
params[1] = makeStrParam(strings.Join(decls, ","))
}
- if err = sendRpc(s.c.sess.buf, headers, proc, 0, params); err != nil {
- if s.c.sess.logFlags&logErrors != 0 {
- s.c.sess.log.Printf("Failed to send Rpc with %v", err)
+ if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil {
+ if conn.sess.logFlags&logErrors != 0 {
+ conn.sess.log.Printf("Failed to send Rpc with %v", err)
}
- s.c.connectionGood = false
+ conn.connectionGood = false
return fmt.Errorf("Failed to send RPC: %v", err)
}
}
@@ -416,15 +487,61 @@ func isProc(s string) bool {
if len(s) == 0 {
return false
}
- if s[0] == '[' && s[len(s)-1] == ']' && strings.ContainsAny(s, "\n\r") == false {
- return true
+ const (
+ outside = iota
+ text
+ escaped
+ )
+ st := outside
+ var rn1, rPrev rune
+ for _, r := range s {
+ rPrev = rn1
+ rn1 = r
+ switch r {
+ // No newlines or string sequences.
+ case '\n', '\r', '\'', ';':
+ return false
+ }
+ switch st {
+ case outside:
+ switch {
+ case unicode.IsSpace(r):
+ return false
+ case r == '[':
+ st = escaped
+ continue
+ case r == ']' && rPrev == ']':
+ st = escaped
+ continue
+ case unicode.IsLetter(r):
+ st = text
+ }
+ case text:
+ switch {
+ case r == '.':
+ st = outside
+ continue
+ case unicode.IsSpace(r):
+ return false
+ }
+ case escaped:
+ switch {
+ case r == ']':
+ st = outside
+ continue
+ }
+ }
}
- return !strings.ContainsAny(s, " \t\n\r;")
+ return true
}
-func (s *Stmt) makeRPCParams(args []namedValue, offset int) ([]Param, []string, error) {
+func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string, error) {
var err error
- params := make([]Param, len(args)+offset)
+ var offset int
+ if !isProc {
+ offset = 2
+ }
+ params := make([]param, len(args)+offset)
decls := make([]string, len(args))
for i, val := range args {
params[i+offset], err = s.makeParam(val.Value)
@@ -434,7 +551,7 @@ func (s *Stmt) makeRPCParams(args []namedValue, offset int) ([]Param, []string,
var name string
if len(val.Name) > 0 {
name = "@" + val.Name
- } else {
+ } else if !isProc {
name = fmt.Sprintf("@p%d", val.Ordinal)
}
params[i+offset].Name = name
@@ -498,6 +615,8 @@ loop:
if token.isError() {
return nil, s.c.checkBadConn(token.getError())
}
+ case ReturnStatus:
+ s.c.setReturnStatus(token)
case error:
return nil, s.c.checkBadConn(token)
}
@@ -541,6 +660,8 @@ func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) {
if token.isError() {
return nil, token.getError()
}
+ case ReturnStatus:
+ s.c.setReturnStatus(token)
case error:
return nil, token
}
@@ -666,14 +787,14 @@ func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) {
return
}
-func makeStrParam(val string) (res Param) {
+func makeStrParam(val string) (res param) {
res.ti.TypeId = typeNVarChar
res.buffer = str2ucs2(val)
res.ti.Size = len(res.buffer)
return
}
-func (s *Stmt) makeParam(val driver.Value) (res Param, err error) {
+func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
if val == nil {
res.ti.TypeId = typeNull
res.buffer = nil
@@ -686,17 +807,34 @@ func (s *Stmt) makeParam(val driver.Value) (res Param, err error) {
res.buffer = make([]byte, 8)
res.ti.Size = 8
binary.LittleEndian.PutUint64(res.buffer, uint64(val))
+ case sql.NullInt64:
+ // only null values should be getting here
+ res.ti.TypeId = typeIntN
+ res.ti.Size = 8
+ res.buffer = []byte{}
+
case float64:
res.ti.TypeId = typeFltN
res.ti.Size = 8
res.buffer = make([]byte, 8)
binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val))
+ case sql.NullFloat64:
+ // only null values should be getting here
+ res.ti.TypeId = typeFltN
+ res.ti.Size = 8
+ res.buffer = []byte{}
+
case []byte:
res.ti.TypeId = typeBigVarBin
res.ti.Size = len(val)
res.buffer = val
case string:
res = makeStrParam(val)
+ case sql.NullString:
+ // only null values should be getting here
+ res.ti.TypeId = typeNVarChar
+ res.buffer = nil
+ res.ti.Size = 8000
case bool:
res.ti.TypeId = typeBitN
res.ti.Size = 1
@@ -704,37 +842,22 @@ func (s *Stmt) makeParam(val driver.Value) (res Param, err error) {
if val {
res.buffer[0] = 1
}
+ case sql.NullBool:
+ // only null values should be getting here
+ res.ti.TypeId = typeBitN
+ res.ti.Size = 1
+ res.buffer = []byte{}
+
case time.Time:
if s.c.sess.loginAck.TDSVersion >= verTDS73 {
res.ti.TypeId = typeDateTimeOffsetN
res.ti.Scale = 7
- res.ti.Size = 10
- buf := make([]byte, 10)
- res.buffer = buf
- days, ns := dateTime2(val)
- ns /= 100
- buf[0] = byte(ns)
- buf[1] = byte(ns >> 8)
- buf[2] = byte(ns >> 16)
- buf[3] = byte(ns >> 24)
- buf[4] = byte(ns >> 32)
- buf[5] = byte(days)
- buf[6] = byte(days >> 8)
- buf[7] = byte(days >> 16)
- _, offset := val.Zone()
- offset /= 60
- buf[8] = byte(offset)
- buf[9] = byte(offset >> 8)
+ res.buffer = encodeDateTimeOffset(val, int(res.ti.Scale))
+ res.ti.Size = len(res.buffer)
} else {
res.ti.TypeId = typeDateTimeN
- res.ti.Size = 8
- res.buffer = make([]byte, 8)
- ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
- dur := val.Sub(ref)
- days := dur / (24 * time.Hour)
- tm := (300 * (dur % (24 * time.Hour))) / time.Second
- binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days))
- binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))
+ res.buffer = encodeDateTime(val)
+ res.ti.Size = len(res.buffer)
}
default:
return s.makeParamExtra(val)
@@ -773,3 +896,83 @@ func (r *Result) LastInsertId() (int64, error) {
lastInsertId := dest[0].(int64)
return lastInsertId, nil
}
+
+var _ driver.Pinger = &Conn{}
+
+// Ping is used to check if the remote server is available and satisfies the Pinger interface.
+func (c *Conn) Ping(ctx context.Context) error {
+ if !c.connectionGood {
+ return driver.ErrBadConn
+ }
+ stmt := &Stmt{c, `select 1;`, 0, nil}
+ _, err := stmt.ExecContext(ctx, nil)
+ return err
+}
+
+var _ driver.ConnBeginTx = &Conn{}
+
+// BeginTx satisfies ConnBeginTx.
+func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
+ if !c.connectionGood {
+ return nil, driver.ErrBadConn
+ }
+ if opts.ReadOnly {
+ return nil, errors.New("Read-only transactions are not supported")
+ }
+
+ var tdsIsolation isoLevel
+ switch sql.IsolationLevel(opts.Isolation) {
+ case sql.LevelDefault:
+ tdsIsolation = isolationUseCurrent
+ case sql.LevelReadUncommitted:
+ tdsIsolation = isolationReadUncommited
+ case sql.LevelReadCommitted:
+ tdsIsolation = isolationReadCommited
+ case sql.LevelWriteCommitted:
+ return nil, errors.New("LevelWriteCommitted isolation level is not supported")
+ case sql.LevelRepeatableRead:
+ tdsIsolation = isolationRepeatableRead
+ case sql.LevelSnapshot:
+ tdsIsolation = isolationSnapshot
+ case sql.LevelSerializable:
+ tdsIsolation = isolationSerializable
+ case sql.LevelLinearizable:
+ return nil, errors.New("LevelLinearizable isolation level is not supported")
+ default:
+ return nil, errors.New("Isolation level is not supported or unknown")
+ }
+ return c.begin(ctx, tdsIsolation)
+}
+
+func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
+ if !c.connectionGood {
+ return nil, driver.ErrBadConn
+ }
+ if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") {
+ return c.prepareCopyIn(ctx, query)
+ }
+
+ return c.prepareContext(ctx, query)
+}
+
+func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+ if !s.c.connectionGood {
+ return nil, driver.ErrBadConn
+ }
+ list := make([]namedValue, len(args))
+ for i, nv := range args {
+ list[i] = namedValue(nv)
+ }
+ return s.queryContext(ctx, list)
+}
+
+func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+ if !s.c.connectionGood {
+ return nil, driver.ErrBadConn
+ }
+ list := make([]namedValue, len(args))
+ for i, nv := range args {
+ list[i] = namedValue(nv)
+ }
+ return s.exec(ctx, list)
+}