summary | refs | log | tree | commit | diff | stats
path: root/src
diff options
context:
space:
mode:
author: Vsevolod Stakhov <vsevolod@highsecure.ru> 2015-11-23 18:36:41 +0000
committer: Vsevolod Stakhov <vsevolod@highsecure.ru> 2015-11-23 18:36:41 +0000
commit: 3280d0a385cc84c4f2b44a556c26a73291d59820 (patch)
tree: 7354c0d009e8df9166aa7d379c4ff03fcf1569a0 /src
parent: 171031923755ae89ed2672130b3fb3665c5c8796 (diff)
download: rspamd-3280d0a385cc84c4f2b44a556c26a73291d59820.tar.gz
rspamd-3280d0a385cc84c4f2b44a556c26a73291d59820.zip
Restore multiple classifiers support
Diffstat (limited to 'src')
-rw-r--r-- src/controller.c 19
-rw-r--r-- src/libserver/cfg_rcl.c 5
-rw-r--r-- src/libserver/task.c 10
-rw-r--r-- src/libserver/task.h 6
-rw-r--r-- src/libserver/worker_util.h 2
-rw-r--r-- src/libstat/stat_api.h 7
-rw-r--r-- src/libstat/stat_process.c 39
-rw-r--r-- src/lua/lua_task.c 28
8 files changed, 71 insertions, 45 deletions
diff --git a/src/controller.c b/src/controller.c
index c0783ce9d..bb494108f 100644
--- a/src/controller.c
+++ b/src/controller.c
@@ -1168,7 +1168,7 @@ rspamd_controller_learn_fin_task (void *ud)
conn_ent = task->fin_arg;
session = conn_ent->ud;
- if (rspamd_learn_task_spam (session->cl, task, session->is_spam, &err) ==
+ if (rspamd_learn_task_spam (task, session->is_spam, session->classifier, &err) ==
RSPAMD_STAT_PROCESS_ERROR) {
msg_info_session ("cannot learn <%s>: %e", task->message_id, err);
rspamd_controller_send_error (conn_ent, err->code, err->message);
@@ -1238,8 +1238,8 @@ rspamd_controller_handle_learn_common (
{
struct rspamd_controller_session *session = conn_ent->ud;
struct rspamd_controller_worker_ctx *ctx;
- struct rspamd_classifier_config *cl;
struct rspamd_task *task;
+ const rspamd_ftok_t *cl_header;
ctx = session->ctx;
@@ -1255,13 +1255,6 @@ rspamd_controller_handle_learn_common (
return 0;
}
- /* XXX: now work with only bayes */
- cl = rspamd_config_find_classifier (ctx->cfg, "bayes");
- if (cl == NULL) {
- rspamd_controller_send_error (conn_ent, 400, "Classifier not found");
- return 0;
- }
-
task = rspamd_task_new (session->ctx->worker, session->cfg);
task->resolver = ctx->resolver;
@@ -1277,8 +1270,14 @@ rspamd_controller_handle_learn_common (
task->http_conn = rspamd_http_connection_ref (conn_ent->conn);;
task->sock = -1;
session->task = task;
- session->cl = cl;
+ cl_header = rspamd_http_message_find_header (msg, "classifier");
+ if (cl_header) {
+ session->classifier = rspamd_mempool_ftokdup (session->pool, cl_header);
+ }
+ else {
+ session->classifier = NULL;
+ }
if (!rspamd_task_load_message (task, msg, msg->body_buf.begin, msg->body_buf.len)) {
rspamd_controller_send_error (conn_ent, task->err->code, task->err->message);
diff --git a/src/libserver/cfg_rcl.c b/src/libserver/cfg_rcl.c
index dc74ece31..e8e0818ca 100644
--- a/src/libserver/cfg_rcl.c
+++ b/src/libserver/cfg_rcl.c
@@ -1652,6 +1652,11 @@ rspamd_rcl_config_init (void)
rspamd_rcl_parse_struct_string,
G_STRUCT_OFFSET (struct rspamd_classifier_config, backend),
0);
+ rspamd_rcl_add_default_handler (sub,
+ "name",
+ rspamd_rcl_parse_struct_string,
+ G_STRUCT_OFFSET (struct rspamd_classifier_config, name),
+ 0);
/*
* Statfile defaults
diff --git a/src/libserver/task.c b/src/libserver/task.c
index c4ae1762c..7d34e830b 100644
--- a/src/libserver/task.c
+++ b/src/libserver/task.c
@@ -645,12 +645,16 @@ rspamd_task_re_cache_check (struct rspamd_task *task, const gchar *re)
}
gboolean
-rspamd_learn_task_spam (struct rspamd_classifier_config *cl,
- struct rspamd_task *task,
+rspamd_learn_task_spam (struct rspamd_task *task,
gboolean is_spam,
+ const gchar *classifier,
GError **err)
{
- return rspamd_stat_learn (task, is_spam, task->cfg->lua_state, err);
+ return rspamd_stat_learn (task,
+ is_spam,
+ task->cfg->lua_state,
+ classifier,
+ err);
}
static gboolean
diff --git a/src/libserver/task.h b/src/libserver/task.h
index b29bcebf6..49357e00b 100644
--- a/src/libserver/task.h
+++ b/src/libserver/task.h
@@ -260,14 +260,14 @@ guint rspamd_task_re_cache_check (struct rspamd_task *task, const gchar *re);
/**
* Learn specified statfile with message in a task
- * @param statfile symbol of statfile
* @param task worker's task object
+ * @param classifier classifier to learn (or NULL to learn all)
* @param err pointer to GError
* @return true if learn succeed
*/
-gboolean rspamd_learn_task_spam (struct rspamd_classifier_config *cl,
- struct rspamd_task *task,
+gboolean rspamd_learn_task_spam (struct rspamd_task *task,
gboolean is_spam,
+ const gchar *classifier,
GError **err);
/**
diff --git a/src/libserver/worker_util.h b/src/libserver/worker_util.h
index 21c86f92e..837e6ac33 100644
--- a/src/libserver/worker_util.h
+++ b/src/libserver/worker_util.h
@@ -84,7 +84,7 @@ struct rspamd_controller_session {
struct rspamd_worker *wrk;
rspamd_mempool_t *pool;
struct rspamd_task *task;
- struct rspamd_classifier_config *cl;
+ gchar *classifier;
rspamd_inet_addr_t *from_addr;
struct rspamd_config *cfg;
gboolean is_spam;
diff --git a/src/libstat/stat_api.h b/src/libstat/stat_api.h
index 493354007..ba5dc4a40 100644
--- a/src/libstat/stat_api.h
+++ b/src/libstat/stat_api.h
@@ -58,6 +58,8 @@ void rspamd_stat_close (void);
/**
* Classify the task specified and insert symbols if needed
* @param task
+ * @param L lua state
+ * @param err error returned
* @return TRUE if task has been classified
*/
rspamd_stat_result_t rspamd_stat_classify (struct rspamd_task *task,
@@ -68,10 +70,13 @@ rspamd_stat_result_t rspamd_stat_classify (struct rspamd_task *task,
* Learn task as spam or ham, task must be processed prior to this call
* @param task task to learn
* @param spam if TRUE learn spam, otherwise learn ham
+ * @param L lua state
+ * @param classifier NULL to learn all classifiers, name to learn a specific one
+ * @param err error returned
* @return TRUE if task has been learned
*/
rspamd_stat_result_t rspamd_stat_learn (struct rspamd_task *task,
- gboolean spam, lua_State *L,
+ gboolean spam, lua_State *L, const gchar *classifier,
GError **err);
/**
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index b19663893..952330b49 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -353,7 +353,11 @@ preprocess_init_stat_token (gpointer k, gpointer v, gpointer d)
static GList*
rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
struct rspamd_task *task,
- lua_State *L, gint op, gboolean spam, GError **err)
+ lua_State *L,
+ gint op,
+ gboolean spam,
+ const gchar *classifier,
+ GError **err)
{
struct rspamd_classifier_config *clcf;
struct rspamd_statfile_config *stcf;
@@ -373,6 +377,15 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
clcf = (struct rspamd_classifier_config *)cur->data;
st_list = NULL;
+ if (classifier != NULL &&
+ (clcf->name == NULL || strcmp (clcf->name, classifier) != 0)) {
+ /* Skip this classifier */
+ msg_debug_task ("skip classifier %s, as we are requested to check %s only",
+ clcf->name, classifier);
+ cur = g_list_next (cur);
+ continue;
+ }
+
if (clcf->pre_callbacks != NULL) {
st_list = rspamd_lua_call_cls_pre_callbacks (clcf, task, FALSE,
FALSE, L);
@@ -518,6 +531,11 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
g_tree_foreach (cbdata.tok->tokens, preprocess_init_stat_token,
&cbdata);
}
+ else if (classifier != NULL) {
+ /* We likely cannot find any classifier with this name */
+ g_set_error (err, rspamd_stat_quark (), 404,
+ "cannot find classifier %s", classifier);
+ }
return cl_runtimes;
}
@@ -538,7 +556,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, GError **err)
/* Initialize classifiers and statfiles runtime */
if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, L,
- RSPAMD_CLASSIFY_OP, FALSE, err)) == NULL) {
+ RSPAMD_CLASSIFY_OP, FALSE, NULL, err)) == NULL) {
return RSPAMD_STAT_PROCESS_OK;
}
@@ -659,7 +677,10 @@ rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d)
}
rspamd_stat_result_t
-rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
+rspamd_stat_learn (struct rspamd_task *task,
+ gboolean spam,
+ lua_State *L,
+ const gchar *classifier,
GError **err)
{
struct rspamd_stat_ctx *st_ctx;
@@ -669,7 +690,8 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
struct preprocess_cb_data cbdata;
GList *cl_runtimes;
GList *cur, *curst;
- gboolean ret = RSPAMD_STAT_PROCESS_ERROR, unlearn = FALSE;
+ gboolean unlearn = FALSE;
+ rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_ERROR;
gulong nrev;
rspamd_learn_t learn_res = RSPAMD_LEARN_OK;
guint i;
@@ -698,8 +720,13 @@ rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
}
/* Initialize classifiers and statfiles runtime */
- if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, L,
- unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP, spam, err)) == NULL) {
+ if ((cl_runtimes = rspamd_stat_preprocess (st_ctx,
+ task,
+ L,
+ unlearn ? RSPAMD_UNLEARN_OP : RSPAMD_LEARN_OP,
+ spam,
+ classifier,
+ err)) == NULL) {
return RSPAMD_STAT_PROCESS_ERROR;
}
diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c
index d9840275f..0e43b7ab5 100644
--- a/src/lua/lua_task.c
+++ b/src/lua/lua_task.c
@@ -1945,8 +1945,7 @@ lua_task_learn (lua_State *L)
{
struct rspamd_task *task = lua_check_task (L, 1);
gboolean is_spam = FALSE;
- const gchar *clname;
- struct rspamd_classifier_config *cl;
+ const gchar *clname = NULL;
GError *err = NULL;
int ret = 1;
@@ -1954,29 +1953,16 @@ lua_task_learn (lua_State *L)
if (lua_gettop (L) > 2) {
clname = luaL_checkstring (L, 3);
}
- else {
- clname = "bayes";
- }
-
- cl = rspamd_config_find_classifier (task->cfg, clname);
- if (cl == NULL) {
- msg_warn_task ("classifier %s is not found", clname);
+ if (!rspamd_learn_task_spam (task, is_spam, clname, &err)) {
lua_pushboolean (L, FALSE);
- lua_pushstring (L, "classifier not found");
- ret = 2;
+ if (err != NULL) {
+ lua_pushstring (L, err->message);
+ ret = 2;
+ }
}
else {
- if (!rspamd_learn_task_spam (cl, task, is_spam, &err)) {
- lua_pushboolean (L, FALSE);
- if (err != NULL) {
- lua_pushstring (L, err->message);
- ret = 2;
- }
- }
- else {
- lua_pushboolean (L, TRUE);
- }
+ lua_pushboolean (L, TRUE);
}
return ret;