aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Vsevolod Stakhov <vsevolod@rspamd.com> 2024-04-09 17:51:12 +0600
committer GitHub <noreply@github.com> 2024-04-09 17:51:12 +0600
commit 3c525c1f73fcc294375842996ab30751f704e4ca (patch)
tree d0d04fb91d55ece3f2d3afb9b2d7c619f127e6de
parent 748f80fdfe4076488442aecf418fa8ac30bce72e (diff)
parent a96ff38346fea032d4800e14b254aa5708ba7843 (diff)
download rspamd-3c525c1f73fcc294375842996ab30751f704e4ca.tar.gz
rspamd-3c525c1f73fcc294375842996ab30751f704e4ca.zip
Merge pull request #4912 from rspamd/vstakhov-grow-factor-rework
Rework grow factor
-rw-r--r-- src/libmime/scan_result.c 74
-rw-r--r-- src/libmime/scan_result.h 18
-rw-r--r-- src/libserver/protocol.c 8
-rw-r--r-- src/libserver/task.c 4
4 files changed, 60 insertions, 44 deletions
diff --git a/src/libmime/scan_result.c b/src/libmime/scan_result.c
index b9660071b..09c3208cf 100644
--- a/src/libmime/scan_result.c
+++ b/src/libmime/scan_result.c
@@ -231,7 +231,7 @@ insert_metric_result(struct rspamd_task *task,
bool *new_sym)
{
struct rspamd_symbol_result *symbol_result = NULL;
- double final_score, *gr_score = NULL, next_gf = 1.0, diff;
+ double final_score, *gr_score = NULL, diff;
struct rspamd_symbol *sdef;
struct rspamd_symbols_group *gr = NULL;
const ucl_object_t *mobj, *sobj;
@@ -368,17 +368,6 @@ insert_metric_result(struct rspamd_task *task,
}
if (diff) {
- /* Handle grow factor */
- if (metric_res->grow_factor && diff > 0) {
- diff *= metric_res->grow_factor;
- next_gf *= task->cfg->grow_factor;
- }
- else if (diff > 0) {
- next_gf = task->cfg->grow_factor;
- }
-
- msg_debug_metric("adjust grow factor to %.2f for symbol %s (%.2f final)",
- next_gf, symbol, diff);
if (sdef) {
PTR_ARRAY_FOREACH(sdef->groups, i, gr)
@@ -418,8 +407,6 @@ insert_metric_result(struct rspamd_task *task,
}
if (!isnan(diff)) {
- metric_res->score += diff;
- metric_res->grow_factor = next_gf;
if (single) {
msg_debug_metric("final score for single symbol %s = %.2f; %.2f diff",
@@ -447,18 +434,6 @@ insert_metric_result(struct rspamd_task *task,
symbol_result = rspamd_mempool_alloc0(task->task_pool, sizeof(*symbol_result));
kh_value(metric_res->symbols, k) = symbol_result;
- /* Handle grow factor */
- if (metric_res->grow_factor && final_score > 0) {
- final_score *= metric_res->grow_factor;
- next_gf *= task->cfg->grow_factor;
- }
- else if (final_score > 0) {
- next_gf = task->cfg->grow_factor;
- }
-
- msg_debug_metric("adjust grow factor to %.2f for symbol %s (%.2f final)",
- next_gf, symbol, final_score);
-
symbol_result->name = sym_cpy;
symbol_result->sym = sdef;
symbol_result->nshots = 1;
@@ -503,7 +478,6 @@ insert_metric_result(struct rspamd_task *task,
const double epsilon = DBL_EPSILON;
metric_res->score += final_score;
- metric_res->grow_factor = next_gf;
symbol_result->score = final_score;
if (final_score > epsilon) {
@@ -1104,3 +1078,49 @@ rspamd_find_metric_result(struct rspamd_task *task,
return NULL;
}
+
+/**
+ * Scale the positive symbol scores of a scan result by a task-wide grow factor.
+ *
+ * The base factor is amplified once per positively scored symbol, in proportion
+ * to that symbol's score relative to the highest positive action limit. The
+ * resulting factor is then applied to every positively scored symbol and the
+ * total score of the result is updated to match.
+ * Nothing is changed unless grow_factor is greater than 1.0.
+ *
+ * @param task task being processed
+ * @param result scan result whose symbol scores and total score are adjusted in place
+ * @param grow_factor base grow factor
+ */
+void rspamd_task_result_adjust_grow_factor(struct rspamd_task *task,
+										   struct rspamd_scan_result *result,
+										   double grow_factor)
+{
+	const char *kk;
+	struct rspamd_symbol_result *res;
+	double final_grow_factor = grow_factor;
+	/*
+	 * Start from zero rather than G_MINDOUBLE: the latter is the smallest
+	 * *positive* double, so it would always pass the `max_limit > 0` check
+	 * below and turn `score / max_limit` into infinity when no action has a
+	 * positive limit.
+	 */
+	double max_limit = 0.0;
+
+	/* Written as a negated comparison so that NaN is rejected as well */
+	if (!(grow_factor > 1.0)) {
+		return;
+	}
+
+	/* Find the highest positive action limit to normalise symbol scores against */
+	for (unsigned int i = 0; i < result->nactions; i++) {
+		struct rspamd_action_config *cur = &result->actions_config[i];
+
+		if (cur->cur_limit > 0 && max_limit < cur->cur_limit) {
+			max_limit = cur->cur_limit;
+		}
+	}
+
+	/* Adjust factor by selecting all symbols and checking those with positive scores */
+	kh_foreach(result->symbols, kk, res, {
+		if (res->score > 0) {
+			/*
+			 * Excess of the base factor over 1.0. It must be positive here
+			 * (grow_factor > 1.0), otherwise every symbol would shrink the
+			 * final factor instead of growing it.
+			 */
+			double mult = grow_factor - 1.0;
+			/* We adjust the factor by the ratio of the score to the max limit */
+			if (max_limit > 0 && !isnan(res->score)) {
+				mult *= res->score / max_limit;
+				final_grow_factor *= 1.0 + mult;
+			}
+		}
+	});
+
+	/* At this stage we know that we have some grow factor to apply */
+	if (final_grow_factor > 1.0) {
+		msg_info_task("calculated final grow factor for task: %.3f (%.2f the original one)",
+					  final_grow_factor, grow_factor);
+		kh_foreach(result->symbols, kk, res, {
+			if (res->score > 0) {
+				/* Replace the symbol's old contribution to the total with the scaled one */
+				result->score -= res->score;
+				res->score *= final_grow_factor;
+				result->score += res->score;
+			}
+		});
+	}
+}
diff --git a/src/libmime/scan_result.h b/src/libmime/scan_result.h
index d4572e1d8..12fdb9459 100644
--- a/src/libmime/scan_result.h
+++ b/src/libmime/scan_result.h
@@ -1,5 +1,5 @@
/*
- * Copyright 2023 Vsevolod Stakhov
+ * Copyright 2024 Vsevolod Stakhov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -99,8 +99,7 @@ struct kh_rspamd_symbols_group_hash_s;
struct rspamd_scan_result {
- double score; /**< total score */
- double grow_factor; /**< current grow factor */
+ double score; /**< total score */
struct rspamd_passthrough_result *passthrough_result;
double positive_score;
double negative_score;
@@ -220,16 +219,11 @@ void rspamd_task_symbol_result_foreach(struct rspamd_task *task,
gpointer ud);
/**
- * Default consolidation function for metric, it get all symbols and multiply symbol
- * weight by some factor that is specified in config. Default factor is 1.
- * @param task worker's task that present message from user
- * @param metric_name name of metric
- * @return result metric weight
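+ * Scale the scores of all positively scored symbols in a scan result (and its
+ * total score) by a factor derived from the grow factor; should be called
+ * after postfilters.
+ * @param task task being processed
+ * @param result scan result that is adjusted in place
+ * @param grow_factor base grow factor; values that are not greater than 1.0 leave the result untouched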
*/
-double rspamd_factor_consolidation_func(struct rspamd_task *task,
- const char *metric_name,
- const char *unused);
-
+void rspamd_task_result_adjust_grow_factor(struct rspamd_task *task,
+ struct rspamd_scan_result *result,
+ double grow_factor);
/**
* Check thresholds and return action for a task
diff --git a/src/libserver/protocol.c b/src/libserver/protocol.c
index 8da246bdb..5de980352 100644
--- a/src/libserver/protocol.c
+++ b/src/libserver/protocol.c
@@ -1322,15 +1322,13 @@ rspamd_scan_result_ucl(struct rspamd_task *task,
sobj = rspamd_metric_symbol_ucl(task, sym);
ucl_object_insert_key(obj, sobj, sym->name, 0, false);
}
- })
+ });
- if (task->cmd != CMD_CHECK)
- {
+ if (task->cmd != CMD_CHECK) {
/* For checkv2 we insert symbols as a separate object */
ucl_object_insert_key(top, obj, "symbols", 0, false);
}
- else
- {
+ else {
/* For legacy check we just insert it as "default" all together */
ucl_object_insert_key(top, obj, DEFAULT_METRIC, 0, false);
}
diff --git a/src/libserver/task.c b/src/libserver/task.c
index f81f34e47..270bb80ea 100644
--- a/src/libserver/task.c
+++ b/src/libserver/task.c
@@ -758,6 +758,10 @@ rspamd_task_process(struct rspamd_task *task, unsigned int stages)
all_done = rspamd_symcache_process_symbols(task, task->cfg->cache,
st);
+ if (all_done) {
+ rspamd_task_result_adjust_grow_factor(task, task->result, task->cfg->grow_factor);
+ }
+
if (all_done && (task->flags & RSPAMD_TASK_FLAG_LEARN_AUTO) &&
!RSPAMD_TASK_IS_EMPTY(task) &&
!(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM))) {