diff options
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/buf.go')
-rw-r--r-- | vendor/github.com/denisenkom/go-mssqldb/buf.go | 222 |
1 files changed, 222 insertions, 0 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/buf.go b/vendor/github.com/denisenkom/go-mssqldb/buf.go new file mode 100644 index 0000000000..42e8ae345c --- /dev/null +++ b/vendor/github.com/denisenkom/go-mssqldb/buf.go @@ -0,0 +1,222 @@ +package mssql + +import ( + "encoding/binary" + "io" + "errors" +) + +type header struct { + PacketType uint8 + Status uint8 + Size uint16 + Spid uint16 + PacketNo uint8 + Pad uint8 +} + +type tdsBuffer struct { + buf []byte + pos uint16 + transport io.ReadWriteCloser + size uint16 + final bool + packet_type uint8 + afterFirst func() +} + +func newTdsBuffer(bufsize int, transport io.ReadWriteCloser) *tdsBuffer { + buf := make([]byte, bufsize) + w := new(tdsBuffer) + w.buf = buf + w.pos = 8 + w.transport = transport + w.size = 0 + return w +} + +func (w *tdsBuffer) flush() (err error) { + // writing packet size + binary.BigEndian.PutUint16(w.buf[2:], w.pos) + + // writing packet into underlying transport + if _, err = w.transport.Write(w.buf[:w.pos]); err != nil { + return err + } + + // execute afterFirst hook if it is set + if w.afterFirst != nil { + w.afterFirst() + w.afterFirst = nil + } + + w.pos = 8 + // packet number + w.buf[6] += 1 + return nil +} + +func (w *tdsBuffer) Write(p []byte) (total int, err error) { + total = 0 + for { + copied := copy(w.buf[w.pos:], p) + w.pos += uint16(copied) + total += copied + if copied == len(p) { + break + } + if err = w.flush(); err != nil { + return + } + p = p[copied:] + } + return +} + +func (w *tdsBuffer) WriteByte(b byte) error { + if int(w.pos) == len(w.buf) { + if err := w.flush(); err != nil { + return err + } + } + w.buf[w.pos] = b + w.pos += 1 + return nil +} + +func (w *tdsBuffer) BeginPacket(packet_type byte) { + w.buf[0] = packet_type + w.buf[1] = 0 // packet is incomplete + w.buf[4] = 0 // spid + w.buf[5] = 0 + w.buf[6] = 1 // packet id + w.buf[7] = 0 // window + w.pos = 8 +} + +func (w *tdsBuffer) FinishPacket() error { + w.buf[1] = 1 // this is last packet + return w.flush() +} + +func (r *tdsBuffer) readNextPacket() error { + header := header{} + var err error + err = binary.Read(r.transport, binary.BigEndian, &header) + if err != nil { + return err + } + offset := uint16(binary.Size(header)) + if int(header.Size) > len(r.buf) { + return errors.New("Invalid packet size, it is longer than buffer size") + } + if int(offset) > int(header.Size) { + return errors.New("Invalid packet size, it is shorter than header size") + } + _, err = io.ReadFull(r.transport, r.buf[offset:header.Size]) + if err != nil { + return err + } + r.pos = offset + r.size = header.Size + r.final = header.Status != 0 + r.packet_type = header.PacketType + return nil +} + +func (r *tdsBuffer) BeginRead() (uint8, error) { + err := r.readNextPacket() + if err != nil { + return 0, err + } + return r.packet_type, nil +} + +func (r *tdsBuffer) ReadByte() (res byte, err error) { + if r.pos == r.size { + if r.final { + return 0, io.EOF + } + err = r.readNextPacket() + if err != nil { + return 0, err + } + } + res = r.buf[r.pos] + r.pos++ + return res, nil +} + +func (r *tdsBuffer) byte() byte { + b, err := r.ReadByte() + if err != nil { + badStreamPanic(err) + } + return b +} + +func (r *tdsBuffer) ReadFull(buf []byte) { + _, err := io.ReadFull(r, buf[:]) + if err != nil { + badStreamPanic(err) + } +} + +func (r *tdsBuffer) uint64() uint64 { + var buf [8]byte + r.ReadFull(buf[:]) + return binary.LittleEndian.Uint64(buf[:]) +} + +func (r *tdsBuffer) int32() int32 { + return int32(r.uint32()) +} + +func (r *tdsBuffer) uint32() uint32 { + var buf [4]byte + r.ReadFull(buf[:]) + return binary.LittleEndian.Uint32(buf[:]) +} + +func (r *tdsBuffer) uint16() uint16 { + var buf [2]byte + r.ReadFull(buf[:]) + return binary.LittleEndian.Uint16(buf[:]) +} + +func (r *tdsBuffer) BVarChar() string { + l := int(r.byte()) + return r.readUcs2(l) +} + +func (r *tdsBuffer) UsVarChar() string { + l := int(r.uint16()) + return r.readUcs2(l) +} + +func (r *tdsBuffer) readUcs2(numchars int) string { + b := make([]byte, numchars*2) + r.ReadFull(b) + res, err := ucs22str(b) + if err != nil { + badStreamPanic(err) + } + return res +} + +func (r *tdsBuffer) Read(buf []byte) (copied int, err error) { + copied = 0 + err = nil + if r.pos == r.size { + if r.final { + return 0, io.EOF + } + err = r.readNextPacket() + if err != nil { + return + } + } + copied = copy(buf, r.buf[r.pos:r.size]) + r.pos += uint16(copied) + return +} |