aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/controller.c54
-rw-r--r--src/libserver/task.c26
-rw-r--r--src/libserver/task.h7
3 files changed, 75 insertions, 12 deletions
diff --git a/src/controller.c b/src/controller.c
index aa90fc8b6..0a359a9a6 100644
--- a/src/controller.c
+++ b/src/controller.c
@@ -1163,25 +1163,55 @@ rspamd_controller_learn_fin_task (void *ud)
struct rspamd_task *task = ud;
struct rspamd_controller_session *session;
struct rspamd_http_connection_entry *conn_ent;
- GError *err = NULL;
conn_ent = task->fin_arg;
session = conn_ent->ud;
- if (rspamd_learn_task_spam (task, session->is_spam, session->classifier, &err) ==
- RSPAMD_STAT_PROCESS_ERROR) {
- msg_info_session ("cannot learn <%s>: %e", task->message_id, err);
- rspamd_controller_send_error (conn_ent, err->code, err->message);
+ if (task->err != NULL) {
+ msg_info_session ("cannot learn <%s>: %e", task->message_id, task->err);
+ rspamd_controller_send_error (conn_ent, task->err->code,
+ task->err->message);
return TRUE;
}
- /* Successful learn */
- msg_info_session ("<%s> learned message as %s: %s",
- rspamd_inet_address_to_string (session->from_addr),
- session->is_spam ? "spam" : "ham",
- task->message_id);
- rspamd_controller_send_string (conn_ent, "{\"success\":true}");
+ if (RSPAMD_TASK_IS_PROCESSED (task)) {
+ /* Successful learn */
+ msg_info_session ("<%s> learned message as %s: %s",
+ rspamd_inet_address_to_string (session->from_addr),
+ session->is_spam ? "spam" : "ham",
+ task->message_id);
+ rspamd_controller_send_string (conn_ent, "{\"success\":true}");
+ return TRUE;
+ }
+
+ if (!rspamd_task_process (task, RSPAMD_TASK_PROCESS_LEARN)) {
+ msg_info_session ("cannot learn <%s>: %e", task->message_id, task->err);
+
+ if (task->err) {
+ rspamd_controller_send_error (conn_ent, task->err->code,
+ task->err->message);
+ }
+ else {
+ rspamd_controller_send_error (conn_ent, 500,
+ "Internal error");
+ }
+ }
+
+ if (RSPAMD_TASK_IS_PROCESSED (task)) {
+ msg_info_session ("<%s> learned message as %s: %s",
+ rspamd_inet_address_to_string (session->from_addr),
+ session->is_spam ? "spam" : "ham",
+ task->message_id);
+ rspamd_controller_send_string (conn_ent, "{\"success\":true}");
+ return TRUE;
+ }
+
+ /* One more iteration */
+ return FALSE;
+
+
+
return TRUE;
}
@@ -1284,6 +1314,8 @@ rspamd_controller_handle_learn_common (
return 0;
}
+ rspamd_learn_task_spam (task, is_spam, session->classifier, NULL);
+
if (!rspamd_task_process (task, RSPAMD_TASK_PROCESS_LEARN)) {
msg_warn_session ("<%s> message cannot be processed", task->message_id);
rspamd_controller_send_error (conn_ent, task->err->code, task->err->message);
diff --git a/src/libserver/task.c b/src/libserver/task.c
index 579cc3461..91ed48e86 100644
--- a/src/libserver/task.c
+++ b/src/libserver/task.c
@@ -457,6 +457,21 @@ rspamd_task_process (struct rspamd_task *task, guint stages)
rspamd_lua_call_post_filters (task);
break;
+ case RSPAMD_TASK_STAGE_LEARN:
+ case RSPAMD_TASK_STAGE_LEARN_PRE:
+ case RSPAMD_TASK_STAGE_LEARN_POST:
+ if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM|RSPAMD_TASK_FLAG_LEARN_HAM)) {
+ if (!rspamd_stat_learn (task,
+ task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM,
+ task->cfg->lua_state, task->classifier,
+ st, &stat_error)) {
+ msg_err_task ("learn error: %e", stat_error);
+ task->err = stat_error;
+ task->processed_stages |= RSPAMD_TASK_STAGE_DONE;
+ }
+ }
+ break;
+
case RSPAMD_TASK_STAGE_DONE:
task->processed_stages |= RSPAMD_TASK_STAGE_DONE;
break;
@@ -610,7 +625,16 @@ rspamd_learn_task_spam (struct rspamd_task *task,
const gchar *classifier,
GError **err)
{
- return FALSE;
+ if (is_spam) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+ }
+ else {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ }
+
+ task->classifier = classifier;
+
+ return TRUE;
}
static gboolean
diff --git a/src/libserver/task.h b/src/libserver/task.h
index 901067ba4..7ede95b31 100644
--- a/src/libserver/task.h
+++ b/src/libserver/task.h
@@ -92,6 +92,9 @@ enum rspamd_task_stage {
RSPAMD_TASK_STAGE_CLASSIFIERS_PRE | \
RSPAMD_TASK_STAGE_CLASSIFIERS | \
RSPAMD_TASK_STAGE_CLASSIFIERS_POST | \
+ RSPAMD_TASK_STAGE_LEARN_PRE | \
+ RSPAMD_TASK_STAGE_LEARN | \
+ RSPAMD_TASK_STAGE_LEARN_POST | \
RSPAMD_TASK_STAGE_DONE)
#define RSPAMD_TASK_FLAG_MIME (1 << 0)
@@ -110,6 +113,8 @@ enum rspamd_task_stage {
#define RSPAMD_TASK_FLAG_NO_STAT (1 << 13)
#define RSPAMD_TASK_FLAG_UNLEARN (1 << 14)
#define RSPAMD_TASK_FLAG_ALREADY_LEARNED (1 << 15)
+#define RSPAMD_TASK_FLAG_LEARN_SPAM (1 << 16)
+#define RSPAMD_TASK_FLAG_LEARN_HAM (1 << 17)
#define RSPAMD_TASK_IS_SKIPPED(task) (((task)->flags & RSPAMD_TASK_FLAG_SKIP))
#define RSPAMD_TASK_IS_JSON(task) (((task)->flags & RSPAMD_TASK_FLAG_JSON))
@@ -192,6 +197,8 @@ struct rspamd_task {
} pre_result; /**< Result of pre-filters */
ucl_object_t *settings; /**< Settings applied to task */
+
+ const gchar *classifier; /**< Classifier to learn (if needed) */
};
/**