From 96d887fb8e167278e5406f91882ac262355c9ebe Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Wed, 6 Jan 2016 15:08:48 +0000 Subject: [PATCH] Fix learning. --- src/controller.c | 54 +++++++++++++++++++++++++++++++++++--------- src/libserver/task.c | 26 ++++++++++++++++++++- src/libserver/task.h | 7 ++++++ 3 files changed, 75 insertions(+), 12 deletions(-) diff --git a/src/controller.c b/src/controller.c index aa90fc8b6..0a359a9a6 100644 --- a/src/controller.c +++ b/src/controller.c @@ -1163,25 +1163,55 @@ rspamd_controller_learn_fin_task (void *ud) struct rspamd_task *task = ud; struct rspamd_controller_session *session; struct rspamd_http_connection_entry *conn_ent; - GError *err = NULL; conn_ent = task->fin_arg; session = conn_ent->ud; - if (rspamd_learn_task_spam (task, session->is_spam, session->classifier, &err) == - RSPAMD_STAT_PROCESS_ERROR) { - msg_info_session ("cannot learn <%s>: %e", task->message_id, err); - rspamd_controller_send_error (conn_ent, err->code, err->message); + if (task->err != NULL) { + msg_info_session ("cannot learn <%s>: %e", task->message_id, task->err); + rspamd_controller_send_error (conn_ent, task->err->code, + task->err->message); return TRUE; } - /* Successful learn */ - msg_info_session ("<%s> learned message as %s: %s", - rspamd_inet_address_to_string (session->from_addr), - session->is_spam ? "spam" : "ham", - task->message_id); - rspamd_controller_send_string (conn_ent, "{\"success\":true}"); + if (RSPAMD_TASK_IS_PROCESSED (task)) { + /* Successful learn */ + msg_info_session ("<%s> learned message as %s: %s", + rspamd_inet_address_to_string (session->from_addr), + session->is_spam ? "spam" : "ham", + task->message_id); + rspamd_controller_send_string (conn_ent, "{\"success\":true}"); + return TRUE; + } + + if (!rspamd_task_process (task, RSPAMD_TASK_PROCESS_LEARN)) { + msg_info_session ("cannot learn <%s>: %e", task->message_id, task->err); + + if (task->err) { + rspamd_controller_send_error (conn_ent, task->err->code, + task->err->message); + } + else { + rspamd_controller_send_error (conn_ent, 500, + "Internal error"); + } + } + + if (RSPAMD_TASK_IS_PROCESSED (task)) { + msg_info_session ("<%s> learned message as %s: %s", + rspamd_inet_address_to_string (session->from_addr), + session->is_spam ? "spam" : "ham", + task->message_id); + rspamd_controller_send_string (conn_ent, "{\"success\":true}"); + return TRUE; + } + + /* One more iteration */ + return FALSE; + + + return TRUE; } @@ -1284,6 +1314,8 @@ rspamd_controller_handle_learn_common ( return 0; } + rspamd_learn_task_spam (task, is_spam, session->classifier, NULL); + if (!rspamd_task_process (task, RSPAMD_TASK_PROCESS_LEARN)) { msg_warn_session ("<%s> message cannot be processed", task->message_id); rspamd_controller_send_error (conn_ent, task->err->code, task->err->message); diff --git a/src/libserver/task.c b/src/libserver/task.c index 579cc3461..91ed48e86 100644 --- a/src/libserver/task.c +++ b/src/libserver/task.c @@ -457,6 +457,21 @@ rspamd_task_process (struct rspamd_task *task, guint stages) rspamd_lua_call_post_filters (task); break; + case RSPAMD_TASK_STAGE_LEARN: + case RSPAMD_TASK_STAGE_LEARN_PRE: + case RSPAMD_TASK_STAGE_LEARN_POST: + if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM|RSPAMD_TASK_FLAG_LEARN_HAM)) { + if (!rspamd_stat_learn (task, + task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM, + task->cfg->lua_state, task->classifier, + st, &stat_error)) { + msg_err_task ("learn error: %e", stat_error); + task->err = stat_error; + task->processed_stages |= RSPAMD_TASK_STAGE_DONE; + } + } + break; + case RSPAMD_TASK_STAGE_DONE: task->processed_stages |= RSPAMD_TASK_STAGE_DONE; break; @@ -610,7 +625,16 @@ rspamd_learn_task_spam (struct rspamd_task *task, const gchar *classifier, GError **err) { - return FALSE; + if (is_spam) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + } + else { + task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + } + + task->classifier = classifier; + + return TRUE; } static gboolean diff --git a/src/libserver/task.h b/src/libserver/task.h index 901067ba4..7ede95b31 100644 --- a/src/libserver/task.h +++ b/src/libserver/task.h @@ -92,6 +92,9 @@ enum rspamd_task_stage { RSPAMD_TASK_STAGE_CLASSIFIERS_PRE | \ RSPAMD_TASK_STAGE_CLASSIFIERS | \ RSPAMD_TASK_STAGE_CLASSIFIERS_POST | \ + RSPAMD_TASK_STAGE_LEARN_PRE | \ + RSPAMD_TASK_STAGE_LEARN | \ + RSPAMD_TASK_STAGE_LEARN_POST | \ RSPAMD_TASK_STAGE_DONE) #define RSPAMD_TASK_FLAG_MIME (1 << 0) @@ -110,6 +113,8 @@ enum rspamd_task_stage { #define RSPAMD_TASK_FLAG_NO_STAT (1 << 13) #define RSPAMD_TASK_FLAG_UNLEARN (1 << 14) #define RSPAMD_TASK_FLAG_ALREADY_LEARNED (1 << 15) +#define RSPAMD_TASK_FLAG_LEARN_SPAM (1 << 16) +#define RSPAMD_TASK_FLAG_LEARN_HAM (1 << 17) #define RSPAMD_TASK_IS_SKIPPED(task) (((task)->flags & RSPAMD_TASK_FLAG_SKIP)) #define RSPAMD_TASK_IS_JSON(task) (((task)->flags & RSPAMD_TASK_FLAG_JSON)) @@ -192,6 +197,8 @@ struct rspamd_task { } pre_result; /**< Result of pre-filters */ ucl_object_t *settings; /**< Settings applied to task */ + + const gchar *classifier; /**< Classifier to learn (if needed) */ }; /** -- 2.39.5