diff options
Diffstat (limited to 'src/libserver/task.c')
-rw-r--r-- | src/libserver/task.c | 43 |
1 files changed, 31 insertions, 12 deletions
diff --git a/src/libserver/task.c b/src/libserver/task.c index 9f5b1f00a..f655ab11b 100644 --- a/src/libserver/task.c +++ b/src/libserver/task.c @@ -730,7 +730,7 @@ rspamd_task_process(struct rspamd_task *task, unsigned int stages) if (all_done && (task->flags & RSPAMD_TASK_FLAG_LEARN_AUTO) && !RSPAMD_TASK_IS_EMPTY(task) && - !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM))) { + !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM | RSPAMD_TASK_FLAG_LEARN_CLASS))) { rspamd_stat_check_autolearn(task); } break; @@ -738,12 +738,32 @@ rspamd_task_process(struct rspamd_task *task, unsigned int stages) case RSPAMD_TASK_STAGE_LEARN: case RSPAMD_TASK_STAGE_LEARN_PRE: case RSPAMD_TASK_STAGE_LEARN_POST: - if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM)) { + if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM | RSPAMD_TASK_FLAG_LEARN_CLASS)) { if (task->err == NULL) { - if (!rspamd_stat_learn(task, - task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM, - task->cfg->lua_state, task->classifier, - st, &stat_error)) { + gboolean learn_result = FALSE; + + if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) { + /* Multi-class learning */ + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + if (autolearn_class) { + learn_result = rspamd_stat_learn_class(task, autolearn_class, + task->cfg->lua_state, task->classifier, + st, &stat_error); + } + else { + g_set_error(&stat_error, g_quark_from_static_string("stat"), 500, + "No autolearn class specified for multi-class learning"); + } + } + else { + /* Legacy binary learning */ + learn_result = rspamd_stat_learn(task, + task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM, + task->cfg->lua_state, task->classifier, + st, &stat_error); + } + + if (!learn_result) { if (stat_error == NULL) { g_set_error(&stat_error, @@ -922,15 +942,14 @@ rspamd_learn_task_spam(struct rspamd_task *task, const char *classifier, GError **err) { + /* Use unified class-based approach internally */ + const char *class_name = is_spam ? "spam" : "ham"; + /* Disable learn auto flag to avoid bad learn codes */ task->flags &= ~RSPAMD_TASK_FLAG_LEARN_AUTO; - if (is_spam) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; - } - else { - task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; - } + /* Use the unified class-based learning approach */ + rspamd_task_set_autolearn_class(task, class_name); task->classifier = classifier; |