Diffstat (limited to 'src/libstat/classifiers/winnow.c')
-rw-r--r-- | src/libstat/classifiers/winnow.c | 694
1 file changed, 694 insertions, 0 deletions
diff --git a/src/libstat/classifiers/winnow.c b/src/libstat/classifiers/winnow.c
new file mode 100644
index 000000000..68d456968
--- /dev/null
+++ b/src/libstat/classifiers/winnow.c
@@ -0,0 +1,694 @@
+/*
+ * Copyright (c) 2009-2012, Vsevolod Stakhov
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *   * Redistributions of source code must retain the above copyright
+ *     notice, this list of conditions and the following disclaimer.
+ *   * Redistributions in binary form must reproduce the above copyright
+ *     notice, this list of conditions and the following disclaimer in the
+ *     documentation and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY AUTHOR ''AS IS'' AND ANY
+ * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL AUTHOR BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+/*
+ * Winnow classifier
+ */
+
+#include "classifiers.h"
+#include "tokenizers.h"
+#include "main.h"
+#include "filter.h"
+#include "cfg_file.h"
+#include "lua/lua_common.h"
+
+#define WINNOW_PROMOTION 1.23
+#define WINNOW_DEMOTION 0.83
+
+#define MEDIAN_WINDOW_SIZE 5
+
+#define MAX_WEIGHT G_MAXDOUBLE / 2.
+
+#define MAX_LEARN_ITERATIONS 100
+
+static inline GQuark
+winnow_error_quark (void)
+{
+    return g_quark_from_static_string ("winnow-error");
+}
+
+struct winnow_callback_data {
+    statfile_pool_t *pool;
+    struct classifier_ctx *ctx;
+    stat_file_t *file;
+    stat_file_t *learn_file;
+    long double sum;
+    long double start;
+    double multiplier;
+    guint32 count;
+    guint32 new_blocks;
+    gboolean in_class;
+    gboolean do_demote;
+    gboolean fresh_run;
+    time_t now;
+};
+
+static const double max_common_weight = MAX_WEIGHT * WINNOW_DEMOTION;
+
+static gboolean
+winnow_classify_callback (gpointer key, gpointer value, gpointer data)
+{
+    token_node_t *node = key;
+    struct winnow_callback_data *cd = data;
+    double v;
+
+    /* Consider that not found blocks have value 1 */
+    v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2,
+            cd->now);
+    if (fabs (v) > ALPHA) {
+        cd->sum += v;
+    }
+    else {
+        cd->sum += 1.0;
+        cd->new_blocks++;
+    }
+
+    cd->count++;
+
+    return FALSE;
+}
+
+static gboolean
+winnow_learn_callback (gpointer key, gpointer value, gpointer data)
+{
+    token_node_t *node = key;
+    struct winnow_callback_data *cd = data;
+    double v, c;
+
+    c = (cd->in_class) ? WINNOW_PROMOTION * cd->multiplier :
+            WINNOW_DEMOTION / cd->multiplier;
+
+    /* Consider that not found blocks have value 1 */
+    v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2,
+            cd->now);
+    if (fabs (v) < ALPHA) {
+        /* Block not found, insert new */
+        cd->start += 1;
+        if (cd->file == cd->learn_file) {
+            statfile_pool_set_block (cd->pool,
+                cd->file,
+                node->h1,
+                node->h2,
+                cd->now,
+                c);
+            node->value = c;
+            cd->new_blocks++;
+        }
+    }
+    else {
+        cd->start += v;
+        /* Here we just increase the extra value of block */
+        if (cd->fresh_run) {
+            node->extra = 0;
+        }
+        else {
+            node->extra++;
+        }
+        node->value = v;
+
+        if (node->extra > 1) {
+            /*
+             * Assume that this node is common for several statfiles, so
+             * decrease its weight proportionally
+             */
+            if (node->value > max_common_weight) {
+                /* Static fluctuation */
+                statfile_pool_set_block (cd->pool,
+                    cd->file,
+                    node->h1,
+                    node->h2,
+                    cd->now,
+                    0.);
+                node->value = 0.;
+            }
+            else if (node->value > WINNOW_PROMOTION * cd->multiplier) {
+                /* Try to decrease its value */
+                /* XXX: it is more intelligent to add some adaptive filter here */
+                if (cd->file == cd->learn_file) {
+                    if (node->value > max_common_weight / 2.) {
+                        node->value *= c;
+                    }
+                    else {
+                        /*
+                         * Too high token value that exists also in other
+                         * statfiles, may be a statistical error, so decrease
+                         * it slightly
+                         */
+                        node->value *= WINNOW_DEMOTION;
+                    }
+                }
+                else {
+                    node->value = WINNOW_DEMOTION / cd->multiplier;
+                }
+                statfile_pool_set_block (cd->pool,
+                    cd->file,
+                    node->h1,
+                    node->h2,
+                    cd->now,
+                    node->value);
+            }
+        }
+        else if (cd->file == cd->learn_file) {
+            /* New block or block that is in only one statfile */
+            /* Set some limit on growing */
+            if (v > MAX_WEIGHT) {
+                node->value = v;
+            }
+            else {
+                node->value *= c;
+            }
+            statfile_pool_set_block (cd->pool,
+                cd->file,
+                node->h1,
+                node->h2,
+                cd->now,
+                node->value);
+        }
+        else if (cd->do_demote) {
+            /* Demote blocks in file */
+            node->value *= WINNOW_DEMOTION / cd->multiplier;
+            statfile_pool_set_block (cd->pool,
+                cd->file,
+                node->h1,
+                node->h2,
+                cd->now,
+                node->value);
+        }
+    }
+
+    cd->sum += node->value;
+    cd->count++;
+
+    return FALSE;
+}
+
+struct classifier_ctx *
+winnow_init (rspamd_mempool_t * pool, struct rspamd_classifier_config *cfg)
+{
+    struct classifier_ctx *ctx =
+        rspamd_mempool_alloc (pool, sizeof (struct classifier_ctx));
+
+    ctx->pool = pool;
+    ctx->cfg = cfg;
+
+    return ctx;
+}
+
+gboolean
+winnow_classify (struct classifier_ctx *ctx,
+    statfile_pool_t * pool,
+    GTree * input,
+    struct rspamd_task *task,
+    lua_State *L)
+{
+    struct winnow_callback_data data;
+    char *sumbuf, *value;
+    long double res = 0., max = 0.;
+    GList *cur;
+    struct rspamd_statfile_config *st, *sel = NULL;
+    int nodes, minnodes;
+
+    g_assert (pool != NULL);
+    g_assert (ctx != NULL);
+
+    data.pool = pool;
+    data.now = time (NULL);
+    data.ctx = ctx;
+
+    if (ctx->cfg->opts &&
+        (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
+        minnodes = strtol (value, NULL, 10);
+        nodes = g_tree_nnodes (input);
+        if (nodes > FEATURE_WINDOW_SIZE) {
+            nodes = nodes / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE;
+        }
+        if (nodes < minnodes) {
+            msg_info (
+                "do not classify message as it has too few tokens: %d, while %d min",
+                nodes,
+                minnodes);
+            return FALSE;
+        }
+    }
+
+    cur = rspamd_lua_call_cls_pre_callbacks (ctx->cfg, task, FALSE, FALSE, L);
+    if (cur) {
+        rspamd_mempool_add_destructor (task->task_pool,
+            (rspamd_mempool_destruct_t)g_list_free, cur);
+    }
+    else {
+        cur = ctx->cfg->statfiles;
+    }
+
+    while (cur) {
+        st = cur->data;
+        data.sum = 0;
+        data.count = 0;
+        data.new_blocks = 0;
+        if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
+            if ((data.file =
+                statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
+                msg_warn ("cannot open %s, skip it", st->path);
+                cur = g_list_next (cur);
+                continue;
+            }
+        }
+
+        if (data.file != NULL) {
+            g_tree_foreach (input, winnow_classify_callback, &data);
+        }
+
+        if (data.count != 0) {
+            res = data.sum / (double)data.count;
+        }
+        else {
+            res = 0;
+        }
+        if (res > max) {
+            max = res;
+            sel = st;
+        }
+        cur = g_list_next (cur);
+    }
+
+    if (sel != NULL) {
+#ifdef WITH_LUA
+        max = rspamd_lua_call_cls_post_callbacks (ctx->cfg, task, max, L);
+#endif
+#ifdef HAVE_TANHL
+        max = tanhl (max);
+#else
+        /*
+         * As some implementations of libm do not support tanhl, try to use
+         * tanh
+         */
+        max = tanh ((double) max);
+#endif
+        sumbuf = rspamd_mempool_alloc (task->task_pool, 32);
+        rspamd_snprintf (sumbuf, 32, "%.2F", max);
+        cur = g_list_prepend (NULL, sumbuf);
+        rspamd_task_insert_result (task, sel->symbol, max, cur);
+    }
+
+    return TRUE;
+}
+
+GList *
+winnow_weights (struct classifier_ctx *ctx,
+    statfile_pool_t * pool,
+    GTree * input,
+    struct rspamd_task *task)
+{
+    struct winnow_callback_data data;
+    long double res = 0.;
+    GList *cur, *resl = NULL;
+    struct rspamd_statfile_config *st;
+    struct classify_weight *w;
+    char *value;
+    int nodes, minnodes;
+
+    g_assert (pool != NULL);
+    g_assert (ctx != NULL);
+
+    data.pool = pool;
+    data.now = time (NULL);
+    data.ctx = ctx;
+
+    if (ctx->cfg->opts &&
+        (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
+        minnodes = strtol (value, NULL, 10);
+        nodes = g_tree_nnodes (input);
+        if (nodes > FEATURE_WINDOW_SIZE) {
+            nodes = nodes / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE;
+        }
+        if (nodes < minnodes) {
+            msg_info (
+                "do not classify message as it has too few tokens: %d, while %d min",
+                nodes,
+                minnodes);
+            return NULL;
+        }
+    }
+
+    cur = ctx->cfg->statfiles;
+    while (cur) {
+        st = cur->data;
+        data.sum = 0;
+        data.count = 0;
+        if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
+            if ((data.file =
+                statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
+                msg_warn ("cannot open %s, skip it", st->path);
+                cur = g_list_next (cur);
+                continue;
+            }
+        }
+
+        if (data.file != NULL) {
+            g_tree_foreach (input, winnow_classify_callback, &data);
+        }
+
+        w = rspamd_mempool_alloc0 (task->task_pool,
+                sizeof (struct classify_weight));
+        if (data.count != 0) {
+            res = data.sum / (double)data.count;
+        }
+        else {
+            res = 0;
+        }
+        w->name = st->symbol;
+        w->weight = res;
+        resl = g_list_prepend (resl, w);
+        cur = g_list_next (cur);
+    }
+
+    if (resl != NULL) {
+        rspamd_mempool_add_destructor (task->task_pool,
+            (rspamd_mempool_destruct_t)g_list_free, resl);
+    }
+
+    return resl;
+}
+
+gboolean
+winnow_learn (struct classifier_ctx *ctx,
+    statfile_pool_t *pool,
+    const char *symbol,
+    GTree * input,
+    int in_class,
+    double *sum,
+    double multiplier,
+    GError **err)
+{
+    struct winnow_callback_data data = {
+        .file = NULL,
+        .multiplier = multiplier
+    };
+    char *value;
+    int nodes, minnodes, iterations = 0;
+    struct rspamd_statfile_config *st, *sel_st = NULL;
+    stat_file_t *sel = NULL, *to_learn;
+    long double res = 0., max = 0., start_value = 0., end_value = 0.;
+    double learn_threshold = 0.0;
+    GList *cur, *to_demote = NULL;
+    gboolean force_learn = FALSE;
+
+    g_assert (pool != NULL);
+    g_assert (ctx != NULL);
+
+    data.pool = pool;
+    data.in_class = in_class;
+    data.now = time (NULL);
+    data.ctx = ctx;
+
+    if (ctx->cfg->opts &&
+        (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
+        minnodes = strtol (value, NULL, 10);
+        nodes = g_tree_nnodes (input);
+        if (nodes > FEATURE_WINDOW_SIZE) {
+            nodes = nodes / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE;
+        }
+        if (nodes < minnodes) {
+            msg_info (
+                "do not learn message as it has too few tokens: %d, while %d min",
+                nodes,
+                minnodes);
+            if (sum != NULL) {
+                *sum = 0;
+            }
+            g_set_error (err,
+                winnow_error_quark (), /* error domain */
+                1,                     /* error code */
+                "message contains too few tokens: %d, while min is %d",
+                nodes, minnodes);
+            return FALSE;
+        }
+    }
+    if (ctx->cfg->opts &&
+        (value =
+        g_hash_table_lookup (ctx->cfg->opts, "learn_threshold")) != NULL) {
+        learn_threshold = strtod (value, NULL);
+    }
+
+    if (learn_threshold <= 1.0 && learn_threshold >= 0) {
+        /* Classify message and check target statfile score */
+        cur = ctx->cfg->statfiles;
+        while (cur) {
+            /* Open or create all statfiles inside classifier */
+            st = cur->data;
+            if (statfile_pool_is_open (pool, st->path) == NULL) {
+                if (statfile_pool_open (pool, st->path, st->size,
+                    FALSE) == NULL) {
+                    msg_warn ("cannot open %s", st->path);
+                    if (statfile_pool_create (pool, st->path, st->size) == -1) {
+                        msg_err ("cannot create statfile %s", st->path);
+                        g_set_error (err,
+                            winnow_error_quark (), /* error domain */
+                            1,                     /* error code */
+                            "cannot create statfile: %s",
+                            st->path);
+                        return FALSE;
+                    }
+                    if (statfile_pool_open (pool, st->path, st->size,
+                        FALSE) == NULL) {
+                        g_set_error (err,
+                            winnow_error_quark (), /* error domain */
+                            1,                     /* error code */
+                            "cannot open statfile %s after creation",
+                            st->path);
+                        msg_err ("cannot open statfile %s after creation",
+                            st->path);
+                        return FALSE;
+                    }
+                }
+            }
+            if (strcmp (st->symbol, symbol) == 0) {
+                sel_st = st;
+            }
+            cur = g_list_next (cur);
+        }
+
+        if (sel_st == NULL) {
+            g_set_error (err,
+                winnow_error_quark (), /* error domain */
+                1,                     /* error code */
+                "cannot find statfile for symbol %s",
+                symbol);
+            msg_err ("cannot find statfile for symbol %s", symbol);
+            return FALSE;
+        }
+
+        to_learn = statfile_pool_is_open (pool, sel_st->path);
+        if (to_learn == NULL) {
+            g_set_error (err,
+                winnow_error_quark (), /* error domain */
+                1,                     /* error code */
+                "statfile %s is not opened; this may happen if your statfile pool is too small to handle all statfiles",
+                sel_st->path);
+            return FALSE;
+        }
+        /* Check target statfile */
+        data.file = to_learn;
+        data.sum = 0;
+        data.count = 0;
+        data.new_blocks = 0;
+        g_tree_foreach (input, winnow_classify_callback, &data);
+        if (data.count > 0) {
+            max = data.sum / (double)data.count;
+        }
+        else {
+            max = 0;
+        }
+        /* If most of the blocks are not present in the target statfile, do a forced learn */
+        if (max < 1 + learn_threshold) {
+            force_learn = TRUE;
+        }
+        /* Check other statfiles */
+        while (cur) {
+            st = cur->data;
+            data.sum = 0;
+            data.count = 0;
+            if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
+                g_set_error (err,
+                    winnow_error_quark (), /* error domain */
+                    1,                     /* error code */
+                    "statfile %s is not opened; this may happen if your statfile pool is too small to handle all statfiles",
+                    st->path);
+                return FALSE;
+            }
+            g_tree_foreach (input, winnow_classify_callback, &data);
+            if (data.count != 0) {
+                res = data.sum / data.count;
+            }
+            else {
+                res = 0;
+            }
+            if (to_learn != data.file && res - max > 1 - learn_threshold) {
+                /* Demote tokens in this statfile */
+                to_demote = g_list_prepend (to_demote, data.file);
+            }
+            cur = g_list_next (cur);
+        }
+    }
+    else {
+        msg_err (
+            "learn threshold is more than 1 or less than 0, so learning is impossible; please check your configuration");
+        g_set_error (err,
+            winnow_error_quark (), /* error domain */
+            1,                     /* error code */
+            "bad learn_threshold setting: %.2f",
+            learn_threshold);
+        return FALSE;
+    }
+    /* If the to_demote list is empty, this message is already classified correctly */
+    if (max > WINNOW_PROMOTION && to_demote == NULL && !force_learn) {
+        msg_info (
+            "this message is already of class %s with threshold %.2f and weight %.2F",
+            sel_st->symbol,
+            learn_threshold,
+            max);
+        goto end;
+    }
+    data.learn_file = to_learn;
+    end_value = max;
+    do {
+        cur = ctx->cfg->statfiles;
+        data.fresh_run = TRUE;
+        while (cur) {
+            st = cur->data;
+            data.sum = 0;
+            data.count = 0;
+            data.new_blocks = 0;
+            data.start = 0;
+            if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
+                return FALSE;
+            }
+            if (to_demote != NULL &&
+                g_list_find (to_demote, data.file) != NULL) {
+                data.do_demote = TRUE;
+            }
+            else {
+                data.do_demote = FALSE;
+            }
+
+            statfile_pool_lock_file (pool, data.file);
+            g_tree_foreach (input, winnow_learn_callback, &data);
+            statfile_pool_unlock_file (pool, data.file);
+            if (data.count != 0) {
+                res = data.sum / data.count;
+            }
+            else {
+                res = 0;
+            }
+            if (res > max) {
+                max = res;
+                sel = data.file;
+            }
+            if (data.file == to_learn) {
+                if (data.count > 0) {
+                    start_value = data.start / data.count;
+                }
+                end_value = res;
+            }
+            cur = g_list_next (cur);
+            data.fresh_run = FALSE;
+        }
+
+        data.multiplier *= WINNOW_PROMOTION;
+        msg_info (
+            "learn iteration %d for statfile %s: %G -> %G, multiplier: %.2f",
+            iterations + 1,
+            symbol,
+            start_value,
+            end_value,
+            data.multiplier);
+    } while ((in_class ? sel != to_learn : sel == to_learn) &&
+        iterations++ < MAX_LEARN_ITERATIONS);
+
+    if (iterations >= MAX_LEARN_ITERATIONS) {
+        msg_warn (
+            "learning statfile %s was not fully successful: iterations count is limited to %d, final sum is %G",
+            sel_st->symbol,
+            MAX_LEARN_ITERATIONS,
+            max);
+        g_set_error (err,
+            winnow_error_quark (), /* error domain */
+            1,                     /* error code */
+            "learning statfile %s was not fully successful: iterations count is limited to %d",
+            sel_st->symbol, MAX_LEARN_ITERATIONS);
+        return FALSE;
+    }
+    else {
+        msg_info (
+            "learned statfile %s successfully with %d iterations and sum %G",
+            sel_st->symbol,
+            iterations + 1,
+            max);
+    }
+
+end:
+    if (sum) {
+#ifdef HAVE_TANHL
+        *sum = (double)tanhl (max);
+#else
+        /*
+         * As some implementations of libm do not support tanhl, try to use
+         * tanh
+         */
+        *sum = tanh ((double) max);
+#endif
+    }
+    return TRUE;
+}
+
+gboolean
+winnow_learn_spam (struct classifier_ctx * ctx,
+    statfile_pool_t *pool,
+    GTree *input,
+    struct rspamd_task *task,
+    gboolean is_spam,
+    lua_State *L,
+    GError **err)
+{
+    g_set_error (err,
+        winnow_error_quark (), /* error domain */
+        1,                     /* error code */
+        "learn spam is not supported for winnow");
+    return FALSE;
+}
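
For readers skimming the diff, the core of the classifier above is a multiplicative weight update: on learning, token weights in the target statfile are promoted by WINNOW_PROMOTION (1.23) while weights in competing statfiles marked for demotion are reduced by WINNOW_DEMOTION (0.83), and a message's score per statfile is the mean of its token weights (unseen tokens counting as 1), squashed through tanh. The standalone sketch below is not part of the commit; it ignores the per-iteration multiplier and the cross-statfile bookkeeping, and the function names (winnow_update_weight, winnow_score) are illustrative only.

    /* Illustrative sketch of the Winnow update and scoring used above. */
    #include <math.h>
    #include <stddef.h>

    #define WINNOW_PROMOTION 1.23
    #define WINNOW_DEMOTION  0.83

    /* Promote or demote a single token weight; an absent token
     * (weight near zero) starts from a neutral weight of 1. */
    double
    winnow_update_weight (double weight, int in_class)
    {
        if (fabs (weight) < 1e-9) {
            weight = 1.0;
        }
        return in_class ? weight * WINNOW_PROMOTION : weight * WINNOW_DEMOTION;
    }

    /* Score a message against one statfile: mean token weight passed
     * through tanh, mirroring winnow_classify's final step. */
    double
    winnow_score (const double *weights, size_t n)
    {
        double sum = 0.0;
        size_t i;

        for (i = 0; i < n; i++) {
            sum += weights[i];
        }
        return n > 0 ? tanh (sum / (double) n) : 0.0;
    }

Because the update is multiplicative rather than additive, repeated promotions grow weights geometrically, which is why the full code also caps weights (MAX_WEIGHT) and iterates learning until the target statfile wins, up to MAX_LEARN_ITERATIONS.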