You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

db.go 9.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. package core
  2. import (
  3. "database/sql"
  4. "database/sql/driver"
  5. "errors"
  6. "fmt"
  7. "reflect"
  8. "regexp"
  9. "sync"
  10. )
  11. var (
  12. DefaultCacheSize = 200
  13. )
  14. func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
  15. vv := reflect.ValueOf(mp)
  16. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  17. return "", []interface{}{}, ErrNoMapPointer
  18. }
  19. args := make([]interface{}, 0, len(vv.Elem().MapKeys()))
  20. var err error
  21. query = re.ReplaceAllStringFunc(query, func(src string) string {
  22. v := vv.Elem().MapIndex(reflect.ValueOf(src[1:]))
  23. if !v.IsValid() {
  24. err = fmt.Errorf("map key %s is missing", src[1:])
  25. } else {
  26. args = append(args, v.Interface())
  27. }
  28. return "?"
  29. })
  30. return query, args, err
  31. }
  32. func StructToSlice(query string, st interface{}) (string, []interface{}, error) {
  33. vv := reflect.ValueOf(st)
  34. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  35. return "", []interface{}{}, ErrNoStructPointer
  36. }
  37. args := make([]interface{}, 0)
  38. var err error
  39. query = re.ReplaceAllStringFunc(query, func(src string) string {
  40. fv := vv.Elem().FieldByName(src[1:]).Interface()
  41. if v, ok := fv.(driver.Valuer); ok {
  42. var value driver.Value
  43. value, err = v.Value()
  44. if err != nil {
  45. return "?"
  46. }
  47. args = append(args, value)
  48. } else {
  49. args = append(args, fv)
  50. }
  51. return "?"
  52. })
  53. if err != nil {
  54. return "", []interface{}{}, err
  55. }
  56. return query, args, nil
  57. }
// cacheStruct is one batch of pre-allocated values of a single type, handed
// out element by element by (*DB).reflectNew.
type cacheStruct struct {
	value reflect.Value // slice created with reflect.MakeSlice
	idx   int           // index of the element handed out most recently
}
// DB wraps *sql.DB and adds named-placeholder query helpers, a Mapper and a
// per-type cache of pre-allocated reflect values.
type DB struct {
	*sql.DB
	// Mapper is set by Open and FromDB to a cache mapper around SnakeMapper.
	Mapper IMapper
	// reflectCache holds the current allocation batch for each type; it is
	// accessed by reflectNew while holding reflectCacheMutex.
	reflectCache      map[reflect.Type]*cacheStruct
	reflectCacheMutex sync.RWMutex
}
  68. func Open(driverName, dataSourceName string) (*DB, error) {
  69. db, err := sql.Open(driverName, dataSourceName)
  70. if err != nil {
  71. return nil, err
  72. }
  73. return &DB{
  74. DB: db,
  75. Mapper: NewCacheMapper(&SnakeMapper{}),
  76. reflectCache: make(map[reflect.Type]*cacheStruct),
  77. }, nil
  78. }
  79. func FromDB(db *sql.DB) *DB {
  80. return &DB{
  81. DB: db,
  82. Mapper: NewCacheMapper(&SnakeMapper{}),
  83. reflectCache: make(map[reflect.Type]*cacheStruct),
  84. }
  85. }
  86. func (db *DB) reflectNew(typ reflect.Type) reflect.Value {
  87. db.reflectCacheMutex.Lock()
  88. defer db.reflectCacheMutex.Unlock()
  89. cs, ok := db.reflectCache[typ]
  90. if !ok || cs.idx+1 > DefaultCacheSize-1 {
  91. cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), DefaultCacheSize, DefaultCacheSize), 0}
  92. db.reflectCache[typ] = cs
  93. } else {
  94. cs.idx = cs.idx + 1
  95. }
  96. return cs.value.Index(cs.idx).Addr()
  97. }
  98. func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
  99. rows, err := db.DB.Query(query, args...)
  100. if err != nil {
  101. if rows != nil {
  102. rows.Close()
  103. }
  104. return nil, err
  105. }
  106. return &Rows{rows, db}, nil
  107. }
  108. func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
  109. query, args, err := MapToSlice(query, mp)
  110. if err != nil {
  111. return nil, err
  112. }
  113. return db.Query(query, args...)
  114. }
  115. func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) {
  116. query, args, err := StructToSlice(query, st)
  117. if err != nil {
  118. return nil, err
  119. }
  120. return db.Query(query, args...)
  121. }
  122. func (db *DB) QueryRow(query string, args ...interface{}) *Row {
  123. rows, err := db.Query(query, args...)
  124. if err != nil {
  125. return &Row{nil, err}
  126. }
  127. return &Row{rows, nil}
  128. }
  129. func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
  130. query, args, err := MapToSlice(query, mp)
  131. if err != nil {
  132. return &Row{nil, err}
  133. }
  134. return db.QueryRow(query, args...)
  135. }
  136. func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
  137. query, args, err := StructToSlice(query, st)
  138. if err != nil {
  139. return &Row{nil, err}
  140. }
  141. return db.QueryRow(query, args...)
  142. }
// Stmt wraps *sql.Stmt together with the DB it belongs to and the positions
// of the named placeholders found in the original query text.
type Stmt struct {
	*sql.Stmt
	db *DB
	// names maps a placeholder name (without the leading '?') to the
	// zero-based position recorded for it when the statement was prepared.
	names map[string]int
}
  148. func (db *DB) Prepare(query string) (*Stmt, error) {
  149. names := make(map[string]int)
  150. var i int
  151. query = re.ReplaceAllStringFunc(query, func(src string) string {
  152. names[src[1:]] = i
  153. i += 1
  154. return "?"
  155. })
  156. stmt, err := db.DB.Prepare(query)
  157. if err != nil {
  158. return nil, err
  159. }
  160. return &Stmt{stmt, db, names}, nil
  161. }
  162. func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
  163. vv := reflect.ValueOf(mp)
  164. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  165. return nil, errors.New("mp should be a map's pointer")
  166. }
  167. args := make([]interface{}, len(s.names))
  168. for k, i := range s.names {
  169. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  170. }
  171. return s.Stmt.Exec(args...)
  172. }
  173. func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
  174. vv := reflect.ValueOf(st)
  175. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  176. return nil, errors.New("mp should be a map's pointer")
  177. }
  178. args := make([]interface{}, len(s.names))
  179. for k, i := range s.names {
  180. args[i] = vv.Elem().FieldByName(k).Interface()
  181. }
  182. return s.Stmt.Exec(args...)
  183. }
  184. func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
  185. rows, err := s.Stmt.Query(args...)
  186. if err != nil {
  187. return nil, err
  188. }
  189. return &Rows{rows, s.db}, nil
  190. }
  191. func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
  192. vv := reflect.ValueOf(mp)
  193. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  194. return nil, errors.New("mp should be a map's pointer")
  195. }
  196. args := make([]interface{}, len(s.names))
  197. for k, i := range s.names {
  198. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  199. }
  200. return s.Query(args...)
  201. }
  202. func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
  203. vv := reflect.ValueOf(st)
  204. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  205. return nil, errors.New("mp should be a map's pointer")
  206. }
  207. args := make([]interface{}, len(s.names))
  208. for k, i := range s.names {
  209. args[i] = vv.Elem().FieldByName(k).Interface()
  210. }
  211. return s.Query(args...)
  212. }
  213. func (s *Stmt) QueryRow(args ...interface{}) *Row {
  214. rows, err := s.Query(args...)
  215. return &Row{rows, err}
  216. }
  217. func (s *Stmt) QueryRowMap(mp interface{}) *Row {
  218. vv := reflect.ValueOf(mp)
  219. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  220. return &Row{nil, errors.New("mp should be a map's pointer")}
  221. }
  222. args := make([]interface{}, len(s.names))
  223. for k, i := range s.names {
  224. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  225. }
  226. return s.QueryRow(args...)
  227. }
  228. func (s *Stmt) QueryRowStruct(st interface{}) *Row {
  229. vv := reflect.ValueOf(st)
  230. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  231. return &Row{nil, errors.New("st should be a struct's pointer")}
  232. }
  233. args := make([]interface{}, len(s.names))
  234. for k, i := range s.names {
  235. args[i] = vv.Elem().FieldByName(k).Interface()
  236. }
  237. return s.QueryRow(args...)
  238. }
  239. var (
  240. re = regexp.MustCompile(`[?](\w+)`)
  241. )
  242. // insert into (name) values (?)
  243. // insert into (name) values (?name)
  244. func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) {
  245. query, args, err := MapToSlice(query, mp)
  246. if err != nil {
  247. return nil, err
  248. }
  249. return db.DB.Exec(query, args...)
  250. }
  251. func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {
  252. query, args, err := StructToSlice(query, st)
  253. if err != nil {
  254. return nil, err
  255. }
  256. return db.DB.Exec(query, args...)
  257. }
  258. type EmptyScanner struct {
  259. }
  260. func (EmptyScanner) Scan(src interface{}) error {
  261. return nil
  262. }
// Tx wraps *sql.Tx together with the DB the transaction was started on.
type Tx struct {
	*sql.Tx
	db *DB // passed on to the Rows and Stmt values this transaction creates
}
  267. func (db *DB) Begin() (*Tx, error) {
  268. tx, err := db.DB.Begin()
  269. if err != nil {
  270. return nil, err
  271. }
  272. return &Tx{tx, db}, nil
  273. }
  274. func (tx *Tx) Prepare(query string) (*Stmt, error) {
  275. names := make(map[string]int)
  276. var i int
  277. query = re.ReplaceAllStringFunc(query, func(src string) string {
  278. names[src[1:]] = i
  279. i += 1
  280. return "?"
  281. })
  282. stmt, err := tx.Tx.Prepare(query)
  283. if err != nil {
  284. return nil, err
  285. }
  286. return &Stmt{stmt, tx.db, names}, nil
  287. }
  288. func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
  289. // TODO:
  290. return stmt
  291. }
  292. func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) {
  293. query, args, err := MapToSlice(query, mp)
  294. if err != nil {
  295. return nil, err
  296. }
  297. return tx.Tx.Exec(query, args...)
  298. }
  299. func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
  300. query, args, err := StructToSlice(query, st)
  301. if err != nil {
  302. return nil, err
  303. }
  304. return tx.Tx.Exec(query, args...)
  305. }
  306. func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
  307. rows, err := tx.Tx.Query(query, args...)
  308. if err != nil {
  309. return nil, err
  310. }
  311. return &Rows{rows, tx.db}, nil
  312. }
  313. func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) {
  314. query, args, err := MapToSlice(query, mp)
  315. if err != nil {
  316. return nil, err
  317. }
  318. return tx.Query(query, args...)
  319. }
  320. func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) {
  321. query, args, err := StructToSlice(query, st)
  322. if err != nil {
  323. return nil, err
  324. }
  325. return tx.Query(query, args...)
  326. }
  327. func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
  328. rows, err := tx.Query(query, args...)
  329. return &Row{rows, err}
  330. }
  331. func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
  332. query, args, err := MapToSlice(query, mp)
  333. if err != nil {
  334. return &Row{nil, err}
  335. }
  336. return tx.QueryRow(query, args...)
  337. }
  338. func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row {
  339. query, args, err := StructToSlice(query, st)
  340. if err != nil {
  341. return &Row{nil, err}
  342. }
  343. return tx.QueryRow(query, args...)
  344. }