You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

buf.go 4.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. package mssql
  2. import (
  3. "encoding/binary"
  4. "io"
  5. "errors"
  6. )
  7. type header struct {
  8. PacketType uint8
  9. Status uint8
  10. Size uint16
  11. Spid uint16
  12. PacketNo uint8
  13. Pad uint8
  14. }
  15. type tdsBuffer struct {
  16. buf []byte
  17. pos uint16
  18. transport io.ReadWriteCloser
  19. size uint16
  20. final bool
  21. packet_type uint8
  22. afterFirst func()
  23. }
  24. func newTdsBuffer(bufsize int, transport io.ReadWriteCloser) *tdsBuffer {
  25. buf := make([]byte, bufsize)
  26. w := new(tdsBuffer)
  27. w.buf = buf
  28. w.pos = 8
  29. w.transport = transport
  30. w.size = 0
  31. return w
  32. }
  33. func (w *tdsBuffer) flush() (err error) {
  34. // writing packet size
  35. binary.BigEndian.PutUint16(w.buf[2:], w.pos)
  36. // writing packet into underlying transport
  37. if _, err = w.transport.Write(w.buf[:w.pos]); err != nil {
  38. return err
  39. }
  40. // execute afterFirst hook if it is set
  41. if w.afterFirst != nil {
  42. w.afterFirst()
  43. w.afterFirst = nil
  44. }
  45. w.pos = 8
  46. // packet number
  47. w.buf[6] += 1
  48. return nil
  49. }
  50. func (w *tdsBuffer) Write(p []byte) (total int, err error) {
  51. total = 0
  52. for {
  53. copied := copy(w.buf[w.pos:], p)
  54. w.pos += uint16(copied)
  55. total += copied
  56. if copied == len(p) {
  57. break
  58. }
  59. if err = w.flush(); err != nil {
  60. return
  61. }
  62. p = p[copied:]
  63. }
  64. return
  65. }
  66. func (w *tdsBuffer) WriteByte(b byte) error {
  67. if int(w.pos) == len(w.buf) {
  68. if err := w.flush(); err != nil {
  69. return err
  70. }
  71. }
  72. w.buf[w.pos] = b
  73. w.pos += 1
  74. return nil
  75. }
  76. func (w *tdsBuffer) BeginPacket(packet_type byte) {
  77. w.buf[0] = packet_type
  78. w.buf[1] = 0 // packet is incomplete
  79. w.buf[4] = 0 // spid
  80. w.buf[5] = 0
  81. w.buf[6] = 1 // packet id
  82. w.buf[7] = 0 // window
  83. w.pos = 8
  84. }
  85. func (w *tdsBuffer) FinishPacket() error {
  86. w.buf[1] = 1 // this is last packet
  87. return w.flush()
  88. }
  89. func (r *tdsBuffer) readNextPacket() error {
  90. header := header{}
  91. var err error
  92. err = binary.Read(r.transport, binary.BigEndian, &header)
  93. if err != nil {
  94. return err
  95. }
  96. offset := uint16(binary.Size(header))
  97. if int(header.Size) > len(r.buf) {
  98. return errors.New("Invalid packet size, it is longer than buffer size")
  99. }
  100. if int(offset) > int(header.Size) {
  101. return errors.New("Invalid packet size, it is shorter than header size")
  102. }
  103. _, err = io.ReadFull(r.transport, r.buf[offset:header.Size])
  104. if err != nil {
  105. return err
  106. }
  107. r.pos = offset
  108. r.size = header.Size
  109. r.final = header.Status != 0
  110. r.packet_type = header.PacketType
  111. return nil
  112. }
  113. func (r *tdsBuffer) BeginRead() (uint8, error) {
  114. err := r.readNextPacket()
  115. if err != nil {
  116. return 0, err
  117. }
  118. return r.packet_type, nil
  119. }
  120. func (r *tdsBuffer) ReadByte() (res byte, err error) {
  121. if r.pos == r.size {
  122. if r.final {
  123. return 0, io.EOF
  124. }
  125. err = r.readNextPacket()
  126. if err != nil {
  127. return 0, err
  128. }
  129. }
  130. res = r.buf[r.pos]
  131. r.pos++
  132. return res, nil
  133. }
  134. func (r *tdsBuffer) byte() byte {
  135. b, err := r.ReadByte()
  136. if err != nil {
  137. badStreamPanic(err)
  138. }
  139. return b
  140. }
  141. func (r *tdsBuffer) ReadFull(buf []byte) {
  142. _, err := io.ReadFull(r, buf[:])
  143. if err != nil {
  144. badStreamPanic(err)
  145. }
  146. }
  147. func (r *tdsBuffer) uint64() uint64 {
  148. var buf [8]byte
  149. r.ReadFull(buf[:])
  150. return binary.LittleEndian.Uint64(buf[:])
  151. }
  152. func (r *tdsBuffer) int32() int32 {
  153. return int32(r.uint32())
  154. }
  155. func (r *tdsBuffer) uint32() uint32 {
  156. var buf [4]byte
  157. r.ReadFull(buf[:])
  158. return binary.LittleEndian.Uint32(buf[:])
  159. }
  160. func (r *tdsBuffer) uint16() uint16 {
  161. var buf [2]byte
  162. r.ReadFull(buf[:])
  163. return binary.LittleEndian.Uint16(buf[:])
  164. }
  165. func (r *tdsBuffer) BVarChar() string {
  166. l := int(r.byte())
  167. return r.readUcs2(l)
  168. }
  169. func (r *tdsBuffer) UsVarChar() string {
  170. l := int(r.uint16())
  171. return r.readUcs2(l)
  172. }
  173. func (r *tdsBuffer) readUcs2(numchars int) string {
  174. b := make([]byte, numchars*2)
  175. r.ReadFull(b)
  176. res, err := ucs22str(b)
  177. if err != nil {
  178. badStreamPanic(err)
  179. }
  180. return res
  181. }
  182. func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
  183. copied = 0
  184. err = nil
  185. if r.pos == r.size {
  186. if r.final {
  187. return 0, io.EOF
  188. }
  189. err = r.readNextPacket()
  190. if err != nil {
  191. return
  192. }
  193. }
  194. copied = copy(buf, r.buf[r.pos:r.size])
  195. r.pos += uint16(copied)
  196. return
  197. }