You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. // Copyright 2019 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package statements
  5. import (
  6. "errors"
  7. "fmt"
  8. "reflect"
  9. "strings"
  10. "xorm.io/builder"
  11. "xorm.io/xorm/schemas"
  12. )
  13. func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) {
  14. if len(sqlOrArgs) > 0 {
  15. return statement.ConvertSQLOrArgs(sqlOrArgs...)
  16. }
  17. if statement.RawSQL != "" {
  18. return statement.GenRawSQL(), statement.RawParams, nil
  19. }
  20. if len(statement.TableName()) <= 0 {
  21. return "", nil, ErrTableNotFound
  22. }
  23. var columnStr = statement.ColumnStr()
  24. if len(statement.SelectStr) > 0 {
  25. columnStr = statement.SelectStr
  26. } else {
  27. if statement.JoinStr == "" {
  28. if columnStr == "" {
  29. if statement.GroupByStr != "" {
  30. columnStr = statement.quoteColumnStr(statement.GroupByStr)
  31. } else {
  32. columnStr = statement.genColumnStr()
  33. }
  34. }
  35. } else {
  36. if columnStr == "" {
  37. if statement.GroupByStr != "" {
  38. columnStr = statement.quoteColumnStr(statement.GroupByStr)
  39. } else {
  40. columnStr = "*"
  41. }
  42. }
  43. }
  44. if columnStr == "" {
  45. columnStr = "*"
  46. }
  47. }
  48. if err := statement.ProcessIDParam(); err != nil {
  49. return "", nil, err
  50. }
  51. sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true)
  52. if err != nil {
  53. return "", nil, err
  54. }
  55. args := append(statement.joinArgs, condArgs...)
  56. // for mssql and use limit
  57. qs := strings.Count(sqlStr, "?")
  58. if len(args)*2 == qs {
  59. args = append(args, args...)
  60. }
  61. return sqlStr, args, nil
  62. }
  63. func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
  64. if statement.RawSQL != "" {
  65. return statement.GenRawSQL(), statement.RawParams, nil
  66. }
  67. statement.SetRefBean(bean)
  68. var sumStrs = make([]string, 0, len(columns))
  69. for _, colName := range columns {
  70. if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
  71. colName = statement.quote(colName)
  72. } else {
  73. colName = statement.ReplaceQuote(colName)
  74. }
  75. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  76. }
  77. sumSelect := strings.Join(sumStrs, ", ")
  78. if err := statement.mergeConds(bean); err != nil {
  79. return "", nil, err
  80. }
  81. sqlStr, condArgs, err := statement.genSelectSQL(sumSelect, true, true)
  82. if err != nil {
  83. return "", nil, err
  84. }
  85. return sqlStr, append(statement.joinArgs, condArgs...), nil
  86. }
  87. func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) {
  88. v := rValue(bean)
  89. isStruct := v.Kind() == reflect.Struct
  90. if isStruct {
  91. statement.SetRefBean(bean)
  92. }
  93. var columnStr = statement.ColumnStr()
  94. if len(statement.SelectStr) > 0 {
  95. columnStr = statement.SelectStr
  96. } else {
  97. // TODO: always generate column names, not use * even if join
  98. if len(statement.JoinStr) == 0 {
  99. if len(columnStr) == 0 {
  100. if len(statement.GroupByStr) > 0 {
  101. columnStr = statement.quoteColumnStr(statement.GroupByStr)
  102. } else {
  103. columnStr = statement.genColumnStr()
  104. }
  105. }
  106. } else {
  107. if len(columnStr) == 0 {
  108. if len(statement.GroupByStr) > 0 {
  109. columnStr = statement.quoteColumnStr(statement.GroupByStr)
  110. }
  111. }
  112. }
  113. }
  114. if len(columnStr) == 0 {
  115. columnStr = "*"
  116. }
  117. if isStruct {
  118. if err := statement.mergeConds(bean); err != nil {
  119. return "", nil, err
  120. }
  121. } else {
  122. if err := statement.ProcessIDParam(); err != nil {
  123. return "", nil, err
  124. }
  125. }
  126. sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true)
  127. if err != nil {
  128. return "", nil, err
  129. }
  130. return sqlStr, append(statement.joinArgs, condArgs...), nil
  131. }
  132. // GenCountSQL generates the SQL for counting
  133. func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) {
  134. if statement.RawSQL != "" {
  135. return statement.GenRawSQL(), statement.RawParams, nil
  136. }
  137. var condArgs []interface{}
  138. var err error
  139. if len(beans) > 0 {
  140. statement.SetRefBean(beans[0])
  141. if err := statement.mergeConds(beans[0]); err != nil {
  142. return "", nil, err
  143. }
  144. }
  145. var selectSQL = statement.SelectStr
  146. if len(selectSQL) <= 0 {
  147. if statement.IsDistinct {
  148. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr())
  149. } else if statement.ColumnStr() != "" {
  150. selectSQL = fmt.Sprintf("count(%s)", statement.ColumnStr())
  151. } else {
  152. selectSQL = "count(*)"
  153. }
  154. }
  155. sqlStr, condArgs, err := statement.genSelectSQL(selectSQL, false, false)
  156. if err != nil {
  157. return "", nil, err
  158. }
  159. return sqlStr, append(statement.joinArgs, condArgs...), nil
  160. }
  161. func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) {
  162. var (
  163. distinct string
  164. dialect = statement.dialect
  165. quote = statement.quote
  166. fromStr = " FROM "
  167. top, mssqlCondi, whereStr string
  168. )
  169. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  170. distinct = "DISTINCT "
  171. }
  172. condSQL, condArgs, err := statement.GenCondSQL(statement.cond)
  173. if err != nil {
  174. return "", nil, err
  175. }
  176. if len(condSQL) > 0 {
  177. whereStr = " WHERE " + condSQL
  178. }
  179. if dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
  180. fromStr += statement.TableName()
  181. } else {
  182. fromStr += quote(statement.TableName())
  183. }
  184. if statement.TableAlias != "" {
  185. if dialect.URI().DBType == schemas.ORACLE {
  186. fromStr += " " + quote(statement.TableAlias)
  187. } else {
  188. fromStr += " AS " + quote(statement.TableAlias)
  189. }
  190. }
  191. if statement.JoinStr != "" {
  192. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  193. }
  194. pLimitN := statement.LimitN
  195. if dialect.URI().DBType == schemas.MSSQL {
  196. if pLimitN != nil {
  197. LimitNValue := *pLimitN
  198. top = fmt.Sprintf("TOP %d ", LimitNValue)
  199. }
  200. if statement.Start > 0 {
  201. var column string
  202. if len(statement.RefTable.PKColumns()) == 0 {
  203. for _, index := range statement.RefTable.Indexes {
  204. if len(index.Cols) == 1 {
  205. column = index.Cols[0]
  206. break
  207. }
  208. }
  209. if len(column) == 0 {
  210. column = statement.RefTable.ColumnsSeq()[0]
  211. }
  212. } else {
  213. column = statement.RefTable.PKColumns()[0].Name
  214. }
  215. if statement.needTableName() {
  216. if len(statement.TableAlias) > 0 {
  217. column = statement.TableAlias + "." + column
  218. } else {
  219. column = statement.TableName() + "." + column
  220. }
  221. }
  222. var orderStr string
  223. if needOrderBy && len(statement.OrderStr) > 0 {
  224. orderStr = " ORDER BY " + statement.OrderStr
  225. }
  226. var groupStr string
  227. if len(statement.GroupByStr) > 0 {
  228. groupStr = " GROUP BY " + statement.GroupByStr
  229. }
  230. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  231. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  232. }
  233. }
  234. var buf strings.Builder
  235. fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
  236. if len(mssqlCondi) > 0 {
  237. if len(whereStr) > 0 {
  238. fmt.Fprint(&buf, " AND ", mssqlCondi)
  239. } else {
  240. fmt.Fprint(&buf, " WHERE ", mssqlCondi)
  241. }
  242. }
  243. if statement.GroupByStr != "" {
  244. fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
  245. }
  246. if statement.HavingStr != "" {
  247. fmt.Fprint(&buf, " ", statement.HavingStr)
  248. }
  249. if needOrderBy && statement.OrderStr != "" {
  250. fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
  251. }
  252. if needLimit {
  253. if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
  254. if statement.Start > 0 {
  255. if pLimitN != nil {
  256. fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
  257. } else {
  258. fmt.Fprintf(&buf, "LIMIT 0 OFFSET %v", statement.Start)
  259. }
  260. } else if pLimitN != nil {
  261. fmt.Fprint(&buf, " LIMIT ", *pLimitN)
  262. }
  263. } else if dialect.URI().DBType == schemas.ORACLE {
  264. if statement.Start != 0 || pLimitN != nil {
  265. oldString := buf.String()
  266. buf.Reset()
  267. rawColStr := columnStr
  268. if rawColStr == "*" {
  269. rawColStr = "at.*"
  270. }
  271. fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
  272. columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start)
  273. }
  274. }
  275. }
  276. if statement.IsForUpdate {
  277. return dialect.ForUpdateSQL(buf.String()), condArgs, nil
  278. }
  279. return buf.String(), condArgs, nil
  280. }
  281. func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) {
  282. if statement.RawSQL != "" {
  283. return statement.GenRawSQL(), statement.RawParams, nil
  284. }
  285. var sqlStr string
  286. var args []interface{}
  287. var joinStr string
  288. var err error
  289. if len(bean) == 0 {
  290. tableName := statement.TableName()
  291. if len(tableName) <= 0 {
  292. return "", nil, ErrTableNotFound
  293. }
  294. tableName = statement.quote(tableName)
  295. if len(statement.JoinStr) > 0 {
  296. joinStr = statement.JoinStr
  297. }
  298. if statement.Conds().IsValid() {
  299. condSQL, condArgs, err := statement.GenCondSQL(statement.Conds())
  300. if err != nil {
  301. return "", nil, err
  302. }
  303. if statement.dialect.URI().DBType == schemas.MSSQL {
  304. sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL)
  305. } else if statement.dialect.URI().DBType == schemas.ORACLE {
  306. sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL)
  307. } else {
  308. sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL)
  309. }
  310. args = condArgs
  311. } else {
  312. if statement.dialect.URI().DBType == schemas.MSSQL {
  313. sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr)
  314. } else if statement.dialect.URI().DBType == schemas.ORACLE {
  315. sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr)
  316. } else {
  317. sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr)
  318. }
  319. args = []interface{}{}
  320. }
  321. } else {
  322. beanValue := reflect.ValueOf(bean[0])
  323. if beanValue.Kind() != reflect.Ptr {
  324. return "", nil, errors.New("needs a pointer")
  325. }
  326. if beanValue.Elem().Kind() == reflect.Struct {
  327. if err := statement.SetRefBean(bean[0]); err != nil {
  328. return "", nil, err
  329. }
  330. }
  331. if len(statement.TableName()) <= 0 {
  332. return "", nil, ErrTableNotFound
  333. }
  334. statement.Limit(1)
  335. sqlStr, args, err = statement.GenGetSQL(bean[0])
  336. if err != nil {
  337. return "", nil, err
  338. }
  339. }
  340. return sqlStr, args, nil
  341. }
  342. func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) {
  343. if statement.RawSQL != "" {
  344. return statement.GenRawSQL(), statement.RawParams, nil
  345. }
  346. var sqlStr string
  347. var args []interface{}
  348. var err error
  349. if len(statement.TableName()) <= 0 {
  350. return "", nil, ErrTableNotFound
  351. }
  352. var columnStr = statement.ColumnStr()
  353. if len(statement.SelectStr) > 0 {
  354. columnStr = statement.SelectStr
  355. } else {
  356. if statement.JoinStr == "" {
  357. if columnStr == "" {
  358. if statement.GroupByStr != "" {
  359. columnStr = statement.quoteColumnStr(statement.GroupByStr)
  360. } else {
  361. columnStr = statement.genColumnStr()
  362. }
  363. }
  364. } else {
  365. if columnStr == "" {
  366. if statement.GroupByStr != "" {
  367. columnStr = statement.quoteColumnStr(statement.GroupByStr)
  368. } else {
  369. columnStr = "*"
  370. }
  371. }
  372. }
  373. if columnStr == "" {
  374. columnStr = "*"
  375. }
  376. }
  377. statement.cond = statement.cond.And(autoCond)
  378. sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true)
  379. if err != nil {
  380. return "", nil, err
  381. }
  382. args = append(statement.joinArgs, condArgs...)
  383. // for mssql and use limit
  384. qs := strings.Count(sqlStr, "?")
  385. if len(args)*2 == qs {
  386. args = append(args, args...)
  387. }
  388. return sqlStr, args, nil
  389. }