aboutsummaryrefslogtreecommitdiffstats
path: root/src/libserver/task.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/libserver/task.c')
-rw-r--r--src/libserver/task.c43
1 files changed, 31 insertions, 12 deletions
diff --git a/src/libserver/task.c b/src/libserver/task.c
index 9f5b1f00a..f655ab11b 100644
--- a/src/libserver/task.c
+++ b/src/libserver/task.c
@@ -730,7 +730,7 @@ rspamd_task_process(struct rspamd_task *task, unsigned int stages)
if (all_done && (task->flags & RSPAMD_TASK_FLAG_LEARN_AUTO) &&
!RSPAMD_TASK_IS_EMPTY(task) &&
- !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM))) {
+ !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM | RSPAMD_TASK_FLAG_LEARN_CLASS))) {
rspamd_stat_check_autolearn(task);
}
break;
@@ -738,12 +738,32 @@ rspamd_task_process(struct rspamd_task *task, unsigned int stages)
case RSPAMD_TASK_STAGE_LEARN:
case RSPAMD_TASK_STAGE_LEARN_PRE:
case RSPAMD_TASK_STAGE_LEARN_POST:
- if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM)) {
+ if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM | RSPAMD_TASK_FLAG_LEARN_CLASS)) {
if (task->err == NULL) {
- if (!rspamd_stat_learn(task,
- task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM,
- task->cfg->lua_state, task->classifier,
- st, &stat_error)) {
+ gboolean learn_result = FALSE;
+
+ if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) {
+ /* Multi-class learning */
+ const char *autolearn_class = rspamd_task_get_autolearn_class(task);
+ if (autolearn_class) {
+ learn_result = rspamd_stat_learn_class(task, autolearn_class,
+ task->cfg->lua_state, task->classifier,
+ st, &stat_error);
+ }
+ else {
+ g_set_error(&stat_error, g_quark_from_static_string("stat"), 500,
+ "No autolearn class specified for multi-class learning");
+ }
+ }
+ else {
+ /* Legacy binary learning */
+ learn_result = rspamd_stat_learn(task,
+ task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM,
+ task->cfg->lua_state, task->classifier,
+ st, &stat_error);
+ }
+
+ if (!learn_result) {
if (stat_error == NULL) {
g_set_error(&stat_error,
@@ -922,15 +942,14 @@ rspamd_learn_task_spam(struct rspamd_task *task,
const char *classifier,
GError **err)
{
+ /* Use unified class-based approach internally */
+ const char *class_name = is_spam ? "spam" : "ham";
+
/* Disable learn auto flag to avoid bad learn codes */
task->flags &= ~RSPAMD_TASK_FLAG_LEARN_AUTO;
- if (is_spam) {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
- }
- else {
- task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
- }
+ /* Use the unified class-based learning approach */
+ rspamd_task_set_autolearn_class(task, class_name);
task->classifier = classifier;