diff options
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go')
-rw-r--r-- | vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go | 164 |
1 files changed, 51 insertions, 113 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go b/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go index 8c0a4e0a2a..3b319af893 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go +++ b/vendor/github.com/denisenkom/go-mssqldb/bulkcopy.go @@ -13,6 +13,12 @@ import ( ) type Bulk struct { + // ctx is used only for AddRow and Done methods. + // This could be removed if AddRow and Done accepted + // a ctx field as well, which is available with the + // database/sql call. + ctx context.Context + cn *Conn metadata []columnStruct bulkColumns []columnStruct @@ -37,14 +43,20 @@ type BulkOptions struct { type DataValue interface{} func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) { - b := Bulk{cn: cn, tablename: table, headerSent: false, columnsName: columns} + b := Bulk{ctx: context.Background(), cn: cn, tablename: table, headerSent: false, columnsName: columns} + b.Debug = false + return &b +} + +func (cn *Conn) CreateBulkContext(ctx context.Context, table string, columns []string) (_ *Bulk) { + b := Bulk{ctx: ctx, cn: cn, tablename: table, headerSent: false, columnsName: columns} b.Debug = false return &b } -func (b *Bulk) sendBulkCommand() (err error) { +func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) { //get table columns info - err = b.getMetadata() + err = b.getMetadata(ctx) if err != nil { return err } @@ -114,13 +126,13 @@ func (b *Bulk) sendBulkCommand() (err error) { query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part) - stmt, err := b.cn.Prepare(query) + stmt, err := b.cn.PrepareContext(ctx, query) if err != nil { return fmt.Errorf("Prepare failed: %s", err.Error()) } b.dlogf(query) - _, err = stmt.Exec(nil) + _, err = stmt.(*Stmt).ExecContext(ctx, nil) if err != nil { return err } @@ -128,9 +140,9 @@ func (b *Bulk) sendBulkCommand() (err error) { b.headerSent = true var buf = b.cn.sess.buf - buf.BeginPacket(packBulkLoadBCP) + buf.BeginPacket(packBulkLoadBCP, false) - // send the columns metadata + // Send the columns metadata. columnMetadata := b.createColMetadata() _, err = buf.Write(columnMetadata) @@ -141,7 +153,7 @@ func (b *Bulk) sendBulkCommand() (err error) { // The arguments are the row values in the order they were specified. func (b *Bulk) AddRow(row []interface{}) (err error) { if !b.headerSent { - err = b.sendBulkCommand() + err = b.sendBulkCommand(b.ctx) if err != nil { return } @@ -216,7 +228,7 @@ func (b *Bulk) Done() (rowcount int64, err error) { buf.FinishPacket() tokchan := make(chan tokenStruct, 5) - go processResponse(context.Background(), b.cn.sess, tokchan, nil) + go processResponse(b.ctx, b.cn.sess, tokchan, nil) var rowCount int64 for token := range tokchan { @@ -267,28 +279,27 @@ func (b *Bulk) createColMetadata() []byte { return buf.Bytes() } -func (b *Bulk) getMetadata() (err error) { - stmt, err := b.cn.Prepare("SET FMTONLY ON") +func (b *Bulk) getMetadata(ctx context.Context) (err error) { + stmt, err := b.cn.prepareContext(ctx, "SET FMTONLY ON") if err != nil { return } - _, err = stmt.Exec(nil) + _, err = stmt.ExecContext(ctx, nil) if err != nil { return } - //get columns info - stmt, err = b.cn.Prepare(fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename)) + // Get columns info. + stmt, err = b.cn.prepareContext(ctx, fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename)) if err != nil { return } - stmt2 := stmt.(*Stmt) - cols, err := stmt2.QueryMeta() + rows, err := stmt.QueryContext(ctx, nil) if err != nil { - return fmt.Errorf("get columns info failed: %v", err.Error()) + return fmt.Errorf("get columns info failed: %v", err) } - b.metadata = cols + b.metadata = rows.(*Rows).cols if b.Debug { for _, col := range b.metadata { @@ -298,33 +309,10 @@ func (b *Bulk) getMetadata() (err error) { } } - return nil -} - -// QueryMeta is almost the same as mssql.Stmt.Query, but returns all the columns info. -func (s *Stmt) QueryMeta() (cols []columnStruct, err error) { - if err = s.sendQuery(nil); err != nil { - return - } - tokchan := make(chan tokenStruct, 5) - go processResponse(context.Background(), s.c.sess, tokchan, s.c.outs) - s.c.clearOuts() -loop: - for tok := range tokchan { - switch token := tok.(type) { - case doneStruct: - break loop - case []columnStruct: - cols = token - break loop - case error: - return nil, s.c.checkBadConn(token) - } - } - return cols, nil + return rows.Close() } -func (b *Bulk) makeParam(val DataValue, col columnStruct) (res Param, err error) { +func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) { res.ti.Size = col.ti.Size res.ti.TypeId = col.ti.TypeId @@ -420,60 +408,30 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res Param, err error) if val.(bool) { res.buffer[0] = 1 } - - case typeDateTime2N, typeDateTimeOffsetN: + case typeDateTime2N: switch val := val.(type) { case time.Time: - days, ns := dateTime2(val) - ns /= int64(math.Pow10(int(col.ti.Scale)*-1) * 1000000000) - - var data = make([]byte, 5) - - data[0] = byte(ns) - data[1] = byte(ns >> 8) - data[2] = byte(ns >> 16) - data[3] = byte(ns >> 24) - data[4] = byte(ns >> 32) - - if col.ti.Scale <= 2 { - res.ti.Size = 6 - } else if col.ti.Scale <= 4 { - res.ti.Size = 7 - } else { - res.ti.Size = 8 - } - var buf []byte - buf = make([]byte, res.ti.Size) - copy(buf, data[0:res.ti.Size-3]) - - buf[res.ti.Size-3] = byte(days) - buf[res.ti.Size-2] = byte(days >> 8) - buf[res.ti.Size-1] = byte(days >> 16) - - if col.ti.TypeId == typeDateTimeOffsetN { - _, offset := val.Zone() - var offsetMinute = uint16(offset / 60) - buf = append(buf, byte(offsetMinute)) - buf = append(buf, byte(offsetMinute>>8)) - res.ti.Size = res.ti.Size + 2 - } - - res.buffer = buf - + res.buffer = encodeDateTime2(val, int(col.ti.Scale)) + res.ti.Size = len(res.buffer) default: err = fmt.Errorf("mssql: invalid type for datetime2 column: %s", val) return } - case typeDateN: + case typeDateTimeOffsetN: switch val := val.(type) { case time.Time: - days, _ := dateTime2(val) + res.buffer = encodeDateTimeOffset(val, int(res.ti.Scale)) + res.ti.Size = len(res.buffer) - res.ti.Size = 3 - res.buffer = make([]byte, 3) - res.buffer[0] = byte(days) - res.buffer[1] = byte(days >> 8) - res.buffer[2] = byte(days >> 16) + default: + err = fmt.Errorf("mssql: invalid type for datetimeoffset column: %s", val) + return + } + case typeDateN: + switch val := val.(type) { + case time.Time: + res.buffer = encodeDate(val) + res.ti.Size = len(res.buffer) default: err = fmt.Errorf("mssql: invalid type for date column: %s", val) return @@ -482,31 +440,11 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res Param, err error) switch val := val.(type) { case time.Time: if col.ti.Size == 4 { - res.ti.Size = 4 - res.buffer = make([]byte, 4) - - ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC) - dur := val.Sub(ref) - days := dur / (24 * time.Hour) - if days < 0 { - err = fmt.Errorf("mssql: Date %s is out of range", val) - return - } - mins := val.Hour()*60 + val.Minute() - - binary.LittleEndian.PutUint16(res.buffer[0:2], uint16(days)) - binary.LittleEndian.PutUint16(res.buffer[2:4], uint16(mins)) + res.buffer = encodeDateTim4(val) + res.ti.Size = len(res.buffer) } else if col.ti.Size == 8 { - res.ti.Size = 8 - res.buffer = make([]byte, 8) - - days := divFloor(val.Unix(), 24*60*60) - //25567 - number of days since Jan 1 1900 UTC to Jan 1 1970 - days = days + 25567 - tm := (val.Hour()*60*60+val.Minute()*60+val.Second())*300 + int(val.Nanosecond()/10000000*3) - - binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days)) - binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm)) + res.buffer = encodeDateTime(val) + res.ti.Size = len(res.buffer) } else { err = fmt.Errorf("mssql: invalid size of column") } @@ -583,7 +521,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res Param, err error) buf[i] = ub[j] } res.buffer = buf - case typeBigVarBin: + case typeBigVarBin, typeBigBinary: switch val := val.(type) { case []byte: res.ti.Size = len(val) |