diff options
author | Vsevolod Stakhov <vsevolod@rambler-co.ru> | 2010-08-05 21:29:40 +0400 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rambler-co.ru> | 2010-08-05 21:29:40 +0400 |
commit | 5d0e4d334fef7f0fe683040d32e2a53b503315f2 (patch) | |
tree | d2f00bbb940564a4e77a287f84c7876e8fd9c009 /src/classifiers | |
parent | 80b5b55a53622875d4973ea1d440dc7fa916f20b (diff) | |
download | rspamd-5d0e4d334fef7f0fe683040d32e2a53b503315f2.tar.gz rspamd-5d0e4d334fef7f0fe683040d32e2a53b503315f2.zip |
* Fixes to winnow learning
Diffstat (limited to 'src/classifiers')
-rw-r--r-- | src/classifiers/classifiers.h | 13 | ||||
-rw-r--r-- | src/classifiers/winnow.c | 148 |
2 files changed, 114 insertions, 47 deletions
diff --git a/src/classifiers/classifiers.h b/src/classifiers/classifiers.h index 02192d795..f69c1284c 100644 --- a/src/classifiers/classifiers.h +++ b/src/classifiers/classifiers.h @@ -24,9 +24,10 @@ struct classify_weight { struct classifier { char *name; struct classifier_ctx* (*init_func)(memory_pool_t *pool, struct classifier_config *cf); - void (*classify_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); - void (*learn_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, - stat_file_t *file, GTree *input, gboolean in_class, double *sum, double multiplier); + gboolean (*classify_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); + gboolean (*learn_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, + const char *symbol, GTree *input, gboolean in_class, + double *sum, double multiplier, GError **err); GList* (*weights_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); }; @@ -35,9 +36,9 @@ struct classifier* get_classifier (char *name); /* Winnow algorithm */ struct classifier_ctx* winnow_init (memory_pool_t *pool, struct classifier_config *cf); -void winnow_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); -void winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, stat_file_t *file, GTree *input, - gboolean in_class, double *sum, double multiplier); +gboolean winnow_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); +gboolean winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input, + gboolean in_class, double *sum, double multiplier, GError **err); GList *winnow_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task); diff --git a/src/classifiers/winnow.c b/src/classifiers/winnow.c index 41cb48e89..ab155ff8c 100644 --- a/src/classifiers/winnow.c +++ b/src/classifiers/winnow.c @@ -46,6 +46,12 @@ #define MAX_LEARN_ITERATIONS 100 +G_INLINE_FUNC GQuark +winnow_error_quark (void) +{ + return g_quark_from_static_string ("winnow-error-quark"); +} + struct winnow_callback_data { statfile_pool_t *pool; struct classifier_ctx *ctx; @@ -53,7 +59,8 @@ struct winnow_callback_data { stat_file_t *learn_file; long double sum; double multiplier; - int count; + guint32 count; + guint32 new_blocks; gboolean in_class; gboolean do_demote; gboolean fresh_run; @@ -62,6 +69,8 @@ struct winnow_callback_data { static const double max_common_weight = MAX_WEIGHT * WINNOW_DEMOTION; + + static gboolean classify_callback (gpointer key, gpointer value, gpointer data) { @@ -73,10 +82,10 @@ classify_callback (gpointer key, gpointer value, gpointer data) v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now); if (fabs (v) > ALPHA) { cd->sum += v; - cd->in_class++; } else { cd->sum += 1.0; + cd->new_blocks ++; } cd->count++; @@ -100,6 +109,7 @@ learn_callback (gpointer key, gpointer value, gpointer data) if (cd->file == cd->learn_file) { statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, c); node->value = c; + cd->new_blocks ++; } } else { @@ -181,7 +191,7 @@ winnow_init (memory_pool_t * pool, struct classifier_config *cfg) return ctx; } -void +gboolean winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * input, struct worker_task *task) { struct winnow_callback_data data; @@ -203,7 +213,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp nodes = g_tree_nnodes (input) / FEATURE_WINDOW_SIZE; if (nodes < minnodes) { msg_info ("do not classify message as it has too few tokens: %d, while %d min", nodes, minnodes); - return; + return FALSE; } } @@ -224,6 +234,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp st = cur->data; data.sum = 0; data.count = 0; + data.new_blocks = 0; if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) { if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) { msg_warn ("cannot open %s, skip it", st->path); @@ -233,9 +244,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp } if (data.file != NULL) { - statfile_pool_lock_file (pool, data.file); g_tree_foreach (input, classify_callback, &data); - statfile_pool_unlock_file (pool, data.file); } if (data.count != 0) { @@ -263,6 +272,8 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp cur = g_list_prepend (NULL, sumbuf); insert_result (task, sel->symbol, max, cur); } + + return TRUE; } GList * @@ -306,9 +317,7 @@ winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inpu } if (data.file != NULL) { - statfile_pool_lock_file (pool, data.file); g_tree_foreach (input, classify_callback, &data); - statfile_pool_unlock_file (pool, data.file); } w = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight)); @@ -333,8 +342,9 @@ winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inpu } -void -winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *file, GTree * input, int in_class, double *sum, double multiplier) +gboolean +winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, const char *symbol, + GTree * input, int in_class, double *sum, double multiplier, GError **err) { struct winnow_callback_data data = { .file = NULL, @@ -343,10 +353,11 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi char *value; int nodes, minnodes, iterations = 0; struct statfile *st, *sel_st; - stat_file_t *sel = NULL; + stat_file_t *sel = NULL, *to_learn; long double res = 0., max = 0.; - double learn_threshold = 1.0; + double learn_threshold = 0.0; GList *cur, *to_demote = NULL; + gboolean force_learn = FALSE; g_assert (pool != NULL); g_assert (ctx != NULL); @@ -355,7 +366,7 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi data.in_class = in_class; data.now = time (NULL); data.ctx = ctx; - data.learn_file = file; + if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) { minnodes = strtol (value, NULL, 10); @@ -363,70 +374,121 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi if (nodes < minnodes) { msg_info ("do not learn message as it has too few tokens: %d, while %d min", nodes, minnodes); *sum = 0; - return; + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "message contains too few tokens: %d, while min is %d", + nodes, minnodes); + return FALSE; } } if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "learn_threshold")) != NULL) { learn_threshold = strtod (value, NULL); } - if (learn_threshold >= 1.0) { + if (learn_threshold <= 1.0 && learn_threshold >= 0) { /* Classify message and check target statfile score */ cur = ctx->cfg->statfiles; + while (cur) { + /* Open or create all statfiles inside classifier */ + st = cur->data; + if (statfile_pool_is_open (pool, st->path) == NULL) { + if (statfile_pool_open (pool, st->path, st->size, FALSE) == NULL) { + msg_warn ("cannot open %s", st->path); + if (statfile_pool_create (pool, st->path, st->size) == -1) { + msg_err ("cannot create statfile %s", st->path); + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "cannot create statfile: %s", + st->path); + return FALSE; + } + if (statfile_pool_open (pool, st->path, st->size, FALSE)) { + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "open statfile %s after creation", + st->path); + msg_err ("cannot open statfile %s after creation", st->path); + return FALSE; + } + } + } + if (strcmp (st->symbol, symbol) == 0) { + sel_st = st; + + } + cur = g_list_next (cur); + } + to_learn = statfile_pool_is_open (pool, sel_st->path); + if (to_learn == NULL) { + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "statfile %s is not opened this maybe if your statfile pool is too small to handle all statfiles", + sel_st->path); + return FALSE; + } /* Check target statfile */ - data.file = file; + data.file = to_learn; data.sum = 0; data.count = 0; - data.file = file; - statfile_pool_lock_file (pool, data.file); + data.new_blocks = 0; g_tree_foreach (input, classify_callback, &data); - statfile_pool_unlock_file (pool, data.file); if (data.count > 0) { max = data.sum / (double)data.count; } else { max = 0; } + /* If most of blocks are not presented in targeted statfile do forced learn */ + if ((data.new_blocks > 1 && (double)data.new_blocks / (double)data.count > 0.5) || max < 1 + learn_threshold) { + force_learn = TRUE; + } + /* Check other statfiles */ while (cur) { st = cur->data; data.sum = 0; data.count = 0; if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) { - if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) { - msg_warn ("cannot open %s, skip it", st->path); - cur = g_list_next (cur); - continue; - } + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "statfile %s is not opened this maybe if your statfile pool is too small to handle all statfiles", + st->path); + return FALSE; } - statfile_pool_lock_file (pool, data.file); g_tree_foreach (input, classify_callback, &data); - statfile_pool_unlock_file (pool, data.file); if (data.count != 0) { res = data.sum / data.count; } else { res = 0; } - if (file != data.file && res / max > learn_threshold) { + if (to_learn != data.file && res - max > 1 - learn_threshold) { /* Demote tokens in this statfile */ to_demote = g_list_prepend (to_demote, data.file); } - else if (file == data.file) { - sel_st = st; - } cur = g_list_next (cur); } } else { - msg_err ("learn threshold is less than 1, so cannot do learn, please check your configuration"); - return; + msg_err ("learn threshold is more than 1 or less than 0, so cannot do learn, please check your configuration"); + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "bad learn_threshold setting: %.2f", + learn_threshold); + return FALSE; } /* If to_demote list is empty this message is already classified correctly */ - if (max > ALPHA && to_demote == NULL) { + if (max > ALPHA && to_demote == NULL && !force_learn) { msg_info ("this message is already of class %s with threshold %.2f and weight %.2F", sel_st->symbol, learn_threshold, max); goto end; } + data.learn_file = to_learn; do { cur = ctx->cfg->statfiles; data.fresh_run = TRUE; @@ -434,12 +496,9 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi st = cur->data; data.sum = 0; data.count = 0; + data.new_blocks = 0; if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) { - if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) { - msg_warn ("cannot open %s, skip it", st->path); - cur = g_list_next (cur); - continue; - } + return FALSE; } if (to_demote != NULL && g_list_find (to_demote, data.file) != NULL) { data.do_demote = TRUE; @@ -470,14 +529,20 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi else { data.multiplier *= WINNOW_PROMOTION; } - } while ((in_class ? sel != file : sel == file) && iterations ++ < MAX_LEARN_ITERATIONS); + } while ((in_class ? sel != to_learn : sel == to_learn) && iterations ++ < MAX_LEARN_ITERATIONS); if (iterations >= MAX_LEARN_ITERATIONS) { msg_warn ("learning statfile %s was not fully successfull: iterations count is limited to %d, final sum is %G", - file->filename, MAX_LEARN_ITERATIONS, max); + sel_st->symbol, MAX_LEARN_ITERATIONS, max); + g_set_error (err, + winnow_error_quark(), /* error domain */ + 1, /* error code */ + "learning statfile %s was not fully successfull: iterations count is limited to %d", + sel_st->symbol, MAX_LEARN_ITERATIONS); + return FALSE; } else { - msg_info ("learned statfile %s successfully with %d iterations and sum %G", file->filename, iterations + 1, max); + msg_info ("learned statfile %s successfully with %d iterations and sum %G", sel_st->symbol, iterations + 1, max); } @@ -485,4 +550,5 @@ end: if (sum) { *sum = (double)max; } + return TRUE; } |