aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rambler-co.ru>2011-07-12 20:46:55 +0400
committerVsevolod Stakhov <vsevolod@rambler-co.ru>2011-07-12 20:46:55 +0400
commitff4871310ff5b269dcd02ea300cf78092860e1d4 (patch)
treecfa435f5de1dc8efc646a0ca1fc6fd261b2c1aa6 /src
parentc4105fc43199d51af271bc24d3345aa57906d973 (diff)
downloadrspamd-ff4871310ff5b269dcd02ea300cf78092860e1d4.tar.gz
rspamd-ff4871310ff5b269dcd02ea300cf78092860e1d4.zip
* First commit to implement multi-statfile filter system with new learning mechanism (untested yet)
Diffstat (limited to 'src')
-rw-r--r--src/cfg_file.h5
-rw-r--r--src/cfg_utils.c40
-rw-r--r--src/cfg_xml.c47
-rw-r--r--src/cfg_xml.h1
-rw-r--r--src/classifiers/bayes.c80
-rw-r--r--src/classifiers/classifiers.c2
-rw-r--r--src/classifiers/classifiers.h6
-rw-r--r--src/classifiers/winnow.c33
-rw-r--r--src/controller.c187
-rw-r--r--src/filter.c83
-rw-r--r--src/filter.h10
-rw-r--r--src/html.c27
-rw-r--r--src/lua/lua_classifier.c178
-rw-r--r--src/lua/lua_common.h2
-rw-r--r--src/main.h4
-rw-r--r--src/util.c23
-rw-r--r--src/util.h3
17 files changed, 649 insertions, 82 deletions
diff --git a/src/cfg_file.h b/src/cfg_file.h
index 7e2dfd413..d15923639 100644
--- a/src/cfg_file.h
+++ b/src/cfg_file.h
@@ -422,14 +422,17 @@ gboolean get_config_checksum (struct config_file *cfg);
void unescape_quotes (gchar *line);
GList* parse_comma_list (memory_pool_t *pool, gchar *line);
-struct classifier_config* check_classifier_cfg (struct config_file *cfg, struct classifier_config *c);
+struct classifier_config* check_classifier_conf (struct config_file *cfg, struct classifier_config *c);
struct worker_conf* check_worker_conf (struct config_file *cfg, struct worker_conf *c);
struct metric* check_metric_conf (struct config_file *cfg, struct metric *c);
+struct statfile* check_statfile_conf (struct config_file *cfg, struct statfile *c);
gboolean parse_normalizer (struct config_file *cfg, struct statfile *st, const gchar *line);
gboolean read_xml_config (struct config_file *cfg, const gchar *filename);
gboolean check_modules_config (struct config_file *cfg);
void insert_classifier_symbols (struct config_file *cfg);
+struct classifier_config* find_classifier_conf (struct config_file *cfg, const gchar *name);
+
#endif /* ifdef CFG_FILE_H */
/*
* vi:ts=4
diff --git a/src/cfg_utils.c b/src/cfg_utils.c
index 720d931ef..6bd16d620 100644
--- a/src/cfg_utils.c
+++ b/src/cfg_utils.c
@@ -751,7 +751,7 @@ parse_comma_list (memory_pool_t * pool, gchar *line)
}
struct classifier_config *
-check_classifier_cfg (struct config_file *cfg, struct classifier_config *c)
+check_classifier_conf (struct config_file *cfg, struct classifier_config *c)
{
if (c == NULL) {
c = memory_pool_alloc0 (cfg->cfg_pool, sizeof (struct classifier_config));
@@ -764,6 +764,20 @@ check_classifier_cfg (struct config_file *cfg, struct classifier_config *c)
return c;
}
+/*
+ * Return a statfile config that is ready for use.
+ * When c is NULL a new one is obtained with memory_pool_alloc0 from the
+ * config pool. If the statfile has no opts hash table yet, one is created
+ * and a g_hash_table_destroy destructor for it is registered on the config pool.
+ */
+struct statfile*
+check_statfile_conf (struct config_file *cfg, struct statfile *c)
+{
+	struct statfile *st = c;
+
+	if (st == NULL) {
+		st = memory_pool_alloc0 (cfg->cfg_pool, sizeof (struct statfile));
+	}
+
+	/* Nothing more to do when the options table already exists */
+	if (st->opts != NULL) {
+		return st;
+	}
+
+	st->opts = g_hash_table_new (g_str_hash, g_str_equal);
+	memory_pool_add_destructor (cfg->cfg_pool, (pool_destruct_func) g_hash_table_destroy, st->opts);
+
+	return st;
+}
+
struct metric *
check_metric_conf (struct config_file *cfg, struct metric *c)
{
@@ -1006,6 +1020,30 @@ insert_classifier_symbols (struct config_file *cfg)
g_hash_table_foreach (cfg->classifiers_symbols, symbols_classifiers_callback, cfg);
}
+/*
+ * Look up a classifier configuration by the name of its classifier.
+ * The comparison is ASCII case-insensitive. Returns NULL when name is NULL
+ * or when no entry of cfg->classifiers matches.
+ */
+struct classifier_config*
+find_classifier_conf (struct config_file *cfg, const gchar *name)
+{
+	GList *item;
+	struct classifier_config *candidate;
+
+	if (name == NULL) {
+		return NULL;
+	}
+
+	for (item = cfg->classifiers; item != NULL; item = g_list_next (item)) {
+		candidate = item->data;
+		if (g_ascii_strcasecmp (candidate->classifier->name, name) == 0) {
+			return candidate;
+		}
+	}
+
+	return NULL;
+}
+
/*
* vi:ts=4
*/
diff --git a/src/cfg_xml.c b/src/cfg_xml.c
index 09f6574a0..6953edb3a 100644
--- a/src/cfg_xml.c
+++ b/src/cfg_xml.c
@@ -471,6 +471,12 @@ static struct xml_parser_rule grammar[] = {
NULL
},
{
+ "spam",
+ xml_handle_boolean,
+ G_STRUCT_OFFSET (struct statfile, is_spam),
+ NULL
+ },
+ {
"normalizer",
handle_statfile_normalizer,
0,
@@ -496,7 +502,11 @@ static struct xml_parser_rule grammar[] = {
},
NULL_ATTR
},
- NULL_DEF_ATTR
+ {
+ handle_statfile_opt,
+ 0,
+ NULL
+ }
},
{ XML_SECTION_MODULE, {
NULL_ATTR
@@ -1017,9 +1027,9 @@ static void
set_lua_globals (struct config_file *cfg, lua_State *L)
{
struct config_file **pcfg;
+
/* First check for global variable 'config' */
lua_getglobal (L, "config");
-
if (lua_isnil (L, -1)) {
/* Assign global table to set up attributes */
lua_newtable (L);
@@ -1038,13 +1048,19 @@ set_lua_globals (struct config_file *cfg, lua_State *L)
lua_setglobal (L, "composites");
}
+ lua_getglobal (L, "classifiers");
+ if (lua_isnil (L, -1)) {
+ lua_newtable (L);
+ lua_setglobal (L, "classifiers");
+ }
+
pcfg = lua_newuserdata (L, sizeof (struct config_file *));
lua_setclass (L, "rspamd{config}", -1);
*pcfg = cfg;
lua_setglobal (L, "rspamd_config");
/* Clear stack from globals */
- lua_pop (L, 3);
+ lua_pop (L, 4);
}
/* Handle lua tag */
@@ -1402,6 +1418,27 @@ handle_statfile_binlog_master (struct config_file *cfg, struct rspamd_xml_userda
return TRUE;
}
+/*
+ * Default handler for otherwise unknown tags inside a statfile section:
+ * stores the tag as a free-form option in the statfile's opts table.
+ *
+ * For <option name="..."> / <param name="..."> the key is the "name"
+ * attribute; for any other tag the key is a pool copy of the tag name.
+ * The value is always a pool copy of the tag data.
+ *
+ * Returns FALSE if an option/param tag has no "name" attribute, TRUE otherwise.
+ */
+gboolean
+handle_statfile_opt (struct config_file *cfg, struct rspamd_xml_userdata *ctx, const gchar *tag, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset)
+{
+	struct statfile *st = ctx->section_pointer;
+	const gchar *name;
+
+	if (g_ascii_strcasecmp (tag, "option") == 0 || g_ascii_strcasecmp (tag, "param") == 0) {
+		if (attrs == NULL || (name = g_hash_table_lookup (attrs, "name")) == NULL) {
+			/* This handler serves statfile sections, so say so in the log */
+			msg_err ("statfile param tag must have \"name\" attribute");
+			return FALSE;
+		}
+	}
+	else {
+		name = memory_pool_strdup (cfg->cfg_pool, tag);
+	}
+
+	g_hash_table_insert (st->opts, (char *)name, memory_pool_strdup (cfg->cfg_pool, data));
+
+	return TRUE;
+}
+
/* Common handlers */
gboolean
xml_handle_string (struct config_file *cfg, struct rspamd_xml_userdata *ctx, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset)
@@ -1617,7 +1654,7 @@ rspamd_xml_start_element (GMarkupParseContext *context, const gchar *element_nam
if (extract_attr ("type", attribute_names, attribute_values, &res)) {
ud->state = XML_READ_CLASSIFIER;
/* Create object */
- ccf = check_classifier_cfg (ud->cfg, NULL);
+ ccf = check_classifier_conf (ud->cfg, NULL);
if ((ccf->classifier = get_classifier (res)) == NULL) {
*error = g_error_new (xml_error_quark (), XML_INVALID_ATTR, "invalid classifier type: %s", res);
ud->state = XML_ERROR;
@@ -1665,7 +1702,7 @@ rspamd_xml_start_element (GMarkupParseContext *context, const gchar *element_nam
/* Now section pointer is statfile and parent pointer is classifier */
ud->parent_pointer = ud->section_pointer;
- ud->section_pointer = memory_pool_alloc0 (ud->cfg->cfg_pool, sizeof (struct statfile));
+ ud->section_pointer = check_statfile_conf (ud->cfg, NULL);
}
else {
rspamd_strlcpy (ud->section_name, element_name, sizeof (ud->section_name));
diff --git a/src/cfg_xml.h b/src/cfg_xml.h
index b48f3d96c..659f35c19 100644
--- a/src/cfg_xml.h
+++ b/src/cfg_xml.h
@@ -154,6 +154,7 @@ gboolean handle_statfile_normalizer (struct config_file *cfg, struct rspamd_xml_
gboolean handle_statfile_binlog (struct config_file *cfg, struct rspamd_xml_userdata *ctx, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset);
gboolean handle_statfile_binlog_rotate (struct config_file *cfg, struct rspamd_xml_userdata *ctx, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset);
gboolean handle_statfile_binlog_master (struct config_file *cfg, struct rspamd_xml_userdata *ctx, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset);
+gboolean handle_statfile_opt (struct config_file *cfg, struct rspamd_xml_userdata *ctx, const gchar *tag, GHashTable *attrs, gchar *data, gpointer user_data, gpointer dest_struct, gint offset);
/* Register new module option */
void register_module_opt (const gchar *mname, const gchar *optname, enum module_opt_type type);
diff --git a/src/classifiers/bayes.c b/src/classifiers/bayes.c
index 7363df522..44e9323a2 100644
--- a/src/classifiers/bayes.c
+++ b/src/classifiers/bayes.c
@@ -30,9 +30,8 @@
#include "../main.h"
#include "../filter.h"
#include "../cfg_file.h"
-#ifdef WITH_LUA
+#include "../binlog.h"
#include "../lua/lua_common.h"
-#endif
#define LOCAL_PROB_DENOM 16.0
@@ -194,15 +193,22 @@ bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input,
}
}
- data.statfiles_num = g_list_length (ctx->cfg->statfiles);
+ cur = call_classifier_pre_callbacks (ctx->cfg, task, FALSE, FALSE);
+ if (cur) {
+ memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
+ }
+ else {
+ cur = ctx->cfg->statfiles;
+ }
+
+ data.statfiles_num = g_list_length (cur);
data.statfiles = g_new0 (struct bayes_statfile_data, data.statfiles_num);
data.pool = pool;
data.now = time (NULL);
data.ctx = ctx;
- cur = ctx->cfg->statfiles;
while (cur) {
- /* Select statfile to learn */
+ /* Select statfile to classify */
st = cur->data;
if ((file = statfile_pool_is_open (pool, st->path)) == NULL) {
if ((file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
@@ -344,6 +350,70 @@ bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symb
return TRUE;
}
+/*
+ * Learn the tokens of a message into every statfile of the classifier whose
+ * is_spam flag equals the requested class.
+ *
+ * @param ctx classifier context (must not be NULL)
+ * @param pool statfile pool used to open and lock statfiles (must not be NULL)
+ * @param input tree of tokens to learn
+ * @param task task being learned, passed to the Lua pre-callbacks
+ * @param is_spam class to learn: only statfiles with the same is_spam flag are updated
+ * @param err error placeholder (not filled by this function)
+ * @return FALSE if the "min_tokens" classifier option is set and the message
+ * has too few tokens, TRUE otherwise (statfiles that cannot be opened are
+ * skipped with a warning)
+ */
+gboolean
+bayes_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool,
+		GTree *input, struct worker_task *task, gboolean is_spam, GError **err)
+{
+	struct bayes_callback_data data;
+	gchar *value;
+	gint nodes, minnodes;
+	struct statfile *st;
+	stat_file_t *file;
+	GList *cur;
+
+	g_assert (pool != NULL);
+	g_assert (ctx != NULL);
+
+	if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
+		minnodes = strtol (value, NULL, 10);
+		nodes = g_tree_nnodes (input);
+		if (nodes > FEATURE_WINDOW_SIZE) {
+			nodes = nodes / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE;
+		}
+		if (nodes < minnodes) {
+			return FALSE;
+		}
+	}
+
+	/*
+	 * This is the learn path: tell the pre-callbacks so (is_learn = TRUE) and
+	 * pass the requested class, instead of the FALSE/FALSE used for classify.
+	 */
+	cur = call_classifier_pre_callbacks (ctx->cfg, task, TRUE, is_spam);
+	if (cur) {
+		/* The list returned by callbacks is owned by the task pool */
+		memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
+	}
+	else {
+		/* No selection from callbacks: fall back to all configured statfiles */
+		cur = ctx->cfg->statfiles;
+	}
+
+	data.pool = pool;
+	data.now = time (NULL);
+	data.ctx = ctx;
+
+	while (cur) {
+		/* Select statfiles to learn */
+		st = cur->data;
+		if (st->is_spam != is_spam) {
+			cur = g_list_next (cur);
+			continue;
+		}
+		if ((file = statfile_pool_is_open (pool, st->path)) == NULL) {
+			if ((file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
+				msg_warn ("cannot open %s", st->path);
+				cur = g_list_next (cur);
+				continue;
+			}
+		}
+		data.file = file;
+		statfile_pool_lock_file (pool, data.file);
+		g_tree_foreach (input, bayes_learn_callback, &data);
+		statfile_inc_revision (file);
+		statfile_pool_unlock_file (pool, data.file);
+		maybe_write_binlog (ctx->cfg, st, file, input);
+
+		cur = g_list_next (cur);
+	}
+
+	return TRUE;
+}
+
GList *
bayes_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task)
{
diff --git a/src/classifiers/classifiers.c b/src/classifiers/classifiers.c
index 6b0554e1b..5e2b9ea88 100644
--- a/src/classifiers/classifiers.c
+++ b/src/classifiers/classifiers.c
@@ -35,6 +35,7 @@ struct classifier classifiers[] = {
.init_func = winnow_init,
.classify_func = winnow_classify,
.learn_func = winnow_learn,
+ .learn_spam_func = winnow_learn_spam,
.weights_func = winnow_weights
},
{
@@ -42,6 +43,7 @@ struct classifier classifiers[] = {
.init_func = bayes_init,
.classify_func = bayes_classify,
.learn_func = bayes_learn,
+ .learn_spam_func = bayes_learn_spam,
.weights_func = bayes_weights
}
};
diff --git a/src/classifiers/classifiers.h b/src/classifiers/classifiers.h
index 601db0205..78ceb196e 100644
--- a/src/classifiers/classifiers.h
+++ b/src/classifiers/classifiers.h
@@ -32,6 +32,8 @@ struct classifier {
gboolean (*learn_func)(struct classifier_ctx* ctx, statfile_pool_t *pool,
const char *symbol, GTree *input, gboolean in_class,
double *sum, double multiplier, GError **err);
+ gboolean (*learn_spam_func)(struct classifier_ctx* ctx, statfile_pool_t *pool,
+ GTree *input, struct worker_task *task, gboolean is_spam, GError **err);
GList* (*weights_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
};
@@ -43,6 +45,8 @@ struct classifier_ctx* winnow_init (memory_pool_t *pool, struct classifier_confi
gboolean winnow_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
gboolean winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input,
gboolean in_class, double *sum, double multiplier, GError **err);
+gboolean winnow_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool,
+ GTree *input, struct worker_task *task, gboolean is_spam, GError **err);
GList *winnow_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
/* Bayes algorithm */
@@ -50,6 +54,8 @@ struct classifier_ctx* bayes_init (memory_pool_t *pool, struct classifier_config
gboolean bayes_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
gboolean bayes_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input,
gboolean in_class, double *sum, double multiplier, GError **err);
+gboolean bayes_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool,
+ GTree *input, struct worker_task *task, gboolean is_spam, GError **err);
GList *bayes_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
/* Array of all defined classifiers */
extern struct classifier classifiers[];
diff --git a/src/classifiers/winnow.c b/src/classifiers/winnow.c
index 2e8b98423..b123ce3e5 100644
--- a/src/classifiers/winnow.c
+++ b/src/classifiers/winnow.c
@@ -223,19 +223,14 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
}
}
- if (ctx->cfg->pre_callbacks) {
-#ifdef WITH_LUA
- cur = call_classifier_pre_callbacks (ctx->cfg, task);
- if (cur) {
- memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
- }
-#else
- cur = ctx->cfg->statfiles;
-#endif
- }
- else {
- cur = ctx->cfg->statfiles;
- }
+ cur = call_classifier_pre_callbacks (ctx->cfg, task, FALSE, FALSE);
+ if (cur) {
+ memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur);
+ }
+ else {
+ cur = ctx->cfg->statfiles;
+ }
+
while (cur) {
st = cur->data;
data.sum = 0;
@@ -597,3 +592,15 @@ end:
}
return TRUE;
}
+
+/*
+ * Spam/ham learning entry point for winnow.
+ * Not implemented: always sets an error in the winnow error domain
+ * (code 1) and returns FALSE.
+ */
+gboolean
+winnow_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool,
+		GTree *input, struct worker_task *task, gboolean is_spam, GError **err)
+{
+	g_set_error (err, winnow_error_quark(), 1, "learn spam is not supported for winnow");
+
+	return FALSE;
+}
diff --git a/src/controller.c b/src/controller.c
index e7e4c10d8..ce33c53a2 100644
--- a/src/controller.c
+++ b/src/controller.c
@@ -35,6 +35,7 @@
#include "classifiers/classifiers.h"
#include "binlog.h"
#include "statfile_sync.h"
+#include "lua/lua_common.h"
#define END "END" CRLF
@@ -49,6 +50,8 @@ enum command_type {
COMMAND_SHUTDOWN,
COMMAND_UPTIME,
COMMAND_LEARN,
+ COMMAND_LEARN_SPAM,
+ COMMAND_LEARN_HAM,
COMMAND_HELP,
COMMAND_COUNTERS,
COMMAND_SYNC,
@@ -84,7 +87,9 @@ static struct controller_command commands[] = {
{"weights", FALSE, COMMAND_WEIGHTS},
{"help", FALSE, COMMAND_HELP},
{"counters", FALSE, COMMAND_COUNTERS},
- {"sync", FALSE, COMMAND_SYNC}
+ {"sync", FALSE, COMMAND_SYNC},
+ {"learn_spam", TRUE, COMMAND_LEARN_SPAM},
+ {"learn_ham", TRUE, COMMAND_LEARN_HAM}
};
static GList *custom_commands = NULL;
@@ -519,7 +524,7 @@ process_command (struct controller_command *cmd, gchar **cmd_args, struct contro
}
}
break;
- case COMMAND_LEARN:
+ case COMMAND_LEARN_SPAM:
if (check_auth (cmd, session)) {
arg = *cmd_args;
if (!arg || *arg == '\0') {
@@ -532,13 +537,105 @@ process_command (struct controller_command *cmd, gchar **cmd_args, struct contro
}
arg = *(cmd_args + 1);
if (arg == NULL || *arg == '\0') {
- msg_debug ("no statfile size specified in learn command");
+ msg_debug ("no message size specified in learn command");
+ r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: symbol and message size" CRLF);
+ if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+ return FALSE;
+ }
+ return TRUE;
+ }
+ size = strtoul (arg, &err_str, 10);
+ if (err_str && *err_str != '\0') {
+ msg_debug ("message size is invalid: %s", arg);
+ r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn size is invalid" CRLF);
+ if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+ return FALSE;
+ }
+ return TRUE;
+ }
+ cl = find_classifier_conf (session->cfg, *cmd_args);
+ if (cl == NULL) {
+ r = rspamd_snprintf (out_buf, sizeof (out_buf), "classifier %s is not defined" CRLF, *cmd_args);
+ if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+ return FALSE;
+ }
+ return TRUE;
+
+ }
+ session->learn_classifier = cl;
+
+ /* By default learn positive */
+ session->in_class = TRUE;
+ rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size);
+ session->state = STATE_LEARN_SPAM;
+ }
+ break;
+ case COMMAND_LEARN_HAM:
+ if (check_auth (cmd, session)) {
+ arg = *cmd_args;
+ if (!arg || *arg == '\0') {
+ msg_debug ("no statfile specified in learn command");
r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: stat filename and its size" CRLF);
if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
return FALSE;
}
return TRUE;
}
+ arg = *(cmd_args + 1);
+ if (arg == NULL || *arg == '\0') {
+ msg_debug ("no message size specified in learn command");
+ r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: symbol and message size" CRLF);
+ if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+ return FALSE;
+ }
+ return TRUE;
+ }
+ size = strtoul (arg, &err_str, 10);
+ if (err_str && *err_str != '\0') {
+ msg_debug ("message size is invalid: %s", arg);
+ r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn size is invalid" CRLF);
+ if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+ return FALSE;
+ }
+ return TRUE;
+ }
+ cl = find_classifier_conf (session->cfg, *cmd_args);
+ if (cl == NULL) {
+ r = rspamd_snprintf (out_buf, sizeof (out_buf), "classifier %s is not defined" CRLF, *cmd_args);
+ if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+ return FALSE;
+ }
+ return TRUE;
+
+ }
+ session->learn_classifier = cl;
+
+ /* By default learn positive */
+ session->in_class = FALSE;
+ rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size);
+ session->state = STATE_LEARN_SPAM;
+ }
+ break;
+ case COMMAND_LEARN:
+ if (check_auth (cmd, session)) {
+ arg = *cmd_args;
+ if (!arg || *arg == '\0') {
+ msg_debug ("no statfile specified in learn command");
+ r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: stat filename and its size" CRLF);
+ if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+ return FALSE;
+ }
+ return TRUE;
+ }
+ arg = *(cmd_args + 1);
+ if (arg == NULL || *arg == '\0') {
+ msg_debug ("no message size specified in learn command");
+ r = rspamd_snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: symbol and message size" CRLF);
+ if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+ return FALSE;
+ }
+ return TRUE;
+ }
size = strtoul (arg, &err_str, 10);
if (err_str && *err_str != '\0') {
msg_debug ("message size is invalid: %s", arg);
@@ -817,6 +914,65 @@ controller_read_socket (f_str_t * in, void *arg)
}
session->state = STATE_REPLY;
break;
+ case STATE_LEARN_SPAM_PRE:
+ session->learn_buf = in;
+ task = construct_task (session->worker);
+
+ task->msg = memory_pool_alloc (task->task_pool, sizeof (f_str_t));
+ task->msg->begin = in->begin;
+ task->msg->len = in->len;
+
+
+ r = process_message (task);
+ if (r == -1) {
+ msg_warn ("processing of message failed");
+ free_task (task, FALSE);
+ session->state = STATE_REPLY;
+ r = rspamd_snprintf (out_buf, sizeof (out_buf), "cannot process message" CRLF);
+ if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+ return FALSE;
+ }
+ return FALSE;
+ }
+
+ r = process_filters (task);
+ if (r == -1) {
+ session->state = STATE_REPLY;
+ r = rspamd_snprintf (out_buf, sizeof (out_buf), "cannot process message" CRLF);
+ free_task (task, FALSE);
+ if (! rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE)) {
+ return FALSE;
+ }
+ }
+ else if (r == 0) {
+ session->state = STATE_LEARN;
+ task->dispatcher = session->dispatcher;
+ session->learn_task = task;
+ rspamd_dispatcher_pause (session->dispatcher);
+ }
+ else {
+ lua_call_post_filters (task);
+ session->state = STATE_REPLY;
+
+ if (! learn_task_spam (session->learn_classifier, task, session->in_class, &err)) {
+ if (err) {
+ i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, learn classifier error: %s" CRLF END, err->message);
+ g_error_free (err);
+ }
+ else {
+ i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, unknown learn classifier error" CRLF END);
+ }
+ }
+ else {
+ i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok" CRLF END);
+ }
+
+ free_task (task, FALSE);
+ if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+ return FALSE;
+ }
+ }
+ break;
case STATE_WEIGHTS:
session->learn_buf = in;
task = construct_task (session->worker);
@@ -926,12 +1082,37 @@ static gboolean
controller_write_socket (void *arg)
{
struct controller_session *session = (struct controller_session *)arg;
+ gint i;
+ gchar out_buf[64];
+ GError *err = NULL;
if (session->state == STATE_QUIT) {
/* Free buffers */
destroy_session (session->s);
return FALSE;
}
+ else if (session->state == STATE_LEARN) {
+ /* Perform actual learn here */
+ if (! learn_task_spam (session->learn_classifier, session->learn_task, session->in_class, &err)) {
+ if (err) {
+ i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, learn classifier error: %s" CRLF END, err->message);
+ g_error_free (err);
+ }
+ else {
+ i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, unknown learn classifier error" CRLF END);
+ }
+ }
+ else {
+ i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok" CRLF END);
+ }
+ learn_task_spam (session->learn_classifier, session->learn_task, session->in_class, &err);
+ session->learn_task->dispatcher = NULL;
+ free_task (session->learn_task, FALSE);
+ session->state = STATE_REPLY;
+ if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) {
+ return FALSE;
+ }
+ }
else if (session->state == STATE_REPLY) {
session->state = STATE_COMMAND;
rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_LINE, BUFSIZ);
diff --git a/src/filter.c b/src/filter.c
index fea91125f..66f233115 100644
--- a/src/filter.c
+++ b/src/filter.c
@@ -944,6 +944,89 @@ learn_task (const gchar *statfile, struct worker_task *task, GError **err)
return TRUE;
}
+/*
+ * Learn the message of a task as spam or ham using the given classifier.
+ *
+ * Text parts are tokenized with the classifier's tokenizer (empty parts are
+ * skipped; when the message has exactly two text parts and they are nearly
+ * identical, the second one is ignored), the subject is tokenized as well,
+ * and the resulting tokens are passed to the classifier's learn_spam_func.
+ *
+ * @param cl classifier configuration to learn
+ * @param task worker's task object holding the parsed message
+ * @param is_spam class passed on to learn_spam_func
+ * @param err error placeholder, may be NULL
+ * @return TRUE if learning succeeded, FALSE otherwise
+ */
+gboolean
+learn_task_spam (struct classifier_config *cl, struct worker_task *task, gboolean is_spam, GError **err)
+{
+	GList *cur, *ex;
+	struct classifier_ctx *cls_ctx;
+	f_str_t c;
+	GTree *tokens = NULL;
+	struct mime_text_part *part;
+	gboolean is_utf = FALSE, is_twopart = FALSE;
+
+	cur = g_list_first (task->text_parts);
+	if (cur != NULL && cur->next != NULL && cur->next->next == NULL) {
+		is_twopart = TRUE;
+	}
+
+	/* Get tokens from each element */
+	while (cur) {
+		part = cur->data;
+		/* Skip empty parts */
+		if (part->is_empty) {
+			cur = g_list_next (cur);
+			continue;
+		}
+		c.begin = part->content->data;
+		c.len = part->content->len;
+		is_utf = part->is_utf;
+		ex = part->urls_offset;
+		if (is_twopart && cur->next == NULL) {
+			/* Compare part's content */
+			if (fuzzy_compare_parts (cur->data, cur->prev->data) >= COMMON_PART_FACTOR) {
+				msg_info ("message <%s> has two common text parts, ignore the last one", task->message_id);
+				break;
+			}
+		}
+		/* Get tokens */
+		if (!cl->tokenizer->tokenize_func (
+				cl->tokenizer, task->task_pool,
+				&c, &tokens, FALSE, is_utf, ex)) {
+			g_set_error (err, filter_error_quark(), 2, "Cannot tokenize message");
+			return FALSE;
+		}
+		cur = g_list_next (cur);
+	}
+
+	/* Handle messages without text */
+	if (tokens == NULL) {
+		g_set_error (err, filter_error_quark(), 3, "Cannot tokenize message, no text data");
+		msg_info ("learn failed for message <%s>, no tokens to extract", task->message_id);
+		return FALSE;
+	}
+
+	/* Take care of subject */
+	tokenize_subject (task, &tokens);
+
+	/* Init classifier */
+	cls_ctx = cl->classifier->init_func (
+			task->task_pool, cl);
+	/* Learn */
+	if (!cl->classifier->learn_spam_func (
+			cls_ctx, task->worker->srv->statfile_pool,
+			tokens, task, is_spam, err)) {
+		/* err itself may be NULL (g_set_error allows that), so check before dereferencing */
+		if (err != NULL && *err != NULL) {
+			msg_info ("learn failed for message <%s>, learn error: %s", task->message_id, (*err)->message);
+		}
+		else {
+			g_set_error (err, filter_error_quark(), 4, "Learn failed, unknown learn classifier error");
+			msg_info ("learn failed for message <%s>, unknown learn error", task->message_id);
+		}
+		return FALSE;
+	}
+	/* Increase statistics */
+	task->worker->srv->stat->messages_learned++;
+
+	msg_info ("learn success for message <%s>",
+			task->message_id);
+	statfile_pool_plan_invalidate (task->worker->srv->statfile_pool,
+			DEFAULT_STATFILE_INVALIDATE_TIME,
+			DEFAULT_STATFILE_INVALIDATE_JITTER);
+
+	return TRUE;
+}
+
/*
* vi:ts=4
*/
diff --git a/src/filter.h b/src/filter.h
index cfcab7cb5..930594170 100644
--- a/src/filter.h
+++ b/src/filter.h
@@ -11,6 +11,7 @@
struct worker_task;
struct rspamd_settings;
+struct classifier_config;
typedef double (*metric_cons_func)(struct worker_task *task, const gchar *metric_name, const gchar *func_name);
typedef void (*filter_func)(struct worker_task *task);
@@ -136,6 +137,15 @@ double factor_consolidation_func (struct worker_task *task, const gchar *metric_
*/
gboolean learn_task (const gchar *statfile, struct worker_task *task, GError **err);
+/*
+ * Learn message in a task as spam or ham using the specified classifier
+ * @param cl classifier configuration used for learning
+ * @param task worker's task object
+ * @param is_spam class to learn, passed to the classifier's learn_spam_func
+ * @param err pointer to GError
+ * @return true if learn succeed
+ */
+gboolean learn_task_spam (struct classifier_config *cl, struct worker_task *task, gboolean is_spam, GError **err);
+
gboolean check_action_str (const gchar *data, gint *result);
const gchar *str_action_metric (enum rspamd_metric_action action);
gint check_metric_action (double score, double required_score, struct metric *metric);
diff --git a/src/html.c b/src/html.c
index 306e1e700..9a6541610 100644
--- a/src/html.c
+++ b/src/html.c
@@ -655,29 +655,6 @@ decode_entitles (gchar *s, guint * len)
}
}
-/*
- * Find the first occurrence of find in s, ignore case.
- */
-static gchar *
-html_strncasestr (const gchar *s, const gchar *find, gsize len)
-{
- gchar c, sc;
- gsize mlen;
-
- if ((c = *find++) != 0) {
- c = g_ascii_tolower (c);
- mlen = strlen (find);
- do {
- do {
- if ((sc = *s++) == 0 || len -- == 0)
- return (NULL);
- } while (g_ascii_tolower (sc) != c);
- } while (g_ascii_strncasecmp (s, find, mlen) != 0);
- s--;
- }
- return ((gchar *)s);
-}
-
static void
check_phishing (struct worker_task *task, struct uri *href_url, const gchar *url_text, gsize remain, tag_id_t id)
{
@@ -803,11 +780,11 @@ parse_tag_url (struct worker_task *task, struct mime_text_part *part, tag_id_t i
/* For A tags search for href= and for IMG tags search for src= */
if (id == Tag_A) {
- c = html_strncasestr (tag_text, "href=", tag_len);
+ c = rspamd_strncasestr (tag_text, "href=", tag_len);
len = sizeof ("href=") - 1;
}
else if (id == Tag_IMG) {
- c = html_strncasestr (tag_text, "src=", tag_len);
+ c = rspamd_strncasestr (tag_text, "src=", tag_len);
len = sizeof ("src=") - 1;
}
diff --git a/src/lua/lua_classifier.c b/src/lua/lua_classifier.c
index f7ce173a7..511767680 100644
--- a/src/lua/lua_classifier.c
+++ b/src/lua/lua_classifier.c
@@ -41,7 +41,18 @@ static const struct luaL_reg classifierlib_m[] = {
};
+LUA_FUNCTION_DEF (statfile, get_symbol);
+LUA_FUNCTION_DEF (statfile, get_path);
+LUA_FUNCTION_DEF (statfile, get_size);
+LUA_FUNCTION_DEF (statfile, is_spam);
+LUA_FUNCTION_DEF (statfile, get_param);
+
static const struct luaL_reg statfilelib_m[] = {
+ LUA_INTERFACE_DEF (statfile, get_symbol),
+ LUA_INTERFACE_DEF (statfile, get_path),
+ LUA_INTERFACE_DEF (statfile, get_size),
+ LUA_INTERFACE_DEF (statfile, is_spam),
+ LUA_INTERFACE_DEF (statfile, get_param),
{"__tostring", lua_class_tostring},
{NULL, NULL}
};
@@ -64,50 +75,81 @@ lua_check_classifier (lua_State * L)
return *((struct classifier_config **)ud);
}
-/* Return list of statfiles that should be checked for this message */
-GList *
-call_classifier_pre_callbacks (struct classifier_config *ccf, struct worker_task *task)
+static GList *
+call_classifier_pre_callback (struct classifier_config *ccf, struct worker_task *task,
+ lua_State *L, gboolean is_learn, gboolean is_spam)
{
- GList *res = NULL, *cur;
- struct classifier_callback_data *cd;
struct classifier_config **pccf;
struct worker_task **ptask;
struct statfile *st;
gint i, len;
+ GList *res = NULL;
- /* Go throught all callbacks and call them, appending results to list */
- cur = g_list_first (ccf->pre_callbacks);
- while (cur) {
- cd = cur->data;
- lua_getglobal (cd->L, cd->name);
+ pccf = lua_newuserdata (L, sizeof (struct classifier_config *));
+ lua_setclass (L, "rspamd{classifier}", -1);
+ *pccf = ccf;
- pccf = lua_newuserdata (cd->L, sizeof (struct classifier_config *));
- lua_setclass (cd->L, "rspamd{classifier}", -1);
- *pccf = ccf;
+ ptask = lua_newuserdata (L, sizeof (struct worker_task *));
+ lua_setclass (L, "rspamd{task}", -1);
+ *ptask = task;
- ptask = lua_newuserdata (cd->L, sizeof (struct worker_task *));
- lua_setclass (cd->L, "rspamd{task}", -1);
- *ptask = task;
+ lua_pushboolean (L, is_learn);
+ lua_pushboolean (L, is_spam);
- if (lua_pcall (cd->L, 2, 1, 0) != 0) {
- msg_warn ("error running function %s: %s", cd->name, lua_tostring (cd->L, -1));
- }
- else {
- if (lua_istable (cd->L, 1)) {
- len = lua_objlen (cd->L, 1);
- for (i = 1; i <= len; i ++) {
- lua_rawgeti (cd->L, 1, i);
- st = lua_check_statfile (cd->L);
- if (st) {
- res = g_list_prepend (res, st);
- }
+ if (lua_pcall (L, 4, 1, 0) != 0) {
+ msg_warn ("error running pre classifier callback %s", lua_tostring (L, -1));
+ }
+ else {
+ if (lua_istable (L, 1)) {
+ len = lua_objlen (L, 1);
+ for (i = 1; i <= len; i ++) {
+ lua_rawgeti (L, 1, i);
+ st = lua_check_statfile (L);
+ if (st) {
+ res = g_list_prepend (res, st);
}
}
}
+ }
+
+ return res;
+}
+
+/* Return list of statfiles that should be checked for this message */
+GList *
+call_classifier_pre_callbacks (struct classifier_config *ccf, struct worker_task *task,
+ gboolean is_learn, gboolean is_spam)
+{
+ GList *res = NULL, *cur;
+ struct classifier_callback_data *cd;
+ lua_State *L;
+
+
+ /* Go throught all callbacks and call them, appending results to list */
+ cur = g_list_first (ccf->pre_callbacks);
+ while (cur) {
+ cd = cur->data;
+ lua_getglobal (cd->L, cd->name);
+
+ res = g_list_concat (res, call_classifier_pre_callback (ccf, task, cd->L, is_learn, is_spam));
cur = g_list_next (cur);
}
+ if (res == NULL) {
+ L = task->cfg->lua_state;
+ /* Check function from global table 'classifiers' */
+ lua_getglobal (L, "classifiers");
+ if (lua_istable (L, -1)) {
+ lua_pushstring (L, ccf->classifier->name);
+ lua_gettable (L, -2);
+ /* Function is now on top */
+ if (lua_isfunction (L, 1)) {
+ res = call_classifier_pre_callback (ccf, task, L, is_learn, is_spam);
+ }
+ }
+ }
+
return res;
}
@@ -226,6 +268,86 @@ lua_classifier_get_statfiles (lua_State *L)
}
/* Statfile functions */
+/*
+ * Push the symbol of the statfile obtained via lua_check_statfile(),
+ * or nil when no statfile could be obtained. Always returns one value.
+ */
+static gint
+lua_statfile_get_symbol (lua_State *L)
+{
+	struct statfile *sf;
+
+	sf = lua_check_statfile (L);
+	if (sf == NULL) {
+		lua_pushnil (L);
+		return 1;
+	}
+
+	lua_pushstring (L, sf->symbol);
+	return 1;
+}
+
+/*
+ * Push the path of the statfile obtained via lua_check_statfile(),
+ * or nil when no statfile could be obtained. Always returns one value.
+ */
+static gint
+lua_statfile_get_path (lua_State *L)
+{
+	struct statfile *sf;
+
+	sf = lua_check_statfile (L);
+	if (sf == NULL) {
+		lua_pushnil (L);
+		return 1;
+	}
+
+	lua_pushstring (L, sf->path);
+	return 1;
+}
+
+/*
+ * Push the size field of the statfile obtained via lua_check_statfile()
+ * as an integer, or nil when no statfile could be obtained.
+ * Always returns one value.
+ */
+static gint
+lua_statfile_get_size (lua_State *L)
+{
+	struct statfile *sf;
+
+	sf = lua_check_statfile (L);
+	if (sf == NULL) {
+		lua_pushnil (L);
+		return 1;
+	}
+
+	lua_pushinteger (L, sf->size);
+	return 1;
+}
+
+/*
+ * Push the is_spam flag of the statfile obtained via lua_check_statfile()
+ * as a boolean, or nil when no statfile could be obtained.
+ * Always returns one value.
+ */
+static gint
+lua_statfile_is_spam (lua_State *L)
+{
+	struct statfile *sf;
+
+	sf = lua_check_statfile (L);
+	if (sf == NULL) {
+		lua_pushnil (L);
+		return 1;
+	}
+
+	lua_pushboolean (L, sf->is_spam);
+	return 1;
+}
+
+/*
+ * Look up a named option of the statfile obtained via lua_check_statfile().
+ *
+ * The option name is read from Lua stack index 2. The value found in
+ * st->opts is pushed as a string; nil is pushed when the statfile is
+ * missing, has no options table, or has no such option.
+ * Always returns one value.
+ */
+static gint
+lua_statfile_get_param (lua_State *L)
+{
+	struct statfile *st = lua_check_statfile (L);
+	const gchar *param, *value;
+
+	param = luaL_checkstring (L, 2);
+
+	if (st != NULL && st->opts != NULL && param != NULL) {
+		value = g_hash_table_lookup (st->opts, param);
+		/* Test the lookup result, not the key: a missing option yields nil */
+		if (value != NULL) {
+			lua_pushstring (L, value);
+			return 1;
+		}
+	}
+	lua_pushnil (L);
+
+	return 1;
+}
+
static struct statfile *
lua_check_statfile (lua_State * L)
{
diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h
index 4427dcae9..3f3fc7a1d 100644
--- a/src/lua/lua_common.h
+++ b/src/lua/lua_common.h
@@ -51,7 +51,7 @@ void lua_call_post_filters (struct worker_task *task);
void add_luabuf (const gchar *line);
/* Classify functions */
-GList *call_classifier_pre_callbacks (struct classifier_config *ccf, struct worker_task *task);
+GList *call_classifier_pre_callbacks (struct classifier_config *ccf, struct worker_task *task, gboolean is_learn, gboolean is_spam);
double call_classifier_post_callbacks (struct classifier_config *ccf, struct worker_task *task, double in);
double lua_normalizer_func (struct config_file *cfg, long double score, void *params);
diff --git a/src/main.h b/src/main.h
index 9a3335d0a..d8761617f 100644
--- a/src/main.h
+++ b/src/main.h
@@ -64,6 +64,7 @@ struct classifier_config;
struct mime_part;
struct rspamd_view;
struct rspamd_dns_resolver;
+struct worker_task;
/**
* Server statistics
@@ -138,6 +139,8 @@ struct controller_session {
enum {
STATE_COMMAND,
STATE_LEARN,
+ STATE_LEARN_SPAM_PRE,
+ STATE_LEARN_SPAM,
STATE_REPLY,
STATE_QUIT,
STATE_OTHER,
@@ -162,6 +165,7 @@ struct controller_session {
f_str_t *in); /**< other command handler to execute at the end of processing */
void *other_data; /**< and its data */
struct rspamd_async_session* s; /**< async session object */
+ struct worker_task *learn_task;
};
typedef void (*controller_func_t)(gchar **args, struct controller_session *session);
diff --git a/src/util.c b/src/util.c
index 6d8cb09e0..16ed43786 100644
--- a/src/util.c
+++ b/src/util.c
@@ -1364,5 +1364,28 @@ escape_braces_addr_fstr (memory_pool_t *pool, f_str_t *in)
}
/*
+ * Find the first occurrence of find in s, ignore case.
+ *
+ * Only the first len bytes of s are considered: the search also stops at a
+ * NUL byte in s, and a match must lie entirely inside those len bytes.
+ * find must be NUL-terminated; an empty find matches at s itself.
+ * A len that is zero or negative yields NULL for a non-empty find.
+ *
+ * Returns a pointer to the start of the match inside s, or NULL.
+ */
+gchar *
+rspamd_strncasestr (const gchar *s, const gchar *find, gint len)
+{
+	gchar c, sc;
+	gsize mlen;
+
+	if ((c = *find++) != 0) {
+		c = g_ascii_tolower (c);
+		/* Length of the needle without its first character */
+		mlen = strlen (find);
+		do {
+			do {
+				/* Check the remaining length before touching *s */
+				if (len-- < 1 || (sc = *s++) == 0)
+					return (NULL);
+			} while (g_ascii_tolower (sc) != c);
+			/* The rest of the needle can no longer fit into the buffer */
+			if (mlen > (gsize)len)
+				return (NULL);
+		} while (g_ascii_strncasecmp (s, find, mlen) != 0);
+		s--;
+	}
+	return ((gchar *)s);
+}
+
+/*
* vi:ts=4
*/
diff --git a/src/util.h b/src/util.h
index 73d96c759..b1d4ad2e9 100644
--- a/src/util.h
+++ b/src/util.h
@@ -158,4 +158,7 @@ void free_task (struct worker_task *task, gboolean is_soft);
void free_task_hard (gpointer ud);
void free_task_soft (gpointer ud);
+/* Find the first occurrence of string find in string s, ignoring case and searching at most len characters of s */
+gchar* rspamd_strncasestr (const gchar *s, const gchar *find, gint len);
+
#endif