g_assert (lua_gettop (L) == level - 1);
}
+static void lua_metric_symbol_callback_return (struct thread_entry *thread_entry,
+ int ret);
+
+static void lua_metric_symbol_callback_error (struct thread_entry *thread_entry,
+ int ret,
+ const char *msg);
+
+static void
+lua_metric_symbol_callback_coro (struct rspamd_task *task,
+ struct rspamd_symcache_item *item,
+ gpointer ud)
+{
+ struct lua_callback_data *cd = ud;
+ struct rspamd_task **ptask;
+ struct thread_entry *thread_entry;
+
+ rspamd_symcache_item_async_inc (task, item, "lua coro symbol");
+ thread_entry = lua_thread_pool_get_for_task (task);
+
+ g_assert(thread_entry->cd == NULL);
+ thread_entry->cd = cd;
+
+ lua_State *thread = thread_entry->lua_state;
+ cd->stack_level = lua_gettop (thread);
+ cd->item = item;
+
+ if (cd->cb_is_ref) {
+ lua_rawgeti (thread, LUA_REGISTRYINDEX, cd->callback.ref);
+ }
+ else {
+ lua_getglobal (thread, cd->callback.name);
+ }
+
+ ptask = lua_newuserdata (thread, sizeof (struct rspamd_task *));
+ rspamd_lua_setclass (thread, "rspamd{task}", -1);
+ *ptask = task;
+
+ thread_entry->finish_callback = lua_metric_symbol_callback_return;
+ thread_entry->error_callback = lua_metric_symbol_callback_error;
+
+ lua_thread_call (thread_entry, 1);
+}
+
+static void
+lua_metric_symbol_callback_error (struct thread_entry *thread_entry,
+ int ret,
+ const char *msg)
+{
+ struct lua_callback_data *cd = thread_entry->cd;
+ struct rspamd_task *task = thread_entry->task;
+ msg_err_task ("call to coroutine (%s) failed (%d): %s", cd->symbol, ret, msg);
+
+ rspamd_symcache_item_async_dec_check (task, cd->item, "lua coro symbol");
+}
+
+static void
+lua_metric_symbol_callback_return (struct thread_entry *thread_entry, int ret)
+{
+ struct lua_callback_data *cd = thread_entry->cd;
+ struct rspamd_task *task = thread_entry->task;
+ int nresults;
+ struct rspamd_symbol_result *s;
+
+ (void)ret;
+
+ lua_State *L = thread_entry->lua_state;
+
+ nresults = lua_gettop (L) - cd->stack_level;
+
+ if (nresults >= 1) {
+ /* Function returned boolean, so maybe we need to insert result? */
+ gint res = 0;
+ gint i;
+ gdouble flag = 1.0;
+ gint type;
+
+ type = lua_type (L, cd->stack_level + 1);
+
+ if (type == LUA_TBOOLEAN) {
+ res = lua_toboolean (L, cd->stack_level + 1);
+ }
+ else if (type == LUA_TFUNCTION) {
+ g_assert_not_reached ();
+ }
+ else {
+ res = lua_tonumber (L, cd->stack_level + 1);
+ }
+
+ if (res) {
+ gint first_opt = 2;
+
+ if (lua_type (L, cd->stack_level + 2) == LUA_TNUMBER) {
+ flag = lua_tonumber (L, cd->stack_level + 2);
+ /* Shift opt index */
+ first_opt = 3;
+ }
+ else {
+ flag = res;
+ }
+
+ s = rspamd_task_insert_result (task, cd->symbol, flag, NULL);
+
+ if (s) {
+ guint last_pos = lua_gettop (L);
+
+ for (i = cd->stack_level + first_opt; i <= last_pos; i++) {
+ if (lua_type (L, i) == LUA_TSTRING) {
+ const char *opt = lua_tostring (L, i);
+
+ rspamd_task_add_result_option (task, s, opt);
+ }
+ else if (lua_type (L, i) == LUA_TTABLE) {
+ lua_pushvalue (L, i);
+
+ for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
+ const char *opt = lua_tostring (L, -1);
+
+ rspamd_task_add_result_option (task, s, opt);
+ }
+
+ lua_pop (L, 1);
+ }
+ }
+ }
+
+ }
+
+ lua_pop (L, nresults);
+ }
+
+ g_assert (lua_gettop (L) == cd->stack_level); /* we properly cleaned up the stack */
+
+ cd->stack_level = 0;
+ rspamd_symcache_item_async_dec_check (task, cd->item, "lua coro symbol");
+}
+
static gint
rspamd_register_symbol_fromlua (lua_State *L,
struct rspamd_config *cfg,
}
if (ref != -1) {
+ if (type & SYMBOL_TYPE_USE_CORO) {
+ /* Coroutines are incompatible with squeezing */
+ no_squeeze = TRUE;
+ }
/*
* We call for routine called lua_squeeze_rules.squeeze_rule if it exists
*/
cd->symbol = rspamd_mempool_strdup (cfg->cfg_pool, name);
}
- ret = rspamd_symcache_add_symbol (cfg->cache,
- name,
- priority,
- lua_metric_symbol_callback,
- cd,
- type,
- parent);
+ if (type & SYMBOL_TYPE_USE_CORO) {
+ ret = rspamd_symcache_add_symbol (cfg->cache,
+ name,
+ priority,
+ lua_metric_symbol_callback_coro,
+ cd,
+ type,
+ parent);
+ }
+ else {
+ ret = rspamd_symcache_add_symbol (cfg->cache,
+ name,
+ priority,
+ lua_metric_symbol_callback,
+ cd,
+ type,
+ parent);
+ }
+
rspamd_mempool_add_destructor (cfg->cfg_pool,
- (rspamd_mempool_destruct_t)lua_destroy_cfg_symbol,
+ (rspamd_mempool_destruct_t) lua_destroy_cfg_symbol,
cd);
}
}
cd->symbol = rspamd_mempool_strdup (cfg->cfg_pool, name);
}
- ret = rspamd_symcache_add_symbol (cfg->cache,
- name,
- priority,
- lua_metric_symbol_callback,
- cd,
- type,
- parent);
+ if (type & SYMBOL_TYPE_USE_CORO) {
+ ret = rspamd_symcache_add_symbol (cfg->cache,
+ name,
+ priority,
+ lua_metric_symbol_callback_coro,
+ cd,
+ type,
+ parent);
+ }
+ else {
+ ret = rspamd_symcache_add_symbol (cfg->cache,
+ name,
+ priority,
+ lua_metric_symbol_callback,
+ cd,
+ type,
+ parent);
+ }
rspamd_mempool_add_destructor (cfg->cfg_pool,
(rspamd_mempool_destruct_t)lua_destroy_cfg_symbol,
cd);
if (strstr (str, "explicit_disable") != NULL) {
ret |= SYMBOL_TYPE_EXPLICIT_DISABLE;
}
+ if (strstr (str, "coro") != NULL) {
+ ret |= SYMBOL_TYPE_USE_CORO;
+ }
}
return ret;
LUA_TRACE_POINT;
struct rspamd_config *cfg = lua_check_config (L, 1);
const gchar *name;
- gint id, nshots;
+ gint id, nshots, flags = 0;
gboolean optional = FALSE, no_squeeze = FALSE;
name = luaL_checkstring (L, 2);
* "weight" - optional weight
* "priority" - optional priority
* "type" - optional type (normal, virtual, callback)
+ * "flags" - optional flags
* -- Metric options
* "score" - optional default score (overridden by metric)
* "group" - optional default group
}
lua_pop (L, 1);
+ lua_pushstring (L, "flags");
+ lua_gettable (L, -2);
+
+ if (lua_type (L, -1) == LUA_TSTRING) {
+ type_str = lua_tostring (L, -1);
+ type |= lua_parse_symbol_flags (type_str);
+ }
+ lua_pop (L, 1);
+
lua_pushstring (L, "condition");
lua_gettable (L, -2);