aboutsummaryrefslogtreecommitdiffstats
path: root/src/controller.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/controller.c')
-rw-r--r--src/controller.c103
1 files changed, 100 insertions, 3 deletions
diff --git a/src/controller.c b/src/controller.c
index 0550ba6b8..6e0e4cac1 100644
--- a/src/controller.c
+++ b/src/controller.c
@@ -53,6 +53,7 @@
#define PATH_HISTORY_RESET "/historyreset"
#define PATH_LEARN_SPAM "/learnspam"
#define PATH_LEARN_HAM "/learnham"
+#define PATH_LEARN_CLASS "/learnclass"
#define PATH_METRICS "/metrics"
#define PATH_READY "/ready"
#define PATH_SAVE_ACTIONS "/saveactions"
@@ -2126,6 +2127,7 @@ rspamd_controller_handle_learn_common(
struct rspamd_controller_worker_ctx *ctx;
struct rspamd_task *task;
const rspamd_ftok_t *cl_header;
+ const char *class_name;
ctx = session->ctx;
@@ -2167,7 +2169,9 @@ rspamd_controller_handle_learn_common(
goto end;
}
- rspamd_learn_task_spam(task, is_spam, session->classifier, NULL);
+ /* Use unified class-based learning approach */
+ class_name = is_spam ? "spam" : "ham";
+ rspamd_task_set_autolearn_class(task, class_name);
if (!rspamd_task_process(task, RSPAMD_TASK_PROCESS_LEARN)) {
msg_warn_session("<%s> message cannot be processed",
@@ -2212,6 +2216,96 @@ rspamd_controller_handle_learnham(
}
/*
+ * Learn class command handler:
+ * request: /learnclass
+ * headers: Password, Class
+ * input: plaintext data
+ * reply: json {"success":true} or {"error":"error message"}
+ */
+static int
+rspamd_controller_handle_learnclass(
+ struct rspamd_http_connection_entry *conn_ent,
+ struct rspamd_http_message *msg)
+{
+ struct rspamd_controller_session *session = conn_ent->ud;
+ struct rspamd_controller_worker_ctx *ctx;
+ struct rspamd_task *task;
+ const rspamd_ftok_t *cl_header, *class_header;
+ char *class_name = NULL;
+
+ ctx = session->ctx;
+
+ if (!rspamd_controller_check_password(conn_ent, session, msg, TRUE)) {
+ return 0;
+ }
+
+ if (rspamd_http_message_get_body(msg, NULL) == NULL) {
+ msg_err_session("got zero length body, cannot continue");
+ rspamd_controller_send_error(conn_ent,
+ 400,
+ "Empty body is not permitted");
+ return 0;
+ }
+
+ class_header = rspamd_http_message_find_header(msg, "Class");
+ if (!class_header) {
+ msg_err_session("missing Class header for multiclass learning");
+ rspamd_controller_send_error(conn_ent,
+ 400,
+ "Class header is required for multiclass learning");
+ return 0;
+ }
+
+ task = rspamd_task_new(session->ctx->worker, session->cfg, session->pool,
+ session->ctx->lang_det, ctx->event_loop, FALSE);
+
+ task->resolver = ctx->resolver;
+ task->s = rspamd_session_create(session->pool,
+ rspamd_controller_learn_fin_task,
+ NULL,
+ (event_finalizer_t) rspamd_task_free,
+ task);
+ task->fin_arg = conn_ent;
+ task->http_conn = rspamd_http_connection_ref(conn_ent->conn);
+ task->sock = -1;
+ session->task = task;
+
+ cl_header = rspamd_http_message_find_header(msg, "classifier");
+ if (cl_header) {
+ session->classifier = rspamd_mempool_ftokdup(session->pool, cl_header);
+ }
+ else {
+ session->classifier = NULL;
+ }
+
+ if (!rspamd_task_load_message(task, msg, msg->body_buf.begin, msg->body_buf.len)) {
+ goto end;
+ }
+
+ /* Set multiclass learning flag and store class name */
+ class_name = rspamd_mempool_ftokdup(task->task_pool, class_header);
+ rspamd_task_set_autolearn_class(task, class_name);
+
+ if (!rspamd_task_process(task, RSPAMD_TASK_PROCESS_LEARN)) {
+ msg_warn_session("<%s> message cannot be processed",
+ MESSAGE_FIELD_CHECK(task, message_id));
+ goto end;
+ }
+
+end:
+ /* Set session spam flag for logging compatibility */
+ if (class_name) {
+ session->is_spam = (strcmp(class_name, "spam") == 0);
+ }
+ else {
+ session->is_spam = FALSE;
+ }
+ rspamd_session_pending(task->s);
+
+ return 0;
+}
+
+/*
* Scan command handler:
* request: /scan
* headers: Password
@@ -3292,7 +3386,7 @@ rspamd_controller_handle_unknown(struct rspamd_http_connection_entry *conn_ent,
rspamd_http_message_add_header(rep, "Access-Control-Allow-Methods",
"POST, GET, OPTIONS");
rspamd_http_message_add_header(rep, "Access-Control-Allow-Headers",
- "Classifier,Content-Type,Password,Map,Weight,Flag,Hash");
+ "Classifier,Class,Content-Type,Password,Map,Weight,Flag,Hash");
rspamd_http_connection_reset(conn_ent->conn);
rspamd_http_router_insert_headers(conn_ent->rt, rep);
rspamd_http_connection_write_message(conn_ent->conn,
@@ -3456,7 +3550,7 @@ rspamd_controller_handle_lua_plugin(struct rspamd_http_connection_entry *conn_en
*/
static int
rspamd_controller_handle_bayes_classifiers(struct rspamd_http_connection_entry *conn_ent,
- struct rspamd_http_message *msg)
+ struct rspamd_http_message *msg)
{
struct rspamd_controller_session *session = conn_ent->ud;
struct rspamd_controller_worker_ctx *ctx = session->ctx;
@@ -4049,6 +4143,9 @@ start_controller_worker(struct rspamd_worker *worker)
PATH_LEARN_HAM,
rspamd_controller_handle_learnham);
rspamd_http_router_add_path(ctx->http,
+ PATH_LEARN_CLASS,
+ rspamd_controller_handle_learnclass);
+ rspamd_http_router_add_path(ctx->http,
PATH_METRICS,
rspamd_controller_handle_metrics);
rspamd_http_router_add_path(ctx->http,