選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

session_insert.go 21KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811
  1. // Copyright 2016 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "errors"
  7. "fmt"
  8. "reflect"
  9. "sort"
  10. "strconv"
  11. "strings"
  12. "xorm.io/builder"
  13. "xorm.io/core"
  14. )
  15. // Insert insert one or more beans
  16. func (session *Session) Insert(beans ...interface{}) (int64, error) {
  17. var affected int64
  18. var err error
  19. if session.isAutoClose {
  20. defer session.Close()
  21. }
  22. for _, bean := range beans {
  23. switch bean.(type) {
  24. case map[string]interface{}:
  25. cnt, err := session.insertMapInterface(bean.(map[string]interface{}))
  26. if err != nil {
  27. return affected, err
  28. }
  29. affected += cnt
  30. case []map[string]interface{}:
  31. s := bean.([]map[string]interface{})
  32. session.autoResetStatement = false
  33. for i := 0; i < len(s); i++ {
  34. cnt, err := session.insertMapInterface(s[i])
  35. if err != nil {
  36. return affected, err
  37. }
  38. affected += cnt
  39. }
  40. case map[string]string:
  41. cnt, err := session.insertMapString(bean.(map[string]string))
  42. if err != nil {
  43. return affected, err
  44. }
  45. affected += cnt
  46. case []map[string]string:
  47. s := bean.([]map[string]string)
  48. session.autoResetStatement = false
  49. for i := 0; i < len(s); i++ {
  50. cnt, err := session.insertMapString(s[i])
  51. if err != nil {
  52. return affected, err
  53. }
  54. affected += cnt
  55. }
  56. default:
  57. sliceValue := reflect.Indirect(reflect.ValueOf(bean))
  58. if sliceValue.Kind() == reflect.Slice {
  59. size := sliceValue.Len()
  60. if size > 0 {
  61. if session.engine.SupportInsertMany() {
  62. cnt, err := session.innerInsertMulti(bean)
  63. if err != nil {
  64. return affected, err
  65. }
  66. affected += cnt
  67. } else {
  68. for i := 0; i < size; i++ {
  69. cnt, err := session.innerInsert(sliceValue.Index(i).Interface())
  70. if err != nil {
  71. return affected, err
  72. }
  73. affected += cnt
  74. }
  75. }
  76. }
  77. } else {
  78. cnt, err := session.innerInsert(bean)
  79. if err != nil {
  80. return affected, err
  81. }
  82. affected += cnt
  83. }
  84. }
  85. }
  86. return affected, err
  87. }
  88. func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) {
  89. sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
  90. if sliceValue.Kind() != reflect.Slice {
  91. return 0, errors.New("needs a pointer to a slice")
  92. }
  93. if sliceValue.Len() <= 0 {
  94. return 0, errors.New("could not insert a empty slice")
  95. }
  96. if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil {
  97. return 0, err
  98. }
  99. tableName := session.statement.TableName()
  100. if len(tableName) <= 0 {
  101. return 0, ErrTableNotFound
  102. }
  103. table := session.statement.RefTable
  104. size := sliceValue.Len()
  105. var colNames []string
  106. var colMultiPlaces []string
  107. var args []interface{}
  108. var cols []*core.Column
  109. for i := 0; i < size; i++ {
  110. v := sliceValue.Index(i)
  111. vv := reflect.Indirect(v)
  112. elemValue := v.Interface()
  113. var colPlaces []string
  114. // handle BeforeInsertProcessor
  115. // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
  116. for _, closure := range session.beforeClosures {
  117. closure(elemValue)
  118. }
  119. if processor, ok := interface{}(elemValue).(BeforeInsertProcessor); ok {
  120. processor.BeforeInsert()
  121. }
  122. // --
  123. if i == 0 {
  124. for _, col := range table.Columns() {
  125. ptrFieldValue, err := col.ValueOfV(&vv)
  126. if err != nil {
  127. return 0, err
  128. }
  129. fieldValue := *ptrFieldValue
  130. if col.IsAutoIncrement && isZero(fieldValue.Interface()) {
  131. continue
  132. }
  133. if col.MapType == core.ONLYFROMDB {
  134. continue
  135. }
  136. if col.IsDeleted {
  137. continue
  138. }
  139. if session.statement.omitColumnMap.contain(col.Name) {
  140. continue
  141. }
  142. if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
  143. continue
  144. }
  145. if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
  146. val, t := session.engine.nowTime(col)
  147. args = append(args, val)
  148. var colName = col.Name
  149. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  150. col := table.GetColumn(colName)
  151. setColumnTime(bean, col, t)
  152. })
  153. } else if col.IsVersion && session.statement.checkVersion {
  154. args = append(args, 1)
  155. var colName = col.Name
  156. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  157. col := table.GetColumn(colName)
  158. setColumnInt(bean, col, 1)
  159. })
  160. } else {
  161. arg, err := session.value2Interface(col, fieldValue)
  162. if err != nil {
  163. return 0, err
  164. }
  165. args = append(args, arg)
  166. }
  167. colNames = append(colNames, col.Name)
  168. cols = append(cols, col)
  169. colPlaces = append(colPlaces, "?")
  170. }
  171. } else {
  172. for _, col := range cols {
  173. ptrFieldValue, err := col.ValueOfV(&vv)
  174. if err != nil {
  175. return 0, err
  176. }
  177. fieldValue := *ptrFieldValue
  178. if col.IsAutoIncrement && isZero(fieldValue.Interface()) {
  179. continue
  180. }
  181. if col.MapType == core.ONLYFROMDB {
  182. continue
  183. }
  184. if col.IsDeleted {
  185. continue
  186. }
  187. if session.statement.omitColumnMap.contain(col.Name) {
  188. continue
  189. }
  190. if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
  191. continue
  192. }
  193. if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
  194. val, t := session.engine.nowTime(col)
  195. args = append(args, val)
  196. var colName = col.Name
  197. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  198. col := table.GetColumn(colName)
  199. setColumnTime(bean, col, t)
  200. })
  201. } else if col.IsVersion && session.statement.checkVersion {
  202. args = append(args, 1)
  203. var colName = col.Name
  204. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  205. col := table.GetColumn(colName)
  206. setColumnInt(bean, col, 1)
  207. })
  208. } else {
  209. arg, err := session.value2Interface(col, fieldValue)
  210. if err != nil {
  211. return 0, err
  212. }
  213. args = append(args, arg)
  214. }
  215. colPlaces = append(colPlaces, "?")
  216. }
  217. }
  218. colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", "))
  219. }
  220. cleanupProcessorsClosures(&session.beforeClosures)
  221. var sql string
  222. if session.engine.dialect.DBType() == core.ORACLE {
  223. temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
  224. session.engine.Quote(tableName),
  225. quoteColumns(colNames, session.engine.Quote, ","))
  226. sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
  227. session.engine.Quote(tableName),
  228. quoteColumns(colNames, session.engine.Quote, ","),
  229. strings.Join(colMultiPlaces, temp))
  230. } else {
  231. sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
  232. session.engine.Quote(tableName),
  233. quoteColumns(colNames, session.engine.Quote, ","),
  234. strings.Join(colMultiPlaces, "),("))
  235. }
  236. res, err := session.exec(sql, args...)
  237. if err != nil {
  238. return 0, err
  239. }
  240. session.cacheInsert(tableName)
  241. lenAfterClosures := len(session.afterClosures)
  242. for i := 0; i < size; i++ {
  243. elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface()
  244. // handle AfterInsertProcessor
  245. if session.isAutoCommit {
  246. // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
  247. for _, closure := range session.afterClosures {
  248. closure(elemValue)
  249. }
  250. if processor, ok := interface{}(elemValue).(AfterInsertProcessor); ok {
  251. processor.AfterInsert()
  252. }
  253. } else {
  254. if lenAfterClosures > 0 {
  255. if value, has := session.afterInsertBeans[elemValue]; has && value != nil {
  256. *value = append(*value, session.afterClosures...)
  257. } else {
  258. afterClosures := make([]func(interface{}), lenAfterClosures)
  259. copy(afterClosures, session.afterClosures)
  260. session.afterInsertBeans[elemValue] = &afterClosures
  261. }
  262. } else {
  263. if _, ok := interface{}(elemValue).(AfterInsertProcessor); ok {
  264. session.afterInsertBeans[elemValue] = nil
  265. }
  266. }
  267. }
  268. }
  269. cleanupProcessorsClosures(&session.afterClosures)
  270. return res.RowsAffected()
  271. }
  272. // InsertMulti insert multiple records
  273. func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
  274. if session.isAutoClose {
  275. defer session.Close()
  276. }
  277. sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
  278. if sliceValue.Kind() != reflect.Slice {
  279. return 0, ErrParamsType
  280. }
  281. if sliceValue.Len() <= 0 {
  282. return 0, nil
  283. }
  284. return session.innerInsertMulti(rowsSlicePtr)
  285. }
  286. func (session *Session) innerInsert(bean interface{}) (int64, error) {
  287. if err := session.statement.setRefBean(bean); err != nil {
  288. return 0, err
  289. }
  290. if len(session.statement.TableName()) <= 0 {
  291. return 0, ErrTableNotFound
  292. }
  293. table := session.statement.RefTable
  294. // handle BeforeInsertProcessor
  295. for _, closure := range session.beforeClosures {
  296. closure(bean)
  297. }
  298. cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used
  299. if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
  300. processor.BeforeInsert()
  301. }
  302. colNames, args, err := session.genInsertColumns(bean)
  303. if err != nil {
  304. return 0, err
  305. }
  306. // insert expr columns, override if exists
  307. exprColumns := session.statement.getExpr()
  308. exprColVals := make([]string, 0, len(exprColumns))
  309. for _, v := range exprColumns {
  310. // remove the expr columns
  311. for i, colName := range colNames {
  312. if colName == strings.Trim(v.colName, "`") {
  313. colNames = append(colNames[:i], colNames[i+1:]...)
  314. args = append(args[:i], args[i+1:]...)
  315. }
  316. }
  317. // append expr column to the end
  318. colNames = append(colNames, v.colName)
  319. exprColVals = append(exprColVals, v.expr)
  320. }
  321. colPlaces := strings.Repeat("?, ", len(colNames)-len(exprColumns))
  322. if len(exprColVals) > 0 {
  323. colPlaces = colPlaces + strings.Join(exprColVals, ", ")
  324. } else {
  325. if len(colPlaces) > 0 {
  326. colPlaces = colPlaces[0 : len(colPlaces)-2]
  327. }
  328. }
  329. var sqlStr string
  330. var tableName = session.statement.TableName()
  331. var output string
  332. if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 {
  333. output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
  334. }
  335. if len(colPlaces) > 0 {
  336. if session.statement.cond.IsValid() {
  337. condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
  338. if err != nil {
  339. return 0, err
  340. }
  341. sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s SELECT %v FROM %v WHERE %v",
  342. session.engine.Quote(tableName),
  343. quoteColumns(colNames, session.engine.Quote, ","),
  344. output,
  345. colPlaces,
  346. session.engine.Quote(tableName),
  347. condSQL,
  348. )
  349. args = append(args, condArgs...)
  350. } else {
  351. sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
  352. session.engine.Quote(tableName),
  353. quoteColumns(colNames, session.engine.Quote, ","),
  354. output,
  355. colPlaces)
  356. }
  357. } else {
  358. if session.engine.dialect.DBType() == core.MYSQL {
  359. sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName))
  360. } else {
  361. sqlStr = fmt.Sprintf("INSERT INTO %s%s DEFAULT VALUES", session.engine.Quote(tableName), output)
  362. }
  363. }
  364. if len(table.AutoIncrement) > 0 && session.engine.dialect.DBType() == core.POSTGRES {
  365. sqlStr = sqlStr + " RETURNING " + session.engine.Quote(table.AutoIncrement)
  366. }
  367. handleAfterInsertProcessorFunc := func(bean interface{}) {
  368. if session.isAutoCommit {
  369. for _, closure := range session.afterClosures {
  370. closure(bean)
  371. }
  372. if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
  373. processor.AfterInsert()
  374. }
  375. } else {
  376. lenAfterClosures := len(session.afterClosures)
  377. if lenAfterClosures > 0 {
  378. if value, has := session.afterInsertBeans[bean]; has && value != nil {
  379. *value = append(*value, session.afterClosures...)
  380. } else {
  381. afterClosures := make([]func(interface{}), lenAfterClosures)
  382. copy(afterClosures, session.afterClosures)
  383. session.afterInsertBeans[bean] = &afterClosures
  384. }
  385. } else {
  386. if _, ok := interface{}(bean).(AfterInsertProcessor); ok {
  387. session.afterInsertBeans[bean] = nil
  388. }
  389. }
  390. }
  391. cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
  392. }
  393. // for postgres, many of them didn't implement lastInsertId, so we should
  394. // implemented it ourself.
  395. if session.engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
  396. res, err := session.queryBytes("select seq_atable.currval from dual", args...)
  397. if err != nil {
  398. return 0, err
  399. }
  400. defer handleAfterInsertProcessorFunc(bean)
  401. session.cacheInsert(tableName)
  402. if table.Version != "" && session.statement.checkVersion {
  403. verValue, err := table.VersionColumn().ValueOf(bean)
  404. if err != nil {
  405. session.engine.logger.Error(err)
  406. } else if verValue.IsValid() && verValue.CanSet() {
  407. session.incrVersionFieldValue(verValue)
  408. }
  409. }
  410. if len(res) < 1 {
  411. return 0, errors.New("insert no error but not returned id")
  412. }
  413. idByte := res[0][table.AutoIncrement]
  414. id, err := strconv.ParseInt(string(idByte), 10, 64)
  415. if err != nil || id <= 0 {
  416. return 1, err
  417. }
  418. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  419. if err != nil {
  420. session.engine.logger.Error(err)
  421. }
  422. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  423. return 1, nil
  424. }
  425. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  426. return 1, nil
  427. } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.DBType() == core.POSTGRES || session.engine.dialect.DBType() == core.MSSQL) {
  428. res, err := session.queryBytes(sqlStr, args...)
  429. if err != nil {
  430. return 0, err
  431. }
  432. defer handleAfterInsertProcessorFunc(bean)
  433. session.cacheInsert(tableName)
  434. if table.Version != "" && session.statement.checkVersion {
  435. verValue, err := table.VersionColumn().ValueOf(bean)
  436. if err != nil {
  437. session.engine.logger.Error(err)
  438. } else if verValue.IsValid() && verValue.CanSet() {
  439. session.incrVersionFieldValue(verValue)
  440. }
  441. }
  442. if len(res) < 1 {
  443. return 0, errors.New("insert successfully but not returned id")
  444. }
  445. idByte := res[0][table.AutoIncrement]
  446. id, err := strconv.ParseInt(string(idByte), 10, 64)
  447. if err != nil || id <= 0 {
  448. return 1, err
  449. }
  450. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  451. if err != nil {
  452. session.engine.logger.Error(err)
  453. }
  454. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  455. return 1, nil
  456. }
  457. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  458. return 1, nil
  459. } else {
  460. res, err := session.exec(sqlStr, args...)
  461. if err != nil {
  462. return 0, err
  463. }
  464. defer handleAfterInsertProcessorFunc(bean)
  465. session.cacheInsert(tableName)
  466. if table.Version != "" && session.statement.checkVersion {
  467. verValue, err := table.VersionColumn().ValueOf(bean)
  468. if err != nil {
  469. session.engine.logger.Error(err)
  470. } else if verValue.IsValid() && verValue.CanSet() {
  471. session.incrVersionFieldValue(verValue)
  472. }
  473. }
  474. if table.AutoIncrement == "" {
  475. return res.RowsAffected()
  476. }
  477. var id int64
  478. id, err = res.LastInsertId()
  479. if err != nil || id <= 0 {
  480. return res.RowsAffected()
  481. }
  482. aiValue, err := table.AutoIncrColumn().ValueOf(bean)
  483. if err != nil {
  484. session.engine.logger.Error(err)
  485. }
  486. if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
  487. return res.RowsAffected()
  488. }
  489. aiValue.Set(int64ToIntValue(id, aiValue.Type()))
  490. return res.RowsAffected()
  491. }
  492. }
  493. // InsertOne insert only one struct into database as a record.
  494. // The in parameter bean must a struct or a point to struct. The return
  495. // parameter is inserted and error
  496. func (session *Session) InsertOne(bean interface{}) (int64, error) {
  497. if session.isAutoClose {
  498. defer session.Close()
  499. }
  500. return session.innerInsert(bean)
  501. }
  502. func (session *Session) cacheInsert(table string) error {
  503. if !session.statement.UseCache {
  504. return nil
  505. }
  506. cacher := session.engine.getCacher(table)
  507. if cacher == nil {
  508. return nil
  509. }
  510. session.engine.logger.Debug("[cache] clear sql:", table)
  511. cacher.ClearIds(table)
  512. return nil
  513. }
  514. // genInsertColumns generates insert needed columns
  515. func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) {
  516. table := session.statement.RefTable
  517. colNames := make([]string, 0, len(table.ColumnsSeq()))
  518. args := make([]interface{}, 0, len(table.ColumnsSeq()))
  519. for _, col := range table.Columns() {
  520. if col.MapType == core.ONLYFROMDB {
  521. continue
  522. }
  523. if col.IsDeleted {
  524. continue
  525. }
  526. if session.statement.omitColumnMap.contain(col.Name) {
  527. continue
  528. }
  529. if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) {
  530. continue
  531. }
  532. if _, ok := session.statement.incrColumns[col.Name]; ok {
  533. continue
  534. } else if _, ok := session.statement.decrColumns[col.Name]; ok {
  535. continue
  536. }
  537. fieldValuePtr, err := col.ValueOf(bean)
  538. if err != nil {
  539. return nil, nil, err
  540. }
  541. fieldValue := *fieldValuePtr
  542. if col.IsAutoIncrement {
  543. switch fieldValue.Type().Kind() {
  544. case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
  545. if fieldValue.Int() == 0 {
  546. continue
  547. }
  548. case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
  549. if fieldValue.Uint() == 0 {
  550. continue
  551. }
  552. case reflect.String:
  553. if len(fieldValue.String()) == 0 {
  554. continue
  555. }
  556. case reflect.Ptr:
  557. if fieldValue.Pointer() == 0 {
  558. continue
  559. }
  560. }
  561. }
  562. // !evalphobia! set fieldValue as nil when column is nullable and zero-value
  563. if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
  564. if col.Nullable && isZero(fieldValue.Interface()) {
  565. var nilValue *int
  566. fieldValue = reflect.ValueOf(nilValue)
  567. }
  568. }
  569. if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
  570. // if time is non-empty, then set to auto time
  571. val, t := session.engine.nowTime(col)
  572. args = append(args, val)
  573. var colName = col.Name
  574. session.afterClosures = append(session.afterClosures, func(bean interface{}) {
  575. col := table.GetColumn(colName)
  576. setColumnTime(bean, col, t)
  577. })
  578. } else if col.IsVersion && session.statement.checkVersion {
  579. args = append(args, 1)
  580. } else {
  581. arg, err := session.value2Interface(col, fieldValue)
  582. if err != nil {
  583. return colNames, args, err
  584. }
  585. args = append(args, arg)
  586. }
  587. colNames = append(colNames, col.Name)
  588. }
  589. return colNames, args, nil
  590. }
  591. func (session *Session) insertMapInterface(m map[string]interface{}) (int64, error) {
  592. if len(m) == 0 {
  593. return 0, ErrParamsType
  594. }
  595. tableName := session.statement.TableName()
  596. if len(tableName) <= 0 {
  597. return 0, ErrTableNotFound
  598. }
  599. var columns = make([]string, 0, len(m))
  600. for k := range m {
  601. columns = append(columns, k)
  602. }
  603. sort.Strings(columns)
  604. qm := strings.Repeat("?,", len(columns))
  605. var args = make([]interface{}, 0, len(m))
  606. for _, colName := range columns {
  607. args = append(args, m[colName])
  608. }
  609. // insert expr columns, override if exists
  610. exprColumns := session.statement.getExpr()
  611. for _, col := range exprColumns {
  612. columns = append(columns, strings.Trim(col.colName, "`"))
  613. qm = qm + col.expr + ","
  614. }
  615. qm = qm[:len(qm)-1]
  616. var sql string
  617. if session.statement.cond.IsValid() {
  618. condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
  619. if err != nil {
  620. return 0, err
  621. }
  622. sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s",
  623. session.engine.Quote(tableName),
  624. strings.Join(columns, "`,`"),
  625. qm,
  626. session.engine.Quote(tableName),
  627. condSQL,
  628. )
  629. args = append(args, condArgs...)
  630. } else {
  631. sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
  632. }
  633. if err := session.cacheInsert(tableName); err != nil {
  634. return 0, err
  635. }
  636. res, err := session.exec(sql, args...)
  637. if err != nil {
  638. return 0, err
  639. }
  640. affected, err := res.RowsAffected()
  641. if err != nil {
  642. return 0, err
  643. }
  644. return affected, nil
  645. }
  646. func (session *Session) insertMapString(m map[string]string) (int64, error) {
  647. if len(m) == 0 {
  648. return 0, ErrParamsType
  649. }
  650. tableName := session.statement.TableName()
  651. if len(tableName) <= 0 {
  652. return 0, ErrTableNotFound
  653. }
  654. var columns = make([]string, 0, len(m))
  655. for k := range m {
  656. columns = append(columns, k)
  657. }
  658. sort.Strings(columns)
  659. var args = make([]interface{}, 0, len(m))
  660. for _, colName := range columns {
  661. args = append(args, m[colName])
  662. }
  663. qm := strings.Repeat("?,", len(columns))
  664. // insert expr columns, override if exists
  665. exprColumns := session.statement.getExpr()
  666. for _, col := range exprColumns {
  667. columns = append(columns, strings.Trim(col.colName, "`"))
  668. qm = qm + col.expr + ","
  669. }
  670. qm = qm[:len(qm)-1]
  671. var sql string
  672. if session.statement.cond.IsValid() {
  673. qm = "(" + qm[:len(qm)-1] + ")"
  674. condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
  675. if err != nil {
  676. return 0, err
  677. }
  678. sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s",
  679. session.engine.Quote(tableName),
  680. strings.Join(columns, "`,`"),
  681. qm,
  682. session.engine.Quote(tableName),
  683. condSQL,
  684. )
  685. args = append(args, condArgs...)
  686. } else {
  687. sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
  688. }
  689. if err := session.cacheInsert(tableName); err != nil {
  690. return 0, err
  691. }
  692. res, err := session.exec(sql, args...)
  693. if err != nil {
  694. return 0, err
  695. }
  696. affected, err := res.RowsAffected()
  697. if err != nil {
  698. return 0, err
  699. }
  700. return affected, nil
  701. }