]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Allow more sane flags and ids syntax when register symbols
authorVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 16 May 2023 10:01:24 +0000 (11:01 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Tue, 16 May 2023 10:01:24 +0000 (11:01 +0100)
For example, flags could be defined as:

```lua
rspamd_config.register_symbol{
    flags = ['ignore_passthrough', 'nice']
  }
```

instead of (compatibility is still there):
```lua
rspamd_config.register_symbol{
    flags = 'ignore_passthrough,nice'
  }
```

src/lua/lua_config.c

index c58b35c56136d88874b43691af2c0442ad6e0fbd..b0e2a475c003de7a71fa7a78efb8071c4455531b 100644 (file)
@@ -1504,43 +1504,41 @@ lua_metric_symbol_callback_return (struct thread_entry *thread_entry, int ret)
        rspamd_symcache_item_async_dec_check (task, cd->item, "lua coro symbol");
 }
 
-static guint32*
-rspamd_process_id_list (const gchar *entries, guint32 *plen)
+static GArray *
+rspamd_process_id_list (const gchar *entries)
 {
        gchar **sym_elts;
-       guint32 *ids, nids;
+       GArray *ret;
 
        sym_elts = g_strsplit_set (entries, ",;", -1);
-       nids = g_strv_length (sym_elts);
-
-       ids = g_malloc (nids * sizeof (guint32));
+       guint nids = g_strv_length (sym_elts);
+       ret = g_array_sized_new(FALSE, FALSE, sizeof (guint32), nids);
 
        for (guint i = 0; i < nids; i ++) {
-               ids[i] = rspamd_config_name_to_id (sym_elts[i], strlen (sym_elts[i]));
+               guint32 v = rspamd_config_name_to_id (sym_elts[i], strlen (sym_elts[i]));
+               g_array_append_val(ret, v);
        }
 
-       *plen = nids;
        g_strfreev (sym_elts);
 
-       return ids;
+       return ret;
 }
 
 static gint
 rspamd_register_symbol_fromlua (lua_State *L,
-               struct rspamd_config *cfg,
-               const gchar *name,
-               gint ref,
-               gdouble weight,
-               gint priority,
-               enum rspamd_symbol_type type,
-               gint parent,
-               const gchar *allowed_ids,
-               const gchar *forbidden_ids,
-               gboolean optional)
+                                                               struct rspamd_config *cfg,
+                                                               const gchar *name,
+                                                               gint ref,
+                                                               gdouble weight,
+                                                               gint priority,
+                                                               enum rspamd_symbol_type type,
+                                                               gint parent,
+                                                               GArray *allowed_ids,
+                                                               GArray *forbidden_ids,
+                                                               gboolean optional)
 {
        struct lua_callback_data *cd;
        gint ret = -1;
-       guint32 *ids, nids;
 
        if (priority == 0 && weight < 0) {
                priority = 1;
@@ -1612,47 +1610,13 @@ rspamd_register_symbol_fromlua (lua_State *L,
        }
 
        if (allowed_ids) {
-               ids = rspamd_process_id_list (allowed_ids, &nids);
-
-               if (nids > 0) {
-                       GString *dbg = g_string_new ("");
-
-                       for (guint i = 0; i < nids; i ++) {
-                               rspamd_printf_gstring (dbg, "%ud,", ids[i]);
-                       }
-
-                       dbg->len --;
-
-                       msg_debug_config ("allowed ids for %s are: %v", name, dbg);
-                       g_string_free (dbg, TRUE);
-
-                       rspamd_symcache_set_allowed_settings_ids (cfg->cache, name,
-                                       ids, nids);
-               }
-
-               g_free (ids);
+               rspamd_symcache_set_allowed_settings_ids (cfg->cache, name,
+                       &g_array_index(allowed_ids, guint32, 0), allowed_ids->len);
        }
 
        if (forbidden_ids) {
-               ids = rspamd_process_id_list (forbidden_ids, &nids);
-
-               if (nids > 0) {
-                       GString *dbg = g_string_new ("");
-
-                       for (guint i = 0; i < nids; i ++) {
-                               rspamd_printf_gstring (dbg, "%ud,", ids[i]);
-                       }
-
-                       dbg->len --;
-
-                       msg_debug_config ("forbidden ids for %s are: %v", name, dbg);
-                       g_string_free (dbg, TRUE);
-
-                       rspamd_symcache_set_forbidden_settings_ids (cfg->cache, name,
-                                       ids, nids);
-               }
-
-               g_free (ids);
+               rspamd_symcache_set_forbidden_settings_ids (cfg->cache, name,
+                       &g_array_index(forbidden_ids, guint32, 0), forbidden_ids->len);
        }
 
        return ret;
@@ -1995,26 +1959,25 @@ lua_config_register_symbol (lua_State * L)
 {
        LUA_TRACE_POINT;
        struct rspamd_config *cfg = lua_check_config (L, 1);
-       const gchar *name = NULL, *flags_str = NULL, *type_str = NULL,
-                       *description = NULL, *group = NULL, *allowed_ids = NULL,
-                       *forbidden_ids = NULL;
+       const gchar *name = NULL, *type_str = NULL,
+                       *description = NULL, *group = NULL;
        double weight = 0, score = NAN, parent_float = NAN;
        gboolean one_shot = FALSE;
-       gint ret = -1, cbref = -1, type, flags = 0;
+       gint ret = -1, cbref = -1;
+       guint type = 0, flags = 0;
        gint64 parent = 0, priority = 0, nshots = 0;
+       GArray *allowed_ids = NULL, *forbidden_ids = NULL;
        GError *err = NULL;
        int prev_top = lua_gettop(L);
 
        if (cfg) {
                if (!rspamd_lua_parse_table_arguments (L, 2, &err,
                                RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT,
-                               "name=S;weight=N;callback=F;flags=S;type=S;priority=I;parent=D;"
-                               "score=D;description=S;group=S;one_shot=B;nshots=I;"
-                               "allowed_ids=S;forbidden_ids=S",
-                               &name, &weight, &cbref, &flags_str, &type_str,
+                               "name=S;weight=N;callback=F;type=S;priority=I;parent=D;"
+                               "score=D;description=S;group=S;one_shot=B;nshots=I",
+                               &name, &weight, &cbref, &type_str,
                                &priority, &parent_float,
-                               &score, &description, &group, &one_shot, &nshots,
-                               &allowed_ids, &forbidden_ids)) {
+                               &score, &description, &group, &one_shot, &nshots)) {
                        msg_err_config ("bad arguments: %e", err);
                        g_error_free (err);
                        lua_settop(L, prev_top);
@@ -2022,6 +1985,49 @@ lua_config_register_symbol (lua_State * L)
                        return luaL_error (L, "invalid arguments");
                }
 
+               /* Deal with flags and ids */
+               lua_pushstring (L, "flags");
+               lua_gettable (L, 2);
+               if (lua_type(L, -1) == LUA_TSTRING) {
+                       flags = lua_parse_symbol_flags (lua_tostring (L, -1));
+               }
+               else if (lua_type(L, -1) == LUA_TTABLE) {
+                       for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
+                               flags |= lua_parse_symbol_flags (lua_tostring (L, -1));
+                       }
+               }
+               lua_pop (L, 1); /* Clean flags */
+
+               lua_pushstring(L, "allowed_ids");
+               lua_gettable (L, 2);
+               if (lua_type(L, -1) == LUA_TSTRING) {
+                       allowed_ids = rspamd_process_id_list(lua_tostring (L, -1));
+               }
+               else if (lua_type(L, -1) == LUA_TTABLE) {
+                       allowed_ids = g_array_sized_new(FALSE, FALSE, sizeof (guint32),
+                                       rspamd_lua_table_size(L, -1));
+                       for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
+                               guint32 v = lua_tointeger(L, -1);
+                               g_array_append_val(allowed_ids, v);
+                       }
+               }
+               lua_pop (L, 1);
+
+               lua_pushstring(L, "forbidden_ids");
+               lua_gettable (L, 2);
+               if (lua_type(L, -1) == LUA_TSTRING) {
+                       forbidden_ids = rspamd_process_id_list(lua_tostring (L, -1));
+               }
+               else if (lua_type(L, -1) == LUA_TTABLE) {
+                       forbidden_ids = g_array_sized_new(FALSE, FALSE, sizeof (guint32),
+                               rspamd_lua_table_size(L, -1));
+                       for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
+                               guint32 v = lua_tointeger(L, -1);
+                               g_array_append_val(forbidden_ids, v);
+                       }
+               }
+               lua_pop (L, 1);
+
                if (nshots == 0) {
                        nshots = cfg->default_max_shots;
                }
@@ -2037,10 +2043,6 @@ lua_config_register_symbol (lua_State * L)
                        return luaL_error (L, "no callback for symbol %s", name);
                }
 
-               if (flags_str) {
-                       type |= lua_parse_symbol_flags (flags_str);
-               }
-
                if (isnan (parent_float)) {
                        parent = -1;
                }
@@ -2054,11 +2056,19 @@ lua_config_register_symbol (lua_State * L)
                                cbref,
                                weight == 0 ? 1.0 : weight,
                                priority,
-                               type,
+                               type | flags,
                                parent,
                                allowed_ids, forbidden_ids,
                                FALSE);
 
+               if (allowed_ids) {
+                       g_array_free(allowed_ids, TRUE);
+               }
+
+               if (forbidden_ids) {
+                       g_array_free(forbidden_ids, TRUE);
+               }
+
                if (ret != -1) {
                        if (!isnan(score) || group) {
                                if (one_shot) {
@@ -2628,8 +2638,10 @@ lua_config_newindex (lua_State *L)
 {
        LUA_TRACE_POINT;
        struct rspamd_config *cfg = lua_check_config (L, 1);
-       const gchar *name, *allowed_ids = NULL, *forbidden_ids = NULL;
-       gint id, nshots, flags = 0;
+       const gchar *name;
+       GArray *allowed_ids = NULL, *forbidden_ids = NULL;
+       gint id, nshots;
+       guint flags = 0;
        gboolean optional = FALSE;
 
        name = luaL_checkstring (L, 2);
@@ -2651,7 +2663,8 @@ lua_config_newindex (lua_State *L)
                                        FALSE);
                }
                else if (lua_type (L, 3) == LUA_TTABLE) {
-                       gint type = SYMBOL_TYPE_NORMAL, priority = 0, idx;
+                       guint type = SYMBOL_TYPE_NORMAL, priority = 0;
+                       gint idx;
                        gdouble weight = 1.0, score = NAN;
                        const char *type_str, *group = NULL, *description = NULL;
 
@@ -2693,7 +2706,7 @@ lua_config_newindex (lua_State *L)
                        lua_gettable (L, -2);
 
                        if (lua_type (L, -1) == LUA_TNUMBER) {
-                               priority = lua_tonumber (L, -1);
+                               priority = lua_tointeger(L, -1);
                        }
                        lua_pop (L, 1);
 
@@ -2714,28 +2727,31 @@ lua_config_newindex (lua_State *L)
                        }
                        lua_pop (L, 1);
 
+                       /* Deal with flags and ids */
                        lua_pushstring (L, "flags");
-                       lua_gettable (L, -2);
-
-                       if (lua_type (L, -1) == LUA_TSTRING) {
-                               type_str = lua_tostring (L, -1);
-                               type |= lua_parse_symbol_flags (type_str);
+                       lua_gettable (L, 2);
+                       if (lua_type(L, -1) == LUA_TSTRING) {
+                               flags = lua_parse_symbol_flags (lua_tostring (L, -1));
                        }
-                       lua_pop (L, 1);
-
-                       lua_pushstring (L, "allowed_ids");
-                       lua_gettable (L, -2);
-
-                       if (lua_type (L, -1) == LUA_TSTRING) {
-                               allowed_ids = lua_tostring (L, -1);
+                       else if (lua_type(L, -1) == LUA_TTABLE) {
+                               for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
+                                       flags |= lua_parse_symbol_flags (lua_tostring (L, -1));
+                               }
                        }
-                       lua_pop (L, 1);
+                       lua_pop (L, 1); /* Clean flags */
 
-                       lua_pushstring (L, "forbidden_ids");
-                       lua_gettable (L, -2);
-
-                       if (lua_type (L, -1) == LUA_TSTRING) {
-                               forbidden_ids = lua_tostring (L, -1);
+                       lua_pushstring(L, "allowed_ids");
+                       lua_gettable (L, 2);
+                       if (lua_type(L, -1) == LUA_TSTRING) {
+                               allowed_ids = rspamd_process_id_list(lua_tostring (L, -1));
+                       }
+                       else if (lua_type(L, -1) == LUA_TTABLE) {
+                               allowed_ids = g_array_sized_new(FALSE, FALSE, sizeof (guint32),
+                                       rspamd_lua_table_size(L, -1));
+                               for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
+                                       guint32 v = lua_tointeger(L, -1);
+                                       g_array_append_val(allowed_ids, v);
+                               }
                        }
                        lua_pop (L, 1);
 
@@ -2745,11 +2761,19 @@ lua_config_newindex (lua_State *L)
                                        idx,
                                        weight,
                                        priority,
-                                       type,
+                                       type | flags,
                                        -1,
                                        allowed_ids, forbidden_ids,
                                        optional);
 
+                       if (allowed_ids) {
+                               g_array_free (allowed_ids, TRUE);
+                       }
+
+                       if (forbidden_ids) {
+                               g_array_free (forbidden_ids, TRUE);
+                       }
+
                        if (id != -1) {
                                /* Check for condition */
                                lua_pushstring (L, "condition");