diff options
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/mssql.go')
-rw-r--r-- | vendor/github.com/denisenkom/go-mssqldb/mssql.go | 385 |
1 files changed, 294 insertions, 91 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/mssql.go b/vendor/github.com/denisenkom/go-mssqldb/mssql.go index 8f5ff2d0ce..cf84f3a12b 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/mssql.go +++ b/vendor/github.com/denisenkom/go-mssqldb/mssql.go @@ -13,32 +13,37 @@ import ( "reflect" "strings" "time" + "unicode" + + "github.com/denisenkom/go-mssqldb/internal/querytext" ) +// ReturnStatus may be used to return the return value from a proc. +// +// var rs mssql.ReturnStatus +// _, err := db.Exec("theproc", &rs) +// log.Printf("return status = %d", rs) +type ReturnStatus int32 + var driverInstance = &Driver{processQueryText: true} var driverInstanceNoProcess = &Driver{processQueryText: false} func init() { sql.Register("mssql", driverInstance) sql.Register("sqlserver", driverInstanceNoProcess) - createDialer = func(p *connectParams) dialer { - return tcpDialer{&net.Dialer{Timeout: p.dial_timeout, KeepAlive: p.keepAlive}} + createDialer = func(p *connectParams) Dialer { + return netDialer{&net.Dialer{KeepAlive: p.keepAlive}} } } -// Abstract the dialer for testing and for non-TCP based connections. -type dialer interface { - Dial(ctx context.Context, addr string) (net.Conn, error) -} +var createDialer func(p *connectParams) Dialer -var createDialer func(p *connectParams) dialer - -type tcpDialer struct { +type netDialer struct { nd *net.Dialer } -func (d tcpDialer) Dial(ctx context.Context, addr string) (net.Conn, error) { - return d.nd.DialContext(ctx, "tcp", addr) +func (d netDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { + return d.nd.DialContext(ctx, network, addr) } type Driver struct { @@ -63,6 +68,29 @@ func (d *Driver) Open(dsn string) (driver.Conn, error) { return d.open(context.Background(), dsn) } +func SetLogger(logger Logger) { + driverInstance.SetLogger(logger) + driverInstanceNoProcess.SetLogger(logger) +} + +func (d *Driver) SetLogger(logger Logger) { + d.log = optionalLogger{logger} +} + +// NewConnector creates a new connector from a DSN. +// The returned connector may be used with sql.OpenDB. +func NewConnector(dsn string) (*Connector, error) { + params, err := parseConnectParams(dsn) + if err != nil { + return nil, err + } + c := &Connector{ + params: params, + driver: driverInstanceNoProcess, + } + return c, nil +} + // Connector holds the parsed DSN and is ready to make a new connection // at any time. // @@ -71,35 +99,64 @@ func (d *Driver) Open(dsn string) (driver.Conn, error) { type Connector struct { params connectParams driver *Driver -} - -// Connect to the server and return a TDS connection. -func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { - return c.driver.connect(ctx, c.params) -} -// Driver underlying the Connector. -func (c *Connector) Driver() driver.Driver { - return c.driver -} - -func SetLogger(logger Logger) { - driverInstance.SetLogger(logger) - driverInstanceNoProcess.SetLogger(logger) -} - -func (d *Driver) SetLogger(logger Logger) { - d.log = optionalLogger{logger} + // SessionInitSQL is executed after marking a given session to be reset. + // When not present, the next query will still reset the session to the + // database defaults. + // + // When present the connection will immediately mark the session to + // be reset, then execute the SessionInitSQL text to setup the session + // that may be different from the base database defaults. + // + // For Example, the application relies on the following defaults + // but is not allowed to set them at the database system level. + // + // SET XACT_ABORT ON; + // SET TEXTSIZE -1; + // SET ANSI_NULLS ON; + // SET LOCK_TIMEOUT 10000; + // + // SessionInitSQL should not attempt to manually call sp_reset_connection. + // This will happen at the TDS layer. + // + // SessionInitSQL is optional. The session will be reset even if + // SessionInitSQL is empty. + SessionInitSQL string + + // Dialer sets a custom dialer for all network operations. + // If Dialer is not set, normal net dialers are used. + Dialer Dialer +} + +type Dialer interface { + DialContext(ctx context.Context, network string, addr string) (net.Conn, error) +} + +func (c *Connector) getDialer(p *connectParams) Dialer { + if c != nil && c.Dialer != nil { + return c.Dialer + } + return createDialer(p) } type Conn struct { + connector *Connector sess *tdsSession transactionCtx context.Context + resetSession bool processQueryText bool connectionGood bool - outs map[string]interface{} + outs map[string]interface{} + returnStatus *ReturnStatus +} + +func (c *Conn) setReturnStatus(s ReturnStatus) { + if c.returnStatus == nil { + return + } + *c.returnStatus = s } func (c *Conn) checkBadConn(err error) error { @@ -117,6 +174,7 @@ func (c *Conn) checkBadConn(err error) error { case nil: return nil case io.EOF: + c.connectionGood = false return driver.ErrBadConn case driver.ErrBadConn: // It is an internal programming error if driver.ErrBadConn @@ -174,7 +232,9 @@ func (c *Conn) sendCommitRequest() error { {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{c.sess.tranid, 1}.pack()}, } - if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil { + reset := c.resetSession + c.resetSession = false + if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil { if c.sess.logFlags&logErrors != 0 { c.sess.log.Printf("Failed to send CommitXact with %v", err) } @@ -199,7 +259,9 @@ func (c *Conn) sendRollbackRequest() error { {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{c.sess.tranid, 1}.pack()}, } - if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil { + reset := c.resetSession + c.resetSession = false + if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, "", reset); err != nil { if c.sess.logFlags&logErrors != 0 { c.sess.log.Printf("Failed to send RollbackXact with %v", err) } @@ -234,12 +296,14 @@ func (c *Conn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) erro {hdrtype: dataStmHdrTransDescr, data: transDescrHdr{0, 1}.pack()}, } - if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil { + reset := c.resetSession + c.resetSession = false + if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, "", reset); err != nil { if c.sess.logFlags&logErrors != 0 { c.sess.log.Printf("Failed to send BeginXact with %v", err) } c.connectionGood = false - return fmt.Errorf("Failed to send BiginXant: %v", err) + return fmt.Errorf("Failed to send BeginXact: %v", err) } return nil } @@ -258,12 +322,12 @@ func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) { if err != nil { return nil, err } - return d.connect(ctx, params) + return d.connect(ctx, nil, params) } // connect to the server, using the provided context for dialing only. -func (d *Driver) connect(ctx context.Context, params connectParams) (*Conn, error) { - sess, err := connect(ctx, d.log, params) +func (d *Driver) connect(ctx context.Context, c *Connector, params connectParams) (*Conn, error) { + sess, err := connect(ctx, c, d.log, params) if err != nil { // main server failed, try fail-over partner if params.failOverPartner == "" { @@ -275,7 +339,7 @@ func (d *Driver) connect(ctx context.Context, params connectParams) (*Conn, erro params.port = params.failOverPort } - sess, err = connect(ctx, d.log, params) + sess, err = connect(ctx, c, d.log, params) if err != nil { // fail-over partner also failed, now fail return nil, err @@ -283,12 +347,13 @@ func (d *Driver) connect(ctx context.Context, params connectParams) (*Conn, erro } conn := &Conn{ + connector: c, sess: sess, transactionCtx: context.Background(), processQueryText: d.processQueryText, connectionGood: true, } - conn.sess.log = d.log + return conn, nil } @@ -314,16 +379,15 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) { return nil, driver.ErrBadConn } if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") { - return c.prepareCopyIn(query) + return c.prepareCopyIn(context.Background(), query) } - return c.prepareContext(context.Background(), query) } func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error) { paramCount := -1 if c.processQueryText { - query, paramCount = parseParams(query) + query, paramCount = querytext.ParseParams(query) } return &Stmt{c, query, paramCount, nil}, nil } @@ -362,11 +426,13 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) { }) } + conn := s.c + // no need to check number of parameters here, it is checked by database/sql - if s.c.sess.logFlags&logSQL != 0 { - s.c.sess.log.Println(s.query) + if conn.sess.logFlags&logSQL != 0 { + conn.sess.log.Println(s.query) } - if s.c.sess.logFlags&logParams != 0 && len(args) > 0 { + if conn.sess.logFlags&logParams != 0 && len(args) > 0 { for i := 0; i < len(args); i++ { if len(args[i].Name) > 0 { s.c.sess.log.Printf("\t@%s\t%v\n", args[i].Name, args[i].Value) @@ -374,36 +440,41 @@ func (s *Stmt) sendQuery(args []namedValue) (err error) { s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i].Value) } } - } + + reset := conn.resetSession + conn.resetSession = false if len(args) == 0 { - if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil { - if s.c.sess.logFlags&logErrors != 0 { - s.c.sess.log.Printf("Failed to send SqlBatch with %v", err) + if err = sendSqlBatch72(conn.sess.buf, s.query, headers, reset); err != nil { + if conn.sess.logFlags&logErrors != 0 { + conn.sess.log.Printf("Failed to send SqlBatch with %v", err) } - s.c.connectionGood = false + conn.connectionGood = false return fmt.Errorf("failed to send SQL Batch: %v", err) } } else { - proc := Sp_ExecuteSql - var params []Param + proc := sp_ExecuteSql + var params []param if isProc(s.query) { proc.name = s.query - params, _, err = s.makeRPCParams(args, 0) + params, _, err = s.makeRPCParams(args, true) + if err != nil { + return + } } else { var decls []string - params, decls, err = s.makeRPCParams(args, 2) + params, decls, err = s.makeRPCParams(args, false) if err != nil { return } params[0] = makeStrParam(s.query) params[1] = makeStrParam(strings.Join(decls, ",")) } - if err = sendRpc(s.c.sess.buf, headers, proc, 0, params); err != nil { - if s.c.sess.logFlags&logErrors != 0 { - s.c.sess.log.Printf("Failed to send Rpc with %v", err) + if err = sendRpc(conn.sess.buf, headers, proc, 0, params, reset); err != nil { + if conn.sess.logFlags&logErrors != 0 { + conn.sess.log.Printf("Failed to send Rpc with %v", err) } - s.c.connectionGood = false + conn.connectionGood = false return fmt.Errorf("Failed to send RPC: %v", err) } } @@ -416,15 +487,61 @@ func isProc(s string) bool { if len(s) == 0 { return false } - if s[0] == '[' && s[len(s)-1] == ']' && strings.ContainsAny(s, "\n\r") == false { - return true + const ( + outside = iota + text + escaped + ) + st := outside + var rn1, rPrev rune + for _, r := range s { + rPrev = rn1 + rn1 = r + switch r { + // No newlines or string sequences. + case '\n', '\r', '\'', ';': + return false + } + switch st { + case outside: + switch { + case unicode.IsSpace(r): + return false + case r == '[': + st = escaped + continue + case r == ']' && rPrev == ']': + st = escaped + continue + case unicode.IsLetter(r): + st = text + } + case text: + switch { + case r == '.': + st = outside + continue + case unicode.IsSpace(r): + return false + } + case escaped: + switch { + case r == ']': + st = outside + continue + } + } } - return !strings.ContainsAny(s, " \t\n\r;") + return true } -func (s *Stmt) makeRPCParams(args []namedValue, offset int) ([]Param, []string, error) { +func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string, error) { var err error - params := make([]Param, len(args)+offset) + var offset int + if !isProc { + offset = 2 + } + params := make([]param, len(args)+offset) decls := make([]string, len(args)) for i, val := range args { params[i+offset], err = s.makeParam(val.Value) @@ -434,7 +551,7 @@ func (s *Stmt) makeRPCParams(args []namedValue, offset int) ([]Param, []string, var name string if len(val.Name) > 0 { name = "@" + val.Name - } else { + } else if !isProc { name = fmt.Sprintf("@p%d", val.Ordinal) } params[i+offset].Name = name @@ -498,6 +615,8 @@ loop: if token.isError() { return nil, s.c.checkBadConn(token.getError()) } + case ReturnStatus: + s.c.setReturnStatus(token) case error: return nil, s.c.checkBadConn(token) } @@ -541,6 +660,8 @@ func (s *Stmt) processExec(ctx context.Context) (res driver.Result, err error) { if token.isError() { return nil, token.getError() } + case ReturnStatus: + s.c.setReturnStatus(token) case error: return nil, token } @@ -666,14 +787,14 @@ func (r *Rows) ColumnTypeNullable(index int) (nullable, ok bool) { return } -func makeStrParam(val string) (res Param) { +func makeStrParam(val string) (res param) { res.ti.TypeId = typeNVarChar res.buffer = str2ucs2(val) res.ti.Size = len(res.buffer) return } -func (s *Stmt) makeParam(val driver.Value) (res Param, err error) { +func (s *Stmt) makeParam(val driver.Value) (res param, err error) { if val == nil { res.ti.TypeId = typeNull res.buffer = nil @@ -686,17 +807,34 @@ func (s *Stmt) makeParam(val driver.Value) (res Param, err error) { res.buffer = make([]byte, 8) res.ti.Size = 8 binary.LittleEndian.PutUint64(res.buffer, uint64(val)) + case sql.NullInt64: + // only null values should be getting here + res.ti.TypeId = typeIntN + res.ti.Size = 8 + res.buffer = []byte{} + case float64: res.ti.TypeId = typeFltN res.ti.Size = 8 res.buffer = make([]byte, 8) binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val)) + case sql.NullFloat64: + // only null values should be getting here + res.ti.TypeId = typeFltN + res.ti.Size = 8 + res.buffer = []byte{} + case []byte: res.ti.TypeId = typeBigVarBin res.ti.Size = len(val) res.buffer = val case string: res = makeStrParam(val) + case sql.NullString: + // only null values should be getting here + res.ti.TypeId = typeNVarChar + res.buffer = nil + res.ti.Size = 8000 case bool: res.ti.TypeId = typeBitN res.ti.Size = 1 @@ -704,37 +842,22 @@ func (s *Stmt) makeParam(val driver.Value) (res Param, err error) { if val { res.buffer[0] = 1 } + case sql.NullBool: + // only null values should be getting here + res.ti.TypeId = typeBitN + res.ti.Size = 1 + res.buffer = []byte{} + case time.Time: if s.c.sess.loginAck.TDSVersion >= verTDS73 { res.ti.TypeId = typeDateTimeOffsetN res.ti.Scale = 7 - res.ti.Size = 10 - buf := make([]byte, 10) - res.buffer = buf - days, ns := dateTime2(val) - ns /= 100 - buf[0] = byte(ns) - buf[1] = byte(ns >> 8) - buf[2] = byte(ns >> 16) - buf[3] = byte(ns >> 24) - buf[4] = byte(ns >> 32) - buf[5] = byte(days) - buf[6] = byte(days >> 8) - buf[7] = byte(days >> 16) - _, offset := val.Zone() - offset /= 60 - buf[8] = byte(offset) - buf[9] = byte(offset >> 8) + res.buffer = encodeDateTimeOffset(val, int(res.ti.Scale)) + res.ti.Size = len(res.buffer) } else { res.ti.TypeId = typeDateTimeN - res.ti.Size = 8 - res.buffer = make([]byte, 8) - ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC) - dur := val.Sub(ref) - days := dur / (24 * time.Hour) - tm := (300 * (dur % (24 * time.Hour))) / time.Second - binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days)) - binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm)) + res.buffer = encodeDateTime(val) + res.ti.Size = len(res.buffer) } default: return s.makeParamExtra(val) @@ -773,3 +896,83 @@ func (r *Result) LastInsertId() (int64, error) { lastInsertId := dest[0].(int64) return lastInsertId, nil } + +var _ driver.Pinger = &Conn{} + +// Ping is used to check if the remote server is available and satisfies the Pinger interface. +func (c *Conn) Ping(ctx context.Context) error { + if !c.connectionGood { + return driver.ErrBadConn + } + stmt := &Stmt{c, `select 1;`, 0, nil} + _, err := stmt.ExecContext(ctx, nil) + return err +} + +var _ driver.ConnBeginTx = &Conn{} + +// BeginTx satisfies ConnBeginTx. +func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if !c.connectionGood { + return nil, driver.ErrBadConn + } + if opts.ReadOnly { + return nil, errors.New("Read-only transactions are not supported") + } + + var tdsIsolation isoLevel + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + tdsIsolation = isolationUseCurrent + case sql.LevelReadUncommitted: + tdsIsolation = isolationReadUncommited + case sql.LevelReadCommitted: + tdsIsolation = isolationReadCommited + case sql.LevelWriteCommitted: + return nil, errors.New("LevelWriteCommitted isolation level is not supported") + case sql.LevelRepeatableRead: + tdsIsolation = isolationRepeatableRead + case sql.LevelSnapshot: + tdsIsolation = isolationSnapshot + case sql.LevelSerializable: + tdsIsolation = isolationSerializable + case sql.LevelLinearizable: + return nil, errors.New("LevelLinearizable isolation level is not supported") + default: + return nil, errors.New("Isolation level is not supported or unknown") + } + return c.begin(ctx, tdsIsolation) +} + +func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if !c.connectionGood { + return nil, driver.ErrBadConn + } + if len(query) > 10 && strings.EqualFold(query[:10], "INSERTBULK") { + return c.prepareCopyIn(ctx, query) + } + + return c.prepareContext(ctx, query) +} + +func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + if !s.c.connectionGood { + return nil, driver.ErrBadConn + } + list := make([]namedValue, len(args)) + for i, nv := range args { + list[i] = namedValue(nv) + } + return s.queryContext(ctx, list) +} + +func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + if !s.c.connectionGood { + return nil, driver.ErrBadConn + } + list := make([]namedValue, len(args)) + for i, nv := range args { + list[i] = namedValue(nv) + } + return s.exec(ctx, list) +} |