summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/denisenkom/go-mssqldb/buf.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/buf.go')
-rw-r--r--vendor/github.com/denisenkom/go-mssqldb/buf.go145
1 files changed, 87 insertions, 58 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/buf.go b/vendor/github.com/denisenkom/go-mssqldb/buf.go
index 42e8ae345c..365acd4833 100644
--- a/vendor/github.com/denisenkom/go-mssqldb/buf.go
+++ b/vendor/github.com/denisenkom/go-mssqldb/buf.go
@@ -2,12 +2,14 @@ package mssql
import (
"encoding/binary"
- "io"
"errors"
+ "io"
)
+type packetType uint8
+
type header struct {
- PacketType uint8
+ PacketType packetType
Status uint8
Size uint16
Spid uint16
@@ -15,55 +17,84 @@ type header struct {
Pad uint8
}
+// tdsBuffer reads and writes TDS packets of data to the transport.
+// The write and read buffers are separate to make sending attn signals
+// possible without locks. Currently attn signals are only sent during
+// reads, not writes.
type tdsBuffer struct {
- buf []byte
- pos uint16
- transport io.ReadWriteCloser
- size uint16
+ transport io.ReadWriteCloser
+
+ packetSize int
+
+ // Write fields.
+ wbuf []byte
+ wpos int
+ wPacketSeq byte
+ wPacketType packetType
+
+ // Read fields.
+ rbuf []byte
+ rpos int
+ rsize int
final bool
- packet_type uint8
- afterFirst func()
+ rPacketType packetType
+
+ // afterFirst is assigned to right after tdsBuffer is created and
+ // before the first use. It is executed after the first packet is
+ // written and then removed.
+ afterFirst func()
}
-func newTdsBuffer(bufsize int, transport io.ReadWriteCloser) *tdsBuffer {
- buf := make([]byte, bufsize)
- w := new(tdsBuffer)
- w.buf = buf
- w.pos = 8
- w.transport = transport
- w.size = 0
- return w
+func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer {
+ return &tdsBuffer{
+ packetSize: int(bufsize),
+ wbuf: make([]byte, 1<<16),
+ rbuf: make([]byte, 1<<16),
+ rpos: 8,
+ transport: transport,
+ }
+}
+
+func (rw *tdsBuffer) ResizeBuffer(packetSize int) {
+ rw.packetSize = packetSize
+}
+
+func (w *tdsBuffer) PackageSize() int {
+ return w.packetSize
}
func (w *tdsBuffer) flush() (err error) {
- // writing packet size
- binary.BigEndian.PutUint16(w.buf[2:], w.pos)
+ // Write packet size.
+ w.wbuf[0] = byte(w.wPacketType)
+ binary.BigEndian.PutUint16(w.wbuf[2:], uint16(w.wpos))
+ w.wbuf[6] = w.wPacketSeq
- // writing packet into underlying transport
- if _, err = w.transport.Write(w.buf[:w.pos]); err != nil {
+ // Write packet into underlying transport.
+ if _, err = w.transport.Write(w.wbuf[:w.wpos]); err != nil {
return err
}
+ // It is possible to create a whole new buffer after a flush.
+ // Useful for debugging. Normally reuse the buffer.
+ // w.wbuf = make([]byte, 1<<16)
- // execute afterFirst hook if it is set
+ // Execute afterFirst hook if it is set.
if w.afterFirst != nil {
w.afterFirst()
w.afterFirst = nil
}
- w.pos = 8
- // packet number
- w.buf[6] += 1
+ w.wpos = 8
+ w.wPacketSeq++
return nil
}
func (w *tdsBuffer) Write(p []byte) (total int, err error) {
- total = 0
for {
- copied := copy(w.buf[w.pos:], p)
- w.pos += uint16(copied)
+ copied := copy(w.wbuf[w.wpos:w.packetSize], p)
+ w.wpos += copied
total += copied
if copied == len(p) {
- break
+ return
}
if err = w.flush(); err != nil {
return
@@ -74,66 +105,64 @@ func (w *tdsBuffer) Write(p []byte) (total int, err error) {
}
func (w *tdsBuffer) WriteByte(b byte) error {
- if int(w.pos) == len(w.buf) {
+ if int(w.wpos) == len(w.wbuf) {
if err := w.flush(); err != nil {
return err
}
}
- w.buf[w.pos] = b
- w.pos += 1
+ w.wbuf[w.wpos] = b
+ w.wpos += 1
return nil
}
-func (w *tdsBuffer) BeginPacket(packet_type byte) {
- w.buf[0] = packet_type
- w.buf[1] = 0 // packet is incomplete
- w.buf[4] = 0 // spid
- w.buf[5] = 0
- w.buf[6] = 1 // packet id
- w.buf[7] = 0 // window
- w.pos = 8
+func (w *tdsBuffer) BeginPacket(packetType packetType) {
+ w.wbuf[1] = 0 // Packet is incomplete. This byte is set again in FinishPacket.
+ w.wpos = 8
+ w.wPacketSeq = 1
+ w.wPacketType = packetType
}
func (w *tdsBuffer) FinishPacket() error {
- w.buf[1] = 1 // this is last packet
+ w.wbuf[1] = 1 // Mark this as the last packet in the message.
return w.flush()
}
+var headerSize = binary.Size(header{})
+
func (r *tdsBuffer) readNextPacket() error {
- header := header{}
+ h := header{}
var err error
- err = binary.Read(r.transport, binary.BigEndian, &header)
+ err = binary.Read(r.transport, binary.BigEndian, &h)
if err != nil {
return err
}
- offset := uint16(binary.Size(header))
- if int(header.Size) > len(r.buf) {
+ if int(h.Size) > len(r.rbuf) {
return errors.New("Invalid packet size, it is longer than buffer size")
}
- if int(offset) > int(header.Size) {
+ if headerSize > int(h.Size) {
return errors.New("Invalid packet size, it is shorter than header size")
}
- _, err = io.ReadFull(r.transport, r.buf[offset:header.Size])
+ _, err = io.ReadFull(r.transport, r.rbuf[headerSize:h.Size])
if err != nil {
return err
}
- r.pos = offset
- r.size = header.Size
- r.final = header.Status != 0
- r.packet_type = header.PacketType
+ r.rpos = headerSize
+ r.rsize = int(h.Size)
+ r.final = h.Status != 0
+ r.rPacketType = h.PacketType
return nil
}
-func (r *tdsBuffer) BeginRead() (uint8, error) {
+func (r *tdsBuffer) BeginRead() (packetType, error) {
err := r.readNextPacket()
if err != nil {
return 0, err
}
- return r.packet_type, nil
+ return r.rPacketType, nil
}
func (r *tdsBuffer) ReadByte() (res byte, err error) {
- if r.pos == r.size {
+ if r.rpos == r.rsize {
if r.final {
return 0, io.EOF
}
@@ -142,8 +171,8 @@ func (r *tdsBuffer) ReadByte() (res byte, err error) {
return 0, err
}
}
- res = r.buf[r.pos]
- r.pos++
+ res = r.rbuf[r.rpos]
+ r.rpos++
return res, nil
}
@@ -207,7 +236,7 @@ func (r *tdsBuffer) readUcs2(numchars int) string {
func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
copied = 0
err = nil
- if r.pos == r.size {
+ if r.rpos == r.rsize {
if r.final {
return 0, io.EOF
}
@@ -216,7 +245,7 @@ func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
return
}
}
- copied = copy(buf, r.buf[r.pos:r.size])
- r.pos += uint16(copied)
+ copied = copy(buf, r.rbuf[r.rpos:r.rsize])
+ r.rpos += copied
return
}