/*
 * Copyright 2024 Vsevolod Stakhov
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "lua_common.h"
#include "sqlite_utils.h"

/***
 * @module rspamd_sqlite3
 * This module provides routines to query sqlite3 databases
@example
local sqlite3 = require "rspamd_sqlite3"

local db = sqlite3.open("/tmp/db.sqlite")

if db then
	db:exec([[ CREATE TABLE x (id INT, value TEXT); ]])

	db:exec([[ INSERT INTO x VALUES (?1, ?2); ]], 1, 'test')

	for row in db:rows([[ SELECT * FROM x ]]) do
		print(string.format('%d -> %s', row.id, row.value))
	end

end
 */

LUA_FUNCTION_DEF(sqlite3, open);
LUA_FUNCTION_DEF(sqlite3, sql);
LUA_FUNCTION_DEF(sqlite3, rows);
LUA_FUNCTION_DEF(sqlite3, close);
LUA_FUNCTION_DEF(sqlite3_stmt, close);

static const struct luaL_reg sqlitelib_f[] = {
	LUA_INTERFACE_DEF(sqlite3, open),
	{NULL, NULL}};

static const struct luaL_reg sqlitelib_m[] = {
	LUA_INTERFACE_DEF(sqlite3, sql),
	{"query", lua_sqlite3_sql},
	{"exec", lua_sqlite3_sql},
	LUA_INTERFACE_DEF(sqlite3, rows),
	{"__tostring", rspamd_lua_class_tostring},
	{"__gc", lua_sqlite3_close},
	{NULL, NULL}};

static const struct luaL_reg sqlitestmtlib_m[] = {
	{"__tostring", rspamd_lua_class_tostring},
	{"__gc", lua_sqlite3_stmt_close},
	{NULL, NULL}};

static void lua_sqlite3_push_row(lua_State *L, sqlite3_stmt *stmt);

/* Returns the sqlite3 handle stored in the userdata at `pos` (raises on wrong type) */
static sqlite3 *
lua_check_sqlite3(lua_State *L, int pos)
{
	void *ud = rspamd_lua_check_udata(L, pos, rspamd_sqlite3_classname);
	luaL_argcheck(L, ud != NULL, pos, "'sqlite3' expected");
	return ud ? *((sqlite3 **) ud) : NULL;
}

/* Returns the statement stored in the userdata at `pos` (raises on wrong type) */
static sqlite3_stmt *
lua_check_sqlite3_stmt(lua_State *L, int pos)
{
	void *ud = rspamd_lua_check_udata(L, pos, rspamd_sqlite3_stmt_classname);
	luaL_argcheck(L, ud != NULL, pos, "'sqlite3_stmt' expected");
	return ud ? *((sqlite3_stmt **) ud) : NULL;
}

/***
 * @function rspamd_sqlite3.open(path)
 * Opens the sqlite3 database at the specified path. The database is created
 * if it does not exist.
 * @param {string} path path to db
 * @return {sqlite3} sqlite3 handle or nil if the database cannot be opened
 */
static int
lua_sqlite3_open(lua_State *L)
{
	const char *path = luaL_checkstring(L, 1);
	sqlite3 *db, **pdb;
	GError *err = NULL;

	if (path == NULL) {
		lua_pushnil(L);
		return 1;
	}

	db = rspamd_sqlite3_open_or_create(NULL, path, NULL, 0, &err);

	if (db == NULL) {
		if (err) {
			msg_err("cannot open db: %e", err);
			g_error_free(err);
		}
		lua_pushnil(L);

		return 1;
	}

	pdb = lua_newuserdata(L, sizeof(db));
	*pdb = db;
	rspamd_lua_setclass(L, rspamd_sqlite3_classname, -1);

	return 1;
}

/*
 * Binds the Lua stack values at indices [start, end] to the positional
 * parameters ?1, ?2, ... of `stmt`.
 * Every stack slot consumes exactly one parameter index, so argument
 * positions always match their placeholders: nil is bound as NULL, and an
 * unsupported value is logged and left unbound (which SQLite treats as NULL).
 */
static void
lua_sqlite3_bind_statements(lua_State *L, int start, int end,
							sqlite3_stmt *stmt)
{
	/* 2^63: the smallest positive double that does not fit into int64_t */
	const double int64_limit = 9223372036854775808.0;
	int i, type, rc, num = 1;
	const char *str;
	gsize slen;
	double n;

	g_assert(start <= end && start > 0 && end > 0);

	for (i = start; i <= end; i++, num++) {
		type = lua_type(L, i);
		rc = SQLITE_OK;

		switch (type) {
		case LUA_TNUMBER:
			n = lua_tonumber(L, i);

			/*
			 * Converting an out-of-range double (or NaN) to an integer is
			 * undefined behaviour, so check the range before casting.
			 * NaN fails both comparisons and is bound as a double.
			 */
			if (n >= -int64_limit && n < int64_limit &&
				n == (double) ((int64_t) n)) {
				rc = sqlite3_bind_int64(stmt, num, (sqlite3_int64) n);
			}
			else {
				rc = sqlite3_bind_double(stmt, num, n);
			}
			break;
		case LUA_TSTRING:
			str = lua_tolstring(L, i, &slen);
			rc = sqlite3_bind_text(stmt, num, str, slen, SQLITE_TRANSIENT);
			break;
		case LUA_TNIL:
			rc = sqlite3_bind_null(stmt, num);
			break;
		default:
			msg_err("invalid type at position %d: %s", i,
					lua_typename(L, type));
			break;
		}

		if (rc != SQLITE_OK) {
			/* e.g. SQLITE_RANGE when more args are passed than placeholders */
			msg_warn("cannot bind argument at position %d: %s", i,
					 sqlite3_errmsg(sqlite3_db_handle(stmt)));
		}
	}
}

/***
 * @function rspamd_sqlite3:sql(query[, args..])
 * Performs sqlite3 query replacing '?1', '?2' and so on with the subsequent args
 * of the function. The statement is stepped only once, so for a query that
 * returns rows only the first row is available. Raises a Lua error if the
 * query cannot be prepared.
 *
 * @param {string} query SQL query
 * @param {string|number|nil} args... variable number of arguments (nil is bound as NULL)
 * @return {table|boolean} if the statement produced a row: that row as a table (column name -> value) followed by `true`; otherwise a single boolean, `true` if the statement has been successfully executed
 */
static int
lua_sqlite3_sql(lua_State *L)
{
	LUA_TRACE_POINT;
	sqlite3 *db = lua_check_sqlite3(L, 1);
	const char *query = luaL_checkstring(L, 2);
	sqlite3_stmt *stmt;
	gboolean ret = FALSE;
	int top = 1, rc;

	if (db && query) {
		if (sqlite3_prepare_v2(db, query, -1, &stmt, NULL) != SQLITE_OK) {
			msg_err("cannot prepare query %s: %s", query, sqlite3_errmsg(db));
			/* The error text may contain '%': never use it as a format */
			return luaL_error(L, "%s", sqlite3_errmsg(db));
		}
		else if (stmt == NULL) {
			/* Whitespace/comment-only query: nothing to bind or step */
			msg_warn("empty sqlite3 query: %s", query);
		}
		else {
			top = lua_gettop(L);

			if (top > 2) {
				/* Push additional arguments to sqlite3 */
				lua_sqlite3_bind_statements(L, 3, top, stmt);
			}

			rc = sqlite3_step(stmt);
			top = 1;

			if (rc == SQLITE_ROW || rc == SQLITE_OK || rc == SQLITE_DONE) {
				ret = TRUE;

				if (rc == SQLITE_ROW) {
					/* Pushed below the boolean, so results are (row, true) */
					lua_sqlite3_push_row(L, stmt);
					top = 2;
				}
			}
			else {
				msg_warn("sqlite3 error: %s", sqlite3_errmsg(db));
			}

			sqlite3_finalize(stmt);
		}
	}

	lua_pushboolean(L, ret);

	return top;
}

/*
 * Pushes the current row of `stmt` as a table mapping column names to values:
 * INTEGER -> decimal string, FLOAT -> number, TEXT/BLOB -> string,
 * NULL (or any other type) -> false.
 */
static void
lua_sqlite3_push_row(lua_State *L, sqlite3_stmt *stmt)
{
	const char *str;
	gsize slen;
	int64_t num;
	char numbuf[32];
	int nresults, i, type;

	nresults = sqlite3_column_count(stmt);
	lua_createtable(L, 0, nresults);

	for (i = 0; i < nresults; i++) {
		lua_pushstring(L, sqlite3_column_name(stmt, i));
		type = sqlite3_column_type(stmt, i);

		switch (type) {
		case SQLITE_INTEGER:
			/*
			 * XXX: int64 values are represented as strings, as there is no
			 * portable way to push them to Lua without losing precision.
			 * "%L" is signed int64 in rspamd_snprintf ("%uL" would print
			 * negative values as huge unsigned numbers).
			 */
			num = sqlite3_column_int64(stmt, i);
			rspamd_snprintf(numbuf, sizeof(numbuf), "%L", num);
			lua_pushstring(L, numbuf);
			break;
		case SQLITE_FLOAT:
			lua_pushnumber(L, sqlite3_column_double(stmt, i));
			break;
		case SQLITE_TEXT:
			/* SQLite recommends fetching the value first, then its size */
			str = (const char *) sqlite3_column_text(stmt, i);
			slen = sqlite3_column_bytes(stmt, i);
			lua_pushlstring(L, str != NULL ? str : "", slen);
			break;
		case SQLITE_BLOB:
			/* A zero-length blob is returned as a NULL pointer */
			str = sqlite3_column_blob(stmt, i);
			slen = sqlite3_column_bytes(stmt, i);
			lua_pushlstring(L, str != NULL ? str : "", slen);
			break;
		default:
			lua_pushboolean(L, 0);
			break;
		}

		lua_settable(L, -3);
	}
}

/*
 * Iterator closure used by rows(): upvalue 1 is the sqlite3_stmt userdata.
 * Returns the next row table, or nil once the statement is exhausted or fails.
 */
static int
lua_sqlite3_next_row(lua_State *L)
{
	LUA_TRACE_POINT;
	sqlite3_stmt **pstmt = (sqlite3_stmt **) lua_touserdata(L, lua_upvalueindex(1));
	sqlite3_stmt *stmt = *pstmt;
	int rc;

	if (stmt != NULL) {
		rc = sqlite3_step(stmt);

		if (rc == SQLITE_ROW) {
			lua_sqlite3_push_row(L, stmt);
			return 1;
		}

		if (rc != SQLITE_DONE) {
			msg_warn("sqlite3 error: %s",
					 sqlite3_errmsg(sqlite3_db_handle(stmt)));
		}

		/*
		 * Iteration is over: finalize now to release the statement and clear
		 * the pointer. Otherwise another call would silently re-run the query
		 * (sqlite3_step auto-resets after SQLITE_DONE); a NULL pointer also
		 * makes the __gc metamethod a no-op.
		 */
		sqlite3_finalize(stmt);
		*pstmt = NULL;
	}

	lua_pushnil(L);

	return 1;
}

/***
 * @function rspamd_sqlite3:rows(query[, args..])
 * Performs sqlite3 query replacing '?1', '?2' and so on with the subsequent args
 * of the function. This function returns an iterator suitable for a loop
 * construction. Raises a Lua error if the query cannot be prepared.
 *
 * @param {string} query SQL query
 * @param {string|number|nil} args... variable number of arguments (nil is bound as NULL)
 * @return {function} iterator to get all rows
@example
for row in db:rows([[ SELECT * FROM x ]]) do
  print(string.format('%d -> %s', row.id, row.value))
end
 */
static int
lua_sqlite3_rows(lua_State *L)
{
	LUA_TRACE_POINT;
	sqlite3 *db = lua_check_sqlite3(L, 1);
	const char *query = luaL_checkstring(L, 2);
	sqlite3_stmt *stmt, **pstmt;
	int top;

	if (db && query) {
		/* Remember the argument count before pushing anything */
		top = lua_gettop(L);

		/*
		 * Allocate the holder before preparing: if the allocation raises,
		 * no prepared statement is leaked. The stmt __gc ignores NULL.
		 */
		pstmt = lua_newuserdata(L, sizeof(stmt));
		*pstmt = NULL;
		rspamd_lua_setclass(L, rspamd_sqlite3_stmt_classname, -1);

		if (sqlite3_prepare_v2(db, query, -1, &stmt, NULL) != SQLITE_OK) {
			msg_err("cannot prepare query %s: %s", query, sqlite3_errmsg(db));
			lua_pushstring(L, sqlite3_errmsg(db));
			return lua_error(L);
		}

		/* NULL for an empty query: the iterator then just returns nil */
		*pstmt = stmt;

		if (top > 2 && stmt != NULL) {
			/* Push additional arguments to sqlite3 */
			lua_sqlite3_bind_statements(L, 3, top, stmt);
		}

		/* Create C closure with the stmt userdata as its only upvalue */
		lua_pushcclosure(L, lua_sqlite3_next_row, 1);
	}
	else {
		lua_pushnil(L);
	}

	return 1;
}

/* __gc for sqlite3 handles */
static int
lua_sqlite3_close(lua_State *L)
{
	LUA_TRACE_POINT;
	sqlite3 *db = lua_check_sqlite3(L, 1);

	if (db) {
		/*
		 * A rows() iterator may outlive the db object. sqlite3_close() would
		 * then fail with SQLITE_BUSY and leak the connection, whereas
		 * sqlite3_close_v2() defers deallocation until the last statement
		 * is finalized.
		 */
		sqlite3_close_v2(db);
	}

	return 0;
}

/* __gc for statements created by rows() */
static int
lua_sqlite3_stmt_close(lua_State *L)
{
	sqlite3_stmt *stmt = lua_check_sqlite3_stmt(L, 1);

	if (stmt) {
		sqlite3_finalize(stmt);
	}

	return 0;
}

/* Preload loader: returns the module table with the `open` function */
static int
lua_load_sqlite3(lua_State *L)
{
	lua_newtable(L);
	luaL_register(L, NULL, sqlitelib_f);

	return 1;
}

/**
 * Registers the sqlite3 and sqlite3_stmt classes and the `rspamd_sqlite3`
 * preload entry.
 * @param L lua stack
 */
void luaopen_sqlite3(lua_State *L)
{
	rspamd_lua_new_class(L, rspamd_sqlite3_classname, sqlitelib_m);
	lua_pop(L, 1);

	rspamd_lua_new_class(L, rspamd_sqlite3_stmt_classname, sqlitestmtlib_m);
	lua_pop(L, 1);

	rspamd_lua_add_preload(L, "rspamd_sqlite3", lua_load_sqlite3);
}