You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

mssql.go 9.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. package mssql
  2. import (
  3. "database/sql"
  4. "database/sql/driver"
  5. "encoding/binary"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "log"
  10. "math"
  11. "net"
  12. "strings"
  13. "time"
  14. )
  15. func init() {
  16. sql.Register("mssql", &MssqlDriver{})
  17. }
  18. type MssqlDriver struct {
  19. log *log.Logger
  20. }
  21. func (d *MssqlDriver) SetLogger(logger *log.Logger) {
  22. d.log = logger
  23. }
  24. func CheckBadConn(err error) error {
  25. if err == io.EOF {
  26. return driver.ErrBadConn
  27. }
  28. switch e := err.(type) {
  29. case net.Error:
  30. if e.Timeout() {
  31. return e
  32. }
  33. return driver.ErrBadConn
  34. default:
  35. return err
  36. }
  37. }
  38. type MssqlConn struct {
  39. sess *tdsSession
  40. }
  41. func (c *MssqlConn) Commit() error {
  42. headers := []headerStruct{
  43. {hdrtype: dataStmHdrTransDescr,
  44. data: transDescrHdr{c.sess.tranid, 1}.pack()},
  45. }
  46. if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
  47. return err
  48. }
  49. tokchan := make(chan tokenStruct, 5)
  50. go processResponse(c.sess, tokchan)
  51. for tok := range tokchan {
  52. switch token := tok.(type) {
  53. case error:
  54. return token
  55. }
  56. }
  57. return nil
  58. }
  59. func (c *MssqlConn) Rollback() error {
  60. headers := []headerStruct{
  61. {hdrtype: dataStmHdrTransDescr,
  62. data: transDescrHdr{c.sess.tranid, 1}.pack()},
  63. }
  64. if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
  65. return err
  66. }
  67. tokchan := make(chan tokenStruct, 5)
  68. go processResponse(c.sess, tokchan)
  69. for tok := range tokchan {
  70. switch token := tok.(type) {
  71. case error:
  72. return token
  73. }
  74. }
  75. return nil
  76. }
  77. func (c *MssqlConn) Begin() (driver.Tx, error) {
  78. headers := []headerStruct{
  79. {hdrtype: dataStmHdrTransDescr,
  80. data: transDescrHdr{0, 1}.pack()},
  81. }
  82. if err := sendBeginXact(c.sess.buf, headers, 0, ""); err != nil {
  83. return nil, CheckBadConn(err)
  84. }
  85. tokchan := make(chan tokenStruct, 5)
  86. go processResponse(c.sess, tokchan)
  87. for tok := range tokchan {
  88. switch token := tok.(type) {
  89. case error:
  90. if c.sess.tranid != 0 {
  91. return nil, token
  92. }
  93. return nil, CheckBadConn(token)
  94. }
  95. }
  96. // successful BEGINXACT request will return sess.tranid
  97. // for started transaction
  98. return c, nil
  99. }
  100. func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {
  101. params, err := parseConnectParams(dsn)
  102. if err != nil {
  103. return nil, err
  104. }
  105. sess, err := connect(params)
  106. if err != nil {
  107. // main server failed, try fail-over partner
  108. if params.failOverPartner == "" {
  109. return nil, err
  110. }
  111. params.host = params.failOverPartner
  112. if params.failOverPort != 0 {
  113. params.port = params.failOverPort
  114. }
  115. sess, err = connect(params)
  116. if err != nil {
  117. // fail-over partner also failed, now fail
  118. return nil, err
  119. }
  120. }
  121. conn := &MssqlConn{sess}
  122. conn.sess.log = (*Logger)(d.log)
  123. return conn, nil
  124. }
  125. func (c *MssqlConn) Close() error {
  126. return c.sess.buf.transport.Close()
  127. }
  128. type MssqlStmt struct {
  129. c *MssqlConn
  130. query string
  131. paramCount int
  132. notifSub *queryNotifSub
  133. }
  134. type queryNotifSub struct {
  135. msgText string
  136. options string
  137. timeout uint32
  138. }
  139. func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) {
  140. q, paramCount := parseParams(query)
  141. return &MssqlStmt{c, q, paramCount, nil}, nil
  142. }
  143. func (s *MssqlStmt) Close() error {
  144. return nil
  145. }
  146. func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Duration) {
  147. to := uint32(timeout / time.Second)
  148. if to < 1 {
  149. to = 1
  150. }
  151. s.notifSub = &queryNotifSub{id, options, to}
  152. }
  153. func (s *MssqlStmt) NumInput() int {
  154. return s.paramCount
  155. }
  156. func (s *MssqlStmt) sendQuery(args []driver.Value) (err error) {
  157. headers := []headerStruct{
  158. {hdrtype: dataStmHdrTransDescr,
  159. data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
  160. }
  161. if s.notifSub != nil {
  162. headers = append(headers, headerStruct{hdrtype: dataStmHdrQueryNotif,
  163. data: queryNotifHdr{s.notifSub.msgText, s.notifSub.options, s.notifSub.timeout}.pack()})
  164. }
  165. if len(args) != s.paramCount {
  166. return errors.New(fmt.Sprintf("sql: expected %d parameters, got %d", s.paramCount, len(args)))
  167. }
  168. if s.c.sess.logFlags&logSQL != 0 {
  169. s.c.sess.log.Println(s.query)
  170. }
  171. if s.c.sess.logFlags&logParams != 0 && len(args) > 0 {
  172. for i := 0; i < len(args); i++ {
  173. s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i])
  174. }
  175. }
  176. if len(args) == 0 {
  177. if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {
  178. if s.c.sess.tranid != 0 {
  179. return err
  180. }
  181. return CheckBadConn(err)
  182. }
  183. } else {
  184. params := make([]Param, len(args)+2)
  185. decls := make([]string, len(args))
  186. params[0], err = s.makeParam(s.query)
  187. if err != nil {
  188. return
  189. }
  190. for i, val := range args {
  191. params[i+2], err = s.makeParam(val)
  192. if err != nil {
  193. return
  194. }
  195. name := fmt.Sprintf("@p%d", i+1)
  196. params[i+2].Name = name
  197. decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+2].ti))
  198. }
  199. params[1], err = s.makeParam(strings.Join(decls, ","))
  200. if err != nil {
  201. return
  202. }
  203. if err = sendRpc(s.c.sess.buf, headers, Sp_ExecuteSql, 0, params); err != nil {
  204. if s.c.sess.tranid != 0 {
  205. return err
  206. }
  207. return CheckBadConn(err)
  208. }
  209. }
  210. return
  211. }
  212. func (s *MssqlStmt) Query(args []driver.Value) (res driver.Rows, err error) {
  213. if err = s.sendQuery(args); err != nil {
  214. return
  215. }
  216. tokchan := make(chan tokenStruct, 5)
  217. go processResponse(s.c.sess, tokchan)
  218. // process metadata
  219. var cols []string
  220. loop:
  221. for tok := range tokchan {
  222. switch token := tok.(type) {
  223. // by ignoring DONE token we effectively
  224. // skip empty result-sets
  225. // this improves results in queryes like that:
  226. // set nocount on; select 1
  227. // see TestIgnoreEmptyResults test
  228. //case doneStruct:
  229. //break loop
  230. case []columnStruct:
  231. cols = make([]string, len(token))
  232. for i, col := range token {
  233. cols[i] = col.ColName
  234. }
  235. break loop
  236. case error:
  237. if s.c.sess.tranid != 0 {
  238. return nil, token
  239. }
  240. return nil, CheckBadConn(token)
  241. }
  242. }
  243. return &MssqlRows{sess: s.c.sess, tokchan: tokchan, cols: cols}, nil
  244. }
  245. func (s *MssqlStmt) Exec(args []driver.Value) (res driver.Result, err error) {
  246. if err = s.sendQuery(args); err != nil {
  247. return
  248. }
  249. tokchan := make(chan tokenStruct, 5)
  250. go processResponse(s.c.sess, tokchan)
  251. var rowCount int64
  252. for token := range tokchan {
  253. switch token := token.(type) {
  254. case doneInProcStruct:
  255. if token.Status&doneCount != 0 {
  256. rowCount = int64(token.RowCount)
  257. }
  258. case doneStruct:
  259. if token.Status&doneCount != 0 {
  260. rowCount = int64(token.RowCount)
  261. }
  262. case error:
  263. if s.c.sess.logFlags&logErrors != 0 {
  264. s.c.sess.log.Println("got error:", token)
  265. }
  266. if s.c.sess.tranid != 0 {
  267. return nil, token
  268. }
  269. return nil, CheckBadConn(token)
  270. }
  271. }
  272. return &MssqlResult{s.c, rowCount}, nil
  273. }
  274. type MssqlRows struct {
  275. sess *tdsSession
  276. cols []string
  277. tokchan chan tokenStruct
  278. nextCols []string
  279. }
  280. func (rc *MssqlRows) Close() error {
  281. for _ = range rc.tokchan {
  282. }
  283. rc.tokchan = nil
  284. return nil
  285. }
  286. func (rc *MssqlRows) Columns() (res []string) {
  287. return rc.cols
  288. }
  289. func (rc *MssqlRows) Next(dest []driver.Value) (err error) {
  290. if rc.nextCols != nil {
  291. return io.EOF
  292. }
  293. for tok := range rc.tokchan {
  294. switch tokdata := tok.(type) {
  295. case []columnStruct:
  296. cols := make([]string, len(tokdata))
  297. for i, col := range tokdata {
  298. cols[i] = col.ColName
  299. }
  300. rc.nextCols = cols
  301. return io.EOF
  302. case []interface{}:
  303. for i := range dest {
  304. dest[i] = tokdata[i]
  305. }
  306. return nil
  307. case error:
  308. return tokdata
  309. }
  310. }
  311. return io.EOF
  312. }
  313. func (rc *MssqlRows) HasNextResultSet() bool {
  314. return rc.nextCols != nil
  315. }
  316. func (rc *MssqlRows) NextResultSet() error {
  317. rc.cols = rc.nextCols
  318. rc.nextCols = nil
  319. if rc.cols == nil {
  320. return io.EOF
  321. }
  322. return nil
  323. }
  324. func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
  325. if val == nil {
  326. res.ti.TypeId = typeNVarChar
  327. res.buffer = nil
  328. res.ti.Size = 2
  329. return
  330. }
  331. switch val := val.(type) {
  332. case int64:
  333. res.ti.TypeId = typeIntN
  334. res.buffer = make([]byte, 8)
  335. res.ti.Size = 8
  336. binary.LittleEndian.PutUint64(res.buffer, uint64(val))
  337. case float64:
  338. res.ti.TypeId = typeFltN
  339. res.ti.Size = 8
  340. res.buffer = make([]byte, 8)
  341. binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val))
  342. case []byte:
  343. res.ti.TypeId = typeBigVarBin
  344. res.ti.Size = len(val)
  345. res.buffer = val
  346. case string:
  347. res.ti.TypeId = typeNVarChar
  348. res.buffer = str2ucs2(val)
  349. res.ti.Size = len(res.buffer)
  350. case bool:
  351. res.ti.TypeId = typeBitN
  352. res.ti.Size = 1
  353. res.buffer = make([]byte, 1)
  354. if val {
  355. res.buffer[0] = 1
  356. }
  357. case time.Time:
  358. if s.c.sess.loginAck.TDSVersion >= verTDS73 {
  359. res.ti.TypeId = typeDateTimeOffsetN
  360. res.ti.Scale = 7
  361. res.ti.Size = 10
  362. buf := make([]byte, 10)
  363. res.buffer = buf
  364. days, ns := dateTime2(val)
  365. ns /= 100
  366. buf[0] = byte(ns)
  367. buf[1] = byte(ns >> 8)
  368. buf[2] = byte(ns >> 16)
  369. buf[3] = byte(ns >> 24)
  370. buf[4] = byte(ns >> 32)
  371. buf[5] = byte(days)
  372. buf[6] = byte(days >> 8)
  373. buf[7] = byte(days >> 16)
  374. _, offset := val.Zone()
  375. offset /= 60
  376. buf[8] = byte(offset)
  377. buf[9] = byte(offset >> 8)
  378. } else {
  379. res.ti.TypeId = typeDateTimeN
  380. res.ti.Size = 8
  381. res.buffer = make([]byte, 8)
  382. ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
  383. dur := val.Sub(ref)
  384. days := dur / (24 * time.Hour)
  385. tm := (300 * (dur % (24 * time.Hour))) / time.Second
  386. binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days))
  387. binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))
  388. }
  389. default:
  390. err = fmt.Errorf("mssql: unknown type for %T", val)
  391. return
  392. }
  393. return
  394. }
  395. type MssqlResult struct {
  396. c *MssqlConn
  397. rowsAffected int64
  398. }
  399. func (r *MssqlResult) RowsAffected() (int64, error) {
  400. return r.rowsAffected, nil
  401. }
  402. func (r *MssqlResult) LastInsertId() (int64, error) {
  403. s, err := r.c.Prepare("select cast(@@identity as bigint)")
  404. if err != nil {
  405. return 0, err
  406. }
  407. defer s.Close()
  408. rows, err := s.Query(nil)
  409. if err != nil {
  410. return 0, err
  411. }
  412. defer rows.Close()
  413. dest := make([]driver.Value, 1)
  414. err = rows.Next(dest)
  415. if err != nil {
  416. return 0, err
  417. }
  418. if dest[0] == nil {
  419. return -1, errors.New("There is no generated identity value")
  420. }
  421. lastInsertId := dest[0].(int64)
  422. return lastInsertId, nil
  423. }