diff options
Diffstat (limited to 'src/controller.c')
-rw-r--r-- | src/controller.c | 153 |
1 files changed, 144 insertions, 9 deletions
diff --git a/src/controller.c b/src/controller.c index 22423e999..6e0e4cac1 100644 --- a/src/controller.c +++ b/src/controller.c @@ -53,6 +53,7 @@ #define PATH_HISTORY_RESET "/historyreset" #define PATH_LEARN_SPAM "/learnspam" #define PATH_LEARN_HAM "/learnham" +#define PATH_LEARN_CLASS "/learnclass" #define PATH_METRICS "/metrics" #define PATH_READY "/ready" #define PATH_SAVE_ACTIONS "/saveactions" @@ -68,6 +69,7 @@ #define PATH_NEIGHBOURS "/neighbours" #define PATH_PLUGINS "/plugins" #define PATH_PING "/ping" +#define PATH_BAYES_CLASSIFIERS "/bayes/classifiers" #define msg_err_session(...) rspamd_default_log_function(G_LOG_LEVEL_CRITICAL, \ session->pool->tag.tagname, session->pool->tag.uid, \ @@ -992,9 +994,9 @@ rspamd_controller_handle_maps(struct rspamd_http_connection_entry *conn_ent, "type", 0, false); ucl_object_insert_key(obj, ucl_object_frombool(editable), "editable", 0, false); - ucl_object_insert_key(obj, ucl_object_frombool(bk->shared->loaded), + ucl_object_insert_key(obj, ucl_object_frombool(map->shared->loaded), "loaded", 0, false); - ucl_object_insert_key(obj, ucl_object_frombool(bk->shared->cached), + ucl_object_insert_key(obj, ucl_object_frombool(map->shared->cached), "cached", 0, false); ucl_array_append(top, obj); } @@ -1012,9 +1014,9 @@ rspamd_controller_handle_maps(struct rspamd_http_connection_entry *conn_ent, "type", 0, false); ucl_object_insert_key(obj, ucl_object_frombool(false), "editable", 0, false); - ucl_object_insert_key(obj, ucl_object_frombool(bk->shared->loaded), + ucl_object_insert_key(obj, ucl_object_frombool(map->shared->loaded), "loaded", 0, false); - ucl_object_insert_key(obj, ucl_object_frombool(bk->shared->cached), + ucl_object_insert_key(obj, ucl_object_frombool(map->shared->cached), "cached", 0, false); ucl_array_append(top, obj); } @@ -1141,7 +1143,7 @@ rspamd_controller_handle_get_map(struct rspamd_http_connection_entry *conn_ent, rspamd_map_traverse(bk->map, rspamd_controller_map_traverse_callback, &map_body, FALSE); rspamd_http_message_set_body_from_fstring_steal(reply, map_body); } - else if (bk->shared->loaded) { + else if (map->shared->loaded) { reply = rspamd_http_new_message(HTTP_RESPONSE); reply->code = 200; rspamd_fstring_t *map_body = rspamd_fstring_new(); @@ -2125,6 +2127,7 @@ rspamd_controller_handle_learn_common( struct rspamd_controller_worker_ctx *ctx; struct rspamd_task *task; const rspamd_ftok_t *cl_header; + const char *class_name; ctx = session->ctx; @@ -2166,7 +2169,9 @@ rspamd_controller_handle_learn_common( goto end; } - rspamd_learn_task_spam(task, is_spam, session->classifier, NULL); + /* Use unified class-based learning approach */ + class_name = is_spam ? "spam" : "ham"; + rspamd_task_set_autolearn_class(task, class_name); if (!rspamd_task_process(task, RSPAMD_TASK_PROCESS_LEARN)) { msg_warn_session("<%s> message cannot be processed", @@ -2211,6 +2216,96 @@ rspamd_controller_handle_learnham( } /* + * Learn class command handler: + * request: /learnclass + * headers: Password, Class + * input: plaintext data + * reply: json {"success":true} or {"error":"error message"} + */ +static int +rspamd_controller_handle_learnclass( + struct rspamd_http_connection_entry *conn_ent, + struct rspamd_http_message *msg) +{ + struct rspamd_controller_session *session = conn_ent->ud; + struct rspamd_controller_worker_ctx *ctx; + struct rspamd_task *task; + const rspamd_ftok_t *cl_header, *class_header; + char *class_name = NULL; + + ctx = session->ctx; + + if (!rspamd_controller_check_password(conn_ent, session, msg, TRUE)) { + return 0; + } + + if (rspamd_http_message_get_body(msg, NULL) == NULL) { + msg_err_session("got zero length body, cannot continue"); + rspamd_controller_send_error(conn_ent, + 400, + "Empty body is not permitted"); + return 0; + } + + class_header = rspamd_http_message_find_header(msg, "Class"); + if (!class_header) { + msg_err_session("missing Class header for multiclass learning"); + rspamd_controller_send_error(conn_ent, + 400, + "Class header is required for multiclass learning"); + return 0; + } + + task = rspamd_task_new(session->ctx->worker, session->cfg, session->pool, + session->ctx->lang_det, ctx->event_loop, FALSE); + + task->resolver = ctx->resolver; + task->s = rspamd_session_create(session->pool, + rspamd_controller_learn_fin_task, + NULL, + (event_finalizer_t) rspamd_task_free, + task); + task->fin_arg = conn_ent; + task->http_conn = rspamd_http_connection_ref(conn_ent->conn); + task->sock = -1; + session->task = task; + + cl_header = rspamd_http_message_find_header(msg, "classifier"); + if (cl_header) { + session->classifier = rspamd_mempool_ftokdup(session->pool, cl_header); + } + else { + session->classifier = NULL; + } + + if (!rspamd_task_load_message(task, msg, msg->body_buf.begin, msg->body_buf.len)) { + goto end; + } + + /* Set multiclass learning flag and store class name */ + class_name = rspamd_mempool_ftokdup(task->task_pool, class_header); + rspamd_task_set_autolearn_class(task, class_name); + + if (!rspamd_task_process(task, RSPAMD_TASK_PROCESS_LEARN)) { + msg_warn_session("<%s> message cannot be processed", + MESSAGE_FIELD_CHECK(task, message_id)); + goto end; + } + +end: + /* Set session spam flag for logging compatibility */ + if (class_name) { + session->is_spam = (strcmp(class_name, "spam") == 0); + } + else { + session->is_spam = FALSE; + } + rspamd_session_pending(task->s); + + return 0; +} + +/* * Scan command handler: * request: /scan * headers: Password @@ -2311,7 +2406,7 @@ rspamd_controller_handle_saveactions( return 0; } - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, msg->body_buf.begin, msg->body_buf.len)) { if ((error = ucl_parser_get_error(parser)) != NULL) { msg_err_session("cannot parse input: %s", error); @@ -2434,7 +2529,7 @@ rspamd_controller_handle_savesymbols( return 0; } - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, msg->body_buf.begin, msg->body_buf.len)) { if ((error = ucl_parser_get_error(parser)) != NULL) { msg_err_session("cannot parse input: %s", error); @@ -3291,7 +3386,7 @@ rspamd_controller_handle_unknown(struct rspamd_http_connection_entry *conn_ent, rspamd_http_message_add_header(rep, "Access-Control-Allow-Methods", "POST, GET, OPTIONS"); rspamd_http_message_add_header(rep, "Access-Control-Allow-Headers", - "Content-Type,Password,Map,Weight,Flag"); + "Classifier,Class,Content-Type,Password,Map,Weight,Flag,Hash"); rspamd_http_connection_reset(conn_ent->conn); rspamd_http_router_insert_headers(conn_ent->rt, rep); rspamd_http_connection_write_message(conn_ent->conn, @@ -3446,6 +3541,40 @@ rspamd_controller_handle_lua_plugin(struct rspamd_http_connection_entry *conn_en return 0; } +/* + * Bayes classifier list command handler: + * request: /bayes/classifiers + * headers: Password + * reply: JSON array of Bayes classifier names + * Note: list is in reverse of declaration order (GList prepend). + */ +static int +rspamd_controller_handle_bayes_classifiers(struct rspamd_http_connection_entry *conn_ent, + struct rspamd_http_message *msg) +{ + struct rspamd_controller_session *session = conn_ent->ud; + struct rspamd_controller_worker_ctx *ctx = session->ctx; + ucl_object_t *arr; + struct rspamd_classifier_config *clc; + GList *cur; + + if (!rspamd_controller_check_password(conn_ent, session, msg, FALSE)) { + return 0; + } + + arr = ucl_object_typed_new(UCL_ARRAY); + cur = g_list_last(ctx->cfg->classifiers); + while (cur) { + clc = cur->data; + ucl_array_append(arr, ucl_object_fromstring(clc->name)); + cur = g_list_previous(cur); + } + + rspamd_controller_send_ucl(conn_ent, arr); + ucl_object_unref(arr); + return 0; +} + static void rspamd_controller_error_handler(struct rspamd_http_connection_entry *conn_ent, @@ -4014,6 +4143,9 @@ start_controller_worker(struct rspamd_worker *worker) PATH_LEARN_HAM, rspamd_controller_handle_learnham); rspamd_http_router_add_path(ctx->http, + PATH_LEARN_CLASS, + rspamd_controller_handle_learnclass); + rspamd_http_router_add_path(ctx->http, PATH_METRICS, rspamd_controller_handle_metrics); rspamd_http_router_add_path(ctx->http, @@ -4055,6 +4187,9 @@ start_controller_worker(struct rspamd_worker *worker) rspamd_http_router_add_path(ctx->http, PATH_PING, rspamd_controller_handle_ping); + rspamd_http_router_add_path(ctx->http, + PATH_BAYES_CLASSIFIERS, + rspamd_controller_handle_bayes_classifiers); rspamd_controller_register_plugins_paths(ctx); #if 0 |