]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Neural: Distinguish missing symbols from symbols with low scores
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 25 Feb 2020 16:04:57 +0000 (16:04 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 25 Feb 2020 16:04:57 +0000 (16:04 +0000)
src/lua/lua_task.c
src/plugins/lua/neural.lua

index 82556279bebe760752962e8008f1760dd40fee78..fd1877daf72225bd00d4e3ba1ac685f995ad9dd7 100644 (file)
@@ -673,11 +673,12 @@ LUA_FUNCTION_DEF (task, get_symbols_numeric);
 LUA_FUNCTION_DEF (task, get_symbols_tokens);
 
 /***
- * @method task:process_ann_tokens(symbols, ann_tokens, offset)
+ * @method task:process_ann_tokens(symbols, ann_tokens, offset, [min])
  * Processes ann tokens
  * @param {table|string} symbols list of symbols in this profile
  * @param {table|number} ann_tokens list of tokens (including metatokens)
  * @param {integer} offset offset for symbols token (#metatokens)
+ * @param {number} min minimum value for symbols found (e.g. for 0 score symbols)
  * @return nothing
  */
 LUA_FUNCTION_DEF (task, process_ann_tokens);
@@ -4760,9 +4761,13 @@ lua_task_process_ann_tokens (lua_State *L)
        LUA_TRACE_POINT;
        struct rspamd_task *task = lua_check_task (L, 1);
        gint offset = luaL_checkinteger (L, 4);
+       gdouble min_score = 0.0;
 
        if (task && lua_istable (L, 2) && lua_istable (L, 3)) {
                guint symlen = rspamd_lua_table_size (L, 2);
+               if (lua_isnumber (L, 5)) {
+                       min_score = lua_tonumber (L, 5);
+               }
 
                for (guint i = 1; i <= symlen; i ++, offset ++) {
                        const gchar *sym;
@@ -4778,8 +4783,10 @@ lua_task_process_ann_tokens (lua_State *L)
                                if (!isnan (sres->score) && !isinf (sres->score) &&
                                                (!sres->sym ||
                                                        !(rspamd_symcache_item_flags (sres->sym->cache_item) & SYMBOL_TYPE_NOSTAT))) {
+                                       gdouble norm_score = fabs (tanh (sres->score));
+
 
-                                       lua_pushnumber (L, fabs (tanh (sres->score)));
+                                       lua_pushnumber (L, MAX (min_score , norm_score));
                                        lua_rawseti (L, 3, offset + 1);
                                }
                        }
index e7f126398e35d40a1e7cf7465782c3b8ff007117..9fcb2d8861809305555a9fbf78955b6a5cd9ad19 100644 (file)
@@ -247,7 +247,7 @@ local function result_to_vector(task, profile)
     vec[i] = v
   end
 
-  task:process_ann_tokens(profile.symbols, vec, #mt)
+  task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)
 
   return vec
 end