You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

lua_sqlite3.c 8.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. /*-
  2. * Copyright 2016 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "lua_common.h"
  17. #include "sqlite_utils.h"
  18. /***
  19. * @module rspamd_sqlite3
  20. * This module provides routines to query sqlite3 databases
  21. @example
  22. local sqlite3 = require "rspamd_sqlite3"
  23. local db = sqlite3.open("/tmp/db.sqlite")
  24. if db then
  25. db:exec([[ CREATE TABLE x (id INT, value TEXT); ]])
  26. db:exec([[ INSERT INTO x VALUES (?1, ?2); ]], 1, 'test')
  27. for row in db:rows([[ SELECT * FROM x ]]) do
  28. print(string.format('%d -> %s', row.id, row.value))
  29. end
  30. end
  31. */
  32. LUA_FUNCTION_DEF (sqlite3, open);
  33. LUA_FUNCTION_DEF (sqlite3, sql);
  34. LUA_FUNCTION_DEF (sqlite3, rows);
  35. LUA_FUNCTION_DEF (sqlite3, close);
  36. LUA_FUNCTION_DEF (sqlite3_stmt, close);
  37. static const struct luaL_reg sqlitelib_f[] = {
  38. LUA_INTERFACE_DEF (sqlite3, open),
  39. {NULL, NULL}
  40. };
  41. static const struct luaL_reg sqlitelib_m[] = {
  42. LUA_INTERFACE_DEF (sqlite3, sql),
  43. {"query", lua_sqlite3_sql},
  44. {"exec", lua_sqlite3_sql},
  45. LUA_INTERFACE_DEF (sqlite3, rows),
  46. {"__tostring", rspamd_lua_class_tostring},
  47. {"__gc", lua_sqlite3_close},
  48. {NULL, NULL}
  49. };
  50. static const struct luaL_reg sqlitestmtlib_m[] = {
  51. {"__tostring", rspamd_lua_class_tostring},
  52. {"__gc", lua_sqlite3_stmt_close},
  53. {NULL, NULL}
  54. };
  55. static void lua_sqlite3_push_row (lua_State *L, sqlite3_stmt *stmt);
  56. static sqlite3 *
  57. lua_check_sqlite3 (lua_State * L, gint pos)
  58. {
  59. void *ud = rspamd_lua_check_udata (L, pos, "rspamd{sqlite3}");
  60. luaL_argcheck (L, ud != NULL, pos, "'sqlite3' expected");
  61. return ud ? *((sqlite3 **)ud) : NULL;
  62. }
  63. static sqlite3_stmt *
  64. lua_check_sqlite3_stmt (lua_State * L, gint pos)
  65. {
  66. void *ud = rspamd_lua_check_udata (L, pos, "rspamd{sqlite3_stmt}");
  67. luaL_argcheck (L, ud != NULL, pos, "'sqlite3_stmt' expected");
  68. return ud ? *((sqlite3_stmt **)ud) : NULL;
  69. }
  70. /***
  71. * @function rspamd_sqlite3.open(path)
  72. * Opens sqlite3 database at the specified path. DB is created if not exists.
  73. * @param {string} path path to db
  74. * @return {sqlite3} sqlite3 handle
  75. */
  76. static gint
  77. lua_sqlite3_open (lua_State *L)
  78. {
  79. const gchar *path = luaL_checkstring (L, 1);
  80. sqlite3 *db, **pdb;
  81. GError *err = NULL;
  82. if (path == NULL) {
  83. lua_pushnil (L);
  84. return 1;
  85. }
  86. db = rspamd_sqlite3_open_or_create (NULL, path, NULL, 0, &err);
  87. if (db == NULL) {
  88. if (err) {
  89. msg_err ("cannot open db: %e", err);
  90. g_error_free (err);
  91. }
  92. lua_pushnil (L);
  93. return 1;
  94. }
  95. pdb = lua_newuserdata (L, sizeof (db));
  96. *pdb = db;
  97. rspamd_lua_setclass (L, "rspamd{sqlite3}", -1);
  98. return 1;
  99. }
  100. static void
  101. lua_sqlite3_bind_statements (lua_State *L, gint start, gint end,
  102. sqlite3_stmt *stmt)
  103. {
  104. gint i, type, num = 1;
  105. const gchar *str;
  106. gsize slen;
  107. gdouble n;
  108. g_assert (start <= end && start > 0 && end > 0);
  109. for (i = start; i <= end; i ++) {
  110. type = lua_type (L, i);
  111. switch (type) {
  112. case LUA_TNUMBER:
  113. n = lua_tonumber (L, i);
  114. if (n == (gdouble)((gint64)n)) {
  115. sqlite3_bind_int64 (stmt, num, n);
  116. }
  117. else {
  118. sqlite3_bind_double (stmt, num, n);
  119. }
  120. num ++;
  121. break;
  122. case LUA_TSTRING:
  123. str = lua_tolstring (L, i, &slen);
  124. sqlite3_bind_text (stmt, num, str, slen, SQLITE_TRANSIENT);
  125. num ++;
  126. break;
  127. default:
  128. msg_err ("invalid type at position %d: %s", i, lua_typename (L, type));
  129. break;
  130. }
  131. }
  132. }
  133. /***
  134. * @function rspamd_sqlite3:sql(query[, args..])
  135. * Performs sqlite3 query replacing '?1', '?2' and so on with the subsequent args
  136. * of the function
  137. *
  138. * @param {string} query SQL query
  139. * @param {string|number} args... variable number of arguments
  140. * @return {boolean} `true` if a statement has been successfully executed
  141. */
  142. static gint
  143. lua_sqlite3_sql (lua_State *L)
  144. {
  145. LUA_TRACE_POINT;
  146. sqlite3 *db = lua_check_sqlite3 (L, 1);
  147. const gchar *query = luaL_checkstring (L, 2);
  148. sqlite3_stmt *stmt;
  149. gboolean ret = FALSE;
  150. gint top = 1, rc;
  151. if (db && query) {
  152. if (sqlite3_prepare_v2 (db, query, -1, &stmt, NULL) != SQLITE_OK) {
  153. msg_err ("cannot prepare query %s: %s", query, sqlite3_errmsg (db));
  154. return luaL_error (L, sqlite3_errmsg (db));
  155. }
  156. else {
  157. top = lua_gettop (L);
  158. if (top > 2) {
  159. /* Push additional arguments to sqlite3 */
  160. lua_sqlite3_bind_statements (L, 3, top, stmt);
  161. }
  162. rc = sqlite3_step (stmt);
  163. top = 1;
  164. if (rc == SQLITE_ROW || rc == SQLITE_OK || rc == SQLITE_DONE) {
  165. ret = TRUE;
  166. if (rc == SQLITE_ROW) {
  167. lua_sqlite3_push_row (L, stmt);
  168. top = 2;
  169. }
  170. }
  171. else {
  172. msg_warn ("sqlite3 error: %s", sqlite3_errmsg (db));
  173. }
  174. sqlite3_finalize (stmt);
  175. }
  176. }
  177. lua_pushboolean (L, ret);
  178. return top;
  179. }
  180. static void
  181. lua_sqlite3_push_row (lua_State *L, sqlite3_stmt *stmt)
  182. {
  183. const gchar *str;
  184. gsize slen;
  185. gint64 num;
  186. gchar numbuf[32];
  187. gint nresults, i, type;
  188. nresults = sqlite3_column_count (stmt);
  189. lua_createtable (L, 0, nresults);
  190. for (i = 0; i < nresults; i ++) {
  191. lua_pushstring (L, sqlite3_column_name (stmt, i));
  192. type = sqlite3_column_type (stmt, i);
  193. switch (type) {
  194. case SQLITE_INTEGER:
  195. /*
  196. * XXX: we represent int64 as strings, as we can nothing else to do
  197. * about it portably
  198. */
  199. num = sqlite3_column_int64 (stmt, i);
  200. rspamd_snprintf (numbuf, sizeof (numbuf), "%uL", num);
  201. lua_pushstring (L, numbuf);
  202. break;
  203. case SQLITE_FLOAT:
  204. lua_pushnumber (L, sqlite3_column_double (stmt, i));
  205. break;
  206. case SQLITE_TEXT:
  207. slen = sqlite3_column_bytes (stmt, i);
  208. str = sqlite3_column_text (stmt, i);
  209. lua_pushlstring (L, str, slen);
  210. break;
  211. case SQLITE_BLOB:
  212. slen = sqlite3_column_bytes (stmt, i);
  213. str = sqlite3_column_blob (stmt, i);
  214. lua_pushlstring (L, str, slen);
  215. break;
  216. default:
  217. lua_pushboolean (L, 0);
  218. break;
  219. }
  220. lua_settable (L, -3);
  221. }
  222. }
  223. static gint
  224. lua_sqlite3_next_row (lua_State *L)
  225. {
  226. LUA_TRACE_POINT;
  227. sqlite3_stmt *stmt = *(sqlite3_stmt **)lua_touserdata (L, lua_upvalueindex (1));
  228. gint rc;
  229. if (stmt != NULL) {
  230. rc = sqlite3_step (stmt);
  231. if (rc == SQLITE_ROW) {
  232. lua_sqlite3_push_row (L, stmt);
  233. return 1;
  234. }
  235. }
  236. lua_pushnil (L);
  237. return 1;
  238. }
  239. /***
  240. * @function rspamd_sqlite3:rows(query[, args..])
  241. * Performs sqlite3 query replacing '?1', '?2' and so on with the subsequent args
  242. * of the function. This function returns iterator suitable for loop construction:
  243. *
  244. * @param {string} query SQL query
  245. * @param {string|number} args... variable number of arguments
  246. * @return {function} iterator to get all rows
  247. @example
  248. for row in db:rows([[ SELECT * FROM x ]]) do
  249. print(string.format('%d -> %s', row.id, row.value))
  250. end
  251. */
  252. static gint
  253. lua_sqlite3_rows (lua_State *L)
  254. {
  255. LUA_TRACE_POINT;
  256. sqlite3 *db = lua_check_sqlite3 (L, 1);
  257. const gchar *query = luaL_checkstring (L, 2);
  258. sqlite3_stmt *stmt, **pstmt;
  259. gint top;
  260. if (db && query) {
  261. if (sqlite3_prepare_v2 (db, query, -1, &stmt, NULL) != SQLITE_OK) {
  262. msg_err ("cannot prepare query %s: %s", query, sqlite3_errmsg (db));
  263. lua_pushstring (L, sqlite3_errmsg (db));
  264. return lua_error (L);
  265. }
  266. else {
  267. top = lua_gettop (L);
  268. if (top > 2) {
  269. /* Push additional arguments to sqlite3 */
  270. lua_sqlite3_bind_statements (L, 3, top, stmt);
  271. }
  272. /* Create C closure */
  273. pstmt = lua_newuserdata (L, sizeof (stmt));
  274. *pstmt = stmt;
  275. rspamd_lua_setclass (L, "rspamd{sqlite3_stmt}", -1);
  276. lua_pushcclosure (L, lua_sqlite3_next_row, 1);
  277. }
  278. }
  279. else {
  280. lua_pushnil (L);
  281. }
  282. return 1;
  283. }
  284. static gint
  285. lua_sqlite3_close (lua_State *L)
  286. {
  287. LUA_TRACE_POINT;
  288. sqlite3 *db = lua_check_sqlite3 (L, 1);
  289. if (db) {
  290. sqlite3_close (db);
  291. }
  292. return 0;
  293. }
  294. static gint
  295. lua_sqlite3_stmt_close (lua_State *L)
  296. {
  297. sqlite3_stmt *stmt = lua_check_sqlite3_stmt (L, 1);
  298. if (stmt) {
  299. sqlite3_finalize (stmt);
  300. }
  301. return 0;
  302. }
  303. static gint
  304. lua_load_sqlite3 (lua_State * L)
  305. {
  306. lua_newtable (L);
  307. luaL_register (L, NULL, sqlitelib_f);
  308. return 1;
  309. }
  310. /**
  311. * Open redis library
  312. * @param L lua stack
  313. * @return
  314. */
  315. void
  316. luaopen_sqlite3 (lua_State * L)
  317. {
  318. rspamd_lua_new_class (L, "rspamd{sqlite3}", sqlitelib_m);
  319. lua_pop (L, 1);
  320. rspamd_lua_new_class (L, "rspamd{sqlite3_stmt}", sqlitestmtlib_m);
  321. lua_pop (L, 1);
  322. rspamd_lua_add_preload (L, "rspamd_sqlite3", lua_load_sqlite3);
  323. }