summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/denisenkom/go-mssqldb/buf.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/denisenkom/go-mssqldb/buf.go')
-rw-r--r--vendor/github.com/denisenkom/go-mssqldb/buf.go222
1 files changed, 222 insertions, 0 deletions
diff --git a/vendor/github.com/denisenkom/go-mssqldb/buf.go b/vendor/github.com/denisenkom/go-mssqldb/buf.go
new file mode 100644
index 0000000000..42e8ae345c
--- /dev/null
+++ b/vendor/github.com/denisenkom/go-mssqldb/buf.go
@@ -0,0 +1,222 @@
+package mssql
+
+import (
+ "encoding/binary"
+ "io"
+ "errors"
+)
+
+type header struct {
+ PacketType uint8
+ Status uint8
+ Size uint16
+ Spid uint16
+ PacketNo uint8
+ Pad uint8
+}
+
+type tdsBuffer struct {
+ buf []byte
+ pos uint16
+ transport io.ReadWriteCloser
+ size uint16
+ final bool
+ packet_type uint8
+ afterFirst func()
+}
+
+func newTdsBuffer(bufsize int, transport io.ReadWriteCloser) *tdsBuffer {
+ buf := make([]byte, bufsize)
+ w := new(tdsBuffer)
+ w.buf = buf
+ w.pos = 8
+ w.transport = transport
+ w.size = 0
+ return w
+}
+
+func (w *tdsBuffer) flush() (err error) {
+ // writing packet size
+ binary.BigEndian.PutUint16(w.buf[2:], w.pos)
+
+ // writing packet into underlying transport
+ if _, err = w.transport.Write(w.buf[:w.pos]); err != nil {
+ return err
+ }
+
+ // execute afterFirst hook if it is set
+ if w.afterFirst != nil {
+ w.afterFirst()
+ w.afterFirst = nil
+ }
+
+ w.pos = 8
+ // packet number
+ w.buf[6] += 1
+ return nil
+}
+
+func (w *tdsBuffer) Write(p []byte) (total int, err error) {
+ total = 0
+ for {
+ copied := copy(w.buf[w.pos:], p)
+ w.pos += uint16(copied)
+ total += copied
+ if copied == len(p) {
+ break
+ }
+ if err = w.flush(); err != nil {
+ return
+ }
+ p = p[copied:]
+ }
+ return
+}
+
+func (w *tdsBuffer) WriteByte(b byte) error {
+ if int(w.pos) == len(w.buf) {
+ if err := w.flush(); err != nil {
+ return err
+ }
+ }
+ w.buf[w.pos] = b
+ w.pos += 1
+ return nil
+}
+
+func (w *tdsBuffer) BeginPacket(packet_type byte) {
+ w.buf[0] = packet_type
+ w.buf[1] = 0 // packet is incomplete
+ w.buf[4] = 0 // spid
+ w.buf[5] = 0
+ w.buf[6] = 1 // packet id
+ w.buf[7] = 0 // window
+ w.pos = 8
+}
+
+func (w *tdsBuffer) FinishPacket() error {
+ w.buf[1] = 1 // this is last packet
+ return w.flush()
+}
+
+func (r *tdsBuffer) readNextPacket() error {
+ header := header{}
+ var err error
+ err = binary.Read(r.transport, binary.BigEndian, &header)
+ if err != nil {
+ return err
+ }
+ offset := uint16(binary.Size(header))
+ if int(header.Size) > len(r.buf) {
+ return errors.New("Invalid packet size, it is longer than buffer size")
+ }
+ if int(offset) > int(header.Size) {
+ return errors.New("Invalid packet size, it is shorter than header size")
+ }
+ _, err = io.ReadFull(r.transport, r.buf[offset:header.Size])
+ if err != nil {
+ return err
+ }
+ r.pos = offset
+ r.size = header.Size
+ r.final = header.Status != 0
+ r.packet_type = header.PacketType
+ return nil
+}
+
+func (r *tdsBuffer) BeginRead() (uint8, error) {
+ err := r.readNextPacket()
+ if err != nil {
+ return 0, err
+ }
+ return r.packet_type, nil
+}
+
+func (r *tdsBuffer) ReadByte() (res byte, err error) {
+ if r.pos == r.size {
+ if r.final {
+ return 0, io.EOF
+ }
+ err = r.readNextPacket()
+ if err != nil {
+ return 0, err
+ }
+ }
+ res = r.buf[r.pos]
+ r.pos++
+ return res, nil
+}
+
+func (r *tdsBuffer) byte() byte {
+ b, err := r.ReadByte()
+ if err != nil {
+ badStreamPanic(err)
+ }
+ return b
+}
+
+func (r *tdsBuffer) ReadFull(buf []byte) {
+ _, err := io.ReadFull(r, buf[:])
+ if err != nil {
+ badStreamPanic(err)
+ }
+}
+
+func (r *tdsBuffer) uint64() uint64 {
+ var buf [8]byte
+ r.ReadFull(buf[:])
+ return binary.LittleEndian.Uint64(buf[:])
+}
+
+func (r *tdsBuffer) int32() int32 {
+ return int32(r.uint32())
+}
+
+func (r *tdsBuffer) uint32() uint32 {
+ var buf [4]byte
+ r.ReadFull(buf[:])
+ return binary.LittleEndian.Uint32(buf[:])
+}
+
+func (r *tdsBuffer) uint16() uint16 {
+ var buf [2]byte
+ r.ReadFull(buf[:])
+ return binary.LittleEndian.Uint16(buf[:])
+}
+
+func (r *tdsBuffer) BVarChar() string {
+ l := int(r.byte())
+ return r.readUcs2(l)
+}
+
+func (r *tdsBuffer) UsVarChar() string {
+ l := int(r.uint16())
+ return r.readUcs2(l)
+}
+
+func (r *tdsBuffer) readUcs2(numchars int) string {
+ b := make([]byte, numchars*2)
+ r.ReadFull(b)
+ res, err := ucs22str(b)
+ if err != nil {
+ badStreamPanic(err)
+ }
+ return res
+}
+
+func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
+ copied = 0
+ err = nil
+ if r.pos == r.size {
+ if r.final {
+ return 0, io.EOF
+ }
+ err = r.readNextPacket()
+ if err != nil {
+ return
+ }
+ }
+ copied = copy(buf, r.buf[r.pos:r.size])
+ r.pos += uint16(copied)
+ return
+}