You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

session_insert.go 17KB


  1. // Copyright 2016 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "errors"
  7. "fmt"
  8. "reflect"
  9. "strconv"
  10. "strings"
  11. "github.com/go-xorm/core"
  12. )
  13. // Insert insert one or more beans
  14. func (session *Session) Insert(beans ...interface{}) (int64, error) {
  15. var affected int64
  16. var err error
  17. if session.isAutoClose {
  18. defer session.Close()
  19. }
  20. for _, bean := range beans {
  21. sliceValue := reflect.Indirect(reflect.ValueOf(bean))
  22. if sliceValue.Kind() == reflect.Slice {
  23. size := sliceValue.Len()
  24. if size > 0 {
  25. if session.engine.SupportInsertMany() {
  26. cnt, err := session.innerInsertMulti(bean)
  27. if err != nil {
  28. return affected, err
  29. }
  30. affected += cnt
  31. } else {
  32. for i := 0; i < size; i++ {
  33. cnt, err := session.innerInsert(sliceValue.Index(i).Interface())
  34. if err != nil {
  35. return affected, err
  36. }
  37. affected += cnt
  38. }
  39. }
  40. }
  41. } else {
  42. cnt, err := session.innerInsert(bean)
  43. if err != nil {
  44. return affected, err
  45. }
  46. affected += cnt
  47. }
  48. }
  49. return affected, err
  50. }
  51. func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) {
  52. sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
  53. if sliceValue.Kind() != reflect.Slice {
  54. return 0, errors.New("needs a pointer to a slice")
  55. }
  56. if sliceValue.Len() <= 0 {
  57. return 0, errors.New("could not insert a empty slice")
  58. }
  59. if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil {
  60. return 0, err
  61. }
  62. tableName := session.statement.TableName()
  63. if len(tableName) <= 0 {
  64. return 0, ErrTableNotFound
  65. }
  66. table := session.statement.RefTable
  67. size := sliceValue.Len()
  68. var colNames []string
  69. var colMultiPlaces []string
  70. var args []interface{}
  71. var cols []*core.Column
  72. for i := 0; i < size; i++ {
  73. v := sliceValue.Index(i)
  74. vv := reflect.Indirect(v)
  75. elemValue := v.Interface()
  76. var colPlaces []string
  77. // handle BeforeInsertProcessor
  78. // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
  79. for _, closure := range session.beforeClosures {
  80. closure(elemValue)
  81. }
  82. if processor, ok := interface{}(elemValue).(BeforeInsertProcessor); ok {
  83. processor.BeforeInsert()
  84. }
  85. // --
  86. if i == 0 {
  87. for _, col := range table.Columns() {
  88. ptrFieldValue, err := col.ValueOfV(&vv)
  89. if err != nil {
  90. return 0, err
  91. }
  92. fieldValue := *ptrFieldValue
  93. if col.IsAutoIncrement && isZero(fieldValue.Interface()) {
  94. continue
  95. }
  96. if col.MapType == core.ONLYFROMDB {
  97. continue
  98. }
  99. if col.IsDeleted {
  100. continue
  101. }
  102. if session.statement.omitColumnMap.contain(col.Name) {
  103. continue
  104. }
  105. if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
  106. continue
  107. }
  108. if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
  109. val, t := session.engine.nowTime(col)
  110. args = append(args, val)
  111. var colName = col.Name
  112. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  113. col := table.GetColumn(colName)
  114. setColumnTime(bean, col, t)
  115. })
  116. } else if col.IsVersion && session.statement.checkVersion {
  117. args = append(args, 1)
  118. var colName = col.Name
  119. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  120. col := table.GetColumn(colName)
  121. setColumnInt(bean, col, 1)
  122. })
  123. } else {
  124. arg, err := session.value2Interface(col, fieldValue)
  125. if err != nil {
  126. return 0, err
  127. }
  128. args = append(args, arg)
  129. }
  130. colNames = append(colNames, col.Name)
  131. cols = append(cols, col)
  132. colPlaces = append(colPlaces, "?")
  133. }
  134. } else {
  135. for _, col := range cols {
  136. ptrFieldValue, err := col.ValueOfV(&vv)
  137. if err != nil {
  138. return 0, err
  139. }
  140. fieldValue := *ptrFieldValue
  141. if col.IsAutoIncrement && isZero(fieldValue.Interface()) {
  142. continue
  143. }
  144. if col.MapType == core.ONLYFROMDB {
  145. continue
  146. }
  147. if col.IsDeleted {
  148. continue
  149. }
  150. if session.statement.omitColumnMap.contain(col.Name) {
  151. continue
  152. }
  153. if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
  154. continue
  155. }
  156. if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
  157. val, t := session.engine.nowTime(col)
  158. args = append(args, val)
  159. var colName = col.Name
  160. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  161. col := table.GetColumn(colName)
  162. setColumnTime(bean, col, t)
  163. })
  164. } else if col.IsVersion && session.statement.checkVersion {
  165. args = append(args, 1)
  166. var colName = col.Name
  167. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  168. col := table.GetColumn(colName)
  169. setColumnInt(bean, col, 1)
  170. })
  171. } else {
  172. arg, err := session.value2Interface(col, fieldValue)
  173. if err != nil {
  174. return 0, err
  175. }
  176. args = append(args, arg)
  177. }
  178. colPlaces = append(colPlaces, "?")
  179. }
  180. }
  181. colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", "))
  182. }
  183. cleanupProcessorsClosures(&session.beforeClosures)
  184. var sql string
  185. if session.engine.dialect.DBType() == core.ORACLE {
  186. temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
  187. session.engine.Quote(tableName),
  188. session.engine.QuoteStr(),
  189. strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
  190. session.engine.QuoteStr())
  191. sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL",
  192. session.engine.Quote(tableName),
  193. session.engine.QuoteStr(),
  194. strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
  195. session.engine.QuoteStr(),
  196. strings.Join(colMultiPlaces, temp))
  197. } else {
  198. sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
  199. session.engine.Quote(tableName),
  200. session.engine.QuoteStr(),
  201. strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
  202. session.engine.QuoteStr(),
  203. strings.Join(colMultiPlaces, "),("))
  204. }
  205. res, err := session.exec(sql, args...)
  206. if err != nil {
  207. return 0, err
  208. }
  209. session.cacheInsert(tableName)
  210. lenAfterClosures := len(session.afterClosures)
  211. for i := 0; i < size; i++ {
  212. elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface()
  213. // handle AfterInsertProcessor
  214. if session.isAutoCommit {
  215. // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
  216. for _, closure := range session.afterClosures {
  217. closure(elemValue)
  218. }
  219. if processor, ok := interface{}(elemValue).(AfterInsertProcessor); ok {
  220. processor.AfterInsert()
  221. }
  222. } else {
  223. if lenAfterClosures > 0 {
  224. if value, has := session.afterInsertBeans[elemValue]; has && value != nil {
  225. *value = append(*value, session.afterClosures...)
  226. } else {
  227. afterClosures := make([]func(interface{}), lenAfterClosures)
  228. copy(afterClosures, session.afterClosures)
  229. session.afterInsertBeans[elemValue] = &afterClosures
  230. }
  231. } else {
  232. if _, ok := interface{}(elemValue).(AfterInsertProcessor); ok {
  233. session.afterInsertBeans[elemValue] = nil
  234. }
  235. }
  236. }
  237. }
  238. cleanupProcessorsClosures(&session.afterClosures)
  239. return res.RowsAffected()
  240. }
  241. // InsertMulti insert multiple records
  242. func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
  243. if session.isAutoClose {
  244. defer session.Close()
  245. }
  246. sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
  247. if sliceValue.Kind() != reflect.Slice {
  248. return 0, ErrParamsType
  249. }
  250. if sliceValue.Len() <= 0 {
  251. return 0, nil
  252. }
  253. return session.innerInsertMulti(rowsSlicePtr)
  254. }
  255. func (session *Session) innerInsert(bean interface{}) (int64, error) {
  256. if err := session.statement.setRefBean(bean); err != nil {
  257. return 0, err
  258. }
  259. if len(session.statement.TableName()) <= 0 {
  260. return 0, ErrTableNotFound
  261. }
  262. table := session.statement.RefTable
  263. // handle BeforeInsertProcessor
  264. for _, closure := range session.beforeClosures {
  265. closure(bean)
  266. }
  267. cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used
  268. if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
  269. processor.BeforeInsert()
  270. }
  271. colNames, args, err := session.genInsertColumns(bean)
  272. if err != nil {
  273. return 0, err
  274. }
  275. // insert expr columns, override if exists
  276. exprColumns := session.statement.getExpr()
  277. exprColVals := make([]string, 0, len(exprColumns))
  278. for _, v := range exprColumns {
  279. // remove the expr columns
  280. for i, colName := range colNames {
  281. if colName == v.colName {
  282. colNames = append(colNames[:i], colNames[i+1:]...)
  283. args = append(args[:i], args[i+1:]...)
  284. }
  285. }
  286. // append expr column to the end
  287. colNames = append(colNames, v.colName)
  288. exprColVals = append(exprColVals, v.expr)
  289. }
  290. colPlaces := strings.Repeat("?, ", len(colNames)-len(exprColumns))
  291. if len(exprColVals) > 0 {
  292. colPlaces = colPlaces + strings.Join(exprColVals, ", ")
  293. } else {
  294. if len(colPlaces) > 0 {
  295. colPlaces = colPlaces[0 : len(colPlaces)-2]
  296. }
  297. }
  298. var sqlStr string
  299. var tableName = session.statement.TableName()
  300. if len(colPlaces) > 0 {
  301. sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
  302. session.engine.Quote(tableName),
  303. session.engine.QuoteStr(),
  304. strings.Join(colNames, session.engine.Quote(", ")),
  305. session.engine.QuoteStr(),
  306. colPlaces)
  307. } else {
  308. if session.engine.dialect.DBType() == core.MYSQL {
  309. sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName))
  310. } else {
  311. sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", session.engine.Quote(tableName))
  312. }
  313. }
  314. handleAfterInsertProcessorFunc := func(bean interface{}) {
  315. if session.isAutoCommit {
  316. for _, closure := range session.afterClosures {
  317. closure(bean)
  318. }
  319. if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
  320. processor.AfterInsert()
  321. }
  322. } else {
  323. lenAfterClosures := len(session.afterClosures)
  324. if lenAfterClosures > 0 {
  325. if value, has := session.afterInsertBeans[bean]; has && value != nil {
  326. *value = append(*value, session.afterClosures...)
  327. } else {
  328. afterClosures := make([]func(interface{}), lenAfterClosures)
  329. copy(afterClosures, session.afterClosures)
  330. session.afterInsertBeans[bean] = &afterClosures
  331. }
  332. } else {
  333. if _, ok := interface{}(bean).(AfterInsertProcessor); ok {
  334. session.afterInsertBeans[bean] = nil
  335. }
  336. }
  337. }
  338. cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
  339. }
  340. // for postgres, many of them didn't implement lastInsertId, so we should
  341. // implemented it ourself.
  342. if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
  343. res, err := session.queryBytes("select seq_atable.currval from dual", args...)
  344. if err != nil {
  345. return 0, err
  346. }
  347. defer handleAfterInsertProcessorFunc(bean)
  348. session.cacheInsert(tableName)
  349. if table.Version != "" && session.statement.checkVersion {
  350. verValue, err := table.VersionColumn().ValueOf(bean)
  351. if err != nil {
  352. session.engine.logger.Error(err)
  353. } else if verValue.IsValid() && verValue.CanSet() {
  354. session.incrVersionFieldValue(verValue)
  355. }
  356. }
  357. if len(res) < 1 {
  358. return 0, errors.New("insert no error but not returned id")
  359. }
  360. idByte := res[0][table.AutoIncrement]
  361. id, err := strconv.ParseInt(string(idByte), 10, 64)
  362. if err != nil || id <= 0 {
  363. return 1, err
  364. }
  365. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  366. if err != nil {
  367. session.engine.logger.Error(err)
  368. }
  369. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  370. return 1, nil
  371. }
  372. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  373. return 1, nil
  374. } else if session.engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 {
  375. //assert table.AutoIncrement != ""
  376. sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement)
  377. res, err := session.queryBytes(sqlStr, args...)
  378. if err != nil {
  379. return 0, err
  380. }
  381. defer handleAfterInsertProcessorFunc(bean)
  382. session.cacheInsert(tableName)
  383. if table.Version != "" && session.statement.checkVersion {
  384. verValue, err := table.VersionColumn().ValueOf(bean)
  385. if err != nil {
  386. session.engine.logger.Error(err)
  387. } else if verValue.IsValid() && verValue.CanSet() {
  388. session.incrVersionFieldValue(verValue)
  389. }
  390. }
  391. if len(res) < 1 {
  392. return 0, errors.New("insert no error but not returned id")
  393. }
  394. idByte := res[0][table.AutoIncrement]
  395. id, err := strconv.ParseInt(string(idByte), 10, 64)
  396. if err != nil || id <= 0 {
  397. return 1, err
  398. }
  399. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  400. if err != nil {
  401. session.engine.logger.Error(err)
  402. }
  403. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  404. return 1, nil
  405. }
  406. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  407. return 1, nil
  408. } else {
  409. res, err := session.exec(sqlStr, args...)
  410. if err != nil {
  411. return 0, err
  412. }
  413. defer handleAfterInsertProcessorFunc(bean)
  414. session.cacheInsert(tableName)
  415. if table.Version != "" && session.statement.checkVersion {
  416. verValue, err := table.VersionColumn().ValueOf(bean)
  417. if err != nil {
  418. session.engine.logger.Error(err)
  419. } else if verValue.IsValid() && verValue.CanSet() {
  420. session.incrVersionFieldValue(verValue)
  421. }
  422. }
  423. if table.AutoIncrement == "" {
  424. return res.RowsAffected()
  425. }
  426. var id int64
  427. id, err = res.LastInsertId()
  428. if err != nil || id <= 0 {
  429. return res.RowsAffected()
  430. }
  431. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  432. if err != nil {
  433. session.engine.logger.Error(err)
  434. }
  435. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  436. return res.RowsAffected()
  437. }
  438. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  439. return res.RowsAffected()
  440. }
  441. }
  442. // InsertOne insert only one struct into database as a record.
  443. // The in parameter bean must a struct or a point to struct. The return
  444. // parameter is inserted and error
  445. func (session *Session) InsertOne(bean interface{}) (int64, error) {
  446. if session.isAutoClose {
  447. defer session.Close()
  448. }
  449. return session.innerInsert(bean)
  450. }
  451. func (session *Session) cacheInsert(table string) error {
  452. if !session.statement.UseCache {
  453. return nil
  454. }
  455. cacher := session.engine.getCacher(table)
  456. if cacher == nil {
  457. return nil
  458. }
  459. session.engine.logger.Debug("[cache] clear sql:", table)
  460. cacher.ClearIds(table)
  461. return nil
  462. }
  463. // genInsertColumns generates insert needed columns
  464. func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) {
  465. table := session.statement.RefTable
  466. colNames := make([]string, 0, len(table.ColumnsSeq()))
  467. args := make([]interface{}, 0, len(table.ColumnsSeq()))
  468. for _, col := range table.Columns() {
  469. if col.MapType == core.ONLYFROMDB {
  470. continue
  471. }
  472. if col.IsDeleted {
  473. continue
  474. }
  475. if session.statement.omitColumnMap.contain(col.Name) {
  476. continue
  477. }
  478. if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
  479. continue
  480. }
  481. if _, ok := session.statement.incrColumns[col.Name]; ok {
  482. continue
  483. } else if _, ok := session.statement.decrColumns[col.Name]; ok {
  484. continue
  485. }
  486. fieldValuePtr, err := col.ValueOf(bean)
  487. if err != nil {
  488. return nil, nil, err
  489. }
  490. fieldValue := *fieldValuePtr
  491. if col.IsAutoIncrement {
  492. switch fieldValue.Type().Kind() {
  493. case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
  494. if fieldValue.Int() == 0 {
  495. continue
  496. }
  497. case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
  498. if fieldValue.Uint() == 0 {
  499. continue
  500. }
  501. case reflect.String:
  502. if len(fieldValue.String()) == 0 {
  503. continue
  504. }
  505. case reflect.Ptr:
  506. if fieldValue.Pointer() == 0 {
  507. continue
  508. }
  509. }
  510. }
  511. // !evalphobia! set fieldValue as nil when column is nullable and zero-value
  512. if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
  513. if col.Nullable && isZero(fieldValue.Interface()) {
  514. var nilValue *int
  515. fieldValue = reflect.ValueOf(nilValue)
  516. }
  517. }
  518. if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
  519. // if time is non-empty, then set to auto time
  520. val, t := session.engine.nowTime(col)
  521. args = append(args, val)
  522. var colName = col.Name
  523. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  524. col := table.GetColumn(colName)
  525. setColumnTime(bean, col, t)
  526. })
  527. } else if col.IsVersion && session.statement.checkVersion {
  528. args = append(args, 1)
  529. } else {
  530. arg, err := session.value2Interface(col, fieldValue)
  531. if err != nil {
  532. return colNames, args, err
  533. }
  534. args = append(args, arg)
  535. }
  536. colNames = append(colNames, col.Name)
  537. }
  538. return colNames, args, nil
  539. }