]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Allow to disable classifiers checks using settings and conditions
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 27 Feb 2017 17:44:48 +0000 (17:44 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 27 Feb 2017 17:44:48 +0000 (17:44 +0000)
src/libserver/symbols_cache.c
src/libserver/symbols_cache.h
src/libstat/stat_process.c

index 03b60a2c231535ee6e48759f8de8113990265ecd..3378cc142b00bea3434266f999b8a8117104d6c5 100644 (file)
@@ -2363,3 +2363,58 @@ rspamd_symbols_cache_get_cksum (struct symbols_cache *cache)
 
        return cache->cksum;
 }
+
+
+gboolean
+rspamd_symbols_cache_is_symbol_enabled (struct rspamd_task *task,
+               struct symbols_cache *cache, const gchar *symbol)
+{
+       gint id;
+       struct cache_savepoint *checkpoint;
+       struct cache_item *item;
+       lua_State *L;
+       struct rspamd_task **ptask;
+       gboolean ret = TRUE;
+
+       g_assert (cache != NULL);
+       g_assert (symbol != NULL);
+
+       id = rspamd_symbols_cache_find_symbol_parent (cache, symbol);
+
+       if (id < 0) {
+               return FALSE;
+       }
+
+       checkpoint = task->checkpoint;
+       item = g_ptr_array_index (cache->items_by_id, id);
+
+       if (checkpoint) {
+               if (isset (checkpoint->processed_bits, id * 2)) {
+                       return FALSE;
+               }
+               else {
+                       if (item->condition_cb != -1) {
+                               /* We also executes condition callback to check if we need this symbol */
+                               L = task->cfg->lua_state;
+                               lua_rawgeti (L, LUA_REGISTRYINDEX, item->condition_cb);
+                               ptask = lua_newuserdata (L, sizeof (struct rspamd_task *));
+                               rspamd_lua_setclass (L, "rspamd{task}", -1);
+                               *ptask = task;
+
+                               if (lua_pcall (L, 1, 1, 0) != 0) {
+                                       msg_info_task ("call to condition for %s failed: %s",
+                                                       item->symbol, lua_tostring (L, -1));
+                                       lua_pop (L, 1);
+                               }
+                               else {
+                                       ret = lua_toboolean (L, -1);
+                                       lua_pop (L, 1);
+                               }
+                       }
+
+                       return ret;
+               }
+       }
+
+       return FALSE;
+}
index eadffded39cb37a2529d333c37ebccc9abb467f8..b003eebab4f8a4876a96b89ebea6c38332e27b03 100644 (file)
@@ -282,4 +282,14 @@ gboolean rspamd_symbols_cache_is_checked (struct rspamd_task *task,
  * @return
  */
 guint64 rspamd_symbols_cache_get_cksum (struct symbols_cache *cache);
+
+/**
+ * Checks if a symbols is enabled (not checked and conditions return true if present)
+ * @param task
+ * @param cache
+ * @param symbol
+ * @return
+ */
+gboolean rspamd_symbols_cache_is_symbol_enabled (struct rspamd_task *task,
+               struct symbols_cache *cache, const gchar *symbol);
 #endif
index 00b26ee2e3072b03d71a401140cfaa313d2fc2fb..7eff21e52e199dbd9af38f86cf943344bd66bac1 100644 (file)
@@ -318,6 +318,12 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
                        continue;
                }
 
+               if (!rspamd_symbols_cache_is_symbol_enabled (task, task->cfg->cache,
+                               st->stcf->symbol)) {
+                       g_ptr_array_index (task->stat_runtimes, i) = NULL;
+                       continue;
+               }
+
                bk_run = st->backend->runtime (task, st->stcf, learn, st->bkcf);
 
                if (bk_run == NULL) {
@@ -349,7 +355,6 @@ rspamd_stat_backends_process (struct rspamd_stat_ctx *st_ctx,
                }
 
                bk_run = g_ptr_array_index (task->stat_runtimes, i);
-               g_assert (st != NULL);
 
                if (bk_run != NULL) {
                        st->backend->process_tokens (task, task->tokens, i, bk_run);
@@ -377,7 +382,6 @@ rspamd_stat_backends_post_process (struct rspamd_stat_ctx *st_ctx,
                }
 
                bk_run = g_ptr_array_index (task->stat_runtimes, i);
-               g_assert (st != NULL);
 
                if (bk_run != NULL) {
                        st->backend->finalize_process (task, bk_run, st_ctx);
@@ -389,10 +393,11 @@ static void
 rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
                struct rspamd_task *task)
 {
-       guint i;
+       guint i, j, id;
        struct rspamd_classifier *cl;
        struct rspamd_statfile *st;
        gpointer bk_run;
+       gboolean skip;
 
        if (st_ctx->classifiers->len == 0) {
                return;
@@ -442,28 +447,45 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
 
                g_assert (cl != NULL);
 
-               if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
-                       msg_debug_task (
-                                       "<%s> contains less tokens than required for %s classifier: "
-                                       "%ud < %ud",
-                                       task->message_id,
-                                       cl->cfg->name,
-                                       task->tokens->len,
-                                       cl->cfg->min_tokens);
-                       continue;
-               }
-               else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
-                       msg_debug_task (
-                                       "<%s> contains more tokens than allowed for %s classifier: "
-                                       "%ud > %ud",
-                                       task->message_id,
-                                       cl->cfg->name,
-                                       task->tokens->len,
-                                       cl->cfg->max_tokens);
-                       continue;
+               /* Ensure that all symbols enabled */
+               skip = FALSE;
+
+               if (!(cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) {
+                       for (j = 0; j < cl->statfiles_ids->len; i++) {
+                               id = g_array_index (cl->statfiles_ids, gint, i);
+                               bk_run =  g_ptr_array_index (task->stat_runtimes, id);
+
+                               if (bk_run == NULL) {
+                                       skip = TRUE;
+                                       break;
+                               }
+                       }
                }
 
-               cl->subrs->classify_func (cl, task->tokens, task);
+               if (!skip) {
+                       if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
+                               msg_debug_task (
+                                               "<%s> contains less tokens than required for %s classifier: "
+                                               "%ud < %ud",
+                                               task->message_id,
+                                               cl->cfg->name,
+                                               task->tokens->len,
+                                               cl->cfg->min_tokens);
+                               continue;
+                       }
+                       else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
+                               msg_debug_task (
+                                               "<%s> contains more tokens than allowed for %s classifier: "
+                                               "%ud > %ud",
+                                               task->message_id,
+                                               cl->cfg->name,
+                                               task->tokens->len,
+                                               cl->cfg->max_tokens);
+                               continue;
+                       }
+
+                       cl->subrs->classify_func (cl, task->tokens, task);
+               }
        }
 }