]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Implement some conditions checks
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 24 Apr 2022 10:59:00 +0000 (11:59 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sun, 24 Apr 2022 10:59:00 +0000 (11:59 +0100)
src/libserver/symcache/symcache_id_list.hxx
src/libserver/symcache/symcache_item.cxx
src/libserver/symcache/symcache_item.hxx
src/libserver/symcache/symcache_runtime.cxx

index b42cb9183ad665da7fde266ad14aedd7d90078ce..444ac2079fae3b7f4597872ac46e7c994ce5ce35 100644 (file)
@@ -20,7 +20,7 @@
 
 #include <cstdint>
 #include <cstring> // for memset
-#include <algorithm> // for sort
+#include <algorithm> // for sort/bsearch
 
 #include "config.h"
 #include "libutil/mem_pool.h"
@@ -43,7 +43,9 @@ struct id_list {
        } data;
 
        id_list() = default;
-       auto reset() {
+
+       auto reset()
+       {
                std::memset(&data, 0, sizeof(data));
        }
 
@@ -121,7 +123,8 @@ struct id_list {
                }
        }
 
-       auto set_ids(const std::uint32_t *ids, std::size_t nids, rspamd_mempool_t *pool) -> void {
+       auto set_ids(const std::uint32_t *ids, std::size_t nids, rspamd_mempool_t *pool) -> void
+       {
                if (nids <= G_N_ELEMENTS(data.st)) {
                        /* Use static version */
                        reset();
@@ -145,6 +148,25 @@ struct id_list {
                        std::sort(data.dyn.n, data.dyn.n + data.dyn.len);
                }
        }
+
+       auto check_id(unsigned int id) const -> bool
+       {
+               if (data.dyn.e == -1) {
+                       return std::binary_search(data.dyn.n, data.dyn.n + data.dyn.len, id);
+               }
+               else {
+                       for (auto elt: data.st) {
+                               if (elt == id) {
+                                       return true;
+                               }
+                               else if (elt == 0) {
+                                       return false;
+                               }
+                       }
+               }
+
+               return false;
+       }
 };
 
 static_assert(std::is_trivial_v<id_list>);
index 0ca080ac0da417ad0c2fdb926e9654b3bca66334..99e3cfb5da9418fd959a78f6b41a189ff081b8ad 100644 (file)
@@ -17,6 +17,8 @@
 #include "symcache_internal.hxx"
 #include "symcache_item.hxx"
 #include "fmt/core.h"
+#include "libserver/task.h"
+#include "lua/lua_common.h"
 
 namespace rspamd::symcache {
 
@@ -205,7 +207,7 @@ auto cache_item::update_counters_check_peak(lua_State *L,
 
 auto cache_item::get_type_str() const -> const char *
 {
-       switch(type) {
+       switch (type) {
        case symcache_item_type::CONNFILTER:
                return "connfilter";
        case symcache_item_type::FILTER:
@@ -227,6 +229,100 @@ auto cache_item::get_type_str() const -> const char *
        RSPAMD_UNREACHABLE;
 }
 
+auto cache_item::is_item_allowed(struct rspamd_task *task, bool exec_only) -> bool
+{
+       const auto *what = "execution";
+
+       if (!exec_only) {
+               what = "symbol insertion";
+       }
+
+       /* Static checks */
+       if (!enabled ||
+               (RSPAMD_TASK_IS_EMPTY(task) && !(flags & SYMBOL_TYPE_EMPTY)) ||
+               (flags & SYMBOL_TYPE_MIME_ONLY && !RSPAMD_TASK_IS_MIME(task))) {
+
+               if (!enabled) {
+                       msg_debug_cache_task("skipping %s of %s as it is permanently disabled",
+                                       what, symbol.c_str());
+
+                       return false;
+               }
+               else {
+                       /*
+                        * If we check merely execution (not insertion), then we disallow
+                        * mime symbols for non mime tasks and vice versa
+                        */
+                       if (exec_only) {
+                               msg_debug_cache_task("skipping check of %s as it cannot be "
+                                                                        "executed for this task type",
+                                               symbol.c_str());
+
+                               return FALSE;
+                       }
+               }
+       }
+
+       /* Settings checks */
+       if (task->settings_elt != nullptr) {
+               if (forbidden_ids.check_id(task->settings_elt->id)) {
+                       msg_debug_cache_task ("deny %s of %s as it is forbidden for "
+                                                                 "settings id %ud",
+                                       what,
+                                       symbol.c_str(),
+                                       task->settings_elt->id);
+
+                       return false;
+               }
+
+               if (!(flags & SYMBOL_TYPE_EXPLICIT_DISABLE)) {
+                       if (allowed_ids.check_id(task->settings_elt->id)) {
+
+                               if (task->settings_elt->policy == RSPAMD_SETTINGS_POLICY_IMPLICIT_ALLOW) {
+                                       msg_debug_cache_task("allow execution of %s settings id %ud "
+                                                                                "allows implicit execution of the symbols;",
+                                                       symbol.c_str(),
+                                                       id);
+
+                                       return true;
+                               }
+
+                               if (exec_only) {
+                                       /*
+                                        * Special case if any of our virtual children are enabled
+                                        */
+                                       if (exec_only_ids.check_id(task->settings_elt->id)) {
+                                               return true;
+                                       }
+                               }
+
+                               msg_debug_cache_task ("deny %s of %s as it is not listed "
+                                                                         "as allowed for settings id %ud",
+                                               what,
+                                               symbol.c_str(),
+                                               task->settings_elt->id);
+                               return false;
+                       }
+               }
+               else {
+                       msg_debug_cache_task ("allow %s of %s for "
+                                                                 "settings id %ud as it can be only disabled explicitly",
+                                       what,
+                                       symbol.c_str(),
+                                       task->settings_elt->id);
+               }
+       }
+       else if (flags & SYMBOL_TYPE_EXPLICIT_ENABLE) {
+               msg_debug_cache_task ("deny %s of %s as it must be explicitly enabled",
+                               what,
+                               symbol.c_str());
+               return false;
+       }
+
+       /* Allow all symbols with no settings id */
+       return true;
+}
+
 auto virtual_item::get_parent(const symcache &cache) const -> const cache_item *
 {
        if (parent) {
@@ -331,4 +427,41 @@ bool operator<(symcache_item_type lhs, symcache_item_type rhs)
        return ret;
 }
 
+item_condition::~item_condition()
+{
+       if (cb != -1 && L != nullptr) {
+               luaL_unref(L, LUA_REGISTRYINDEX, cb);
+       }
+}
+
+auto item_condition::check(std::string_view sym_name, struct rspamd_task *task) const -> bool
+{
+       if (cb != -1 && L != nullptr) {
+               auto ret = false;
+
+               lua_rawgeti(L, LUA_REGISTRYINDEX, cb);
+
+               lua_pushcfunction (L, &rspamd_lua_traceback);
+               auto err_idx = lua_gettop(L);
+
+               auto **ptask = (struct rspamd_task **) lua_newuserdata(L, sizeof(struct rspamd_task *));
+               rspamd_lua_setclass(L, "rspamd{task}", -1);
+               *ptask = task;
+
+               if (lua_pcall(L, 1, 1, err_idx) != 0) {
+                       msg_info_task("call to condition for %s failed: %s",
+                                       sym_name.data(), lua_tostring(L, -1));
+               }
+               else {
+                       ret = lua_toboolean(L, -1);
+               }
+
+               lua_settop(L, err_idx - 1);
+
+               return ret;
+       }
+
+       return true;
+}
+
 }
index 18a46d3176e3e5d25c7223accd9572f97c7fac31..484065cc2785c7ffc7fd0ba89bc046ff59425980 100644 (file)
 #include <string_view>
 #include <memory>
 #include <variant>
+#include <algorithm>
 
 #include "rspamd_symcache.h"
 #include "symcache_id_list.hxx"
 #include "contrib/expected/expected.hpp"
 #include "contrib/libev/ev.h"
-#include "lua/lua_common.h"
 
 namespace rspamd::symcache {
 
@@ -67,16 +67,10 @@ private:
        lua_State *L;
        int cb;
 public:
-       item_condition(lua_State *_L, int _cb) : L(_L), cb(_cb)
-       {
-       }
+       item_condition(lua_State *_L, int _cb) : L(_L), cb(_cb) {}
+       virtual ~item_condition();
 
-       virtual ~item_condition()
-       {
-               if (cb != -1 && L != nullptr) {
-                       luaL_unref(L, LUA_REGISTRYINDEX, cb);
-               }
-       }
+       auto check(std::string_view sym_name, struct rspamd_task *task) const -> bool;
 };
 
 class normal_item {
@@ -98,6 +92,11 @@ public:
        {
                // TODO
        }
+
+       auto check_conditions(std::string_view sym_name, struct rspamd_task *task) -> bool {
+               return std::all_of(std::begin(conditions), std::end(conditions),
+                                                  [&](const auto &cond) { return cond.check(sym_name, task); });
+       }
 };
 
 class virtual_item {
@@ -288,10 +287,21 @@ public:
                                                                        double cur_time,
                                                                        double last_resort) -> bool;
 
+       /**
+        * Increase frequency for a symbol
+        */
        auto inc_frequency() -> void {
                g_atomic_int_inc(&st->hits);
        }
 
+       /**
+        * Check if an item is allowed to be executed not checking item conditions
+        * @param task
+        * @param exec_only
+        * @return
+        */
+       auto is_item_allowed(struct rspamd_task *task, bool exec_only) -> bool;
+
 private:
        /**
         * Constructor for a normal symbols with callback
index c86f79087bc0d3543b5c889a2ec89c96cb1df31a..74dd4a18329a0d22fe2985218cf595e8eaf52bee 100644 (file)
@@ -18,6 +18,7 @@
 #include "symcache_item.hxx"
 #include "symcache_runtime.hxx"
 #include "libutil/cxx/util.hxx"
+#include "libserver/task.h"
 
 namespace rspamd::symcache {
 
@@ -31,13 +32,11 @@ constexpr static const auto PROFILE_PROBABILITY = 0.01;
 auto
 symcache_runtime::create_savepoint(struct rspamd_task *task, symcache &cache) -> symcache_runtime *
 {
-       struct symcache_runtime *checkpoint;
-
        cache.maybe_resort();
 
        auto &&cur_order = cache.get_cache_order();
-       checkpoint = (struct symcache_runtime *) rspamd_mempool_alloc0 (task->task_pool,
-                       sizeof(*checkpoint) +
+       auto *checkpoint = (symcache_runtime *) rspamd_mempool_alloc0 (task->task_pool,
+                       sizeof(symcache_runtime) +
                        sizeof(struct cache_dynamic_item) * cur_order->size());
 
        checkpoint->order = cache.get_cache_order();