summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/denisenkom/go-mssqldb/token.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/token.go')
-rw-r--r--vendor/github.com/denisenkom/go-mssqldb/token.go387
1 files changed, 327 insertions, 60 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/token.go b/vendor/github.com/denisenkom/go-mssqldb/token.go
index f20bd14cc9..5f2167eb86 100644
--- a/vendor/github.com/denisenkom/go-mssqldb/token.go
+++ b/vendor/github.com/denisenkom/go-mssqldb/token.go
@@ -1,30 +1,40 @@
package mssql
import (
+ "context"
"encoding/binary"
+ "errors"
+ "fmt"
"io"
+ "net"
"strconv"
"strings"
)
+//go:generate stringer -type token
+
+type token byte
+
// token ids
const (
- tokenReturnStatus = 121 // 0x79
- tokenColMetadata = 129 // 0x81
- tokenOrder = 169 // 0xA9
- tokenError = 170 // 0xAA
- tokenInfo = 171 // 0xAB
- tokenLoginAck = 173 // 0xad
- tokenRow = 209 // 0xd1
- tokenNbcRow = 210 // 0xd2
- tokenEnvChange = 227 // 0xE3
- tokenSSPI = 237 // 0xED
- tokenDone = 253 // 0xFD
- tokenDoneProc = 254
- tokenDoneInProc = 255
+ tokenReturnStatus token = 121 // 0x79
+ tokenColMetadata token = 129 // 0x81
+ tokenOrder token = 169 // 0xA9
+ tokenError token = 170 // 0xAA
+ tokenInfo token = 171 // 0xAB
+ tokenReturnValue token = 0xAC
+ tokenLoginAck token = 173 // 0xad
+ tokenRow token = 209 // 0xd1
+ tokenNbcRow token = 210 // 0xd2
+ tokenEnvChange token = 227 // 0xE3
+ tokenSSPI token = 237 // 0xED
+ tokenDone token = 253 // 0xFD
+ tokenDoneProc token = 254
+ tokenDoneInProc token = 255
)
// done flags
+// https://msdn.microsoft.com/en-us/library/dd340421.aspx
const (
doneFinal = 0
doneMore = 1
@@ -59,6 +69,13 @@ const (
envRouting = 20
)
+// COLMETADATA flags
+// https://msdn.microsoft.com/en-us/library/dd357363.aspx
+const (
+ colFlagNullable = 1
+ // TODO implement more flags
+)
+
// interface for all tokens
type tokenStruct interface{}
@@ -70,6 +87,19 @@ type doneStruct struct {
Status uint16
CurCmd uint16
RowCount uint64
+ errors []Error
+}
+
+func (d doneStruct) isError() bool {
+ return d.Status&doneError != 0 || len(d.errors) > 0
+}
+
+func (d doneStruct) getError() Error {
+ if len(d.errors) > 0 {
+ return d.errors[len(d.errors)-1]
+ } else {
+ return Error{Message: "Request failed but didn't provide reason"}
+ }
}
type doneInProcStruct doneStruct
@@ -120,27 +150,23 @@ func processEnvChg(sess *tdsSession) {
badStreamPanic(err)
}
case envTypLanguage:
- //currently ignored
- // old value
- _, err = readBVarChar(r)
- if err != nil {
- badStreamPanic(err)
- }
+ // currently ignored
// new value
- _, err = readBVarChar(r)
- if err != nil {
+ if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
- case envTypCharset:
- //currently ignored
// old value
- _, err = readBVarChar(r)
- if err != nil {
+ if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
+ case envTypCharset:
+ // currently ignored
// new value
- _, err = readBVarChar(r)
- if err != nil {
+ if _, err = readBVarChar(r); err != nil {
+ badStreamPanic(err)
+ }
+ // old value
+ if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
case envTypPacketSize:
@@ -156,38 +182,55 @@ func processEnvChg(sess *tdsSession) {
if err != nil {
badStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error())
}
- if len(sess.buf.buf) != packetsizei {
- newbuf := make([]byte, packetsizei)
- copy(newbuf, sess.buf.buf)
- sess.buf.buf = newbuf
- }
+ sess.buf.ResizeBuffer(packetsizei)
case envSortId:
// currently ignored
- // old value, should be 0
+ // new value
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
- // new value
+ // old value, should be 0
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
case envSortFlags:
// currently ignored
- // old value, should be 0
+ // new value
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
- // new value
+ // old value, should be 0
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
case envSqlCollation:
// currently ignored
- // old value
- if _, err = readBVarChar(r); err != nil {
+ var collationSize uint8
+ err = binary.Read(r, binary.LittleEndian, &collationSize)
+ if err != nil {
badStreamPanic(err)
}
- // new value
+
+ // SQL Collation data should contain 5 bytes in length
+ if collationSize != 5 {
+ badStreamPanicf("Invalid SQL Collation size value returned from server: %s", collationSize)
+ }
+
+ // 4 bytes, contains: LCID ColFlags Version
+ var info uint32
+ err = binary.Read(r, binary.LittleEndian, &info)
+ if err != nil {
+ badStreamPanic(err)
+ }
+
+ // 1 byte, contains: sortID
+ var sortID uint8
+ err = binary.Read(r, binary.LittleEndian, &sortID)
+ if err != nil {
+ badStreamPanic(err)
+ }
+
+ // old value, should be 0
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
@@ -226,21 +269,21 @@ func processEnvChg(sess *tdsSession) {
sess.tranid = 0
case envEnlistDTC:
// currently ignored
- // old value
+ // new value, should be 0
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
- // new value, should be 0
+ // old value
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
case envDefectTran:
// currently ignored
- // old value, should be 0
+ // new value
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
- // new value
+ // old value, should be 0
if _, err = readBVarChar(r); err != nil {
badStreamPanic(err)
}
@@ -358,6 +401,7 @@ func parseOrder(r *tdsBuffer) (res orderStruct) {
return res
}
+// https://msdn.microsoft.com/en-us/library/dd340421.aspx
func parseDone(r *tdsBuffer) (res doneStruct) {
res.Status = r.uint16()
res.CurCmd = r.uint16()
@@ -365,6 +409,7 @@ func parseDone(r *tdsBuffer) (res doneStruct) {
return res
}
+// https://msdn.microsoft.com/en-us/library/dd340553.aspx
func parseDoneInProc(r *tdsBuffer) (res doneInProcStruct) {
res.Status = r.uint16()
res.CurCmd = r.uint16()
@@ -473,26 +518,57 @@ func parseInfo(r *tdsBuffer) (res Error) {
return
}
-func processResponse(sess *tdsSession, ch chan tokenStruct) {
+// https://msdn.microsoft.com/en-us/library/dd303881.aspx
+func parseReturnValue(r *tdsBuffer) (nv namedValue) {
+ /*
+ ParamOrdinal
+ ParamName
+ Status
+ UserType
+ Flags
+ TypeInfo
+ CryptoMetadata
+ Value
+ */
+ r.uint16()
+ nv.Name = r.BVarChar()
+ r.byte()
+ r.uint32() // UserType (uint16 prior to 7.2)
+ r.uint16()
+ ti := readTypeInfo(r)
+ nv.Value = ti.Reader(&ti, r)
+ return
+}
+
+func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
defer func() {
if err := recover(); err != nil {
+ if sess.logFlags&logErrors != 0 {
+ sess.log.Printf("ERROR: Intercepted panic %v", err)
+ }
ch <- err
}
close(ch)
}()
+
packet_type, err := sess.buf.BeginRead()
if err != nil {
+ if sess.logFlags&logErrors != 0 {
+ sess.log.Printf("ERROR: BeginRead failed %v", err)
+ }
ch <- err
return
}
if packet_type != packReply {
- badStreamPanicf("invalid response packet type, expected REPLY, actual: %d", packet_type)
+ badStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply))
}
var columns []columnStruct
- var lastError Error
- var failed bool
+ errs := make([]Error, 0, 5)
for {
- token := sess.buf.byte()
+ token := token(sess.buf.byte())
+ if sess.logFlags&logDebug != 0 {
+ sess.log.Printf("got token %v", token)
+ }
switch token {
case tokenSSPI:
ch <- parseSSPIMsg(sess.buf)
@@ -514,18 +590,17 @@ func processResponse(sess *tdsSession, ch chan tokenStruct) {
ch <- done
case tokenDone, tokenDoneProc:
done := parseDone(sess.buf)
- if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 {
- sess.log.Printf("(%d row(s) affected)\n", done.RowCount)
- }
- if done.Status&doneError != 0 || failed {
- ch <- lastError
- return
+ done.errors = errs
+ if sess.logFlags&logDebug != 0 {
+ sess.log.Printf("got DONE or DONEPROC status=%d", done.Status)
}
if done.Status&doneSrvError != 0 {
- lastError.Message = "Server Error"
- ch <- lastError
+ ch <- errors.New("SQL Server had internal error")
return
}
+ if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 {
+ sess.log.Printf("(%d row(s) affected)\n", done.RowCount)
+ }
ch <- done
if done.Status&doneMore == 0 {
return
@@ -544,18 +619,210 @@ func processResponse(sess *tdsSession, ch chan tokenStruct) {
case tokenEnvChange:
processEnvChg(sess)
case tokenError:
- lastError = parseError72(sess.buf)
- failed = true
+ err := parseError72(sess.buf)
+ if sess.logFlags&logDebug != 0 {
+ sess.log.Printf("got ERROR %d %s", err.Number, err.Message)
+ }
+ errs = append(errs, err)
if sess.logFlags&logErrors != 0 {
- sess.log.Println(lastError.Message)
+ sess.log.Println(err.Message)
}
case tokenInfo:
info := parseInfo(sess.buf)
+ if sess.logFlags&logDebug != 0 {
+ sess.log.Printf("got INFO %d %s", info.Number, info.Message)
+ }
if sess.logFlags&logMessages != 0 {
sess.log.Println(info.Message)
}
+ case tokenReturnValue:
+ nv := parseReturnValue(sess.buf)
+ if len(nv.Name) > 0 {
+ name := nv.Name[1:] // Remove the leading "@".
+ if ov, has := outs[name]; has {
+ err = scanIntoOut(nv.Value, ov)
+ if err != nil {
+ fmt.Println("scan error", err)
+ ch <- err
+ }
+ }
+ }
+ default:
+ badStreamPanic(fmt.Errorf("unknown token type returned: %v", token))
+ }
+ }
+}
+
+func scanIntoOut(fromServer, scanInto interface{}) error {
+ switch fs := fromServer.(type) {
+ case int64:
+ switch si := scanInto.(type) {
+ case *int64:
+ *si = fs
default:
- badStreamPanicf("Unknown token type: %d", token)
+ return fmt.Errorf("unsupported scan into type %[1]T for server type %[2]T", scanInto, fromServer)
+ }
+ return nil
+ case string:
+ switch si := scanInto.(type) {
+ case *string:
+ *si = fs
+ default:
+ return fmt.Errorf("unsupported scan into type %[1]T for server type %[2]T", scanInto, fromServer)
+ }
+ return nil
+ }
+ return fmt.Errorf("unsupported type from server %[1]T=%[1]v", fromServer)
+}
+
+type parseRespIter byte
+
+const (
+ parseRespIterContinue parseRespIter = iota // Continue parsing current token.
+ parseRespIterNext // Fetch the next token.
+ parseRespIterDone // Done with parsing the response.
+)
+
+type parseRespState byte
+
+const (
+ parseRespStateNormal parseRespState = iota // Normal response state.
+ parseRespStateCancel // Query is canceled, wait for server to confirm.
+ parseRespStateClosing // Waiting for tokens to come through.
+)
+
+type parseResp struct {
+ sess *tdsSession
+ ctxDone <-chan struct{}
+ state parseRespState
+ cancelError error
+}
+
+func (ts *parseResp) sendAttention(ch chan tokenStruct) parseRespIter {
+ if err := sendAttention(ts.sess.buf); err != nil {
+ ts.dlogf("failed to send attention signal %v", err)
+ ch <- err
+ return parseRespIterDone
+ }
+ ts.state = parseRespStateCancel
+ return parseRespIterContinue
+}
+
+func (ts *parseResp) dlog(msg string) {
+ if ts.sess.logFlags&logDebug != 0 {
+ ts.sess.log.Println(msg)
+ }
+}
+func (ts *parseResp) dlogf(f string, v ...interface{}) {
+ if ts.sess.logFlags&logDebug != 0 {
+ ts.sess.log.Printf(f, v...)
+ }
+}
+
+func (ts *parseResp) iter(ctx context.Context, ch chan tokenStruct, tokChan chan tokenStruct) parseRespIter {
+ switch ts.state {
+ default:
+ panic("unknown state")
+ case parseRespStateNormal:
+ select {
+ case tok, ok := <-tokChan:
+ if !ok {
+ ts.dlog("response finished")
+ return parseRespIterDone
+ }
+ if err, ok := tok.(net.Error); ok && err.Timeout() {
+ ts.cancelError = err
+ ts.dlog("got timeout error, sending attention signal to server")
+ return ts.sendAttention(ch)
+ }
+ // Pass the token along.
+ ch <- tok
+ return parseRespIterContinue
+
+ case <-ts.ctxDone:
+ ts.ctxDone = nil
+ ts.dlog("got cancel message, sending attention signal to server")
+ return ts.sendAttention(ch)
+ }
+ case parseRespStateCancel: // Read all responses until a DONE or error is received.Auth
+ select {
+ case tok, ok := <-tokChan:
+ if !ok {
+ ts.dlog("response finished but waiting for attention ack")
+ return parseRespIterNext
+ }
+ switch tok := tok.(type) {
+ default:
+ // Ignore all other tokens while waiting.
+ // The TDS spec says other tokens may arrive after an attention
+ // signal is sent. Ignore these tokens and continue looking for
+ // a DONE with attention confirm mark.
+ case doneStruct:
+ if tok.Status&doneAttn != 0 {
+ ts.dlog("got cancellation confirmation from server")
+ if ts.cancelError != nil {
+ ch <- ts.cancelError
+ ts.cancelError = nil
+ } else {
+ ch <- ctx.Err()
+ }
+ return parseRespIterDone
+ }
+
+ // If an error happens during cancel, pass it along and just stop.
+ // We are uncertain to receive more tokens.
+ case error:
+ ch <- tok
+ ts.state = parseRespStateClosing
+ }
+ return parseRespIterContinue
+ case <-ts.ctxDone:
+ ts.ctxDone = nil
+ ts.state = parseRespStateClosing
+ return parseRespIterContinue
+ }
+ case parseRespStateClosing: // Wait for current token chan to close.
+ if _, ok := <-tokChan; !ok {
+ ts.dlog("response finished")
+ return parseRespIterDone
+ }
+ return parseRespIterContinue
+ }
+}
+
+func processResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) {
+ ts := &parseResp{
+ sess: sess,
+ ctxDone: ctx.Done(),
+ }
+ defer func() {
+ // Ensure any remaining error is piped through
+ // or the query may look like it executed when it actually failed.
+ if ts.cancelError != nil {
+ ch <- ts.cancelError
+ ts.cancelError = nil
+ }
+ close(ch)
+ }()
+
+ // Loop over multiple responses.
+ for {
+ ts.dlog("initiating response reading")
+
+ tokChan := make(chan tokenStruct)
+ go processSingleResponse(sess, tokChan, outs)
+
+ // Loop over multiple tokens in response.
+ tokensLoop:
+ for {
+ switch ts.iter(ctx, ch, tokChan) {
+ case parseRespIterContinue:
+ // Nothing, continue to next token.
+ case parseRespIterNext:
+ break tokensLoop
+ case parseRespIterDone:
+ return
+ }
}
}
}