aboutsummaryrefslogtreecommitdiffstats
path: root/src/filter.c
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rambler-co.ru>2009-07-09 20:45:11 +0400
committerVsevolod Stakhov <vsevolod@rambler-co.ru>2009-07-09 20:45:11 +0400
commit2234daebbb352b444b322d43cc6c1093f0ce949c (patch)
tree320131facabccd4f5aa3eddc465bc50a707b2b00 /src/filter.c
parent19baadf6a0e6b2554de67b674a2c6f30efda13bb (diff)
downloadrspamd-2234daebbb352b444b322d43cc6c1093f0ce949c.tar.gz
rspamd-2234daebbb352b444b322d43cc6c1093f0ce949c.zip
* Make autolearn working
Diffstat (limited to 'src/filter.c')
-rw-r--r--src/filter.c105
1 files changed, 100 insertions, 5 deletions
diff --git a/src/filter.c b/src/filter.c
index daa9b0e29..1c45f0886 100644
--- a/src/filter.c
+++ b/src/filter.c
@@ -62,6 +62,7 @@ insert_result (struct worker_task *task, const char *metric_name, const char *sy
/* Create new metric chain */
metric_res = memory_pool_alloc (task->task_pool, sizeof (struct metric_result));
metric_res->symbols = g_hash_table_new (g_str_hash, g_str_equal);
+ metric_res->checked = FALSE;
memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_hash_table_destroy, metric_res->symbols);
metric_res->metric = metric;
g_hash_table_insert (task->results, (gpointer)metric_name, metric_res);
@@ -214,11 +215,19 @@ call_filter_by_name (struct worker_task *task, const char *name, enum filter_typ
}
static void
-metric_process_callback (gpointer key, gpointer value, void *data)
+metric_process_callback_common (gpointer key, gpointer value, void *data, gboolean is_forced)
{
struct worker_task *task = (struct worker_task *)data;
struct metric_result *metric_res = (struct metric_result *)value;
+ if (metric_res->checked && !is_forced) {
+ /* Already checked */
+ return;
+ }
+
+ /* Set flag */
+ metric_res->checked = TRUE;
+
if (metric_res->metric->func != NULL) {
metric_res->score = metric_res->metric->func (task, metric_res->metric->name, metric_res->metric->func_name);
}
@@ -229,6 +238,18 @@ metric_process_callback (gpointer key, gpointer value, void *data)
metric_res->score, metric_res->metric->name);
}
+static void
+metric_process_callback_normal (gpointer key, gpointer value, void *data)
+{
+ metric_process_callback_common (key, value, data, FALSE);
+}
+
+static void
+metric_process_callback_forced (gpointer key, gpointer value, void *data)
+{
+ metric_process_callback_common (key, value, data, TRUE);
+}
+
static int
continue_process_filters (struct worker_task *task)
{
@@ -359,7 +380,7 @@ process_filters (struct worker_task *task)
}
/* Process all metrics */
- g_hash_table_foreach (task->results, metric_process_callback, task);
+ g_hash_table_foreach (task->results, metric_process_callback_forced, task);
return 1;
}
@@ -443,6 +464,75 @@ composites_foreach_callback (gpointer key, gpointer value, void *data)
return;
}
+static gboolean
+check_autolearn (struct statfile_autolearn_params *params, struct worker_task *task)
+{
+ const char *metric_name = DEFAULT_METRIC;
+ struct metric_result *metric_res;
+ GList *cur;
+
+ if (params->metric != NULL) {
+ metric_name = params->metric;
+ }
+
+ /* First check threshold */
+ metric_res = g_hash_table_lookup (task->results, metric_name);
+ if (metric_res == NULL) {
+ if (params->symbols == NULL && params->threshold_max > 0) {
+ /* For ham messages */
+ return TRUE;
+ }
+ msg_debug ("check_autolearn: metric %s has no results", metric_name);
+ return FALSE;
+ }
+ else {
+ /* Process score of metric */
+ metric_process_callback_normal ((void *)metric_name, metric_res, task);
+ if ((params->threshold_min != 0 && metric_res->score > params->threshold_min) ||
+ (params->threshold_max != 0 && metric_res->score < params->threshold_max)) {
+ /* Now check for specific symbols */
+ if (params->symbols) {
+ cur = params->symbols;
+ while (cur) {
+ if (g_hash_table_lookup (metric_res->symbols, cur->data) == NULL) {
+ return FALSE;
+ }
+ cur = g_list_next (cur);
+ }
+ }
+ /* Now allow processing of actual autolearn */
+ return TRUE;
+ }
+ }
+
+ return FALSE;
+}
+
+static void
+process_autolearn (struct statfile *st, struct worker_task *task, GTree *tokens,
+ struct classifier *classifier, char *filename, struct classifier_ctx* ctx)
+{
+ if (check_autolearn (st->autolearn, task)) {
+ if (tokens) {
+ msg_info ("process_autolearn: message with id <%s> autolearned statfile '%s'", task->message_id, filename);
+ /* Check opened */
+ if (! statfile_pool_is_open (task->worker->srv->statfile_pool, filename)) {
+ /* Try open */
+ if (statfile_pool_open (task->worker->srv->statfile_pool, filename) == NULL) {
+ /* Try create */
+ if (statfile_pool_create (task->worker->srv->statfile_pool,
+ filename, st->size / sizeof (struct stat_file_block)) == -1) {
+ msg_info ("process_autolearn: error while creating statfile %s", filename);
+ return;
+ }
+ }
+ }
+
+ classifier->learn_func (ctx, task->worker->srv->statfile_pool, filename, tokens, 1);
+ }
+ }
+}
+
static void
composites_metric_callback (gpointer key, gpointer value, void *data)
{
@@ -498,7 +588,7 @@ statfiles_callback (gpointer key, gpointer value, void *arg)
filename = resolve_stat_filename (task->task_pool, st->pattern, task->from, "");
}
- if (statfile_pool_open (task->worker->srv->statfile_pool, filename) == NULL) {
+ if (statfile_pool_open (task->worker->srv->statfile_pool, filename) == NULL && !check_autolearn (st->autolearn, task)) {
return;
}
@@ -513,6 +603,7 @@ statfiles_callback (gpointer key, gpointer value, void *arg)
msg_info ("statfiles_callback: cannot tokenize input");
return;
}
+ cur = g_list_next (cur);
}
g_hash_table_insert (data->tokens, st->tokenizer, tokens);
}
@@ -533,6 +624,10 @@ statfiles_callback (gpointer key, gpointer value, void *arg)
classifier->classify_func (res_data->ctx, task->worker->srv->statfile_pool, filename, tokens, st->weight);
+ if (st->autolearn) {
+ /* Process autolearn */
+ process_autolearn (st, task, tokens, classifier, filename, res_data->ctx);
+ }
}
static void
@@ -548,7 +643,6 @@ statfiles_results_callback (gpointer key, gpointer value, void *arg)
filename = classifier->result_file_func (res->ctx, w);
insert_result (task, res->metric->name, classifier->name, *w, NULL);
msg_debug ("statfiles_results_callback: got total weight %.2f for metric %s", *w, res->metric->name);
-
}
@@ -566,7 +660,8 @@ process_statfiles (struct worker_task *task)
g_hash_table_destroy (cd.tokens);
g_hash_table_destroy (cd.classifiers);
- g_hash_table_foreach (task->results, metric_process_callback, task);
+ /* Process results */
+ g_hash_table_foreach (task->results, metric_process_callback_forced, task);
task->state = WRITE_REPLY;
}