diff options
author | Vsevolod Stakhov <vsevolod@rambler-co.ru> | 2009-07-09 20:45:11 +0400 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rambler-co.ru> | 2009-07-09 20:45:11 +0400 |
commit | 2234daebbb352b444b322d43cc6c1093f0ce949c (patch) | |
tree | 320131facabccd4f5aa3eddc465bc50a707b2b00 /src/filter.c | |
parent | 19baadf6a0e6b2554de67b674a2c6f30efda13bb (diff) | |
download | rspamd-2234daebbb352b444b322d43cc6c1093f0ce949c.tar.gz rspamd-2234daebbb352b444b322d43cc6c1093f0ce949c.zip |
* Make autolearn working
Diffstat (limited to 'src/filter.c')
-rw-r--r-- | src/filter.c | 105 |
1 files changed, 100 insertions, 5 deletions
diff --git a/src/filter.c b/src/filter.c index daa9b0e29..1c45f0886 100644 --- a/src/filter.c +++ b/src/filter.c @@ -62,6 +62,7 @@ insert_result (struct worker_task *task, const char *metric_name, const char *sy /* Create new metric chain */ metric_res = memory_pool_alloc (task->task_pool, sizeof (struct metric_result)); metric_res->symbols = g_hash_table_new (g_str_hash, g_str_equal); + metric_res->checked = FALSE; memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_hash_table_destroy, metric_res->symbols); metric_res->metric = metric; g_hash_table_insert (task->results, (gpointer)metric_name, metric_res); @@ -214,11 +215,19 @@ call_filter_by_name (struct worker_task *task, const char *name, enum filter_typ } static void -metric_process_callback (gpointer key, gpointer value, void *data) +metric_process_callback_common (gpointer key, gpointer value, void *data, gboolean is_forced) { struct worker_task *task = (struct worker_task *)data; struct metric_result *metric_res = (struct metric_result *)value; + if (metric_res->checked && !is_forced) { + /* Already checked */ + return; + } + + /* Set flag */ + metric_res->checked = TRUE; + if (metric_res->metric->func != NULL) { metric_res->score = metric_res->metric->func (task, metric_res->metric->name, metric_res->metric->func_name); } @@ -229,6 +238,18 @@ metric_process_callback (gpointer key, gpointer value, void *data) metric_res->score, metric_res->metric->name); } +static void +metric_process_callback_normal (gpointer key, gpointer value, void *data) +{ + metric_process_callback_common (key, value, data, FALSE); +} + +static void +metric_process_callback_forced (gpointer key, gpointer value, void *data) +{ + metric_process_callback_common (key, value, data, TRUE); +} + static int continue_process_filters (struct worker_task *task) { @@ -359,7 +380,7 @@ process_filters (struct worker_task *task) } /* Process all metrics */ - g_hash_table_foreach (task->results, metric_process_callback, task); + g_hash_table_foreach (task->results, metric_process_callback_forced, task); return 1; } @@ -443,6 +464,75 @@ composites_foreach_callback (gpointer key, gpointer value, void *data) return; } +static gboolean +check_autolearn (struct statfile_autolearn_params *params, struct worker_task *task) +{ + const char *metric_name = DEFAULT_METRIC; + struct metric_result *metric_res; + GList *cur; + + if (params->metric != NULL) { + metric_name = params->metric; + } + + /* First check threshold */ + metric_res = g_hash_table_lookup (task->results, metric_name); + if (metric_res == NULL) { + if (params->symbols == NULL && params->threshold_max > 0) { + /* For ham messages */ + return TRUE; + } + msg_debug ("check_autolearn: metric %s has no results", metric_name); + return FALSE; + } + else { + /* Process score of metric */ + metric_process_callback_normal ((void *)metric_name, metric_res, task); + if ((params->threshold_min != 0 && metric_res->score > params->threshold_min) || + (params->threshold_max != 0 && metric_res->score < params->threshold_max)) { + /* Now check for specific symbols */ + if (params->symbols) { + cur = params->symbols; + while (cur) { + if (g_hash_table_lookup (metric_res->symbols, cur->data) == NULL) { + return FALSE; + } + cur = g_list_next (cur); + } + } + /* Now allow processing of actual autolearn */ + return TRUE; + } + } + + return FALSE; +} + +static void +process_autolearn (struct statfile *st, struct worker_task *task, GTree *tokens, + struct classifier *classifier, char *filename, struct classifier_ctx* ctx) +{ + if (check_autolearn (st->autolearn, task)) { + if (tokens) { + msg_info ("process_autolearn: message with id <%s> autolearned statfile '%s'", task->message_id, filename); + /* Check opened */ + if (! statfile_pool_is_open (task->worker->srv->statfile_pool, filename)) { + /* Try open */ + if (statfile_pool_open (task->worker->srv->statfile_pool, filename) == NULL) { + /* Try create */ + if (statfile_pool_create (task->worker->srv->statfile_pool, + filename, st->size / sizeof (struct stat_file_block)) == -1) { + msg_info ("process_autolearn: error while creating statfile %s", filename); + return; + } + } + } + + classifier->learn_func (ctx, task->worker->srv->statfile_pool, filename, tokens, 1); + } + } +} + static void composites_metric_callback (gpointer key, gpointer value, void *data) { @@ -498,7 +588,7 @@ statfiles_callback (gpointer key, gpointer value, void *arg) filename = resolve_stat_filename (task->task_pool, st->pattern, task->from, ""); } - if (statfile_pool_open (task->worker->srv->statfile_pool, filename) == NULL) { + if (statfile_pool_open (task->worker->srv->statfile_pool, filename) == NULL && !check_autolearn (st->autolearn, task)) { return; } @@ -513,6 +603,7 @@ statfiles_callback (gpointer key, gpointer value, void *arg) msg_info ("statfiles_callback: cannot tokenize input"); return; } + cur = g_list_next (cur); } g_hash_table_insert (data->tokens, st->tokenizer, tokens); } @@ -533,6 +624,10 @@ statfiles_callback (gpointer key, gpointer value, void *arg) classifier->classify_func (res_data->ctx, task->worker->srv->statfile_pool, filename, tokens, st->weight); + if (st->autolearn) { + /* Process autolearn */ + process_autolearn (st, task, tokens, classifier, filename, res_data->ctx); + } } static void @@ -548,7 +643,6 @@ statfiles_results_callback (gpointer key, gpointer value, void *arg) filename = classifier->result_file_func (res->ctx, w); insert_result (task, res->metric->name, classifier->name, *w, NULL); msg_debug ("statfiles_results_callback: got total weight %.2f for metric %s", *w, res->metric->name); - } @@ -566,7 +660,8 @@ process_statfiles (struct worker_task *task) g_hash_table_destroy (cd.tokens); g_hash_table_destroy (cd.classifiers); - g_hash_table_foreach (task->results, metric_process_callback, task); + /* Process results */ + g_hash_table_foreach (task->results, metric_process_callback_forced, task); task->state = WRITE_REPLY; } |