diff options
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/buf.go')
-rw-r--r-- | vendor/github.com/denisenkom/go-mssqldb/buf.go | 145 |
1 files changed, 87 insertions, 58 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/buf.go b/vendor/github.com/denisenkom/go-mssqldb/buf.go index 42e8ae345c..365acd4833 100644 --- a/vendor/github.com/denisenkom/go-mssqldb/buf.go +++ b/vendor/github.com/denisenkom/go-mssqldb/buf.go @@ -2,12 +2,14 @@ package mssql import ( "encoding/binary" - "io" "errors" + "io" ) +type packetType uint8 + type header struct { - PacketType uint8 + PacketType packetType Status uint8 Size uint16 Spid uint16 @@ -15,55 +17,84 @@ type header struct { Pad uint8 } +// tdsBuffer reads and writes TDS packets of data to the transport. +// The write and read buffers are separate to make sending attn signals +// possible without locks. Currently attn signals are only sent during +// reads, not writes. type tdsBuffer struct { - buf []byte - pos uint16 - transport io.ReadWriteCloser - size uint16 + transport io.ReadWriteCloser + + packetSize int + + // Write fields. + wbuf []byte + wpos int + wPacketSeq byte + wPacketType packetType + + // Read fields. + rbuf []byte + rpos int + rsize int final bool - packet_type uint8 - afterFirst func() + rPacketType packetType + + // afterFirst is assigned to right after tdsBuffer is created and + // before the first use. It is executed after the first packet is + // written and then removed. + afterFirst func() } -func newTdsBuffer(bufsize int, transport io.ReadWriteCloser) *tdsBuffer { - buf := make([]byte, bufsize) - w := new(tdsBuffer) - w.buf = buf - w.pos = 8 - w.transport = transport - w.size = 0 - return w +func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer { + return &tdsBuffer{ + packetSize: int(bufsize), + wbuf: make([]byte, 1<<16), + rbuf: make([]byte, 1<<16), + rpos: 8, + transport: transport, + } +} + +func (rw *tdsBuffer) ResizeBuffer(packetSize int) { + rw.packetSize = packetSize +} + +func (w *tdsBuffer) PackageSize() int { + return w.packetSize } func (w *tdsBuffer) flush() (err error) { - // writing packet size - binary.BigEndian.PutUint16(w.buf[2:], w.pos) + // Write packet size. + w.wbuf[0] = byte(w.wPacketType) + binary.BigEndian.PutUint16(w.wbuf[2:], uint16(w.wpos)) + w.wbuf[6] = w.wPacketSeq - // writing packet into underlying transport - if _, err = w.transport.Write(w.buf[:w.pos]); err != nil { + // Write packet into underlying transport. + if _, err = w.transport.Write(w.wbuf[:w.wpos]); err != nil { return err } + // It is possible to create a whole new buffer after a flush. + // Useful for debugging. Normally reuse the buffer. + // w.wbuf = make([]byte, 1<<16) - // execute afterFirst hook if it is set + // Execute afterFirst hook if it is set. if w.afterFirst != nil { w.afterFirst() w.afterFirst = nil } - w.pos = 8 - // packet number - w.buf[6] += 1 + w.wpos = 8 + w.wPacketSeq++ return nil } func (w *tdsBuffer) Write(p []byte) (total int, err error) { - total = 0 for { - copied := copy(w.buf[w.pos:], p) - w.pos += uint16(copied) + copied := copy(w.wbuf[w.wpos:w.packetSize], p) + w.wpos += copied total += copied if copied == len(p) { - break + return } if err = w.flush(); err != nil { return @@ -74,66 +105,64 @@ func (w *tdsBuffer) Write(p []byte) (total int, err error) { } func (w *tdsBuffer) WriteByte(b byte) error { - if int(w.pos) == len(w.buf) { + if int(w.wpos) == len(w.wbuf) { if err := w.flush(); err != nil { return err } } - w.buf[w.pos] = b - w.pos += 1 + w.wbuf[w.wpos] = b + w.wpos += 1 return nil } -func (w *tdsBuffer) BeginPacket(packet_type byte) { - w.buf[0] = packet_type - w.buf[1] = 0 // packet is incomplete - w.buf[4] = 0 // spid - w.buf[5] = 0 - w.buf[6] = 1 // packet id - w.buf[7] = 0 // window - w.pos = 8 +func (w *tdsBuffer) BeginPacket(packetType packetType) { + w.wbuf[1] = 0 // Packet is incomplete. This byte is set again in FinishPacket. + w.wpos = 8 + w.wPacketSeq = 1 + w.wPacketType = packetType } func (w *tdsBuffer) FinishPacket() error { - w.buf[1] = 1 // this is last packet + w.wbuf[1] = 1 // Mark this as the last packet in the message. return w.flush() } +var headerSize = binary.Size(header{}) + func (r *tdsBuffer) readNextPacket() error { - header := header{} + h := header{} var err error - err = binary.Read(r.transport, binary.BigEndian, &header) + err = binary.Read(r.transport, binary.BigEndian, &h) if err != nil { return err } - offset := uint16(binary.Size(header)) - if int(header.Size) > len(r.buf) { + if int(h.Size) > len(r.rbuf) { return errors.New("Invalid packet size, it is longer than buffer size") } - if int(offset) > int(header.Size) { + if headerSize > int(h.Size) { return errors.New("Invalid packet size, it is shorter than header size") } - _, err = io.ReadFull(r.transport, r.buf[offset:header.Size]) + _, err = io.ReadFull(r.transport, r.rbuf[headerSize:h.Size]) if err != nil { return err } - r.pos = offset - r.size = header.Size - r.final = header.Status != 0 - r.packet_type = header.PacketType + r.rpos = headerSize + r.rsize = int(h.Size) + r.final = h.Status != 0 + r.rPacketType = h.PacketType return nil } -func (r *tdsBuffer) BeginRead() (uint8, error) { +func (r *tdsBuffer) BeginRead() (packetType, error) { err := r.readNextPacket() if err != nil { return 0, err } - return r.packet_type, nil + return r.rPacketType, nil } func (r *tdsBuffer) ReadByte() (res byte, err error) { - if r.pos == r.size { + if r.rpos == r.rsize { if r.final { return 0, io.EOF } @@ -142,8 +171,8 @@ func (r *tdsBuffer) ReadByte() (res byte, err error) { return 0, err } } - res = r.buf[r.pos] - r.pos++ + res = r.rbuf[r.rpos] + r.rpos++ return res, nil } @@ -207,7 +236,7 @@ func (r *tdsBuffer) readUcs2(numchars int) string { func (r *tdsBuffer) Read(buf []byte) (copied int, err error) { copied = 0 err = nil - if r.pos == r.size { + if r.rpos == r.rsize { if r.final { return 0, io.EOF } @@ -216,7 +245,7 @@ func (r *tdsBuffer) Read(buf []byte) (copied int, err error) { return } } - copied = copy(buf, r.buf[r.pos:r.size]) - r.pos += uint16(copied) + copied = copy(buf, r.rbuf[r.rpos:r.rsize]) + r.rpos += copied return } |