aboutsummaryrefslogtreecommitdiffstats
path: root/src/classifiers
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rambler-co.ru>2010-08-05 21:29:40 +0400
committerVsevolod Stakhov <vsevolod@rambler-co.ru>2010-08-05 21:29:40 +0400
commit5d0e4d334fef7f0fe683040d32e2a53b503315f2 (patch)
treed2f00bbb940564a4e77a287f84c7876e8fd9c009 /src/classifiers
parent80b5b55a53622875d4973ea1d440dc7fa916f20b (diff)
downloadrspamd-5d0e4d334fef7f0fe683040d32e2a53b503315f2.tar.gz
rspamd-5d0e4d334fef7f0fe683040d32e2a53b503315f2.zip
* Fixes to winnow learning
Diffstat (limited to 'src/classifiers')
-rw-r--r--src/classifiers/classifiers.h13
-rw-r--r--src/classifiers/winnow.c148
2 files changed, 114 insertions, 47 deletions
diff --git a/src/classifiers/classifiers.h b/src/classifiers/classifiers.h
index 02192d795..f69c1284c 100644
--- a/src/classifiers/classifiers.h
+++ b/src/classifiers/classifiers.h
@@ -24,9 +24,10 @@ struct classify_weight {
struct classifier {
char *name;
struct classifier_ctx* (*init_func)(memory_pool_t *pool, struct classifier_config *cf);
- void (*classify_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
- void (*learn_func)(struct classifier_ctx* ctx, statfile_pool_t *pool,
- stat_file_t *file, GTree *input, gboolean in_class, double *sum, double multiplier);
+ gboolean (*classify_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
+ gboolean (*learn_func)(struct classifier_ctx* ctx, statfile_pool_t *pool,
+ const char *symbol, GTree *input, gboolean in_class,
+ double *sum, double multiplier, GError **err);
GList* (*weights_func)(struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
};
@@ -35,9 +36,9 @@ struct classifier* get_classifier (char *name);
/* Winnow algorithm */
struct classifier_ctx* winnow_init (memory_pool_t *pool, struct classifier_config *cf);
-void winnow_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
-void winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, stat_file_t *file, GTree *input,
- gboolean in_class, double *sum, double multiplier);
+gboolean winnow_classify (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
+gboolean winnow_learn (struct classifier_ctx* ctx, statfile_pool_t *pool, const char *symbol, GTree *input,
+ gboolean in_class, double *sum, double multiplier, GError **err);
GList *winnow_weights (struct classifier_ctx* ctx, statfile_pool_t *pool, GTree *input, struct worker_task *task);
diff --git a/src/classifiers/winnow.c b/src/classifiers/winnow.c
index 41cb48e89..ab155ff8c 100644
--- a/src/classifiers/winnow.c
+++ b/src/classifiers/winnow.c
@@ -46,6 +46,12 @@
#define MAX_LEARN_ITERATIONS 100
+G_INLINE_FUNC GQuark
+winnow_error_quark (void)
+{
+ return g_quark_from_static_string ("winnow-error-quark");
+}
+
struct winnow_callback_data {
statfile_pool_t *pool;
struct classifier_ctx *ctx;
@@ -53,7 +59,8 @@ struct winnow_callback_data {
stat_file_t *learn_file;
long double sum;
double multiplier;
- int count;
+ guint32 count;
+ guint32 new_blocks;
gboolean in_class;
gboolean do_demote;
gboolean fresh_run;
@@ -62,6 +69,8 @@ struct winnow_callback_data {
static const double max_common_weight = MAX_WEIGHT * WINNOW_DEMOTION;
+
+
static gboolean
classify_callback (gpointer key, gpointer value, gpointer data)
{
@@ -73,10 +82,10 @@ classify_callback (gpointer key, gpointer value, gpointer data)
v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now);
if (fabs (v) > ALPHA) {
cd->sum += v;
- cd->in_class++;
}
else {
cd->sum += 1.0;
+ cd->new_blocks ++;
}
cd->count++;
@@ -100,6 +109,7 @@ learn_callback (gpointer key, gpointer value, gpointer data)
if (cd->file == cd->learn_file) {
statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, c);
node->value = c;
+ cd->new_blocks ++;
}
}
else {
@@ -181,7 +191,7 @@ winnow_init (memory_pool_t * pool, struct classifier_config *cfg)
return ctx;
}
-void
+gboolean
winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * input, struct worker_task *task)
{
struct winnow_callback_data data;
@@ -203,7 +213,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
nodes = g_tree_nnodes (input) / FEATURE_WINDOW_SIZE;
if (nodes < minnodes) {
msg_info ("do not classify message as it has too few tokens: %d, while %d min", nodes, minnodes);
- return;
+ return FALSE;
}
}
@@ -224,6 +234,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
st = cur->data;
data.sum = 0;
data.count = 0;
+ data.new_blocks = 0;
if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
msg_warn ("cannot open %s, skip it", st->path);
@@ -233,9 +244,7 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
}
if (data.file != NULL) {
- statfile_pool_lock_file (pool, data.file);
g_tree_foreach (input, classify_callback, &data);
- statfile_pool_unlock_file (pool, data.file);
}
if (data.count != 0) {
@@ -263,6 +272,8 @@ winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inp
cur = g_list_prepend (NULL, sumbuf);
insert_result (task, sel->symbol, max, cur);
}
+
+ return TRUE;
}
GList *
@@ -306,9 +317,7 @@ winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inpu
}
if (data.file != NULL) {
- statfile_pool_lock_file (pool, data.file);
g_tree_foreach (input, classify_callback, &data);
- statfile_pool_unlock_file (pool, data.file);
}
w = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight));
@@ -333,8 +342,9 @@ winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * inpu
}
-void
-winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *file, GTree * input, int in_class, double *sum, double multiplier)
+gboolean
+winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, const char *symbol,
+ GTree * input, int in_class, double *sum, double multiplier, GError **err)
{
struct winnow_callback_data data = {
.file = NULL,
@@ -343,10 +353,11 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
char *value;
int nodes, minnodes, iterations = 0;
struct statfile *st, *sel_st;
- stat_file_t *sel = NULL;
+ stat_file_t *sel = NULL, *to_learn;
long double res = 0., max = 0.;
- double learn_threshold = 1.0;
+ double learn_threshold = 0.0;
GList *cur, *to_demote = NULL;
+ gboolean force_learn = FALSE;
g_assert (pool != NULL);
g_assert (ctx != NULL);
@@ -355,7 +366,7 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
data.in_class = in_class;
data.now = time (NULL);
data.ctx = ctx;
- data.learn_file = file;
+
if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
minnodes = strtol (value, NULL, 10);
@@ -363,70 +374,121 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
if (nodes < minnodes) {
msg_info ("do not learn message as it has too few tokens: %d, while %d min", nodes, minnodes);
*sum = 0;
- return;
+ g_set_error (err,
+ winnow_error_quark(), /* error domain */
+ 1, /* error code */
+ "message contains too few tokens: %d, while min is %d",
+ nodes, minnodes);
+ return FALSE;
}
}
if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "learn_threshold")) != NULL) {
learn_threshold = strtod (value, NULL);
}
- if (learn_threshold >= 1.0) {
+ if (learn_threshold <= 1.0 && learn_threshold >= 0) {
/* Classify message and check target statfile score */
cur = ctx->cfg->statfiles;
+ while (cur) {
+ /* Open or create all statfiles inside classifier */
+ st = cur->data;
+ if (statfile_pool_is_open (pool, st->path) == NULL) {
+ if (statfile_pool_open (pool, st->path, st->size, FALSE) == NULL) {
+ msg_warn ("cannot open %s", st->path);
+ if (statfile_pool_create (pool, st->path, st->size) == -1) {
+ msg_err ("cannot create statfile %s", st->path);
+ g_set_error (err,
+ winnow_error_quark(), /* error domain */
+ 1, /* error code */
+ "cannot create statfile: %s",
+ st->path);
+ return FALSE;
+ }
+ if (statfile_pool_open (pool, st->path, st->size, FALSE)) {
+ g_set_error (err,
+ winnow_error_quark(), /* error domain */
+ 1, /* error code */
+ "open statfile %s after creation",
+ st->path);
+ msg_err ("cannot open statfile %s after creation", st->path);
+ return FALSE;
+ }
+ }
+ }
+ if (strcmp (st->symbol, symbol) == 0) {
+ sel_st = st;
+
+ }
+ cur = g_list_next (cur);
+ }
+ to_learn = statfile_pool_is_open (pool, sel_st->path);
+ if (to_learn == NULL) {
+ g_set_error (err,
+ winnow_error_quark(), /* error domain */
+ 1, /* error code */
+ "statfile %s is not opened this maybe if your statfile pool is too small to handle all statfiles",
+ sel_st->path);
+ return FALSE;
+ }
/* Check target statfile */
- data.file = file;
+ data.file = to_learn;
data.sum = 0;
data.count = 0;
- data.file = file;
- statfile_pool_lock_file (pool, data.file);
+ data.new_blocks = 0;
g_tree_foreach (input, classify_callback, &data);
- statfile_pool_unlock_file (pool, data.file);
if (data.count > 0) {
max = data.sum / (double)data.count;
}
else {
max = 0;
}
+ /* If most of blocks are not presented in targeted statfile do forced learn */
+ if ((data.new_blocks > 1 && (double)data.new_blocks / (double)data.count > 0.5) || max < 1 + learn_threshold) {
+ force_learn = TRUE;
+ }
+ /* Check other statfiles */
while (cur) {
st = cur->data;
data.sum = 0;
data.count = 0;
if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
- if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
- msg_warn ("cannot open %s, skip it", st->path);
- cur = g_list_next (cur);
- continue;
- }
+ g_set_error (err,
+ winnow_error_quark(), /* error domain */
+ 1, /* error code */
+ "statfile %s is not opened this maybe if your statfile pool is too small to handle all statfiles",
+ st->path);
+ return FALSE;
}
- statfile_pool_lock_file (pool, data.file);
g_tree_foreach (input, classify_callback, &data);
- statfile_pool_unlock_file (pool, data.file);
if (data.count != 0) {
res = data.sum / data.count;
}
else {
res = 0;
}
- if (file != data.file && res / max > learn_threshold) {
+ if (to_learn != data.file && res - max > 1 - learn_threshold) {
/* Demote tokens in this statfile */
to_demote = g_list_prepend (to_demote, data.file);
}
- else if (file == data.file) {
- sel_st = st;
- }
cur = g_list_next (cur);
}
}
else {
- msg_err ("learn threshold is less than 1, so cannot do learn, please check your configuration");
- return;
+ msg_err ("learn threshold is more than 1 or less than 0, so cannot do learn, please check your configuration");
+ g_set_error (err,
+ winnow_error_quark(), /* error domain */
+ 1, /* error code */
+ "bad learn_threshold setting: %.2f",
+ learn_threshold);
+ return FALSE;
}
/* If to_demote list is empty this message is already classified correctly */
- if (max > ALPHA && to_demote == NULL) {
+ if (max > ALPHA && to_demote == NULL && !force_learn) {
msg_info ("this message is already of class %s with threshold %.2f and weight %.2F",
sel_st->symbol, learn_threshold, max);
goto end;
}
+ data.learn_file = to_learn;
do {
cur = ctx->cfg->statfiles;
data.fresh_run = TRUE;
@@ -434,12 +496,9 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
st = cur->data;
data.sum = 0;
data.count = 0;
+ data.new_blocks = 0;
if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
- if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
- msg_warn ("cannot open %s, skip it", st->path);
- cur = g_list_next (cur);
- continue;
- }
+ return FALSE;
}
if (to_demote != NULL && g_list_find (to_demote, data.file) != NULL) {
data.do_demote = TRUE;
@@ -470,14 +529,20 @@ winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, stat_file_t *fi
else {
data.multiplier *= WINNOW_PROMOTION;
}
- } while ((in_class ? sel != file : sel == file) && iterations ++ < MAX_LEARN_ITERATIONS);
+ } while ((in_class ? sel != to_learn : sel == to_learn) && iterations ++ < MAX_LEARN_ITERATIONS);
if (iterations >= MAX_LEARN_ITERATIONS) {
msg_warn ("learning statfile %s was not fully successfull: iterations count is limited to %d, final sum is %G",
- file->filename, MAX_LEARN_ITERATIONS, max);
+ sel_st->symbol, MAX_LEARN_ITERATIONS, max);
+ g_set_error (err,
+ winnow_error_quark(), /* error domain */
+ 1, /* error code */
+ "learning statfile %s was not fully successfull: iterations count is limited to %d",
+ sel_st->symbol, MAX_LEARN_ITERATIONS);
+ return FALSE;
}
else {
- msg_info ("learned statfile %s successfully with %d iterations and sum %G", file->filename, iterations + 1, max);
+ msg_info ("learned statfile %s successfully with %d iterations and sum %G", sel_st->symbol, iterations + 1, max);
}
@@ -485,4 +550,5 @@ end:
if (sum) {
*sum = (double)max;
}
+ return TRUE;
}