diff options
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/token.go')
-rw-r--r-- | vendor/github.com/denisenkom/go-mssqldb/token.go | 387 |
1 files changed, 327 insertions, 60 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/token.go b/vendor/github.com/denisenkom/go-mssqldb/token.go index f20bd14cc9..5f2167eb86 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/token.go +++ b/vendor/github.com/denisenkom/go-mssqldb/token.go @@ -1,30 +1,40 @@ package mssql import ( + "context" "encoding/binary" + "errors" + "fmt" "io" + "net" "strconv" "strings" ) +//go:generate stringer -type token + +type token byte + // token ids const ( - tokenReturnStatus = 121 // 0x79 - tokenColMetadata = 129 // 0x81 - tokenOrder = 169 // 0xA9 - tokenError = 170 // 0xAA - tokenInfo = 171 // 0xAB - tokenLoginAck = 173 // 0xad - tokenRow = 209 // 0xd1 - tokenNbcRow = 210 // 0xd2 - tokenEnvChange = 227 // 0xE3 - tokenSSPI = 237 // 0xED - tokenDone = 253 // 0xFD - tokenDoneProc = 254 - tokenDoneInProc = 255 + tokenReturnStatus token = 121 // 0x79 + tokenColMetadata token = 129 // 0x81 + tokenOrder token = 169 // 0xA9 + tokenError token = 170 // 0xAA + tokenInfo token = 171 // 0xAB + tokenReturnValue token = 0xAC + tokenLoginAck token = 173 // 0xad + tokenRow token = 209 // 0xd1 + tokenNbcRow token = 210 // 0xd2 + tokenEnvChange token = 227 // 0xE3 + tokenSSPI token = 237 // 0xED + tokenDone token = 253 // 0xFD + tokenDoneProc token = 254 + tokenDoneInProc token = 255 ) // done flags +// https://msdn.microsoft.com/en-us/library/dd340421.aspx const ( doneFinal = 0 doneMore = 1 @@ -59,6 +69,13 @@ const ( envRouting = 20 ) +// COLMETADATA flags +// https://msdn.microsoft.com/en-us/library/dd357363.aspx +const ( + colFlagNullable = 1 + // TODO implement more flags +) + // interface for all tokens type tokenStruct interface{} @@ -70,6 +87,19 @@ type doneStruct struct { Status uint16 CurCmd uint16 RowCount uint64 + errors []Error +} + +func (d doneStruct) isError() bool { + return d.Status&doneError != 0 || len(d.errors) > 0 +} + +func (d doneStruct) getError() Error { + if len(d.errors) > 0 { + return d.errors[len(d.errors)-1] + } else { + return Error{Message: "Request failed but didn't provide reason"} + } } type doneInProcStruct doneStruct @@ -120,27 +150,23 @@ func processEnvChg(sess *tdsSession) { badStreamPanic(err) } case envTypLanguage: - //currently ignored - // old value - _, err = readBVarChar(r) - if err != nil { - badStreamPanic(err) - } + // currently ignored // new value - _, err = readBVarChar(r) - if err != nil { + if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } - case envTypCharset: - //currently ignored // old value - _, err = readBVarChar(r) - if err != nil { + if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } + case envTypCharset: + // currently ignored // new value - _, err = readBVarChar(r) - if err != nil { + if _, err = readBVarChar(r); err != nil { + badStreamPanic(err) + } + // old value + if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envTypPacketSize: @@ -156,38 +182,55 @@ func processEnvChg(sess *tdsSession) { if err != nil { badStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error()) } - if len(sess.buf.buf) != packetsizei { - newbuf := make([]byte, packetsizei) - copy(newbuf, sess.buf.buf) - sess.buf.buf = newbuf - } + sess.buf.ResizeBuffer(packetsizei) case envSortId: // currently ignored - // old value, should be 0 + // new value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } - // new value + // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envSortFlags: // currently ignored - // old value, should be 0 + // new value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } - // new value + // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envSqlCollation: // currently ignored - // old value - if _, err = readBVarChar(r); err != nil { + var collationSize uint8 + err = binary.Read(r, binary.LittleEndian, &collationSize) + if err != nil { badStreamPanic(err) } - // new value + + // SQL Collation data should contain 5 bytes in length + if collationSize != 5 { + badStreamPanicf("Invalid SQL Collation size value returned from server: %s", collationSize) + } + + // 4 bytes, contains: LCID ColFlags Version + var info uint32 + err = binary.Read(r, binary.LittleEndian, &info) + if err != nil { + badStreamPanic(err) + } + + // 1 byte, contains: sortID + var sortID uint8 + err = binary.Read(r, binary.LittleEndian, &sortID) + if err != nil { + badStreamPanic(err) + } + + // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } @@ -226,21 +269,21 @@ func processEnvChg(sess *tdsSession) { sess.tranid = 0 case envEnlistDTC: // currently ignored - // old value + // new value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } - // new value, should be 0 + // old value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envDefectTran: // currently ignored - // old value, should be 0 + // new value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } - // new value + // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } @@ -358,6 +401,7 @@ func parseOrder(r *tdsBuffer) (res orderStruct) { return res } +// https://msdn.microsoft.com/en-us/library/dd340421.aspx func parseDone(r *tdsBuffer) (res doneStruct) { res.Status = r.uint16() res.CurCmd = r.uint16() @@ -365,6 +409,7 @@ func parseDone(r *tdsBuffer) (res doneStruct) { return res } +// https://msdn.microsoft.com/en-us/library/dd340553.aspx func parseDoneInProc(r *tdsBuffer) (res doneInProcStruct) { res.Status = r.uint16() res.CurCmd = r.uint16() @@ -473,26 +518,57 @@ func parseInfo(r *tdsBuffer) (res Error) { return } -func processResponse(sess *tdsSession, ch chan tokenStruct) { +// https://msdn.microsoft.com/en-us/library/dd303881.aspx +func parseReturnValue(r *tdsBuffer) (nv namedValue) { + /* + ParamOrdinal + ParamName + Status + UserType + Flags + TypeInfo + CryptoMetadata + Value + */ + r.uint16() + nv.Name = r.BVarChar() + r.byte() + r.uint32() // UserType (uint16 prior to 7.2) + r.uint16() + ti := readTypeInfo(r) + nv.Value = ti.Reader(&ti, r) + return +} + +func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) { defer func() { if err := recover(); err != nil { + if sess.logFlags&logErrors != 0 { + sess.log.Printf("ERROR: Intercepted panic %v", err) + } ch <- err } close(ch) }() + packet_type, err := sess.buf.BeginRead() if err != nil { + if sess.logFlags&logErrors != 0 { + sess.log.Printf("ERROR: BeginRead failed %v", err) + } ch <- err return } if packet_type != packReply { - badStreamPanicf("invalid response packet type, expected REPLY, actual: %d", packet_type) + badStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply)) } var columns []columnStruct - var lastError Error - var failed bool + errs := make([]Error, 0, 5) for { - token := sess.buf.byte() + token := token(sess.buf.byte()) + if sess.logFlags&logDebug != 0 { + sess.log.Printf("got token %v", token) + } switch token { case tokenSSPI: ch <- parseSSPIMsg(sess.buf) @@ -514,18 +590,17 @@ func processResponse(sess *tdsSession, ch chan tokenStruct) { ch <- done case tokenDone, tokenDoneProc: done := parseDone(sess.buf) - if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 { - sess.log.Printf("(%d row(s) affected)\n", done.RowCount) - } - if done.Status&doneError != 0 || failed { - ch <- lastError - return + done.errors = errs + if sess.logFlags&logDebug != 0 { + sess.log.Printf("got DONE or DONEPROC status=%d", done.Status) } if done.Status&doneSrvError != 0 { - lastError.Message = "Server Error" - ch <- lastError + ch <- errors.New("SQL Server had internal error") return } + if sess.logFlags&logRows != 0 && done.Status&doneCount != 0 { + sess.log.Printf("(%d row(s) affected)\n", done.RowCount) + } ch <- done if done.Status&doneMore == 0 { return @@ -544,18 +619,210 @@ func processResponse(sess *tdsSession, ch chan tokenStruct) { case tokenEnvChange: processEnvChg(sess) case tokenError: - lastError = parseError72(sess.buf) - failed = true + err := parseError72(sess.buf) + if sess.logFlags&logDebug != 0 { + sess.log.Printf("got ERROR %d %s", err.Number, err.Message) + } + errs = append(errs, err) if sess.logFlags&logErrors != 0 { - sess.log.Println(lastError.Message) + sess.log.Println(err.Message) } case tokenInfo: info := parseInfo(sess.buf) + if sess.logFlags&logDebug != 0 { + sess.log.Printf("got INFO %d %s", info.Number, info.Message) + } if sess.logFlags&logMessages != 0 { sess.log.Println(info.Message) } + case tokenReturnValue: + nv := parseReturnValue(sess.buf) + if len(nv.Name) > 0 { + name := nv.Name[1:] // Remove the leading "@". + if ov, has := outs[name]; has { + err = scanIntoOut(nv.Value, ov) + if err != nil { + fmt.Println("scan error", err) + ch <- err + } + } + } + default: + badStreamPanic(fmt.Errorf("unknown token type returned: %v", token)) + } + } +} + +func scanIntoOut(fromServer, scanInto interface{}) error { + switch fs := fromServer.(type) { + case int64: + switch si := scanInto.(type) { + case *int64: + *si = fs default: - badStreamPanicf("Unknown token type: %d", token) + return fmt.Errorf("unsupported scan into type %[1]T for server type %[2]T", scanInto, fromServer) + } + return nil + case string: + switch si := scanInto.(type) { + case *string: + *si = fs + default: + return fmt.Errorf("unsupported scan into type %[1]T for server type %[2]T", scanInto, fromServer) + } + return nil + } + return fmt.Errorf("unsupported type from server %[1]T=%[1]v", fromServer) +} + +type parseRespIter byte + +const ( + parseRespIterContinue parseRespIter = iota // Continue parsing current token. + parseRespIterNext // Fetch the next token. + parseRespIterDone // Done with parsing the response. +) + +type parseRespState byte + +const ( + parseRespStateNormal parseRespState = iota // Normal response state. + parseRespStateCancel // Query is canceled, wait for server to confirm. + parseRespStateClosing // Waiting for tokens to come through. +) + +type parseResp struct { + sess *tdsSession + ctxDone <-chan struct{} + state parseRespState + cancelError error +} + +func (ts *parseResp) sendAttention(ch chan tokenStruct) parseRespIter { + if err := sendAttention(ts.sess.buf); err != nil { + ts.dlogf("failed to send attention signal %v", err) + ch <- err + return parseRespIterDone + } + ts.state = parseRespStateCancel + return parseRespIterContinue +} + +func (ts *parseResp) dlog(msg string) { + if ts.sess.logFlags&logDebug != 0 { + ts.sess.log.Println(msg) + } +} +func (ts *parseResp) dlogf(f string, v ...interface{}) { + if ts.sess.logFlags&logDebug != 0 { + ts.sess.log.Printf(f, v...) + } +} + +func (ts *parseResp) iter(ctx context.Context, ch chan tokenStruct, tokChan chan tokenStruct) parseRespIter { + switch ts.state { + default: + panic("unknown state") + case parseRespStateNormal: + select { + case tok, ok := <-tokChan: + if !ok { + ts.dlog("response finished") + return parseRespIterDone + } + if err, ok := tok.(net.Error); ok && err.Timeout() { + ts.cancelError = err + ts.dlog("got timeout error, sending attention signal to server") + return ts.sendAttention(ch) + } + // Pass the token along. + ch <- tok + return parseRespIterContinue + + case <-ts.ctxDone: + ts.ctxDone = nil + ts.dlog("got cancel message, sending attention signal to server") + return ts.sendAttention(ch) + } + case parseRespStateCancel: // Read all responses until a DONE or error is received.Auth + select { + case tok, ok := <-tokChan: + if !ok { + ts.dlog("response finished but waiting for attention ack") + return parseRespIterNext + } + switch tok := tok.(type) { + default: + // Ignore all other tokens while waiting. + // The TDS spec says other tokens may arrive after an attention + // signal is sent. Ignore these tokens and continue looking for + // a DONE with attention confirm mark. + case doneStruct: + if tok.Status&doneAttn != 0 { + ts.dlog("got cancellation confirmation from server") + if ts.cancelError != nil { + ch <- ts.cancelError + ts.cancelError = nil + } else { + ch <- ctx.Err() + } + return parseRespIterDone + } + + // If an error happens during cancel, pass it along and just stop. + // We are uncertain to receive more tokens. + case error: + ch <- tok + ts.state = parseRespStateClosing + } + return parseRespIterContinue + case <-ts.ctxDone: + ts.ctxDone = nil + ts.state = parseRespStateClosing + return parseRespIterContinue + } + case parseRespStateClosing: // Wait for current token chan to close. + if _, ok := <-tokChan; !ok { + ts.dlog("response finished") + return parseRespIterDone + } + return parseRespIterContinue + } +} + +func processResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs map[string]interface{}) { + ts := &parseResp{ + sess: sess, + ctxDone: ctx.Done(), + } + defer func() { + // Ensure any remaining error is piped through + // or the query may look like it executed when it actually failed. + if ts.cancelError != nil { + ch <- ts.cancelError + ts.cancelError = nil + } + close(ch) + }() + + // Loop over multiple responses. + for { + ts.dlog("initiating response reading") + + tokChan := make(chan tokenStruct) + go processSingleResponse(sess, tokChan, outs) + + // Loop over multiple tokens in response. + tokensLoop: + for { + switch ts.iter(ctx, ch, tokChan) { + case parseRespIterContinue: + // Nothing, continue to next token. + case parseRespIterNext: + break tokensLoop + case parseRespIterDone: + return + } } } } |