]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Allow async events to be registered from LUA rules
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 21 Apr 2017 09:37:13 +0000 (10:37 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 21 Apr 2017 09:37:13 +0000 (10:37 +0100)
Here is an example:

```lua

rspamd_config.ASYNC_RULE = {
  callback = function(task)
    local ret = false

    local function test1(task)
      local function test2(task)
        return ret
      end

      ret = true
      return test2
    end

    return test1
  end
}
```

src/lua/lua_config.c

index 636f4343bad3d6fdafe2c1633062f74fb9a98221..710604cf6c0b5e42c7084b36351be3a24d21c08a 100644 (file)
@@ -846,6 +846,11 @@ struct lua_callback_data {
        gint order;
 };
 
+struct lua_watcher_data {
+       struct lua_callback_data *cbd;
+       gint cb_ref;
+};
+
 /*
  * Unref symbol if it is local reference
  */
@@ -875,6 +880,120 @@ rspamd_compare_order_func (gconstpointer a, gconstpointer b)
        return cb2->order - cb1->order;
 }
 
+static void
+lua_watcher_callback (gpointer session_data, gpointer ud)
+{
+       struct rspamd_task *task = session_data, **ptask;
+       struct lua_watcher_data *wd = ud;
+       lua_State *L;
+       gint level, nresults, err_idx, ret;
+       GString *tb;
+       struct rspamd_symbol_result *s;
+
+       L = wd->cbd->L;
+       level = lua_gettop (L);
+       lua_pushcfunction (L, &rspamd_lua_traceback);
+       err_idx = lua_gettop (L);
+
+       level ++;
+       lua_rawgeti (L, LUA_REGISTRYINDEX, wd->cb_ref);
+
+       ptask = lua_newuserdata (L, sizeof (struct rspamd_task *));
+       rspamd_lua_setclass (L, "rspamd{task}", -1);
+       *ptask = task;
+
+       if ((ret = lua_pcall (L, 1, LUA_MULTRET, err_idx)) != 0) {
+               tb = lua_touserdata (L, -1);
+               msg_err_task ("call to (%s) failed (%d): %v",
+                               wd->cbd->symbol, ret, tb);
+
+               if (tb) {
+                       g_string_free (tb, TRUE);
+                       lua_pop (L, 1);
+               }
+       }
+       else {
+               nresults = lua_gettop (L) - level;
+
+               if (nresults >= 1) {
+                       /* Function returned boolean, so maybe we need to insert result? */
+                       gint res = 0;
+                       gint i;
+                       gdouble flag = 1.0;
+                       gint type;
+                       struct lua_watcher_data *nwd;
+
+                       type = lua_type (L, level + 1);
+
+                       if (type == LUA_TBOOLEAN) {
+                               res = lua_toboolean (L, level + 1);
+                       }
+                       else if (type == LUA_TFUNCTION) {
+                               /* Function returned a closure that should be watched for */
+                               nwd = rspamd_mempool_alloc (task->task_pool, sizeof (*nwd));
+                               lua_pushvalue (L, level + 1);
+                               nwd->cb_ref = luaL_ref (L, LUA_REGISTRYINDEX);
+                               nwd->cbd = wd->cbd;
+                               rspamd_session_watcher_push_callback (task->s,
+                                               rspamd_session_get_watcher (task->s),
+                                               lua_watcher_callback, nwd);
+                               /*
+                                * We immediately pop watcher since we have not registered
+                                * any async events from here
+                                */
+                               rspamd_session_watcher_pop (task->s,
+                                               rspamd_session_get_watcher (task->s));
+                       }
+                       else {
+                               res = lua_tonumber (L, level + 1);
+                       }
+
+                       if (res) {
+                               gint first_opt = 2;
+
+                               if (lua_type (L, level + 2) == LUA_TNUMBER) {
+                                       flag = lua_tonumber (L, level + 2);
+                                       /* Shift opt index */
+                                       first_opt = 3;
+                               }
+                               else {
+                                       flag = res;
+                               }
+
+                               s = rspamd_task_insert_result (task,
+                                               wd->cbd->symbol, flag, NULL);
+
+                               if (s) {
+                                       guint last_pos = lua_gettop (L);
+
+                                       for (i = level + first_opt; i <= last_pos; i++) {
+                                               if (lua_type (L, i) == LUA_TSTRING) {
+                                                       const char *opt = lua_tostring (L, i);
+
+                                                       rspamd_task_add_result_option (task, s, opt);
+                                               }
+                                               else if (lua_type (L, i) == LUA_TTABLE) {
+                                                       lua_pushvalue (L, i);
+
+                                                       for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
+                                                               const char *opt = lua_tostring (L, -1);
+
+                                                               rspamd_task_add_result_option (task, s, opt);
+                                                       }
+
+                                                       lua_pop (L, 1);
+                                               }
+                                       }
+                               }
+                       }
+
+                       lua_pop (L, nresults);
+               }
+       }
+
+       lua_pop (L, 1); /* Error function */
+}
+
 static void
 lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
 {
@@ -918,10 +1037,30 @@ lua_metric_symbol_callback (struct rspamd_task *task, gpointer ud)
                        gint res = 0;
                        gint i;
                        gdouble flag = 1.0;
+                       gint type;
+                       struct lua_watcher_data *wd;
 
-                       if (lua_type (cd->L, level + 1) == LUA_TBOOLEAN) {
+                       type = lua_type (cd->L, level + 1);
+
+                       if (type == LUA_TBOOLEAN) {
                                res = lua_toboolean (L, level + 1);
                        }
+                       else if (type == LUA_TFUNCTION) {
+                               /* Function returned a closure that should be watched for */
+                               wd = rspamd_mempool_alloc (task->task_pool, sizeof (*wd));
+                               lua_pushvalue (cd->L, level + 1);
+                               wd->cb_ref = luaL_ref (L, LUA_REGISTRYINDEX);
+                               wd->cbd = cd;
+                               rspamd_session_watcher_push_callback (task->s,
+                                               rspamd_session_get_watcher (task->s),
+                                               lua_watcher_callback, wd);
+                               /*
+                                * We immediately pop watcher since we have not registered
+                                * any async events from here
+                                */
+                               rspamd_session_watcher_pop (task->s,
+                                               rspamd_session_get_watcher (task->s));
+                       }
                        else {
                                res = lua_tonumber (L, level + 1);
                        }
@@ -1029,6 +1168,7 @@ rspamd_register_symbol_fromlua (lua_State *L,
                                type,
                                parent);
        }
+
        rspamd_mempool_add_destructor (cfg->cfg_pool,
                        (rspamd_mempool_destruct_t)lua_destroy_cfg_symbol,
                        cd);