From 3280d0a385cc84c4f2b44a556c26a73291d59820 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Mon, 23 Nov 2015 18:36:41 +0000 Subject: [PATCH] Restore multiple classifiers support --- src/controller.c | 19 +++++++++--------- src/libserver/cfg_rcl.c | 5 +++++ src/libserver/task.c | 10 +++++++--- src/libserver/task.h | 6 +++--- src/libserver/worker_util.h | 2 +- src/libstat/stat_api.h | 7 ++++++- src/libstat/stat_process.c | 39 +++++++++++++++++++++++++++++++------ src/lua/lua_task.c | 28 +++++++------------------- 8 files changed, 71 insertions(+), 45 deletions(-) diff --git a/src/controller.c b/src/controller.c index c0783ce9d..bb494108f 100644 --- a/src/controller.c +++ b/src/controller.c @@ -1168,7 +1168,7 @@ rspamd_controller_learn_fin_task (void *ud) conn_ent = task->fin_arg; session = conn_ent->ud; - if (rspamd_learn_task_spam (session->cl, task, session->is_spam, &err) == + if (rspamd_learn_task_spam (task, session->is_spam, session->classifier, &err) == RSPAMD_STAT_PROCESS_ERROR) { msg_info_session ("cannot learn <%s>: %e", task->message_id, err); rspamd_controller_send_error (conn_ent, err->code, err->message); @@ -1238,8 +1238,8 @@ rspamd_controller_handle_learn_common ( { struct rspamd_controller_session *session = conn_ent->ud; struct rspamd_controller_worker_ctx *ctx; - struct rspamd_classifier_config *cl; struct rspamd_task *task; + const rspamd_ftok_t *cl_header; ctx = session->ctx; @@ -1255,13 +1255,6 @@ rspamd_controller_handle_learn_common ( return 0; } - /* XXX: now work with only bayes */ - cl = rspamd_config_find_classifier (ctx->cfg, "bayes"); - if (cl == NULL) { - rspamd_controller_send_error (conn_ent, 400, "Classifier not found"); - return 0; - } - task = rspamd_task_new (session->ctx->worker, session->cfg); task->resolver = ctx->resolver; @@ -1277,8 +1270,14 @@ rspamd_controller_handle_learn_common ( task->http_conn = rspamd_http_connection_ref (conn_ent->conn);; task->sock = -1; session->task = task; - session->cl = cl; + cl_header = rspamd_http_message_find_header (msg, "classifier"); + if (cl_header) { + session->classifier = rspamd_mempool_ftokdup (session->pool, cl_header); + } + else { + session->classifier = NULL; + } if (!rspamd_task_load_message (task, msg, msg->body_buf.begin, msg->body_buf.len)) { rspamd_controller_send_error (conn_ent, task->err->code, task->err->message); diff --git a/src/libserver/cfg_rcl.c b/src/libserver/cfg_rcl.c index dc74ece31..e8e0818ca 100644 --- a/src/libserver/cfg_rcl.c +++ b/src/libserver/cfg_rcl.c @@ -1652,6 +1652,11 @@ rspamd_rcl_config_init (void) rspamd_rcl_parse_struct_string, G_STRUCT_OFFSET (struct rspamd_classifier_config, backend), 0); + rspamd_rcl_add_default_handler (sub, + "name", + rspamd_rcl_parse_struct_string, + G_STRUCT_OFFSET (struct rspamd_classifier_config, name), + 0); /* * Statfile defaults diff --git a/src/libserver/task.c b/src/libserver/task.c index c4ae1762c..7d34e830b 100644 --- a/src/libserver/task.c +++ b/src/libserver/task.c @@ -645,12 +645,16 @@ rspamd_task_re_cache_check (struct rspamd_task *task, const gchar *re) } gboolean -rspamd_learn_task_spam (struct rspamd_classifier_config *cl, - struct rspamd_task *task, +rspamd_learn_task_spam (struct rspamd_task *task, gboolean is_spam, + const gchar *classifier, GError **err) { - return rspamd_stat_learn (task, is_spam, task->cfg->lua_state, err); + return rspamd_stat_learn (task, + is_spam, + task->cfg->lua_state, + classifier, + err); } static gboolean diff --git a/src/libserver/task.h b/src/libserver/task.h index b29bcebf6..49357e00b 100644 --- a/src/libserver/task.h +++ b/src/libserver/task.h @@ -260,14 +260,14 @@ guint rspamd_task_re_cache_check (struct rspamd_task *task, const gchar *re); /** * Learn specified statfile with message in a task - * @param statfile symbol of statfile * @param task worker's task object + * @param classifier classifier to learn (or NULL to learn all) * @param err pointer to GError * @return true if learn succeed */ -gboolean rspamd_learn_task_spam (struct rspamd_classifier_config *cl, - struct rspamd_task *task, +gboolean rspamd_learn_task_spam (struct rspamd_task *task, gboolean is_spam, + const gchar *classifier, GError **err); /** diff --git a/src/libserver/worker_util.h b/src/libserver/worker_util.h index 21c86f92e..837e6ac33 100644 --- a/src/libserver/worker_util.h +++ b/src/libserver/worker_util.h @@ -84,7 +84,7 @@ struct rspamd_controller_session { struct rspamd_worker *wrk; rspamd_mempool_t *pool; struct rspamd_task *task; - struct rspamd_classifier_config *cl; + gchar *classifier; rspamd_inet_addr_t *from_addr; struct rspamd_config *cfg; gboolean is_spam; diff --git a/src/libstat/stat_api.h b/src/libstat/stat_api.h index 493354007..ba5dc4a40 100644 --- a/src/libstat/stat_api.h +++ b/src/libstat/stat_api.h @@ -58,6 +58,8 @@ void rspamd_stat_close (void); /** * Classify the task specified and insert symbols if needed * @param task + * @param L lua state + * @param err error returned * @return TRUE if task has been classified */ rspamd_stat_result_t rspamd_stat_classify (struct rspamd_task *task, @@ -68,10 +70,13 @@ rspamd_stat_result_t rspamd_stat_classify (struct rspamd_task *task, * Learn task as spam or ham, task must be processed prior to this call * @param task task to learn * @param spam if TRUE learn spam, otherwise learn ham + * @param L lua state + * @param classifier NULL to learn all classifiers, name to learn a specific one + * @param err error returned * @return TRUE if task has been learned */ rspamd_stat_result_t rspamd_stat_learn (struct rspamd_task *task, - gboolean spam, lua_State *L, + gboolean spam, lua_State *L, const gchar *classifier, GError **err); /** diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index b19663893..952330b49 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -353,7 +353,11 @@ preprocess_init_stat_token (gpointer k, gpointer v, gpointer d) static GList* rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, struct rspamd_task *task, - lua_State *L, gint op, gboolean spam, GError **err) + lua_State *L, + gint op, + gboolean spam, + const gchar *classifier, + GError **err) { struct rspamd_classifier_config *clcf; struct rspamd_statfile_config *stcf; @@ -373,6 +377,15 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, clcf = (struct rspamd_classifier_config *)cur->data; st_list = NULL; + if (classifier != NULL && + (clcf->name == NULL || strcmp (clcf->name, classifier) != 0)) { + /* Skip this classifier */ + msg_debug_task ("skip classifier %s, as we are requested to check %s only", + clcf->name, classifier); + cur = g_list_next (cur); + continue; + } + if (clcf->pre_callbacks != NULL) { st_list = rspamd_lua_call_cls_pre_callbacks (clcf, task, FALSE, FALSE, L); @@ -518,6 +531,11 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, g_tree_foreach (cbdata.tok->tokens, preprocess_init_stat_token, &cbdata); } + else if (classifier != NULL) { + /* We likely cannot find any classifier with this name */ + g_set_error (err, rspamd_stat_quark (), 404, + "cannot find classifier %s", classifier); + } return cl_runtimes; } @@ -538,7 +556,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err) /* Initialize classifiers and statfiles runtime */ if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, L, - RSPAMD_CLASSIFY_OP, FALSE, err)) == NULL) { + RSPAMD_CLASSIFY_OP, FALSE, NULL, err)) == NULL) { return RSPAMD_STAT_PROCESS_OK; } @@ -659,7 +677,10 @@ rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d) } rspamd_stat_result_t -rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, +rspamd_stat_learn (struct rspamd_task *task, + gboolean spam, + lua_State *L, + const gchar *classifier, GError **err) { struct rspamd_stat_ctx *st_ctx; @@ -669,7 +690,8 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, struct preprocess_cb_data cbdata; GList *cl_runtimes; GList *cur, *curst; - gboolean ret = RSPAMD_STAT_PROCESS_ERROR, unlearn = FALSE; + gboolean unlearn = FALSE; + rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_ERROR; gulong nrev; rspamd_learn_t learn_res = RSPAMD_LEARN_OK; guint i; @@ -698,8 +720,13 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L, } /* Initialize classifiers and statfiles runtime */ - if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, L, - unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP, spam, err)) == NULL) { + if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, + task, + L, + unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP, + spam, + classifier, + err)) == NULL) { return RSPAMD_STAT_PROCESS_ERROR; } diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c index d9840275f..0e43b7ab5 100644 --- a/src/lua/lua_task.c +++ b/src/lua/lua_task.c @@ -1945,8 +1945,7 @@ lua_task_learn (lua_State *L) { struct rspamd_task *task = lua_check_task (L, 1); gboolean is_spam = FALSE; - const gchar *clname; - struct rspamd_classifier_config *cl; + const gchar *clname = NULL; GError *err = NULL; int ret = 1; @@ -1954,29 +1953,16 @@ lua_task_learn (lua_State *L) if (lua_gettop (L) > 2) { clname = luaL_checkstring (L, 3); } - else { - clname = "bayes"; - } - - cl = rspamd_config_find_classifier (task->cfg, clname); - if (cl == NULL) { - msg_warn_task ("classifier %s is not found", clname); + if (!rspamd_learn_task_spam (task, is_spam, clname, &err)) { lua_pushboolean (L, FALSE); - lua_pushstring (L, "classifier not found"); - ret = 2; + if (err != NULL) { + lua_pushstring (L, err->message); + ret = 2; + } } else { - if (!rspamd_learn_task_spam (cl, task, is_spam, &err)) { - lua_pushboolean (L, FALSE); - if (err != NULL) { - lua_pushstring (L, err->message); - ret = 2; - } - } - else { - lua_pushboolean (L, TRUE); - } + lua_pushboolean (L, TRUE); } return ret; -- 2.39.5