about | summary | refs | log | tree | commit | diff | stats
path: root/src/libstat/stat_process.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/libstat/stat_process.c')
-rw-r--r-- src/libstat/stat_process.c 620
1 files changed, 531 insertions, 89 deletions
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index 176064087..11b31decc 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -32,6 +32,78 @@
static const double similarity_threshold = 80.0;
+/**
+ * Attach a multi-class classification result to a task.
+ *
+ * The result is stored in the task memory pool under the
+ * "multiclass_bayes_result" key, with rspamd_multiclass_result_free passed
+ * as the pool destructor callback for that variable.
+ *
+ * @param task   task to annotate, must not be NULL
+ * @param result result to store, must not be NULL
+ */
+void rspamd_task_set_multiclass_result(struct rspamd_task *task, rspamd_multiclass_result_t *result)
+{
+	rspamd_mempool_destruct_t release_cb =
+		(rspamd_mempool_destruct_t) rspamd_multiclass_result_free;
+
+	g_assert(task != NULL);
+	g_assert(result != NULL);
+
+	rspamd_mempool_set_variable(task->task_pool, "multiclass_bayes_result",
+		result, release_cb);
+}
+
+/**
+ * Fetch the multi-class classification result attached to a task.
+ *
+ * @param task task to query, must not be NULL
+ * @return whatever the task pool holds under the "multiclass_bayes_result"
+ *         key, cast to a result pointer
+ */
+rspamd_multiclass_result_t *
+rspamd_task_get_multiclass_result(struct rspamd_task *task)
+{
+	gpointer stored;
+
+	g_assert(task != NULL);
+
+	stored = rspamd_mempool_get_variable(task->task_pool,
+		"multiclass_bayes_result");
+
+	return (rspamd_multiclass_result_t *) stored;
+}
+
+/**
+ * Release a multi-class result together with the arrays it owns.
+ *
+ * @param result result to destroy; NULL is accepted and ignored
+ */
+void rspamd_multiclass_result_free(rspamd_multiclass_result_t *result)
+{
+	if (result != NULL) {
+		g_free(result->class_names);
+		g_free(result->probabilities);
+		/* winning_class is only a borrowed reference, so it is left alone */
+		g_free(result);
+	}
+}
+
+/**
+ * Record the class a task should be autolearned as.
+ *
+ * A copy of the class name is made in the task pool and stored under the
+ * "autolearn_class" key; RSPAMD_TASK_FLAG_LEARN_CLASS is raised. For
+ * backward compatibility the binary LEARN_SPAM / LEARN_HAM flags are
+ * raised as well when the name is "spam"/"S" or "ham"/"H".
+ *
+ * @param task       task to annotate, must not be NULL
+ * @param class_name class name to remember, must not be NULL
+ */
+void rspamd_task_set_autolearn_class(struct rspamd_task *task, const char *class_name)
+{
+	g_assert(task != NULL);
+	g_assert(class_name != NULL);
+
+	/* Store the class name in the mempool */
+	const char *class_name_copy = rspamd_mempool_strdup(task->task_pool, class_name);
+	rspamd_mempool_set_variable(task->task_pool, "autolearn_class",
+		(gpointer) class_name_copy, NULL);
+
+	/*
+	 * Drop any binary flag left over from an earlier call: otherwise a task
+	 * first marked "spam" and then "ham" (or a custom class) would carry
+	 * contradictory LEARN_SPAM / LEARN_HAM bits.
+	 */
+	task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM);
+
+	/* Set the appropriate flags */
+	task->flags |= RSPAMD_TASK_FLAG_LEARN_CLASS;
+
+	/* For backward compatibility, also set binary flags */
+	if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+		task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+	}
+	else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+		task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+	}
+}
+
+/**
+ * Report the class a task is flagged to be learned as.
+ *
+ * When RSPAMD_TASK_FLAG_LEARN_CLASS is set, the value stored in the task
+ * pool under "autolearn_class" is returned. Otherwise the legacy binary
+ * flags are consulted, spam taking precedence over ham.
+ *
+ * @param task task to query, must not be NULL
+ * @return class name, or NULL when no learn flag is set
+ */
+const char *
+rspamd_task_get_autolearn_class(struct rspamd_task *task)
+{
+	const char *legacy_name = NULL;
+
+	g_assert(task != NULL);
+
+	if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) {
+		return (const char *) rspamd_mempool_get_variable(task->task_pool, "autolearn_class");
+	}
+
+	/* No explicit class: derive the name from the old binary flags */
+	if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) {
+		legacy_name = "spam";
+	}
+	else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
+		legacy_name = "ham";
+	}
+
+	return legacy_name;
+}
+
static void
rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx,
struct rspamd_task *task)
@@ -394,18 +466,9 @@ rspamd_stat_classifiers_process(struct rspamd_stat_ctx *st_ctx,
}
/*
- * Do not classify a message if some class is missing
+ * Multi-class approach: don't check for missing classes
+ * Missing tokens naturally result in 0 probability
*/
- if (!(task->flags & RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS)) {
- msg_info_task("skip statistics as SPAM class is missing");
-
- return;
- }
- if (!(task->flags & RSPAMD_TASK_FLAG_HAS_HAM_TOKENS)) {
- msg_info_task("skip statistics as HAM class is missing");
-
- return;
- }
for (i = 0; i < st_ctx->classifiers->len; i++) {
cl = g_ptr_array_index(st_ctx->classifiers, i);
@@ -565,7 +628,24 @@ rspamd_stat_cache_check(struct rspamd_stat_ctx *st_ctx,
if (sel->cache && sel->cachecf) {
rt = cl->cache->runtime(task, sel->cachecf, FALSE);
- learn_res = cl->cache->check(task, spam, rt);
+
+ /* For multi-class learning, determine spam boolean from class name if available */
+ gboolean cache_spam = spam; /* Default to original spam parameter */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (autolearn_class) {
+ if (strcmp(autolearn_class, "spam") == 0 || strcmp(autolearn_class, "S") == 0) {
+ cache_spam = TRUE;
+ }
+ else if (strcmp(autolearn_class, "ham") == 0 || strcmp(autolearn_class, "H") == 0) {
+ cache_spam = FALSE;
+ }
+ else {
+ /* For other classes, use a heuristic or default to spam for cache purposes */
+ cache_spam = TRUE; /* Non-ham classes are treated as spam for cache */
+ }
+ }
+
+ learn_res = cl->cache->check(task, cache_spam, rt);
}
if (learn_res == RSPAMD_LEARN_IGNORE) {
@@ -658,9 +738,63 @@ rspamd_stat_classifiers_learn(struct rspamd_stat_ctx *st_ctx,
continue;
}
- if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam,
- task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
- learned = TRUE;
+ /* Check if classifier supports multi-class learning and if we should use it */
+ if (cl->subrs->learn_class_func && cl->cfg->class_names && cl->cfg->class_names->len > 2) {
+ /* Multi-class learning: determine class name from task flags or autolearn result */
+ const char *class_name = NULL;
+
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM) {
+ /* Find spam class name */
+ for (unsigned int k = 0; k < cl->cfg->class_names->len; k++) {
+ const char *check_class = (const char *) g_ptr_array_index(cl->cfg->class_names, k);
+ /* Look for statfile with this class that is spam */
+ GList *cur = cl->cfg->statfiles;
+ while (cur) {
+ struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data;
+ if (stcf->class_name && strcmp(stcf->class_name, check_class) == 0 && stcf->is_spam) {
+ class_name = check_class;
+ break;
+ }
+ cur = g_list_next(cur);
+ }
+ if (class_name) break;
+ }
+ if (!class_name) class_name = "spam"; /* fallback */
+ }
+ else if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
+ /* Find ham class name */
+ for (unsigned int k = 0; k < cl->cfg->class_names->len; k++) {
+ const char *check_class = (const char *) g_ptr_array_index(cl->cfg->class_names, k);
+ /* Look for statfile with this class that is ham */
+ GList *cur = cl->cfg->statfiles;
+ while (cur) {
+ struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data;
+ if (stcf->class_name && strcmp(stcf->class_name, check_class) == 0 && !stcf->is_spam) {
+ class_name = check_class;
+ break;
+ }
+ cur = g_list_next(cur);
+ }
+ if (class_name) break;
+ }
+ if (!class_name) class_name = "ham"; /* fallback */
+ }
+ else {
+ /* Fallback to spam/ham based on the spam parameter */
+ class_name = spam ? "spam" : "ham";
+ }
+
+ if (cl->subrs->learn_class_func(cl, task->tokens, task, class_name,
+ task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+ learned = TRUE;
+ }
+ }
+ else {
+ /* Binary learning: use existing function */
+ if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam,
+ task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+ learned = TRUE;
+ }
}
}
@@ -759,9 +893,26 @@ rspamd_stat_backends_learn(struct rspamd_stat_ctx *st_ctx,
backend_found = TRUE;
if (!(task->flags & RSPAMD_TASK_FLAG_UNLEARN)) {
- if (!!spam != !!st->stcf->is_spam) {
- /* If we are not unlearning, then do not touch another class */
- continue;
+ /* For multiclass learning, check if this statfile has any tokens to learn */
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) {
+ /* Multiclass learning: only process statfiles that have tokens set up by the classifier */
+ gboolean has_tokens = FALSE;
+ for (unsigned int k = 0; k < task->tokens->len && !has_tokens; k++) {
+ rspamd_token_t *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, k);
+ if (tok->values[id] != 0) {
+ has_tokens = TRUE;
+ }
+ }
+ if (!has_tokens) {
+ continue;
+ }
+ }
+ else {
+ /* Binary learning: use traditional spam/ham check */
+ if (!!spam != !!st->stcf->is_spam) {
+ /* If we are not unlearning, then do not touch another class */
+ continue;
+ }
}
}
@@ -870,7 +1021,24 @@ rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx,
if (cl->cache) {
cache_run = cl->cache->runtime(task, cl->cachecf, TRUE);
- cl->cache->learn(task, spam, cache_run);
+
+ /* For multi-class learning, determine spam boolean from class name if available */
+ gboolean cache_spam = spam; /* Default to original spam parameter */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (autolearn_class) {
+ if (strcmp(autolearn_class, "spam") == 0 || strcmp(autolearn_class, "S") == 0) {
+ cache_spam = TRUE;
+ }
+ else if (strcmp(autolearn_class, "ham") == 0 || strcmp(autolearn_class, "H") == 0) {
+ cache_spam = FALSE;
+ }
+ else {
+ /* For other classes, use a heuristic or default to spam for cache purposes */
+ cache_spam = TRUE; /* Non-ham classes are treated as spam for cache */
+ }
+ }
+
+ cl->cache->learn(task, cache_spam, cache_run);
}
}
@@ -879,6 +1047,218 @@ rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx,
return res;
}
+/**
+ * Feed the tokens of a task to the matching classifiers as `class_name`.
+ *
+ * Classifiers are filtered by `classifier` (case-insensitive name match)
+ * when it is non-NULL. A classifier is skipped when the task's token count
+ * falls outside its configured min_tokens/max_tokens limits. Classifiers
+ * that provide learn_class_func are called with the class name; the others
+ * fall back to learn_spam_func, which only understands "spam"/"S" and
+ * "ham"/"H" (any other class is skipped for them).
+ *
+ * @param st_ctx     statistics context holding the classifiers array
+ * @param task       task whose tokens are learned
+ * @param classifier optional classifier name filter, may be NULL
+ * @param class_name class to learn the tokens as
+ * @param err        optional error slot
+ * @return TRUE if at least one classifier reported successful learning
+ */
+static gboolean
+rspamd_stat_classifiers_learn_class(struct rspamd_stat_ctx *st_ctx,
+		struct rspamd_task *task,
+		const char *classifier,
+		const char *class_name,
+		GError **err)
+{
+	struct rspamd_classifier *cl, *sel = NULL;
+	unsigned int i;
+	gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;
+
+	if (task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) {
+		/*
+		 * Do not learn twice. The refusal must not depend on whether the
+		 * caller supplied a usable error slot; only the error report does.
+		 */
+		if (err != NULL && *err == NULL) {
+			g_set_error(err, rspamd_stat_quark(), 208, "<%s> has been already "
+														"learned as %s, ignore it",
+				MESSAGE_FIELD(task, message_id),
+				class_name);
+		}
+
+		return FALSE;
+	}
+
+	/* Check whether we have learned that file */
+	for (i = 0; i < st_ctx->classifiers->len; i++) {
+		cl = g_ptr_array_index(st_ctx->classifiers, i);
+
+		/* Skip other classifiers if they are not needed */
+		if (classifier != NULL && (cl->cfg->name == NULL ||
+								   g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) {
+			continue;
+		}
+
+		sel = cl;
+
+		/* Now check max and min tokens */
+		if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
+			msg_info_task(
+				"<%s> contains less tokens than required for %s classifier: "
+				"%ud < %ud",
+				MESSAGE_FIELD(task, message_id),
+				cl->cfg->name,
+				task->tokens->len,
+				cl->cfg->min_tokens);
+			too_small = TRUE;
+			continue;
+		}
+		else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
+			msg_info_task(
+				"<%s> contains more tokens than allowed for %s classifier: "
+				"%ud > %ud",
+				MESSAGE_FIELD(task, message_id),
+				cl->cfg->name,
+				task->tokens->len,
+				cl->cfg->max_tokens);
+			too_large = TRUE;
+			continue;
+		}
+
+		/* Use the new multi-class learning function if available */
+		if (cl->subrs->learn_class_func) {
+			if (cl->subrs->learn_class_func(cl, task->tokens, task, class_name,
+					task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+				learned = TRUE;
+			}
+		}
+		else {
+			/* Fallback to binary learning with class name mapping */
+			gboolean is_spam;
+			if (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0) {
+				is_spam = TRUE;
+			}
+			else if (strcmp(class_name, "ham") == 0 || strcmp(class_name, "H") == 0) {
+				is_spam = FALSE;
+			}
+			else {
+				/* For unknown classes with binary classifier, skip */
+				msg_info_task("skipping class '%s' for binary classifier %s",
+					class_name, cl->cfg->name);
+				continue;
+			}
+
+			if (cl->subrs->learn_spam_func(cl, task->tokens, task, is_spam,
+					task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
+				learned = TRUE;
+			}
+		}
+	}
+
+	if (sel == NULL) {
+		/* Nothing matched the filter, or no classifiers are configured */
+		if (classifier) {
+			g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier "
+														"with name %s",
+				classifier);
+		}
+		else {
+			g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined");
+		}
+
+		return FALSE;
+	}
+
+	if (!learned && err && *err == NULL) {
+		/*
+		 * Report a token-limit rejection against the last selected
+		 * classifier; the counts are unsigned, hence %u.
+		 */
+		if (too_large) {
+			g_set_error(err, rspamd_stat_quark(), 204,
+				"<%s> contains more tokens than allowed for %s classifier: "
+				"%u > %u",
+				MESSAGE_FIELD(task, message_id),
+				sel->cfg->name,
+				task->tokens->len,
+				sel->cfg->max_tokens);
+		}
+		else if (too_small) {
+			g_set_error(err, rspamd_stat_quark(), 204,
+				"<%s> contains less tokens than required for %s classifier: "
+				"%u < %u",
+				MESSAGE_FIELD(task, message_id),
+				sel->cfg->name,
+				task->tokens->len,
+				sel->cfg->min_tokens);
+		}
+	}
+
+	return learned;
+}
+
+/**
+ * Run one learning stage for a task, learning it as `class_name`.
+ *
+ * The task must already be classified. Depending on `stage` this either
+ * preprocesses the task and checks the learn cache (LEARN_PRE), learns on
+ * classifiers and backends (LEARN), or finalises backends (LEARN_POST).
+ * The legacy helpers still take a spam boolean, which is TRUE only for the
+ * class names "spam" and "S".
+ *
+ * On success, and on the "no classifiers" / "empty message" early exits,
+ * `stage` is added to task->processed_stages. It is not added when a stage
+ * helper fails.
+ *
+ * @param task       task to learn, must be classified
+ * @param class_name class to learn the task as, must not be NULL
+ * @param L          Lua state (unused in this function)
+ * @param classifier optional classifier name filter, may be NULL
+ * @param stage      one of the RSPAMD_TASK_STAGE_LEARN* stages
+ * @param err        optional error slot
+ * @return RSPAMD_STAT_PROCESS_OK or RSPAMD_STAT_PROCESS_ERROR
+ */
+rspamd_stat_result_t
+rspamd_stat_learn_class(struct rspamd_task *task,
+		const char *class_name,
+		lua_State *L,
+		const char *classifier,
+		unsigned int stage,
+		GError **err)
+{
+	struct rspamd_stat_ctx *st_ctx;
+	rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK;
+	gboolean spam;
+
+	/*
+	 * We assume now that a task has been already classified before
+	 * coming to learn
+	 */
+	g_assert(RSPAMD_TASK_IS_CLASSIFIED(task));
+	/* class_name is handed to strcmp() and "%s" below, so NULL is fatal */
+	g_assert(class_name != NULL);
+
+	st_ctx = rspamd_stat_get_ctx();
+	g_assert(st_ctx != NULL);
+
+	msg_debug_bayes("learn class stage %d has been called for class '%s'", stage, class_name);
+
+	if (st_ctx->classifiers->len == 0) {
+		msg_debug_bayes("no classifiers defined");
+		task->processed_stages |= stage;
+		return ret;
+	}
+
+	if (task->message == NULL) {
+		ret = RSPAMD_STAT_PROCESS_ERROR;
+		if (err && *err == NULL) {
+			g_set_error(err, rspamd_stat_quark(), 500,
+				"Trying to learn an empty message");
+		}
+
+		task->processed_stages |= stage;
+		return ret;
+	}
+
+	/* Spam boolean for the binary-only helpers, computed once for all stages */
+	spam = (strcmp(class_name, "spam") == 0 || strcmp(class_name, "S") == 0);
+
+	if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
+		/* Process classifiers */
+		rspamd_stat_preprocess(st_ctx, task, TRUE, spam);
+
+		if (!rspamd_stat_cache_check(st_ctx, task, classifier, spam, err)) {
+			msg_debug_bayes("cache check failed, skip learning");
+			return RSPAMD_STAT_PROCESS_ERROR;
+		}
+	}
+	else if (stage == RSPAMD_TASK_STAGE_LEARN) {
+		/* Process classifiers */
+		if (!rspamd_stat_classifiers_learn_class(st_ctx, task, classifier,
+				class_name, err)) {
+			if (err && *err == NULL) {
+				g_set_error(err, rspamd_stat_quark(), 500,
+					"Unknown statistics error, found when learning classifiers;"
+					" classifier: %s",
+					task->classifier);
+			}
+			return RSPAMD_STAT_PROCESS_ERROR;
+		}
+
+		/* Process backends */
+		if (!rspamd_stat_backends_learn(st_ctx, task, classifier, spam, err)) {
+			if (err && *err == NULL) {
+				g_set_error(err, rspamd_stat_quark(), 500,
+					"Unknown statistics error, found when storing data on backend;"
+					" classifier: %s",
+					task->classifier);
+			}
+			return RSPAMD_STAT_PROCESS_ERROR;
+		}
+	}
+	else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) {
+		/* Process backends */
+		if (!rspamd_stat_backends_post_learn(st_ctx, task, classifier, spam, err)) {
+			return RSPAMD_STAT_PROCESS_ERROR;
+		}
+	}
+
+	task->processed_stages |= stage;
+
+	return ret;
+}
+
rspamd_stat_result_t
rspamd_stat_learn(struct rspamd_task *task,
gboolean spam, lua_State *L, const char *classifier, unsigned int stage,
@@ -1039,12 +1419,11 @@ rspamd_stat_check_autolearn(struct rspamd_task *task)
if (mres) {
if (mres->score > rspamd_task_get_required_score(task, mres)) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
-
+ rspamd_task_set_autolearn_class(task, "spam");
ret = TRUE;
}
else if (mres->score < 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ rspamd_task_set_autolearn_class(task, "ham");
ret = TRUE;
}
}
@@ -1076,12 +1455,11 @@ rspamd_stat_check_autolearn(struct rspamd_task *task)
if (mres) {
if (mres->score >= spam_score) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
-
+ rspamd_task_set_autolearn_class(task, "spam");
ret = TRUE;
}
else if (mres->score <= ham_score) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ rspamd_task_set_autolearn_class(task, "ham");
ret = TRUE;
}
}
@@ -1117,11 +1495,16 @@ rspamd_stat_check_autolearn(struct rspamd_task *task)
/* We can have immediate results */
if (lua_ret) {
if (strcmp(lua_ret, "ham") == 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ rspamd_task_set_autolearn_class(task, "ham");
ret = TRUE;
}
else if (strcmp(lua_ret, "spam") == 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+ rspamd_task_set_autolearn_class(task, "spam");
+ ret = TRUE;
+ }
+ else {
+ /* Multi-class: any other class name */
+ rspamd_task_set_autolearn_class(task, lua_ret);
ret = TRUE;
}
}
@@ -1139,79 +1522,138 @@ rspamd_stat_check_autolearn(struct rspamd_task *task)
}
}
else if (ucl_object_type(obj) == UCL_OBJECT) {
- /* Try to find autolearn callback */
- if (cl->autolearn_cbref == 0) {
- /* We don't have preprocessed cb id, so try to get it */
- if (!rspamd_lua_require_function(L, "lua_bayes_learn",
- "autolearn")) {
- msg_err_task("cannot get autolearn library from "
- "`lua_bayes_learn`");
- }
- else {
- cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX);
+ /* Check if this is a multi-class autolearn configuration */
+ const ucl_object_t *multiclass_obj = ucl_object_lookup(obj, "multiclass");
+
+ if (multiclass_obj && ucl_object_type(multiclass_obj) == UCL_OBJECT) {
+ /* Multi-class threshold-based autolearn */
+ const ucl_object_t *thresholds_obj = ucl_object_lookup(multiclass_obj, "thresholds");
+
+ if (thresholds_obj && ucl_object_type(thresholds_obj) == UCL_OBJECT) {
+ /* Iterate through class thresholds */
+ ucl_object_iter_t it = NULL;
+ const ucl_object_t *class_obj;
+ const char *class_name;
+
+ while ((class_obj = ucl_object_iterate(thresholds_obj, &it, true))) {
+ class_name = ucl_object_key(class_obj);
+
+ if (class_name && ucl_object_type(class_obj) == UCL_ARRAY && class_obj->len == 2) {
+ /* [min_score, max_score] for this class */
+ const ucl_object_t *min_elt = ucl_array_find_index(class_obj, 0);
+ const ucl_object_t *max_elt = ucl_array_find_index(class_obj, 1);
+
+ if ((ucl_object_type(min_elt) == UCL_FLOAT || ucl_object_type(min_elt) == UCL_INT) &&
+ (ucl_object_type(max_elt) == UCL_FLOAT || ucl_object_type(max_elt) == UCL_INT)) {
+
+ double min_score = ucl_object_todouble(min_elt);
+ double max_score = ucl_object_todouble(max_elt);
+
+ if (mres && mres->score >= min_score && mres->score <= max_score) {
+ rspamd_task_set_autolearn_class(task, class_name);
+ ret = TRUE;
+ msg_debug_bayes("multiclass autolearn: score %.2f matches class '%s' [%.2f, %.2f]",
+ mres->score, class_name, min_score, max_score);
+ break; /* Stop at first matching class */
+ }
+ }
+ }
+ }
}
}
-
- if (cl->autolearn_cbref != -1) {
- lua_pushcfunction(L, &rspamd_lua_traceback);
- err_idx = lua_gettop(L);
- lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
-
- ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
- *ptask = task;
- rspamd_lua_setclass(L, rspamd_task_classname, -1);
- /* Push the whole object as well */
- ucl_object_push_lua(L, obj, true);
-
- if (lua_pcall(L, 2, 1, err_idx) != 0) {
- msg_err_task("call to autolearn script failed: "
- "%s",
- lua_tostring(L, -1));
+ else {
+ /* Try to find autolearn callback */
+ if (cl->autolearn_cbref == 0) {
+ /* We don't have preprocessed cb id, so try to get it */
+ if (!rspamd_lua_require_function(L, "lua_bayes_learn",
+ "autolearn")) {
+ msg_err_task("cannot get autolearn library from "
+ "`lua_bayes_learn`");
+ }
+ else {
+ cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX);
+ }
}
- else {
- lua_ret = lua_tostring(L, -1);
- if (lua_ret) {
- if (strcmp(lua_ret, "ham") == 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
- ret = TRUE;
- }
- else if (strcmp(lua_ret, "spam") == 0) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
- ret = TRUE;
+ if (cl->autolearn_cbref != -1) {
+ lua_pushcfunction(L, &rspamd_lua_traceback);
+ err_idx = lua_gettop(L);
+ lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
+
+ ptask = lua_newuserdata(L, sizeof(struct rspamd_task *));
+ *ptask = task;
+ rspamd_lua_setclass(L, rspamd_task_classname, -1);
+ /* Push the whole object as well */
+ ucl_object_push_lua(L, obj, true);
+
+ if (lua_pcall(L, 2, 1, err_idx) != 0) {
+ msg_err_task("call to autolearn script failed: "
+ "%s",
+ lua_tostring(L, -1));
+ }
+ else {
+ lua_ret = lua_tostring(L, -1);
+
+ if (lua_ret) {
+ if (strcmp(lua_ret, "ham") == 0) {
+ rspamd_task_set_autolearn_class(task, "ham");
+ ret = TRUE;
+ }
+ else if (strcmp(lua_ret, "spam") == 0) {
+ rspamd_task_set_autolearn_class(task, "spam");
+ ret = TRUE;
+ }
+ else {
+ /* Multi-class: any other class name */
+ rspamd_task_set_autolearn_class(task, lua_ret);
+ ret = TRUE;
+ }
}
}
- }
- lua_settop(L, err_idx - 1);
+ lua_settop(L, err_idx - 1);
+ }
}
- }
- if (ret) {
- /* Do not autolearn if we have this symbol already */
- if (rspamd_stat_has_classifier_symbols(task, mres, cl)) {
- ret = FALSE;
- task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM |
- RSPAMD_TASK_FLAG_LEARN_SPAM);
- }
- else if (mres != NULL) {
- if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) {
- msg_info_task("<%s>: autolearn ham for classifier "
- "'%s' as message's "
- "score is negative: %.2f",
- MESSAGE_FIELD(task, message_id), cl->cfg->name,
- mres->score);
- }
- else {
- msg_info_task("<%s>: autolearn spam for classifier "
- "'%s' as message's "
- "action is reject, score: %.2f",
- MESSAGE_FIELD(task, message_id), cl->cfg->name,
- mres->score);
+ if (ret) {
+ /* Do not autolearn if we have this symbol already */
+ if (rspamd_stat_has_classifier_symbols(task, mres, cl)) {
+ ret = FALSE;
+ task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM |
+ RSPAMD_TASK_FLAG_LEARN_SPAM |
+ RSPAMD_TASK_FLAG_LEARN_CLASS);
+ /* Clear the autolearn class from mempool */
+ rspamd_mempool_set_variable(task->task_pool, "autolearn_class", NULL, NULL);
}
+ else if (mres != NULL) {
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+
+ if (autolearn_class) {
+ if (strcmp(autolearn_class, "ham") == 0) {
+ msg_info_task("<%s>: autolearn ham for classifier "
+ "'%s' as message's "
+ "score is negative: %.2f",
+ MESSAGE_FIELD(task, message_id), cl->cfg->name,
+ mres->score);
+ }
+ else if (strcmp(autolearn_class, "spam") == 0) {
+ msg_info_task("<%s>: autolearn spam for classifier "
+ "'%s' as message's "
+ "action is reject, score: %.2f",
+ MESSAGE_FIELD(task, message_id), cl->cfg->name,
+ mres->score);
+ }
+ else {
+ msg_info_task("<%s>: autolearn class '%s' for classifier "
+ "'%s', score: %.2f",
+ MESSAGE_FIELD(task, message_id), autolearn_class,
+ cl->cfg->name, mres->score);
+ }
+ }
- task->classifier = cl->cfg->name;
- break;
+ task->classifier = cl->cfg->name;
+ break;
+ }
}
}
}