diff options
Diffstat (limited to 'src/controller.c')
-rw-r--r-- | src/controller.c | 103 |
1 files changed, 100 insertions, 3 deletions
diff --git a/src/controller.c b/src/controller.c index 0550ba6b8..6e0e4cac1 100644 --- a/src/controller.c +++ b/src/controller.c @@ -53,6 +53,7 @@ #define PATH_HISTORY_RESET "/historyreset" #define PATH_LEARN_SPAM "/learnspam" #define PATH_LEARN_HAM "/learnham" +#define PATH_LEARN_CLASS "/learnclass" #define PATH_METRICS "/metrics" #define PATH_READY "/ready" #define PATH_SAVE_ACTIONS "/saveactions" @@ -2126,6 +2127,7 @@ rspamd_controller_handle_learn_common( struct rspamd_controller_worker_ctx *ctx; struct rspamd_task *task; const rspamd_ftok_t *cl_header; + const char *class_name; ctx = session->ctx; @@ -2167,7 +2169,9 @@ rspamd_controller_handle_learn_common( goto end; } - rspamd_learn_task_spam(task, is_spam, session->classifier, NULL); + /* Use unified class-based learning approach */ + class_name = is_spam ? "spam" : "ham"; + rspamd_task_set_autolearn_class(task, class_name); if (!rspamd_task_process(task, RSPAMD_TASK_PROCESS_LEARN)) { msg_warn_session("<%s> message cannot be processed", @@ -2212,6 +2216,96 @@ rspamd_controller_handle_learnham( } /* + * Learn class command handler: + * request: /learnclass + * headers: Password, Class + * input: plaintext data + * reply: json {"success":true} or {"error":"error message"} + */ +static int +rspamd_controller_handle_learnclass( + struct rspamd_http_connection_entry *conn_ent, + struct rspamd_http_message *msg) +{ + struct rspamd_controller_session *session = conn_ent->ud; + struct rspamd_controller_worker_ctx *ctx; + struct rspamd_task *task; + const rspamd_ftok_t *cl_header, *class_header; + char *class_name = NULL; + + ctx = session->ctx; + + if (!rspamd_controller_check_password(conn_ent, session, msg, TRUE)) { + return 0; + } + + if (rspamd_http_message_get_body(msg, NULL) == NULL) { + msg_err_session("got zero length body, cannot continue"); + rspamd_controller_send_error(conn_ent, + 400, + "Empty body is not permitted"); + return 0; + } + + class_header = rspamd_http_message_find_header(msg, "Class"); + if (!class_header) { + msg_err_session("missing Class header for multiclass learning"); + rspamd_controller_send_error(conn_ent, + 400, + "Class header is required for multiclass learning"); + return 0; + } + + task = rspamd_task_new(session->ctx->worker, session->cfg, session->pool, + session->ctx->lang_det, ctx->event_loop, FALSE); + + task->resolver = ctx->resolver; + task->s = rspamd_session_create(session->pool, + rspamd_controller_learn_fin_task, + NULL, + (event_finalizer_t) rspamd_task_free, + task); + task->fin_arg = conn_ent; + task->http_conn = rspamd_http_connection_ref(conn_ent->conn); + task->sock = -1; + session->task = task; + + cl_header = rspamd_http_message_find_header(msg, "classifier"); + if (cl_header) { + session->classifier = rspamd_mempool_ftokdup(session->pool, cl_header); + } + else { + session->classifier = NULL; + } + + if (!rspamd_task_load_message(task, msg, msg->body_buf.begin, msg->body_buf.len)) { + goto end; + } + + /* Set multiclass learning flag and store class name */ + class_name = rspamd_mempool_ftokdup(task->task_pool, class_header); + rspamd_task_set_autolearn_class(task, class_name); + + if (!rspamd_task_process(task, RSPAMD_TASK_PROCESS_LEARN)) { + msg_warn_session("<%s> message cannot be processed", + MESSAGE_FIELD_CHECK(task, message_id)); + goto end; + } + +end: + /* Set session spam flag for logging compatibility */ + if (class_name) { + session->is_spam = (strcmp(class_name, "spam") == 0); + } + else { + session->is_spam = FALSE; + } + rspamd_session_pending(task->s); + + return 0; +} + +/* * Scan command handler: * request: /scan * headers: Password @@ -3292,7 +3386,7 @@ rspamd_controller_handle_unknown(struct rspamd_http_connection_entry *conn_ent, rspamd_http_message_add_header(rep, "Access-Control-Allow-Methods", "POST, GET, OPTIONS"); rspamd_http_message_add_header(rep, "Access-Control-Allow-Headers", - "Classifier,Content-Type,Password,Map,Weight,Flag,Hash"); + "Classifier,Class,Content-Type,Password,Map,Weight,Flag,Hash"); rspamd_http_connection_reset(conn_ent->conn); rspamd_http_router_insert_headers(conn_ent->rt, rep); rspamd_http_connection_write_message(conn_ent->conn, @@ -3456,7 +3550,7 @@ rspamd_controller_handle_lua_plugin(struct rspamd_http_connection_entry *conn_en */ static int rspamd_controller_handle_bayes_classifiers(struct rspamd_http_connection_entry *conn_ent, - struct rspamd_http_message *msg) + struct rspamd_http_message *msg) { struct rspamd_controller_session *session = conn_ent->ud; struct rspamd_controller_worker_ctx *ctx = session->ctx; @@ -4049,6 +4143,9 @@ start_controller_worker(struct rspamd_worker *worker) PATH_LEARN_HAM, rspamd_controller_handle_learnham); rspamd_http_router_add_path(ctx->http, + PATH_LEARN_CLASS, + rspamd_controller_handle_learnclass); + rspamd_http_router_add_path(ctx->http, PATH_METRICS, rspamd_controller_handle_metrics); rspamd_http_router_add_path(ctx->http, |