diff options
-rw-r--r-- | src/classifiers/classifiers.h | 13 | ||||
-rw-r--r-- | src/classifiers/winnow.c | 148 | ||||
-rw-r--r-- | src/controller.c | 32 | ||||
-rw-r--r-- | src/filter.c | 2 |
4 files changed, 135 insertions, 60 deletions
diff --git a/src/classifiers/classifiers.h b/src/classifiers/classifiers.h index 02192d795..f69c1284c 100644 --- a/src/classifiers/classifiers.h +++ b/src/classifiers/classifiers.h @@ -24,9 +24,10 @@ struct classify_weight { struct classifier { char *name; struct classifier_ctx* (*init_func)(memory_pool_t *pool, struct classifier_config *cf); - void (*classify_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); - void (*learn_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, - stat_file_t *file, GTree *input, gboolean in_class, double *sum, double multiplier); + gboolean (*classify_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); + gboolean (*learn_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, + const char *symbol, GTree *input, gboolean in_class, + double *sum, double multiplier, GError **err); GList* (*weights_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); }; @@ -35,9 +36,9 @@ struct classifier* get_classifier (char *name); /* Winnow algorithm */ struct classifier_ctx* winnow_init (memory_pool_t *pool, struct classifier_config *cf); -void winnow_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); -void winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, stat_file_t *file, GTree *input, - gboolean in_class, double *sum, double multiplier); +gboolean winnow_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); +gboolean winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input, + gboolean in_class, double *sum, double multiplier, GError **err); GList *winnow_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); diff --git a/src/classifiers/winnow.c b/src/classifiers/winnow.c index 41cb48e89..ab155ff8c 100644 --- a/src/classifiers/winnow.c +++ b/src/classifiers/winnow.c @@ -46,6 +46,12 @@ #define MAX_LEARN_ITERATIONS 100 +G_INLINE_FUNC GQuark +winnow_error_quark (void) +{ + return g_quark_from_static_string ("winnow-error-quark"); +} + struct winnow_callback_data { statfile_pool_t *pool; struct classifier_ctx *ctx; @@ -53,7 +59,8 @@ struct winnow_callback_data { stat_file_t *learn_file; long double sum; double multiplier; - int count; + guint32 count; + guint32 new_blocks; gboolean in_class; gboolean do_demote; gboolean fresh_run; @@ -62,6 +69,8 @@ struct winnow_callback_data { static const double max_common_weight = MAX_WEIGHT * WINNOW_DEMOTION; + + static gboolean classify_callback (gpointer key, gpointer value, gpointer data) { @@ -73,10 +82,10 @@ classify_callback (gpointer key, gpointer value, gpointer data) v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now); if (fabs (v) > ALPHA) { cd->sum += v; - cd->in_class++; } else { cd->sum += 1.0; + cd->new_blocks ++; } cd->count++; @@ -100,6 +109,7 @@ learn_callback (gpointer key, gpointer value, gpointer data) if (cd->file == cd->learn_file) { statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, c); node->value = c; + cd->new_blocks ++; } } else { @@ -181,7 +191,7 @@ winnow_init (memory_pool_t * pool, struct classifier_config *cfg) return ctx; } -void +gboolean winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * input, struct worker_task *task) { struct winnow_callback_data data; @@ -203,7 +213,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp nodes = g_tree_nnodes (input) / FEATURE_WINDOW_SIZE; if (nodes < minnodes) { msg_info ("do not classify message as it has too few tokens: %d, while %d min", nodes, minnodes); - return; + return FALSE; } } @@ -224,6 +234,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp st = cur->data; data.sum = 0; data.count = 0; + data.new_blocks = 0; if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) { if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) { msg_warn ("cannot open %s, skip it", st->path); @@ -233,9 +244,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp } if (data.file != NULL) { - statfile_pool_lock_file (pool, data.file); g_tree_foreach (input, classify_callback, &data); - statfile_pool_unlock_file (pool, data.file); } if (data.count != 0) { @@ -263,6 +272,8 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp cur = g_list_prepend (NULL, sumbuf); insert_result (task, sel->symbol, max, cur); } + + return TRUE; } GList * @@ -306,9 +317,7 @@ winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inpu } if (data.file != NULL) { - statfile_pool_lock_file (pool, data.file); g_tree_foreach (input, classify_callback, &data); - statfile_pool_unlock_file (pool, data.file); } w = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight)); @@ -333,8 +342,9 @@ winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inpu } -void -winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *file, GTree * input, int in_class, double *sum, double multiplier) +gboolean +winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, const char *symbol, + GTree * input, int in_class, double *sum, double multiplier, GError **err) { struct winnow_callback_data data = { .file = NULL, @@ -343,10 +353,11 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi char *value; int nodes, minnodes, iterations = 0; struct statfile *st, *sel_st; - stat_file_t *sel = NULL; + stat_file_t *sel = NULL, *to_learn; long double res = 0., max = 0.; - double learn_threshold = 1.0; + double learn_threshold = 0.0; GList *cur, *to_demote = NULL; + gboolean force_learn = FALSE; g_assert (pool != NULL); g_assert (ctx != NULL); @@ -355,7 +366,7 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi data.in_class = in_class; data.now = time (NULL); data.ctx = ctx; - data.learn_file = file; + if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) { minnodes = strtol (value, NULL, 10); @@ -363,70 +374,121 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi if (nodes < minnodes) { msg_info ("do not learn message as it has too few tokens: %d, while %d min", nodes, minnodes); *sum = 0; - return; + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "message contains too few tokens: %d, while min is %d", + nodes, minnodes); + return FALSE; } } if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "learn_threshold")) != NULL) { learn_threshold = strtod (value, NULL); } - if (learn_threshold >= 1.0) { + if (learn_threshold <= 1.0 && learn_threshold >= 0) { /* Classify message and check target statfile score */ cur = ctx->cfg->statfiles; + while (cur) { + /* Open or create all statfiles inside classifier */ + st = cur->data; + if (statfile_pool_is_open (pool, st->path) == NULL) { + if (statfile_pool_open (pool, st->path, st->size, FALSE) == NULL) { + msg_warn ("cannot open %s", st->path); + if (statfile_pool_create (pool, st->path, st->size) == -1) { + msg_err ("cannot create statfile %s", st->path); + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "cannot create statfile: %s", + st->path); + return FALSE; + } + if (statfile_pool_open (pool, st->path, st->size, FALSE)) { + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "open statfile %s after creation", + st->path); + msg_err ("cannot open statfile %s after creation", st->path); + return FALSE; + } + } + } + if (strcmp (st->symbol, symbol) == 0) { + sel_st = st; + + } + cur = g_list_next (cur); + } + to_learn = statfile_pool_is_open (pool, sel_st->path); + if (to_learn == NULL) { + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "statfile %s is not opened this maybe if your statfile pool is too small to handle all statfiles", + sel_st->path); + return FALSE; + } /* Check target statfile */ - data.file = file; + data.file = to_learn; data.sum = 0; data.count = 0; - data.file = file; - statfile_pool_lock_file (pool, data.file); + data.new_blocks = 0; g_tree_foreach (input, classify_callback, &data); - statfile_pool_unlock_file (pool, data.file); if (data.count > 0) { max = data.sum / (double)data.count; } else { max = 0; } + /* If most of blocks are not presented in targeted statfile do forced learn */ + if ((data.new_blocks > 1 && (double)data.new_blocks / (double)data.count > 0.5) || max < 1 + learn_threshold) { + force_learn = TRUE; + } + /* Check other statfiles */ while (cur) { st = cur->data; data.sum = 0; data.count = 0; if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) { - if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) { - msg_warn ("cannot open %s, skip it", st->path); - cur = g_list_next (cur); - continue; - } + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "statfile %s is not opened this maybe if your statfile pool is too small to handle all statfiles", + st->path); + return FALSE; } - statfile_pool_lock_file (pool, data.file); g_tree_foreach (input, classify_callback, &data); - statfile_pool_unlock_file (pool, data.file); if (data.count != 0) { res = data.sum / data.count; } else { res = 0; } - if (file != data.file && res / max > learn_threshold) { + if (to_learn != data.file && res - max > 1 - learn_threshold) { /* Demote tokens in this statfile */ to_demote = g_list_prepend (to_demote, data.file); } - else if (file == data.file) { - sel_st = st; - } cur = g_list_next (cur); } } else { - msg_err ("learn threshold is less than 1, so cannot do learn, please check your configuration"); - return; + msg_err ("learn threshold is more than 1 or less than 0, so cannot do learn, please check your configuration"); + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "bad learn_threshold setting: %.2f", + learn_threshold); + return FALSE; } /* If to_demote list is empty this message is already classified correctly */ - if (max > ALPHA && to_demote == NULL) { + if (max > ALPHA && to_demote == NULL && !force_learn) { msg_info ("this message is already of class %s with threshold %.2f and weight %.2F", sel_st->symbol, learn_threshold, max); goto end; } + data.learn_file = to_learn; do { cur = ctx->cfg->statfiles; data.fresh_run = TRUE; @@ -434,12 +496,9 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi st = cur->data; data.sum = 0; data.count = 0; + data.new_blocks = 0; if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) { - if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) { - msg_warn ("cannot open %s, skip it", st->path); - cur = g_list_next (cur); - continue; - } + return FALSE; } if (to_demote != NULL && g_list_find (to_demote, data.file) != NULL) { data.do_demote = TRUE; @@ -470,14 +529,20 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi else { data.multiplier *= WINNOW_PROMOTION; } - } while ((in_class ? sel != file : sel == file) && iterations ++ < MAX_LEARN_ITERATIONS); + } while ((in_class ? sel != to_learn : sel == to_learn) && iterations ++ < MAX_LEARN_ITERATIONS); if (iterations >= MAX_LEARN_ITERATIONS) { msg_warn ("learning statfile %s was not fully successfull: iterations count is limited to %d, final sum is %G", - file->filename, MAX_LEARN_ITERATIONS, max); + sel_st->symbol, MAX_LEARN_ITERATIONS, max); + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "learning statfile %s was not fully successfull: iterations count is limited to %d", + sel_st->symbol, MAX_LEARN_ITERATIONS); + return FALSE; } else { - msg_info ("learned statfile %s successfully with %d iterations and sum %G", file->filename, iterations + 1, max); + msg_info ("learned statfile %s successfully with %d iterations and sum %G", sel_st->symbol, iterations + 1, max); } @@ -485,4 +550,5 @@ end: if (sum) { *sum = (double)max; } + return TRUE; } diff --git a/src/controller.c b/src/controller.c index 6ef9f7e08..15b14cf8d 100644 --- a/src/controller.c +++ b/src/controller.c @@ -704,6 +704,7 @@ controller_read_socket (f_str_t * in, void *arg) struct mime_text_part *part; GList *comp_list, *cur = NULL; GTree *tokens = NULL; + GError *err = NULL; f_str_t c; double sum; @@ -818,26 +819,33 @@ controller_read_socket (f_str_t * in, void *arg) return TRUE; } + + /* Init classifier */ + cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier); /* Get or create statfile */ statfile = get_statfile_by_symbol (session->worker->srv->statfile_pool, session->learn_classifier, - session->learn_symbol, &st, TRUE); - if (statfile == NULL) { - msg_info ("learn failed for message <%s>, no statfile found: %s", task->message_id, session->learn_symbol); + session->learn_symbol, &st, TRUE); + + if (statfile == NULL || + ! session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool, + session->learn_symbol, tokens, session->in_class, &sum, + session->learn_multiplier, &err)) { + if (err) { + i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, learn classifier error: %s" CRLF, err->message); + msg_info ("learn failed for message <%s>, learn error: %s", task->message_id, err->message); + g_error_free (err); + } + else { + i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, unknown learn classifier error" CRLF); + msg_info ("learn failed for message <%s>, unknown learn error", task->message_id); + } free_task (task, FALSE); - i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn failed, invalid symbol" CRLF); if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { return FALSE; } + session->state = STATE_REPLY; return TRUE; } - - /* Init classifier */ - cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier); - - /* XXX: remove this awful legacy */ - session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool, - statfile, tokens, session->in_class, &sum, - session->learn_multiplier); session->worker->srv->stat->messages_learned++; maybe_write_binlog (session->learn_classifier, st, statfile, tokens); diff --git a/src/filter.c b/src/filter.c index 9e2da0c57..90566ded9 100644 --- a/src/filter.c +++ b/src/filter.c @@ -438,7 +438,7 @@ process_autolearn (struct statfile *st, struct worker_task *task, GTree * tokens return; } - classifier->learn_func (ctx, task->worker->srv->statfile_pool, statfile, tokens, TRUE, NULL, 1.); + classifier->learn_func (ctx, task->worker->srv->statfile_pool, st->symbol, tokens, TRUE, NULL, 1., NULL); maybe_write_binlog (ctx->cfg, st, statfile, tokens); } } |