/* Create new metric chain */
metric_res = memory_pool_alloc (task->task_pool, sizeof (struct metric_result));
metric_res->symbols = g_hash_table_new (g_str_hash, g_str_equal);
+ metric_res->checked = FALSE;
memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_hash_table_destroy, metric_res->symbols);
metric_res->metric = metric;
g_hash_table_insert (task->results, (gpointer)metric_name, metric_res);
}
static void
-metric_process_callback (gpointer key, gpointer value, void *data)
+metric_process_callback_common (gpointer key, gpointer value, void *data, gboolean is_forced)
{
struct worker_task *task = (struct worker_task *)data;
struct metric_result *metric_res = (struct metric_result *)value;
+ if (metric_res->checked && !is_forced) {
+ /* Already checked */
+ return;
+ }
+
+ /* Set flag */
+ metric_res->checked = TRUE;
+
if (metric_res->metric->func != NULL) {
metric_res->score = metric_res->metric->func (task, metric_res->metric->name, metric_res->metric->func_name);
}
metric_res->score, metric_res->metric->name);
}
+static void
+metric_process_callback_normal (gpointer key, gpointer value, void *data)
+{
+ metric_process_callback_common (key, value, data, FALSE);
+}
+
+static void
+metric_process_callback_forced (gpointer key, gpointer value, void *data)
+{
+ metric_process_callback_common (key, value, data, TRUE);
+}
+
static int
continue_process_filters (struct worker_task *task)
{
}
/* Process all metrics */
- g_hash_table_foreach (task->results, metric_process_callback, task);
+ g_hash_table_foreach (task->results, metric_process_callback_forced, task);
return 1;
}
return;
}
+static gboolean
+check_autolearn (struct statfile_autolearn_params *params, struct worker_task *task)
+{
+ const char *metric_name = DEFAULT_METRIC;
+ struct metric_result *metric_res;
+ GList *cur;
+
+ if (params->metric != NULL) {
+ metric_name = params->metric;
+ }
+
+ /* First check threshold */
+ metric_res = g_hash_table_lookup (task->results, metric_name);
+ if (metric_res == NULL) {
+ if (params->symbols == NULL && params->threshold_max > 0) {
+ /* For ham messages */
+ return TRUE;
+ }
+ msg_debug ("check_autolearn: metric %s has no results", metric_name);
+ return FALSE;
+ }
+ else {
+ /* Process score of metric */
+ metric_process_callback_normal ((void *)metric_name, metric_res, task);
+ if ((params->threshold_min != 0 && metric_res->score > params->threshold_min) ||
+ (params->threshold_max != 0 && metric_res->score < params->threshold_max)) {
+ /* Now check for specific symbols */
+ if (params->symbols) {
+ cur = params->symbols;
+ while (cur) {
+ if (g_hash_table_lookup (metric_res->symbols, cur->data) == NULL) {
+ return FALSE;
+ }
+ cur = g_list_next (cur);
+ }
+ }
+ /* Now allow processing of actual autolearn */
+ return TRUE;
+ }
+ }
+
+ return FALSE;
+}
+
+static void
+process_autolearn (struct statfile *st, struct worker_task *task, GTree *tokens,
+ struct classifier *classifier, char *filename, struct classifier_ctx* ctx)
+{
+ if (check_autolearn (st->autolearn, task)) {
+ if (tokens) {
+ msg_info ("process_autolearn: message with id <%s> autolearned statfile '%s'", task->message_id, filename);
+ /* Check opened */
+ if (! statfile_pool_is_open (task->worker->srv->statfile_pool, filename)) {
+ /* Try open */
+ if (statfile_pool_open (task->worker->srv->statfile_pool, filename) == NULL) {
+ /* Try create */
+ if (statfile_pool_create (task->worker->srv->statfile_pool,
+ filename, st->size / sizeof (struct stat_file_block)) == -1) {
+ msg_info ("process_autolearn: error while creating statfile %s", filename);
+ return;
+ }
+ }
+ }
+
+ classifier->learn_func (ctx, task->worker->srv->statfile_pool, filename, tokens, 1);
+ }
+ }
+}
+
static void
composites_metric_callback (gpointer key, gpointer value, void *data)
{
filename = resolve_stat_filename (task->task_pool, st->pattern, task->from, "");
}
- if (statfile_pool_open (task->worker->srv->statfile_pool, filename) == NULL) {
+ if (statfile_pool_open (task->worker->srv->statfile_pool, filename) == NULL && !check_autolearn (st->autolearn, task)) {
return;
}
msg_info ("statfiles_callback: cannot tokenize input");
return;
}
+ cur = g_list_next (cur);
}
g_hash_table_insert (data->tokens, st->tokenizer, tokens);
}
classifier->classify_func (res_data->ctx, task->worker->srv->statfile_pool, filename, tokens, st->weight);
+ if (st->autolearn) {
+ /* Process autolearn */
+ process_autolearn (st, task, tokens, classifier, filename, res_data->ctx);
+ }
}
static void
filename = classifier->result_file_func (res->ctx, w);
insert_result (task, res->metric->name, classifier->name, *w, NULL);
msg_debug ("statfiles_results_callback: got total weight %.2f for metric %s", *w, res->metric->name);
-
}
g_hash_table_destroy (cd.tokens);
g_hash_table_destroy (cd.classifiers);
- g_hash_table_foreach (task->results, metric_process_callback, task);
+ /* Process results */
+ g_hash_table_foreach (task->results, metric_process_callback_forced, task);
task->state = WRITE_REPLY;
}