summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go
diff options
context:
space:
mode:
authorAntoine GIRARD <sapk@users.noreply.github.com>2019-06-30 05:28:17 +0200
committerLunny Xiao <xiaolunwen@gmail.com>2019-06-30 11:28:17 +0800
commit1e46eedce721e22e3e1e94eec6e0afccd36147c1 (patch)
tree171585bf8d7345134a900868903d83de0a22f22a /vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go
parentaeb8f7aad8b71909fdc81e34c29a70a0bb700f33 (diff)
downloadgitea-1e46eedce721e22e3e1e94eec6e0afccd36147c1.tar.gz
gitea-1e46eedce721e22e3e1e94eec6e0afccd36147c1.zip
update mssql drive to last working version 20180314172330-6a30f4e59a44 (#7306)
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go')
-rw-r--r--vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go616
1 files changed, 616 insertions, 0 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go b/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go
new file mode 100644
index 0000000000..8c0a4e0a2a
--- /dev/null
+++ b/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go
@@ -0,0 +1,616 @@
+package mssql
+
+import (
+ "bytes"
+ "context"
+ "encoding/binary"
+ "fmt"
+ "math"
+ "reflect"
+ "strconv"
+ "strings"
+ "time"
+)
+
+type Bulk struct {
+ cn *Conn
+ metadata []columnStruct
+ bulkColumns []columnStruct
+ columnsName []string
+ tablename string
+ numRows int
+
+ headerSent bool
+ Options BulkOptions
+ Debug bool
+}
+type BulkOptions struct {
+ CheckConstraints bool
+ FireTriggers bool
+ KeepNulls bool
+ KilobytesPerBatch int
+ RowsPerBatch int
+ Order []string
+ Tablock bool
+}
+
+type DataValue interface{}
+
+func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
+ b := Bulk{cn: cn, tablename: table, headerSent: false, columnsName: columns}
+ b.Debug = false
+ return &b
+}
+
+func (b *Bulk) sendBulkCommand() (err error) {
+ //get table columns info
+ err = b.getMetadata()
+ if err != nil {
+ return err
+ }
+
+ //match the columns
+ for _, colname := range b.columnsName {
+ var bulkCol *columnStruct
+
+ for _, m := range b.metadata {
+ if m.ColName == colname {
+ bulkCol = &m
+ break
+ }
+ }
+ if bulkCol != nil {
+
+ if bulkCol.ti.TypeId == typeUdt {
+ //send udt as binary
+ bulkCol.ti.TypeId = typeBigVarBin
+ }
+ b.bulkColumns = append(b.bulkColumns, *bulkCol)
+ b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId)
+ } else {
+ return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename)
+ }
+ }
+
+ //create the bulk command
+
+ //columns definitions
+ var col_defs bytes.Buffer
+ for i, col := range b.bulkColumns {
+ if i != 0 {
+ col_defs.WriteString(", ")
+ }
+ col_defs.WriteString("[" + col.ColName + "] " + makeDecl(col.ti))
+ }
+
+ //options
+ var with_opts []string
+
+ if b.Options.CheckConstraints {
+ with_opts = append(with_opts, "CHECK_CONSTRAINTS")
+ }
+ if b.Options.FireTriggers {
+ with_opts = append(with_opts, "FIRE_TRIGGERS")
+ }
+ if b.Options.KeepNulls {
+ with_opts = append(with_opts, "KEEP_NULLS")
+ }
+ if b.Options.KilobytesPerBatch > 0 {
+ with_opts = append(with_opts, fmt.Sprintf("KILOBYTES_PER_BATCH = %d", b.Options.KilobytesPerBatch))
+ }
+ if b.Options.RowsPerBatch > 0 {
+ with_opts = append(with_opts, fmt.Sprintf("ROWS_PER_BATCH = %d", b.Options.RowsPerBatch))
+ }
+ if len(b.Options.Order) > 0 {
+ with_opts = append(with_opts, fmt.Sprintf("ORDER(%s)", strings.Join(b.Options.Order, ",")))
+ }
+ if b.Options.Tablock {
+ with_opts = append(with_opts, "TABLOCK")
+ }
+ var with_part string
+ if len(with_opts) > 0 {
+ with_part = fmt.Sprintf("WITH (%s)", strings.Join(with_opts, ","))
+ }
+
+ query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part)
+
+ stmt, err := b.cn.Prepare(query)
+ if err != nil {
+ return fmt.Errorf("Prepare failed: %s", err.Error())
+ }
+ b.dlogf(query)
+
+ _, err = stmt.Exec(nil)
+ if err != nil {
+ return err
+ }
+
+ b.headerSent = true
+
+ var buf = b.cn.sess.buf
+ buf.BeginPacket(packBulkLoadBCP)
+
+ // send the columns metadata
+ columnMetadata := b.createColMetadata()
+ _, err = buf.Write(columnMetadata)
+
+ return
+}
+
+// AddRow immediately writes the row to the destination table.
+// The arguments are the row values in the order they were specified.
+func (b *Bulk) AddRow(row []interface{}) (err error) {
+ if !b.headerSent {
+ err = b.sendBulkCommand()
+ if err != nil {
+ return
+ }
+ }
+
+ if len(row) != len(b.bulkColumns) {
+ return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d",
+ len(row), len(b.bulkColumns))
+ }
+
+ bytes, err := b.makeRowData(row)
+ if err != nil {
+ return
+ }
+
+ _, err = b.cn.sess.buf.Write(bytes)
+ if err != nil {
+ return
+ }
+
+ b.numRows = b.numRows + 1
+ return
+}
+
+func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) {
+ buf := new(bytes.Buffer)
+ buf.WriteByte(byte(tokenRow))
+
+ var logcol bytes.Buffer
+ for i, col := range b.bulkColumns {
+
+ if b.Debug {
+ logcol.WriteString(fmt.Sprintf(" col[%d]='%v' ", i, row[i]))
+ }
+ param, err := b.makeParam(row[i], col)
+ if err != nil {
+ return nil, fmt.Errorf("bulkcopy: %s", err.Error())
+ }
+
+ if col.ti.Writer == nil {
+ return nil, fmt.Errorf("no writer for column: %s, TypeId: %#x",
+ col.ColName, col.ti.TypeId)
+ }
+ err = col.ti.Writer(buf, param.ti, param.buffer)
+ if err != nil {
+ return nil, fmt.Errorf("bulkcopy: %s", err.Error())
+ }
+ }
+
+ b.dlogf("row[%d] %s\n", b.numRows, logcol.String())
+
+ return buf.Bytes(), nil
+}
+
+func (b *Bulk) Done() (rowcount int64, err error) {
+ if b.headerSent == false {
+ //no rows had been sent
+ return 0, nil
+ }
+ var buf = b.cn.sess.buf
+ buf.WriteByte(byte(tokenDone))
+
+ binary.Write(buf, binary.LittleEndian, uint16(doneFinal))
+ binary.Write(buf, binary.LittleEndian, uint16(0)) // curcmd
+
+ if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
+ binary.Write(buf, binary.LittleEndian, uint64(0)) //rowcount 0
+ } else {
+ binary.Write(buf, binary.LittleEndian, uint32(0)) //rowcount 0
+ }
+
+ buf.FinishPacket()
+
+ tokchan := make(chan tokenStruct, 5)
+ go processResponse(context.Background(), b.cn.sess, tokchan, nil)
+
+ var rowCount int64
+ for token := range tokchan {
+ switch token := token.(type) {
+ case doneStruct:
+ if token.Status&doneCount != 0 {
+ rowCount = int64(token.RowCount)
+ }
+ if token.isError() {
+ return 0, token.getError()
+ }
+ case error:
+ return 0, b.cn.checkBadConn(token)
+ }
+ }
+ return rowCount, nil
+}
+
+func (b *Bulk) createColMetadata() []byte {
+ buf := new(bytes.Buffer)
+ buf.WriteByte(byte(tokenColMetadata)) // token
+ binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count
+
+ for i, col := range b.bulkColumns {
+
+ if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
+ binary.Write(buf, binary.LittleEndian, uint32(col.UserType)) // usertype, always 0?
+ } else {
+ binary.Write(buf, binary.LittleEndian, uint16(col.UserType))
+ }
+ binary.Write(buf, binary.LittleEndian, uint16(col.Flags))
+
+ writeTypeInfo(buf, &b.bulkColumns[i].ti)
+
+ if col.ti.TypeId == typeNText ||
+ col.ti.TypeId == typeText ||
+ col.ti.TypeId == typeImage {
+
+ tablename_ucs2 := str2ucs2(b.tablename)
+ binary.Write(buf, binary.LittleEndian, uint16(len(tablename_ucs2)/2))
+ buf.Write(tablename_ucs2)
+ }
+ colname_ucs2 := str2ucs2(col.ColName)
+ buf.WriteByte(uint8(len(colname_ucs2) / 2))
+ buf.Write(colname_ucs2)
+ }
+
+ return buf.Bytes()
+}
+
+func (b *Bulk) getMetadata() (err error) {
+ stmt, err := b.cn.Prepare("SET FMTONLY ON")
+ if err != nil {
+ return
+ }
+
+ _, err = stmt.Exec(nil)
+ if err != nil {
+ return
+ }
+
+ //get columns info
+ stmt, err = b.cn.Prepare(fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename))
+ if err != nil {
+ return
+ }
+ stmt2 := stmt.(*Stmt)
+ cols, err := stmt2.QueryMeta()
+ if err != nil {
+ return fmt.Errorf("get columns info failed: %v", err.Error())
+ }
+ b.metadata = cols
+
+ if b.Debug {
+ for _, col := range b.metadata {
+ b.dlogf("col: %s typeId: %#x size: %d scale: %d prec: %d flags: %d lcid: %#x\n",
+ col.ColName, col.ti.TypeId, col.ti.Size, col.ti.Scale, col.ti.Prec,
+ col.Flags, col.ti.Collation.LcidAndFlags)
+ }
+ }
+
+ return nil
+}
+
+// QueryMeta is almost the same as mssql.Stmt.Query, but returns all the columns info.
+func (s *Stmt) QueryMeta() (cols []columnStruct, err error) {
+ if err = s.sendQuery(nil); err != nil {
+ return
+ }
+ tokchan := make(chan tokenStruct, 5)
+ go processResponse(context.Background(), s.c.sess, tokchan, s.c.outs)
+ s.c.clearOuts()
+loop:
+ for tok := range tokchan {
+ switch token := tok.(type) {
+ case doneStruct:
+ break loop
+ case []columnStruct:
+ cols = token
+ break loop
+ case error:
+ return nil, s.c.checkBadConn(token)
+ }
+ }
+ return cols, nil
+}
+
+func (b *Bulk) makeParam(val DataValue, col columnStruct) (res Param, err error) {
+ res.ti.Size = col.ti.Size
+ res.ti.TypeId = col.ti.TypeId
+
+ if val == nil {
+ res.ti.Size = 0
+ return
+ }
+
+ switch col.ti.TypeId {
+
+ case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN:
+ var intvalue int64
+
+ switch val := val.(type) {
+ case int:
+ intvalue = int64(val)
+ case int32:
+ intvalue = int64(val)
+ case int64:
+ intvalue = val
+ default:
+ err = fmt.Errorf("mssql: invalid type for int column")
+ return
+ }
+
+ res.buffer = make([]byte, res.ti.Size)
+ if col.ti.Size == 1 {
+ res.buffer[0] = byte(intvalue)
+ } else if col.ti.Size == 2 {
+ binary.LittleEndian.PutUint16(res.buffer, uint16(intvalue))
+ } else if col.ti.Size == 4 {
+ binary.LittleEndian.PutUint32(res.buffer, uint32(intvalue))
+ } else if col.ti.Size == 8 {
+ binary.LittleEndian.PutUint64(res.buffer, uint64(intvalue))
+ }
+ case typeFlt4, typeFlt8, typeFltN:
+ var floatvalue float64
+
+ switch val := val.(type) {
+ case float32:
+ floatvalue = float64(val)
+ case float64:
+ floatvalue = val
+ case int:
+ floatvalue = float64(val)
+ case int64:
+ floatvalue = float64(val)
+ default:
+ err = fmt.Errorf("mssql: invalid type for float column: %s", val)
+ return
+ }
+
+ if col.ti.Size == 4 {
+ res.buffer = make([]byte, 4)
+ binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(float32(floatvalue)))
+ } else if col.ti.Size == 8 {
+ res.buffer = make([]byte, 8)
+ binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(floatvalue))
+ }
+ case typeNVarChar, typeNText, typeNChar:
+
+ switch val := val.(type) {
+ case string:
+ res.buffer = str2ucs2(val)
+ case []byte:
+ res.buffer = val
+ default:
+ err = fmt.Errorf("mssql: invalid type for nvarchar column: %s", val)
+ return
+ }
+ res.ti.Size = len(res.buffer)
+
+ case typeVarChar, typeBigVarChar, typeText, typeChar, typeBigChar:
+ switch val := val.(type) {
+ case string:
+ res.buffer = []byte(val)
+ case []byte:
+ res.buffer = val
+ default:
+ err = fmt.Errorf("mssql: invalid type for varchar column: %s", val)
+ return
+ }
+ res.ti.Size = len(res.buffer)
+
+ case typeBit, typeBitN:
+ if reflect.TypeOf(val).Kind() != reflect.Bool {
+ err = fmt.Errorf("mssql: invalid type for bit column: %s", val)
+ return
+ }
+ res.ti.TypeId = typeBitN
+ res.ti.Size = 1
+ res.buffer = make([]byte, 1)
+ if val.(bool) {
+ res.buffer[0] = 1
+ }
+
+ case typeDateTime2N, typeDateTimeOffsetN:
+ switch val := val.(type) {
+ case time.Time:
+ days, ns := dateTime2(val)
+ ns /= int64(math.Pow10(int(col.ti.Scale)*-1) * 1000000000)
+
+ var data = make([]byte, 5)
+
+ data[0] = byte(ns)
+ data[1] = byte(ns >> 8)
+ data[2] = byte(ns >> 16)
+ data[3] = byte(ns >> 24)
+ data[4] = byte(ns >> 32)
+
+ if col.ti.Scale <= 2 {
+ res.ti.Size = 6
+ } else if col.ti.Scale <= 4 {
+ res.ti.Size = 7
+ } else {
+ res.ti.Size = 8
+ }
+ var buf []byte
+ buf = make([]byte, res.ti.Size)
+ copy(buf, data[0:res.ti.Size-3])
+
+ buf[res.ti.Size-3] = byte(days)
+ buf[res.ti.Size-2] = byte(days >> 8)
+ buf[res.ti.Size-1] = byte(days >> 16)
+
+ if col.ti.TypeId == typeDateTimeOffsetN {
+ _, offset := val.Zone()
+ var offsetMinute = uint16(offset / 60)
+ buf = append(buf, byte(offsetMinute))
+ buf = append(buf, byte(offsetMinute>>8))
+ res.ti.Size = res.ti.Size + 2
+ }
+
+ res.buffer = buf
+
+ default:
+ err = fmt.Errorf("mssql: invalid type for datetime2 column: %s", val)
+ return
+ }
+ case typeDateN:
+ switch val := val.(type) {
+ case time.Time:
+ days, _ := dateTime2(val)
+
+ res.ti.Size = 3
+ res.buffer = make([]byte, 3)
+ res.buffer[0] = byte(days)
+ res.buffer[1] = byte(days >> 8)
+ res.buffer[2] = byte(days >> 16)
+ default:
+ err = fmt.Errorf("mssql: invalid type for date column: %s", val)
+ return
+ }
+ case typeDateTime, typeDateTimeN, typeDateTim4:
+ switch val := val.(type) {
+ case time.Time:
+ if col.ti.Size == 4 {
+ res.ti.Size = 4
+ res.buffer = make([]byte, 4)
+
+ ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
+ dur := val.Sub(ref)
+ days := dur / (24 * time.Hour)
+ if days < 0 {
+ err = fmt.Errorf("mssql: Date %s is out of range", val)
+ return
+ }
+ mins := val.Hour()*60 + val.Minute()
+
+ binary.LittleEndian.PutUint16(res.buffer[0:2], uint16(days))
+ binary.LittleEndian.PutUint16(res.buffer[2:4], uint16(mins))
+ } else if col.ti.Size == 8 {
+ res.ti.Size = 8
+ res.buffer = make([]byte, 8)
+
+ days := divFloor(val.Unix(), 24*60*60)
+ //25567 - number of days since Jan 1 1900 UTC to Jan 1 1970
+ days = days + 25567
+ tm := (val.Hour()*60*60+val.Minute()*60+val.Second())*300 + int(val.Nanosecond()/10000000*3)
+
+ binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days))
+ binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))
+ } else {
+ err = fmt.Errorf("mssql: invalid size of column")
+ }
+
+ default:
+ err = fmt.Errorf("mssql: invalid type for datetime column: %s", val)
+ }
+
+ // case typeMoney, typeMoney4, typeMoneyN:
+ case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
+ var value float64
+ switch v := val.(type) {
+ case int:
+ value = float64(v)
+ case int8:
+ value = float64(v)
+ case int16:
+ value = float64(v)
+ case int32:
+ value = float64(v)
+ case int64:
+ value = float64(v)
+ case float32:
+ value = float64(v)
+ case float64:
+ value = v
+ case string:
+ if value, err = strconv.ParseFloat(v, 64); err != nil {
+ return res, fmt.Errorf("bulk: unable to convert string to float: %v", err)
+ }
+ default:
+ return res, fmt.Errorf("unknown value for decimal: %#v", v)
+ }
+
+ perc := col.ti.Prec
+ scale := col.ti.Scale
+ var dec Decimal
+ dec, err = Float64ToDecimalScale(value, scale)
+ if err != nil {
+ return res, err
+ }
+ dec.prec = perc
+
+ var length byte
+ switch {
+ case perc <= 9:
+ length = 4
+ case perc <= 19:
+ length = 8
+ case perc <= 28:
+ length = 12
+ default:
+ length = 16
+ }
+
+ buf := make([]byte, length+1)
+ // first byte length written by typeInfo.writer
+ res.ti.Size = int(length) + 1
+ // second byte sign
+ if value < 0 {
+ buf[0] = 0
+ } else {
+ buf[0] = 1
+ }
+
+ ub := dec.UnscaledBytes()
+ l := len(ub)
+ if l > int(length) {
+ err = fmt.Errorf("decimal out of range: %s", dec)
+ return res, err
+ }
+ // reverse the bytes
+ for i, j := 1, l-1; j >= 0; i, j = i+1, j-1 {
+ buf[i] = ub[j]
+ }
+ res.buffer = buf
+ case typeBigVarBin:
+ switch val := val.(type) {
+ case []byte:
+ res.ti.Size = len(val)
+ res.buffer = val
+ default:
+ err = fmt.Errorf("mssql: invalid type for Binary column: %s", val)
+ return
+ }
+ case typeGuid:
+ switch val := val.(type) {
+ case []byte:
+ res.ti.Size = len(val)
+ res.buffer = val
+ default:
+ err = fmt.Errorf("mssql: invalid type for Guid column: %s", val)
+ return
+ }
+
+ default:
+ err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId)
+ }
+ return
+
+}
+
+func (b *Bulk) dlogf(format string, v ...interface{}) {
+ if b.Debug {
+ b.cn.sess.log.Printf(format, v...)
+ }
+}