diff options
Diffstat (limited to 'src/libstat')
-rw-r--r-- | src/libstat/backends/redis_backend.c | 246 | ||||
-rw-r--r-- | src/libstat/classifiers/bayes.c | 145 | ||||
-rw-r--r-- | src/libstat/classifiers/classifiers.h | 38 | ||||
-rw-r--r-- | src/libstat/classifiers/lua_classifier.c | 15 | ||||
-rw-r--r-- | src/libstat/learn_cache/redis_cache.c | 220 | ||||
-rw-r--r-- | src/libstat/learn_cache/sqlite3_cache.c | 2 | ||||
-rw-r--r-- | src/libstat/stat_api.h | 25 | ||||
-rw-r--r-- | src/libstat/stat_config.c | 78 | ||||
-rw-r--r-- | src/libstat/stat_internal.h | 4 | ||||
-rw-r--r-- | src/libstat/stat_process.c | 366 | ||||
-rw-r--r-- | src/libstat/tokenizers/osb.c | 49 | ||||
-rw-r--r-- | src/libstat/tokenizers/tokenizers.c | 521 | ||||
-rw-r--r-- | src/libstat/tokenizers/tokenizers.h | 37 |
13 files changed, 1054 insertions, 692 deletions
diff --git a/src/libstat/backends/redis_backend.c b/src/libstat/backends/redis_backend.c index 74d8c3bf1..b003d5a27 100644 --- a/src/libstat/backends/redis_backend.c +++ b/src/libstat/backends/redis_backend.c @@ -42,9 +42,9 @@ INIT_LOG_MODULE(stat_redis) #define REDIS_STAT_TIMEOUT 30 struct redis_stat_ctx { + lua_State *L; struct rspamd_statfile_config *stcf; - struct upstream_list *read_servers; - struct upstream_list *write_servers; + gint conf_ref; struct rspamd_stat_async_elt *stat_elt; const gchar *redis_object; const gchar *password; @@ -104,12 +104,29 @@ struct rspamd_redis_stat_cbdata { #define GET_TASK_ELT(task, elt) (task == NULL ? NULL : (task)->elt) +static const gchar *M = "redis statistics"; + static GQuark rspamd_redis_stat_quark (void) { - return g_quark_from_static_string ("redis-statistics"); + return g_quark_from_static_string (M); } +static inline struct upstream_list * +rspamd_redis_get_servers (struct redis_stat_ctx *ctx, + const gchar *what) +{ + lua_State *L = ctx->L; + struct upstream_list *res; + + lua_rawgeti (L, LUA_REGISTRYINDEX, ctx->conf_ref); + lua_pushstring (L, what); + lua_gettable (L, -2); + res = *((struct upstream_list**)lua_touserdata (L, -1)); + lua_settop (L, 0); + + return res; +} /* * Non-static for lua unit testing @@ -134,6 +151,7 @@ rspamd_redis_expand_object (const gchar *pattern, GString *tb; const gchar *rcpt = NULL; gint err_idx; + gboolean expansion_errored = FALSE; g_assert (ctx != NULL); stcf = ctx->stcf; @@ -196,6 +214,9 @@ rspamd_redis_expand_object (const gchar *pattern, if (elt) { tlen += strlen (elt); } + else { + expansion_errored = TRUE; + } break; case 'r': @@ -209,11 +230,15 @@ rspamd_redis_expand_object (const gchar *pattern, if (elt) { tlen += strlen (elt); } + else { + expansion_errored = TRUE; + } break; case 'l': if (stcf->label) { tlen += strlen (stcf->label); } + /* Label miss is OK */ break; case 's': if (ctx->new_schema) { @@ -251,7 +276,8 @@ rspamd_redis_expand_object (const gchar *pattern, } } - if (target == NULL || task == NULL) { + + if (target == NULL || task == NULL || expansion_errored) { return tlen; } @@ -501,14 +527,14 @@ rspamd_redis_tokens_to_query (struct rspamd_task *task, "HSET %b_tokens %b %b:%b", prefix, (size_t) prefix_len, n0, (size_t) l0, - tok->t1->begin, tok->t1->len, - tok->t2->begin, tok->t2->len); + tok->t1->stemmed.begin, tok->t1->stemmed.len, + tok->t2->stemmed.begin, tok->t2->stemmed.len); } else if (tok->t1) { redisAsyncCommand (rt->redis, NULL, NULL, "HSET %b_tokens %b %b", prefix, (size_t) prefix_len, n0, (size_t) l0, - tok->t1->begin, tok->t1->len); + tok->t1->stemmed.begin, tok->t1->stemmed.len); } } else { @@ -522,14 +548,14 @@ rspamd_redis_tokens_to_query (struct rspamd_task *task, "HSET %b %s %b:%b", n0, (size_t) l0, "tokens", - tok->t1->begin, tok->t1->len, - tok->t2->begin, tok->t2->len); + tok->t1->stemmed.begin, tok->t1->stemmed.len, + tok->t2->stemmed.begin, tok->t2->stemmed.len); } else if (tok->t1) { redisAsyncCommand (rt->redis, NULL, NULL, "HSET %b %s %b", n0, (size_t) l0, "tokens", - tok->t1->begin, tok->t1->len); + tok->t1->stemmed.begin, tok->t1->stemmed.len); } } @@ -928,6 +954,7 @@ rspamd_redis_async_stat_cb (struct rspamd_stat_async_elt *elt, gpointer d) struct rspamd_redis_stat_elt *redis_elt = elt->ud; struct rspamd_redis_stat_cbdata *cbdata; rspamd_inet_addr_t *addr; + struct upstream_list *ups; g_assert (redis_elt != NULL); @@ -941,8 +968,15 @@ rspamd_redis_async_stat_cb (struct rspamd_stat_async_elt *elt, gpointer d) /* Disable further events unless needed */ elt->enabled = FALSE; + ups = rspamd_redis_get_servers (ctx, "read_servers"); + + if (!ups) { + return; + } + cbdata = g_malloc0 (sizeof (*cbdata)); - cbdata->selected = rspamd_upstream_get (ctx->read_servers, + + cbdata->selected = rspamd_upstream_get (ups, RSPAMD_UPSTREAM_ROUND_ROBIN, NULL, 0); @@ -1231,78 +1265,6 @@ rspamd_redis_learned (redisAsyncContext *c, gpointer r, gpointer priv) rspamd_session_remove_event (task->s, rspamd_redis_fin_learn, rt); } } - -static gboolean -rspamd_redis_try_ucl (struct redis_stat_ctx *backend, - const ucl_object_t *obj, - struct rspamd_config *cfg, - const gchar *symbol) -{ - const ucl_object_t *elt, *relt; - - elt = ucl_object_lookup_any (obj, "read_servers", "servers", NULL); - - if (elt == NULL) { - return FALSE; - } - - backend->read_servers = rspamd_upstreams_create (cfg->ups_ctx); - if (!rspamd_upstreams_from_ucl (backend->read_servers, elt, - REDIS_DEFAULT_PORT, NULL)) { - msg_err ("statfile %s cannot get read servers configuration", - symbol); - return FALSE; - } - - relt = elt; - - elt = ucl_object_lookup (obj, "write_servers"); - if (elt == NULL) { - /* Use read servers as write ones */ - g_assert (relt != NULL); - backend->write_servers = rspamd_upstreams_create (cfg->ups_ctx); - if (!rspamd_upstreams_from_ucl (backend->write_servers, relt, - REDIS_DEFAULT_PORT, NULL)) { - msg_err ("statfile %s cannot get write servers configuration", - symbol); - return FALSE; - } - } - else { - backend->write_servers = rspamd_upstreams_create (cfg->ups_ctx); - if (!rspamd_upstreams_from_ucl (backend->write_servers, elt, - REDIS_DEFAULT_PORT, NULL)) { - msg_err ("statfile %s cannot get write servers configuration", - symbol); - rspamd_upstreams_destroy (backend->write_servers); - backend->write_servers = NULL; - } - } - - elt = ucl_object_lookup_any (obj, "db", "database", "dbname", NULL); - if (elt) { - if (ucl_object_type (elt) == UCL_STRING) { - backend->dbname = ucl_object_tostring (elt); - } - else if (ucl_object_type (elt) == UCL_INT) { - backend->dbname = ucl_object_tostring_forced (elt); - } - } - else { - backend->dbname = NULL; - } - - elt = ucl_object_lookup (obj, "password"); - if (elt) { - backend->password = ucl_object_tostring (elt); - } - else { - backend->password = NULL; - } - - return TRUE; -} - static void rspamd_redis_parse_classifier_opts (struct redis_stat_ctx *backend, const ucl_object_t *obj, @@ -1360,14 +1322,6 @@ rspamd_redis_parse_classifier_opts (struct redis_stat_ctx *backend, backend->redis_object = ucl_object_tostring (elt); } - elt = ucl_object_lookup (obj, "timeout"); - if (elt) { - backend->timeout = ucl_object_todouble (elt); - } - else { - backend->timeout = REDIS_DEFAULT_TIMEOUT; - } - elt = ucl_object_lookup (obj, "store_tokens"); if (elt) { backend->store_tokens = ucl_object_toboolean (elt); @@ -1414,24 +1368,27 @@ rspamd_redis_init (struct rspamd_stat_ctx *ctx, struct rspamd_redis_stat_elt *st_elt; const ucl_object_t *obj; gboolean ret = FALSE; + gint conf_ref = -1; + lua_State *L = (lua_State *)cfg->lua_state; backend = g_malloc0 (sizeof (*backend)); + backend->L = L; + backend->timeout = REDIS_DEFAULT_TIMEOUT; /* First search in backend configuration */ obj = ucl_object_lookup (st->classifier->cfg->opts, "backend"); if (obj != NULL && ucl_object_type (obj) == UCL_OBJECT) { - ret = rspamd_redis_try_ucl (backend, obj, cfg, stf->symbol); + ret = rspamd_lua_try_load_redis (L, obj, cfg, &conf_ref); } /* Now try statfiles config */ - if (!ret) { - ret = rspamd_redis_try_ucl (backend, stf->opts, cfg, stf->symbol); + if (!ret && stf->opts) { + ret = rspamd_lua_try_load_redis (L, stf->opts, cfg, &conf_ref); } /* Now try classifier config */ - if (!ret) { - ret = rspamd_redis_try_ucl (backend, st->classifier->cfg->opts, cfg, - stf->symbol); + if (!ret && st->classifier->cfg->opts) { + ret = rspamd_lua_try_load_redis (L, st->classifier->cfg->opts, cfg, &conf_ref); } /* Now try global redis settings */ @@ -1444,12 +1401,12 @@ rspamd_redis_init (struct rspamd_stat_ctx *ctx, specific_obj = ucl_object_lookup (obj, "statistics"); if (specific_obj) { - ret = rspamd_redis_try_ucl (backend, specific_obj, cfg, - stf->symbol); + ret = rspamd_lua_try_load_redis (L, + specific_obj, cfg, &conf_ref); } else { - ret = rspamd_redis_try_ucl (backend, obj, cfg, - stf->symbol); + ret = rspamd_lua_try_load_redis (L, + obj, cfg, &conf_ref); } } } @@ -1460,6 +1417,36 @@ rspamd_redis_init (struct rspamd_stat_ctx *ctx, return NULL; } + backend->conf_ref = conf_ref; + + /* Check some common table values */ + lua_rawgeti (L, LUA_REGISTRYINDEX, conf_ref); + + lua_pushstring (L, "timeout"); + lua_gettable (L, -2); + if (lua_type (L, -1) == LUA_TNUMBER) { + backend->timeout = lua_tonumber (L, -1); + } + lua_pop (L, 1); + + lua_pushstring (L, "db"); + lua_gettable (L, -2); + if (lua_type (L, -1) == LUA_TSTRING) { + backend->dbname = rspamd_mempool_strdup (cfg->cfg_pool, + lua_tostring (L, -1)); + } + lua_pop (L, 1); + + lua_pushstring (L, "password"); + lua_gettable (L, -2); + if (lua_type (L, -1) == LUA_TSTRING) { + backend->password = rspamd_mempool_strdup (cfg->cfg_pool, + lua_tostring (L, -1)); + } + lua_pop (L, 1); + + lua_settop (L, 0); + rspamd_redis_parse_classifier_opts (backend, st->classifier->cfg->opts, cfg); stf->clcf->flags |= RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND; backend->stcf = stf; @@ -1485,24 +1472,35 @@ rspamd_redis_runtime (struct rspamd_task *task, struct redis_stat_ctx *ctx = REDIS_CTX (c); struct redis_stat_runtime *rt; struct upstream *up; + struct upstream_list *ups; + char *object_expanded = NULL; rspamd_inet_addr_t *addr; g_assert (ctx != NULL); g_assert (stcf != NULL); - if (learn && ctx->write_servers == NULL) { - msg_err_task ("no write servers defined for %s, cannot learn", stcf->symbol); - return NULL; - } - if (learn) { - up = rspamd_upstream_get (ctx->write_servers, + ups = rspamd_redis_get_servers (ctx, "write_servers"); + + if (!ups) { + msg_err_task ("no write servers defined for %s, cannot learn", + stcf->symbol); + return NULL; + } + up = rspamd_upstream_get (ups, RSPAMD_UPSTREAM_MASTER_SLAVE, NULL, 0); } else { - up = rspamd_upstream_get (ctx->read_servers, + ups = rspamd_redis_get_servers (ctx, "read_servers"); + + if (!ups) { + msg_err_task ("no read servers defined for %s, cannot stat", + stcf->symbol); + return NULL; + } + up = rspamd_upstream_get (ups, RSPAMD_UPSTREAM_ROUND_ROBIN, NULL, 0); @@ -1513,15 +1511,22 @@ rspamd_redis_runtime (struct rspamd_task *task, return NULL; } + if (rspamd_redis_expand_object (ctx->redis_object, ctx, task, + &object_expanded) == 0) { + msg_err_task ("expansion for learning failed for symbol %s " + "(maybe learning per user classifier with no user or recipient)", + stcf->symbol); + return NULL; + } + rt = rspamd_mempool_alloc0 (task->task_pool, sizeof (*rt)); rspamd_mempool_add_destructor (task->task_pool, rspamd_gerror_free_maybe, &rt->err); - rspamd_redis_expand_object (ctx->redis_object, ctx, task, - &rt->redis_object_expanded); rt->selected = up; rt->task = task; rt->ctx = ctx; rt->stcf = stcf; + rt->redis_object_expanded = object_expanded; addr = rspamd_upstream_addr (up); g_assert (addr != NULL); @@ -1549,13 +1554,10 @@ void rspamd_redis_close (gpointer p) { struct redis_stat_ctx *ctx = REDIS_CTX (p); + lua_State *L = ctx->L; - if (ctx->read_servers) { - rspamd_upstreams_destroy (ctx->read_servers); - } - - if (ctx->write_servers) { - rspamd_upstreams_destroy (ctx->write_servers); + if (ctx->conf_ref) { + luaL_unref (L, LUA_REGISTRYINDEX, ctx->conf_ref); } g_free (ctx); @@ -1594,7 +1596,7 @@ rspamd_redis_process_tokens (struct rspamd_task *task, if (redisAsyncCommand (rt->redis, rspamd_redis_connected, rt, "HGET %s %s", rt->redis_object_expanded, learned_key) == REDIS_OK) { - rspamd_session_add_event (task->s, NULL, rspamd_redis_fin, rt, rspamd_redis_stat_quark ()); + rspamd_session_add_event (task->s, rspamd_redis_fin, rt, M); rt->has_event = TRUE; if (rspamd_event_pending (&rt->timeout_event, EV_TIMEOUT)) { @@ -1657,6 +1659,7 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens, { struct redis_stat_runtime *rt = REDIS_RUNTIME (p); struct upstream *up; + struct upstream_list *ups; rspamd_inet_addr_t *addr; struct timeval tv; rspamd_fstring_t *query; @@ -1670,7 +1673,12 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens, return FALSE; } - up = rspamd_upstream_get (rt->ctx->write_servers, + ups = rspamd_redis_get_servers (rt->ctx, "write_servers"); + + if (!ups) { + return FALSE; + } + up = rspamd_upstream_get (ups, RSPAMD_UPSTREAM_MASTER_SLAVE, NULL, 0); @@ -1798,7 +1806,7 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens, "RSIG"); } - rspamd_session_add_event (task->s, NULL, rspamd_redis_fin_learn, rt, rspamd_redis_stat_quark ()); + rspamd_session_add_event (task->s, rspamd_redis_fin_learn, rt, M); rt->has_event = TRUE; /* Set timeout */ diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c index 5b6b5a0fe..2b0cf21e8 100644 --- a/src/libstat/classifiers/bayes.c +++ b/src/libstat/classifiers/bayes.c @@ -38,7 +38,7 @@ G_STRFUNC, \ __VA_ARGS__) -INIT_LOG_MODULE(bayes) +INIT_LOG_MODULE_PUBLIC(bayes) static inline GQuark bayes_error_quark (void) @@ -80,6 +80,8 @@ inv_chi_square (struct rspamd_task *task, gdouble value, gint freedom_deg) sum = prob; + msg_debug_bayes ("m: %f, prob: %g", m, prob); + /* * m is our confidence in class * prob is e ^ x (small value since x is normally less than zero @@ -89,7 +91,7 @@ inv_chi_square (struct rspamd_task *task, gdouble value, gint freedom_deg) for (i = 1; i < freedom_deg; i++) { prob *= m / (gdouble)i; sum += prob; - msg_debug_bayes ("prob: %.6f, sum: %.6f", prob, sum); + msg_debug_bayes ("i=%d, prob: %g, sum: %g", i, prob, sum); } return MIN (1.0, sum); @@ -109,7 +111,7 @@ struct bayes_task_closure { * Mathematically we use pow(complexity, complexity), where complexity is the * window index */ -static const double feature_weight[] = { 0, 1, 4, 27, 256, 3125, 46656, 823543 }; +static const double feature_weight[] = { 0, 3125, 256, 27, 1, 0, 0, 0 }; #define PROB_COMBINE(prob, cnt, weight, assumed) (((weight) * (assumed) + (cnt) * (prob)) / ((weight) + (cnt))) /* @@ -121,12 +123,12 @@ bayes_classify_token (struct rspamd_classifier *ctx, { guint i; gint id; - guint64 spam_count = 0, ham_count = 0, total_count = 0; + guint spam_count = 0, ham_count = 0, total_count = 0; struct rspamd_statfile *st; struct rspamd_task *task; const gchar *token_type = "txt"; double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob, - ham_prob, fw, w, norm_sum, norm_sub, val; + ham_prob, fw, w, val; task = cl->task; @@ -145,8 +147,8 @@ bayes_classify_token (struct rspamd_classifier *ctx, msg_debug_bayes ( "token(meta) %uL <%*s:%*s> probabilistically skipped", tok->data, - (int) tok->t1->len, tok->t1->begin, - (int) tok->t2->len, tok->t2->begin); + (int) tok->t1->original.len, tok->t1->original.begin, + (int) tok->t2->original.len, tok->t2->original.begin); } return; @@ -173,7 +175,7 @@ bayes_classify_token (struct rspamd_classifier *ctx, } /* Probability for this token */ - if (total_count > 0) { + if (total_count >= ctx->cfg->min_token_hits) { spam_freq = ((double)spam_count / MAX (1., (double) ctx->spam_learns)); ham_freq = ((double)ham_count / MAX (1., (double)ctx->ham_learns)); spam_prob = spam_freq / (spam_freq + ham_freq); @@ -187,19 +189,27 @@ bayes_classify_token (struct rspamd_classifier *ctx, G_N_ELEMENTS (feature_weight)]; } - norm_sum = (spam_freq + ham_freq) * (spam_freq + ham_freq); - norm_sub = (spam_freq - ham_freq) * (spam_freq - ham_freq); - w = (norm_sub) / (norm_sum) * - (fw * total_count) / (4.0 * (1.0 + fw * total_count)); + w = (fw * total_count) / (1.0 + fw * total_count); + bayes_spam_prob = PROB_COMBINE (spam_prob, total_count, w, 0.5); - norm_sub = (ham_freq - spam_freq) * (ham_freq - spam_freq); - w = (norm_sub) / (norm_sum) * - (fw * total_count) / (4.0 * (1.0 + fw * total_count)); + + if ((bayes_spam_prob > 0.5 && bayes_spam_prob < 0.5 + ctx->cfg->min_prob_strength) || + (bayes_spam_prob < 0.5 && bayes_spam_prob > 0.5 - ctx->cfg->min_prob_strength)) { + msg_debug_bayes ( + "token %uL <%*s:%*s> skipped, prob not in range: %f", + tok->data, + (int) tok->t1->stemmed.len, tok->t1->stemmed.begin, + (int) tok->t2->stemmed.len, tok->t2->stemmed.begin, + bayes_spam_prob); + + return; + } + bayes_ham_prob = PROB_COMBINE (ham_prob, total_count, w, 0.5); - cl->spam_prob += log2 (bayes_spam_prob); - cl->ham_prob += log2 (bayes_ham_prob); + cl->spam_prob += log (bayes_spam_prob); + cl->ham_prob += log (bayes_ham_prob); cl->processed_tokens ++; if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) { @@ -210,29 +220,31 @@ bayes_classify_token (struct rspamd_classifier *ctx, } if (tok->t1 && tok->t2) { - msg_debug_bayes ("token(%s) %uL <%*s:%*s>: weight: %f, total_count: %L, " - "spam_count: %L, ham_count: %L," + msg_debug_bayes ("token(%s) %uL <%*s:%*s>: weight: %f, cf: %f, " + "total_count: %ud, " + "spam_count: %ud, ham_count: %ud," "spam_prob: %.3f, ham_prob: %.3f, " "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, " "current spam prob: %.3f, current ham prob: %.3f", token_type, tok->data, - (int) tok->t1->len, tok->t1->begin, - (int) tok->t2->len, tok->t2->begin, - fw, total_count, spam_count, ham_count, + (int) tok->t1->stemmed.len, tok->t1->stemmed.begin, + (int) tok->t2->stemmed.len, tok->t2->stemmed.begin, + fw, w, total_count, spam_count, ham_count, spam_prob, ham_prob, bayes_spam_prob, bayes_ham_prob, cl->spam_prob, cl->ham_prob); } else { - msg_debug_bayes ("token(%s) %uL <?:?>: weight: %f, total_count: %L, " - "spam_count: %L, ham_count: %L," + msg_debug_bayes ("token(%s) %uL <?:?>: weight: %f, cf: %f, " + "total_count: %ud, " + "spam_count: %ud, ham_count: %ud," "spam_prob: %.3f, ham_prob: %.3f, " "bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, " "current spam prob: %.3f, current ham prob: %.3f", token_type, tok->data, - fw, total_count, spam_count, ham_count, + fw, w, total_count, spam_count, ham_count, spam_prob, ham_prob, bayes_spam_prob, bayes_ham_prob, cl->spam_prob, cl->ham_prob); @@ -243,13 +255,20 @@ bayes_classify_token (struct rspamd_classifier *ctx, gboolean -bayes_init (rspamd_mempool_t *pool, struct rspamd_classifier *cl) +bayes_init (struct rspamd_config *cfg, + struct event_base *ev_base, + struct rspamd_classifier *cl) { cl->cfg->flags |= RSPAMD_FLAG_CLASSIFIER_INTEGER; return TRUE; } +void +bayes_fin (struct rspamd_classifier *cl) +{ +} + gboolean bayes_classify (struct rspamd_classifier * ctx, GPtrArray *tokens, @@ -318,14 +337,51 @@ bayes_classify (struct rspamd_classifier * ctx, bayes_classify_token (ctx, tok, &cl); } - h = 1 - inv_chi_square (task, cl.spam_prob, cl.processed_tokens); - s = 1 - inv_chi_square (task, cl.ham_prob, cl.processed_tokens); + if (cl.processed_tokens == 0) { + msg_info_bayes ("no tokens found in bayes database " + "(%ud total tokens, %ud text tokens), ignore stats", + tokens->len, text_tokens); + + return TRUE; + } + + if (ctx->cfg->min_tokens > 0 && + cl.text_tokens < (gint)(ctx->cfg->min_tokens * 0.1)) { + msg_info_bayes ("ignore bayes probability since we have " + "found too few text tokens: %uL (of %ud checked), " + "at least %d is required", + cl.text_tokens, + text_tokens, + (gint)(ctx->cfg->min_tokens * 0.1)); + + return TRUE; + } + + if (cl.spam_prob > -300 && cl.ham_prob > -300) { + /* Fisher value is low enough to apply inv_chi_square */ + h = 1 - inv_chi_square (task, cl.spam_prob, cl.processed_tokens); + s = 1 - inv_chi_square (task, cl.ham_prob, cl.processed_tokens); + } + else { + /* Use naive method */ + if (cl.spam_prob < cl.ham_prob) { + h = (1.0 - exp(cl.spam_prob - cl.ham_prob)) / + (1.0 + exp(cl.spam_prob - cl.ham_prob)); + s = 1.0 - h; + } + else { + s = (1.0 - exp(cl.ham_prob - cl.spam_prob)) / + (1.0 + exp(cl.ham_prob - cl.spam_prob)); + h = 1.0 - s; + } + } if (isfinite (s) && isfinite (h)) { final_prob = (s + 1.0 - h) / 2.; msg_debug_bayes ( "<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f," - " %L tokens processed of %ud total tokens (%uL text tokens)", + " %L tokens processed of %ud total tokens;" + " %uL text tokens found of %ud text tokens)", task->message_id, cl.ham_prob, h, @@ -333,7 +389,8 @@ bayes_classify (struct rspamd_classifier * ctx, s, cl.processed_tokens, tokens->len, - cl.text_tokens); + cl.text_tokens, + text_tokens); } else { /* @@ -357,17 +414,6 @@ bayes_classify (struct rspamd_classifier * ctx, } } - if (ctx->cfg->min_tokens > 0 && - cl.text_tokens < (gint)(ctx->cfg->min_tokens * 0.1)) { - msg_info_bayes ("ignore bayes probability %.2f since we have " - "too few text tokens: %uL, at least %d is required", - final_prob, - cl.text_tokens, - (gint)(ctx->cfg->min_tokens * 0.1)); - - return TRUE; - } - pprob = rspamd_mempool_alloc (task->task_pool, sizeof (*pprob)); *pprob = final_prob; rspamd_mempool_set_variable (task->task_pool, "bayes_prob", pprob, NULL); @@ -399,6 +445,19 @@ bayes_classify (struct rspamd_classifier * ctx, (final_prob - 0.5) * 200.); final_prob = rspamd_normalize_probability (final_prob, 0.5); g_assert (st != NULL); + + if (final_prob > 1 || final_prob < 0) { + msg_err_bayes ("internal error: probability %f is outside of the " + "allowed range [0..1]", final_prob); + + if (final_prob > 1) { + final_prob = 1.0; + } + else { + final_prob = 0.0; + } + } + rspamd_task_insert_result (task, st->stcf->symbol, final_prob, @@ -483,8 +542,8 @@ bayes_learn_spam (struct rspamd_classifier * ctx, msg_debug_bayes ("token %uL <%*s:%*s>: window: %d, total_count: %d, " "spam_count: %d, ham_count: %d", tok->data, - (int) tok->t1->len, tok->t1->begin, - (int) tok->t2->len, tok->t2->begin, + (int) tok->t1->stemmed.len, tok->t1->stemmed.begin, + (int) tok->t2->stemmed.len, tok->t2->stemmed.begin, tok->window_idx, total_cnt, spam_cnt, ham_cnt); } else { diff --git a/src/libstat/classifiers/classifiers.h b/src/libstat/classifiers/classifiers.h index e30f2153a..fd6daf433 100644 --- a/src/libstat/classifiers/classifiers.h +++ b/src/libstat/classifiers/classifiers.h @@ -3,6 +3,7 @@ #include "config.h" #include "mem_pool.h" +#include <event.h> #define RSPAMD_DEFAULT_CLASSIFIER "bayes" /* Consider this value as 0 */ @@ -10,28 +11,32 @@ struct rspamd_classifier_config; struct rspamd_task; +struct rspamd_config; struct rspamd_classifier; struct token_node_s; struct rspamd_stat_classifier { char *name; - gboolean (*init_func)(rspamd_mempool_t *pool, - struct rspamd_classifier *cl); + gboolean (*init_func)(struct rspamd_config *cfg, + struct event_base *ev_base, + struct rspamd_classifier *cl); gboolean (*classify_func)(struct rspamd_classifier * ctx, - GPtrArray *tokens, - struct rspamd_task *task); + GPtrArray *tokens, + struct rspamd_task *task); gboolean (*learn_spam_func)(struct rspamd_classifier * ctx, - GPtrArray *input, - struct rspamd_task *task, - gboolean is_spam, - gboolean unlearn, - GError **err); + GPtrArray *input, + struct rspamd_task *task, + gboolean is_spam, + gboolean unlearn, + GError **err); + void (*fin_func)(struct rspamd_classifier *cl); }; /* Bayes algorithm */ -gboolean bayes_init (rspamd_mempool_t *pool, - struct rspamd_classifier *); +gboolean bayes_init (struct rspamd_config *cfg, + struct event_base *ev_base, + struct rspamd_classifier *); gboolean bayes_classify (struct rspamd_classifier *ctx, GPtrArray *tokens, struct rspamd_task *task); @@ -41,10 +46,12 @@ gboolean bayes_learn_spam (struct rspamd_classifier *ctx, gboolean is_spam, gboolean unlearn, GError **err); +void bayes_fin (struct rspamd_classifier *); /* Generic lua classifier */ -gboolean lua_classifier_init (rspamd_mempool_t *pool, - struct rspamd_classifier *); +gboolean lua_classifier_init (struct rspamd_config *cfg, + struct event_base *ev_base, + struct rspamd_classifier *); gboolean lua_classifier_classify (struct rspamd_classifier *ctx, GPtrArray *tokens, struct rspamd_task *task); @@ -55,6 +62,11 @@ gboolean lua_classifier_learn_spam (struct rspamd_classifier *ctx, gboolean unlearn, GError **err); +extern guint rspamd_bayes_log_id; +#define msg_debug_bayes(...) rspamd_conditional_debug_fast (NULL, task->from_addr, \ + rspamd_bayes_log_id, "bayes", task->task_pool->tag.uid, \ + G_STRFUNC, \ + __VA_ARGS__) #endif /* diff --git a/src/libstat/classifiers/lua_classifier.c b/src/libstat/classifiers/lua_classifier.c index 7b495b165..83ce7b0e1 100644 --- a/src/libstat/classifiers/lua_classifier.c +++ b/src/libstat/classifiers/lua_classifier.c @@ -47,8 +47,9 @@ static GHashTable *lua_classifiers = NULL; INIT_LOG_MODULE(luacl) gboolean -lua_classifier_init (rspamd_mempool_t *pool, - struct rspamd_classifier *cl) +lua_classifier_init (struct rspamd_config *cfg, + struct event_base *ev_base, + struct rspamd_classifier *cl) { struct rspamd_lua_classifier_ctx *ctx; lua_State *L = cl->ctx->cfg->lua_state; @@ -62,7 +63,7 @@ lua_classifier_init (rspamd_mempool_t *pool, ctx = g_hash_table_lookup (lua_classifiers, cl->subrs->name); if (ctx != NULL) { - msg_err_pool ("duplicate lua classifier definition: %s", + msg_err_config ("duplicate lua classifier definition: %s", cl->subrs->name); return FALSE; @@ -70,7 +71,7 @@ lua_classifier_init (rspamd_mempool_t *pool, lua_getglobal (L, "rspamd_classifiers"); if (lua_type (L, -1) != LUA_TTABLE) { - msg_err_pool ("cannot register classifier %s: no rspamd_classifier global", + msg_err_config ("cannot register classifier %s: no rspamd_classifier global", cl->subrs->name); lua_pop (L, 1); @@ -81,7 +82,7 @@ lua_classifier_init (rspamd_mempool_t *pool, lua_gettable (L, -2); if (lua_type (L, -1) != LUA_TTABLE) { - msg_err_pool ("cannot register classifier %s: bad lua type: %s", + msg_err_config ("cannot register classifier %s: bad lua type: %s", cl->subrs->name, lua_typename (L, lua_type (L, -1))); lua_pop (L, 2); @@ -92,7 +93,7 @@ lua_classifier_init (rspamd_mempool_t *pool, lua_gettable (L, -2); if (lua_type (L, -1) != LUA_TFUNCTION) { - msg_err_pool ("cannot register classifier %s: bad lua type for classify: %s", + msg_err_config ("cannot register classifier %s: bad lua type for classify: %s", cl->subrs->name, lua_typename (L, lua_type (L, -1))); lua_pop (L, 3); @@ -105,7 +106,7 @@ lua_classifier_init (rspamd_mempool_t *pool, lua_gettable (L, -2); if (lua_type (L, -1) != LUA_TFUNCTION) { - msg_err_pool ("cannot register classifier %s: bad lua type for learn: %s", + msg_err_config ("cannot register classifier %s: bad lua type for learn: %s", cl->subrs->name, lua_typename (L, lua_type (L, -1))); lua_pop (L, 3); diff --git a/src/libstat/learn_cache/redis_cache.c b/src/libstat/learn_cache/redis_cache.c index 11bc13aae..6a0aa1da7 100644 --- a/src/libstat/learn_cache/redis_cache.c +++ b/src/libstat/learn_cache/redis_cache.c @@ -22,20 +22,23 @@ #include "ucl.h" #include "hiredis.h" #include "adapters/libevent.h" +#include "lua/lua_common.h" #define REDIS_DEFAULT_TIMEOUT 0.5 #define REDIS_STAT_TIMEOUT 30 #define REDIS_DEFAULT_PORT 6379 #define DEFAULT_REDIS_KEY "learned_ids" +static const gchar *M = "redis learn cache"; + struct rspamd_redis_cache_ctx { + lua_State *L; struct rspamd_statfile_config *stcf; - struct upstream_list *read_servers; - struct upstream_list *write_servers; const gchar *password; const gchar *dbname; const gchar *redis_object; gdouble timeout; + gint conf_ref; }; struct rspamd_redis_cache_runtime { @@ -50,7 +53,23 @@ struct rspamd_redis_cache_runtime { static GQuark rspamd_stat_cache_redis_quark (void) { - return g_quark_from_static_string ("redis-statistics"); + return g_quark_from_static_string (M); +} + +static inline struct upstream_list * +rspamd_redis_get_servers (struct rspamd_redis_cache_ctx *ctx, + const gchar *what) +{ + lua_State *L = ctx->L; + struct upstream_list *res; + + lua_rawgeti (L, LUA_REGISTRYINDEX, ctx->conf_ref); + lua_pushstring (L, what); + lua_gettable (L, -2); + res = *((struct upstream_list**)lua_touserdata (L, -1)); + lua_settop (L, 0); + + return res; } static void @@ -208,94 +227,6 @@ rspamd_stat_cache_redis_generate_id (struct rspamd_task *task) rspamd_mempool_set_variable (task->task_pool, "words_hash", b32out, g_free); } -static gboolean -rspamd_redis_cache_try_ucl (struct rspamd_redis_cache_ctx *cache_ctx, - const ucl_object_t *obj, - struct rspamd_config *cfg, - const gchar *symbol) -{ - const ucl_object_t *elt, *relt; - - elt = ucl_object_lookup_any (obj, "read_servers", "servers", NULL); - - if (elt == NULL) { - return FALSE; - } - - cache_ctx->read_servers = rspamd_upstreams_create (cfg->ups_ctx); - if (!rspamd_upstreams_from_ucl (cache_ctx->read_servers, elt, - REDIS_DEFAULT_PORT, NULL)) { - msg_err ("statfile %s cannot get read servers configuration", - symbol); - return FALSE; - } - - relt = elt; - - elt = ucl_object_lookup (obj, "write_servers"); - if (elt == NULL) { - /* Use read servers as write ones */ - g_assert (relt != NULL); - cache_ctx->write_servers = rspamd_upstreams_create (cfg->ups_ctx); - if (!rspamd_upstreams_from_ucl (cache_ctx->write_servers, relt, - REDIS_DEFAULT_PORT, NULL)) { - msg_err ("statfile %s cannot get write servers configuration", - symbol); - return FALSE; - } - } - else { - cache_ctx->write_servers = rspamd_upstreams_create (cfg->ups_ctx); - if (!rspamd_upstreams_from_ucl (cache_ctx->write_servers, elt, - REDIS_DEFAULT_PORT, NULL)) { - msg_err ("statfile %s cannot get write servers configuration", - symbol); - rspamd_upstreams_destroy (cache_ctx->write_servers); - cache_ctx->write_servers = NULL; - } - } - - - elt = ucl_object_lookup (obj, "timeout"); - if (elt) { - cache_ctx->timeout = ucl_object_todouble (elt); - } - else { - cache_ctx->timeout = REDIS_DEFAULT_TIMEOUT; - } - - elt = ucl_object_lookup (obj, "password"); - if (elt) { - cache_ctx->password = ucl_object_tostring (elt); - } - else { - cache_ctx->password = NULL; - } - - elt = ucl_object_lookup_any (obj, "db", "database", "dbname", NULL); - if (elt) { - if (ucl_object_type (elt) == UCL_STRING) { - cache_ctx->dbname = ucl_object_tostring (elt); - } - else if (ucl_object_type (elt) == UCL_INT) { - cache_ctx->dbname = ucl_object_tostring_forced (elt); - } - } - else { - cache_ctx->dbname = NULL; - } - - elt = ucl_object_lookup_any (obj, "cache_key", "key", NULL); - if (elt == NULL || ucl_object_type (elt) != UCL_STRING) { - cache_ctx->redis_object = DEFAULT_REDIS_KEY; - } - else { - cache_ctx->redis_object = ucl_object_tostring (elt); - } - - return TRUE; -} - gpointer rspamd_stat_cache_redis_init (struct rspamd_stat_ctx *ctx, struct rspamd_config *cfg, @@ -306,24 +237,27 @@ rspamd_stat_cache_redis_init (struct rspamd_stat_ctx *ctx, struct rspamd_statfile_config *stf = st->stcf; const ucl_object_t *obj; gboolean ret = FALSE; + lua_State *L = (lua_State *)cfg->lua_state; + gint conf_ref = -1; cache_ctx = g_malloc0 (sizeof (*cache_ctx)); + cache_ctx->timeout = REDIS_DEFAULT_TIMEOUT; + cache_ctx->L = L; /* First search in backend configuration */ obj = ucl_object_lookup (st->classifier->cfg->opts, "backend"); if (obj != NULL && ucl_object_type (obj) == UCL_OBJECT) { - ret = rspamd_redis_cache_try_ucl (cache_ctx, obj, cfg, stf->symbol); + ret = rspamd_lua_try_load_redis (L, obj, cfg, &conf_ref); } /* Now try statfiles config */ - if (!ret) { - ret = rspamd_redis_cache_try_ucl (cache_ctx, stf->opts, cfg, stf->symbol); + if (!ret && stf->opts) { + ret = rspamd_lua_try_load_redis (L, stf->opts, cfg, &conf_ref); } /* Now try classifier config */ - if (!ret) { - ret = rspamd_redis_cache_try_ucl (cache_ctx, st->classifier->cfg->opts, cfg, - stf->symbol); + if (!ret && st->classifier->cfg->opts) { + ret = rspamd_lua_try_load_redis (L, st->classifier->cfg->opts, cfg, &conf_ref); } /* Now try global redis settings */ @@ -336,23 +270,61 @@ rspamd_stat_cache_redis_init (struct rspamd_stat_ctx *ctx, specific_obj = ucl_object_lookup (obj, "statistics"); if (specific_obj) { - ret = rspamd_redis_cache_try_ucl (cache_ctx, specific_obj, cfg, - stf->symbol); + ret = rspamd_lua_try_load_redis (L, + specific_obj, cfg, &conf_ref); } else { - ret = rspamd_redis_cache_try_ucl (cache_ctx, obj, cfg, - stf->symbol); + ret = rspamd_lua_try_load_redis (L, + obj, cfg, &conf_ref); } } } - if (!ret) { msg_err_config ("cannot init redis cache for %s", stf->symbol); g_free (cache_ctx); return NULL; } + obj = ucl_object_lookup (st->classifier->cfg->opts, "cache_key"); + + if (obj) { + cache_ctx->redis_object = ucl_object_tostring (obj); + } + else { + cache_ctx->redis_object = DEFAULT_REDIS_KEY; + } + + cache_ctx->conf_ref = conf_ref; + + /* Check some common table values */ + lua_rawgeti (L, LUA_REGISTRYINDEX, conf_ref); + + lua_pushstring (L, "timeout"); + lua_gettable (L, -2); + if (lua_type (L, -1) == LUA_TNUMBER) { + cache_ctx->timeout = lua_tonumber (L, -1); + } + lua_pop (L, 1); + + lua_pushstring (L, "db"); + lua_gettable (L, -2); + if (lua_type (L, -1) == LUA_TSTRING) { + cache_ctx->dbname = rspamd_mempool_strdup (cfg->cfg_pool, + lua_tostring (L, -1)); + } + lua_pop (L, 1); + + lua_pushstring (L, "password"); + lua_gettable (L, -2); + if (lua_type (L, -1) == LUA_TSTRING) { + cache_ctx->password = rspamd_mempool_strdup (cfg->cfg_pool, + lua_tostring (L, -1)); + } + lua_pop (L, 1); + + lua_settop (L, 0); + cache_ctx->stcf = stf; return (gpointer)cache_ctx; @@ -365,28 +337,39 @@ rspamd_stat_cache_redis_runtime (struct rspamd_task *task, struct rspamd_redis_cache_ctx *ctx = c; struct rspamd_redis_cache_runtime *rt; struct upstream *up; + struct upstream_list *ups; rspamd_inet_addr_t *addr; g_assert (ctx != NULL); - if (learn && ctx->write_servers == NULL) { - msg_err_task ("no write servers defined for %s, cannot learn", - ctx->stcf->symbol); - return NULL; - } - if (task->tokens == NULL || task->tokens->len == 0) { return NULL; } if (learn) { - up = rspamd_upstream_get (ctx->write_servers, + ups = rspamd_redis_get_servers (ctx, "write_servers"); + + if (!ups) { + msg_err_task ("no write servers defined for %s, cannot learn", + ctx->stcf->symbol); + return NULL; + } + + up = rspamd_upstream_get (ups, RSPAMD_UPSTREAM_MASTER_SLAVE, NULL, 0); } else { - up = rspamd_upstream_get (ctx->read_servers, + ups = rspamd_redis_get_servers (ctx, "read_servers"); + + if (!ups) { + msg_err_task ("no read servers defined for %s, cannot check", + ctx->stcf->symbol); + return NULL; + } + + up = rspamd_upstream_get (ups, RSPAMD_UPSTREAM_ROUND_ROBIN, NULL, 0); @@ -453,7 +436,10 @@ rspamd_stat_cache_redis_check (struct rspamd_task *task, if (redisAsyncCommand (rt->redis, rspamd_stat_cache_redis_get, rt, "HGET %s %s", rt->ctx->redis_object, h) == REDIS_OK) { - rspamd_session_add_event (task->s, NULL, rspamd_redis_cache_fin, rt, rspamd_stat_cache_redis_quark ()); + rspamd_session_add_event (task->s, + rspamd_redis_cache_fin, + rt, + M); event_add (&rt->timeout_event, &tv); rt->has_event = TRUE; } @@ -485,7 +471,8 @@ rspamd_stat_cache_redis_learn (struct rspamd_task *task, if (redisAsyncCommand (rt->redis, rspamd_stat_cache_redis_set, rt, "HSET %s %s %d", rt->ctx->redis_object, h, flag) == REDIS_OK) { - rspamd_session_add_event (task->s, NULL, rspamd_redis_cache_fin, rt, rspamd_stat_cache_redis_quark ()); + rspamd_session_add_event (task->s, + rspamd_redis_cache_fin, rt, M); event_add (&rt->timeout_event, &tv); rt->has_event = TRUE; } @@ -497,5 +484,14 @@ rspamd_stat_cache_redis_learn (struct rspamd_task *task, void rspamd_stat_cache_redis_close (gpointer c) { + struct rspamd_redis_cache_ctx *ctx = (struct rspamd_redis_cache_ctx *)c; + lua_State *L; + + L = ctx->L; + + if (ctx->conf_ref) { + luaL_unref (L, LUA_REGISTRYINDEX, ctx->conf_ref); + } + g_free (ctx); } diff --git a/src/libstat/learn_cache/sqlite3_cache.c b/src/libstat/learn_cache/sqlite3_cache.c index 255c835bb..52921326d 100644 --- a/src/libstat/learn_cache/sqlite3_cache.c +++ b/src/libstat/learn_cache/sqlite3_cache.c @@ -221,6 +221,8 @@ rspamd_stat_cache_sqlite3_check (struct rspamd_task *task, /* We have some existing record in the table */ if (!!flag == !!is_spam) { /* Already learned */ + msg_warn_task ("already seen stat hash: %*bs", + rspamd_cryptobox_HASHBYTES, out); return RSPAMD_LEARN_INGORE; } else { diff --git a/src/libstat/stat_api.h b/src/libstat/stat_api.h index 84db8ee01..9dcd6f8e8 100644 --- a/src/libstat/stat_api.h +++ b/src/libstat/stat_api.h @@ -26,16 +26,25 @@ * High level statistics API */ -#define RSPAMD_STAT_TOKEN_FLAG_TEXT (1 << 0) -#define RSPAMD_STAT_TOKEN_FLAG_META (1 << 1) -#define RSPAMD_STAT_TOKEN_FLAG_LUA_META (1 << 2) -#define RSPAMD_STAT_TOKEN_FLAG_EXCEPTION (1 << 3) -#define RSPAMD_STAT_TOKEN_FLAG_SUBJECT (1 << 4) -#define RSPAMD_STAT_TOKEN_FLAG_UNIGRAM (1 << 5) +#define RSPAMD_STAT_TOKEN_FLAG_TEXT (1u << 0) +#define RSPAMD_STAT_TOKEN_FLAG_META (1u << 1) +#define RSPAMD_STAT_TOKEN_FLAG_LUA_META (1u << 2) +#define RSPAMD_STAT_TOKEN_FLAG_EXCEPTION (1u << 3) +#define RSPAMD_STAT_TOKEN_FLAG_HEADER (1u << 4) +#define RSPAMD_STAT_TOKEN_FLAG_UNIGRAM (1u << 5) +#define RSPAMD_STAT_TOKEN_FLAG_UTF (1u << 6) +#define RSPAMD_STAT_TOKEN_FLAG_NORMALISED (1u << 7) +#define RSPAMD_STAT_TOKEN_FLAG_STEMMED (1u << 8) +#define RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE (1u << 9) +#define RSPAMD_STAT_TOKEN_FLAG_STOP_WORD (1u << 9) +#define RSPAMD_STAT_TOKEN_FLAG_SKIPPED (1u << 10) +#define RSPAMD_STAT_TOKEN_FLAG_INVISIBLE_SPACES (1u << 11) typedef struct rspamd_stat_token_s { - const gchar *begin; - gsize len; + rspamd_ftok_t original; /* utf8 raw */ + rspamd_ftok_unicode_t unicode; /* array of unicode characters, normalized, lowercased */ + rspamd_ftok_t normalized; /* normalized and lowercased utf8 */ + rspamd_ftok_t stemmed; /* stemmed utf8 */ guint flags; } rspamd_stat_token_t; diff --git a/src/libstat/stat_config.c b/src/libstat/stat_config.c index 9d1e57f13..101db4fe6 100644 --- a/src/libstat/stat_config.c +++ b/src/libstat/stat_config.c @@ -28,6 +28,7 @@ static struct rspamd_stat_classifier lua_classifier = { .init_func = lua_classifier_init, .classify_func = lua_classifier_classify, .learn_spam_func = lua_classifier_learn_spam, + .fin_func = NULL, }; static struct rspamd_stat_classifier stat_classifiers[] = { @@ -36,6 +37,7 @@ static struct rspamd_stat_classifier stat_classifiers[] = { .init_func = bayes_init, .classify_func = bayes_classify, .learn_spam_func = bayes_learn_spam, + .fin_func = bayes_fin, } }; @@ -162,6 +164,67 @@ rspamd_stat_init (struct rspamd_config *cfg, struct event_base *ev_base) stat_ctx->classifiers = g_ptr_array_new (); stat_ctx->async_elts = g_queue_new (); stat_ctx->ev_base = ev_base; + stat_ctx->lua_stat_tokens_ref = -1; + + /* Interact with lua_stat */ + if (luaL_dostring (L, "return require \"lua_stat\"") != 0) { + msg_err_config ("cannot require lua_stat: %s", + lua_tostring (L, -1)); + } + else { + if (lua_type (L, -1) != LUA_TTABLE) { + msg_err_config ("lua stat must return " + "table and not %s", + lua_typename (L, lua_type (L, -1))); + } + else { + lua_pushstring (L, "gen_stat_tokens"); + lua_gettable (L, -2); + + if (lua_type (L, -1) != LUA_TFUNCTION) { + msg_err_config ("gen_stat_tokens must return " + "function and not %s", + lua_typename (L, lua_type (L, -1))); + } + else { + /* Call this function to obtain closure */ + gint err_idx, ret; + GString *tb; + struct rspamd_config **pcfg; + + lua_pushcfunction (L, &rspamd_lua_traceback); + err_idx = lua_gettop (L); + lua_pushvalue (L, err_idx - 1); + + pcfg = lua_newuserdata (L, sizeof (*pcfg)); + *pcfg = cfg; + rspamd_lua_setclass (L, "rspamd{config}", -1); + + if ((ret = lua_pcall (L, 1, 1, err_idx)) != 0) { + tb = lua_touserdata (L, -1); + msg_err_config ("call to gen_stat_tokens lua " + "script failed (%d): %v", ret, tb); + + if (tb) { + g_string_free (tb, TRUE); + } + } + else { + if (lua_type (L, -1) != LUA_TFUNCTION) { + msg_err_config ("gen_stat_tokens invocation must return " + "function and not %s", + lua_typename (L, lua_type (L, -1))); + } + else { + stat_ctx->lua_stat_tokens_ref = luaL_ref (L, LUA_REGISTRYINDEX); + } + } + } + } + } + + /* Cleanup mess */ + lua_settop (L, 0); /* Create statfiles from the classifiers */ cur = cfg->classifiers; @@ -182,7 +245,7 @@ rspamd_stat_init (struct rspamd_config *cfg, struct event_base *ev_base) continue; } - if (!cl->subrs->init_func (cfg->cfg_pool, cl)) { + if (!cl->subrs->init_func (cfg, ev_base, cl)) { g_free (cl); msg_err_config ("cannot init classifier type %s", clf->name); cur = g_list_next (cur); @@ -328,6 +391,11 @@ rspamd_stat_close (void) } g_array_free (cl->statfiles_ids, TRUE); + + if (cl->subrs->fin_func) { + cl->subrs->fin_func (cl); + } + g_free (cl); } @@ -342,6 +410,12 @@ rspamd_stat_close (void) g_queue_free (stat_ctx->async_elts); g_ptr_array_free (st_ctx->statfiles, TRUE); g_ptr_array_free (st_ctx->classifiers, TRUE); + + if (st_ctx->lua_stat_tokens_ref != -1) { + luaL_unref (st_ctx->cfg->lua_state, LUA_REGISTRYINDEX, + st_ctx->lua_stat_tokens_ref); + } + g_free (st_ctx); /* Set global var to NULL */ @@ -475,11 +549,11 @@ rspamd_stat_ctx_register_async (rspamd_stat_async_handler handler, g_assert (st_ctx != NULL); elt = g_malloc0 (sizeof (*elt)); - REF_INIT_RETAIN (elt, rspamd_async_elt_dtor); elt->handler = handler; elt->cleanup = cleanup; elt->ud = d; elt->timeout = timeout; + REF_INIT_RETAIN (elt, rspamd_async_elt_dtor); /* Enabled by default */ diff --git a/src/libstat/stat_internal.h b/src/libstat/stat_internal.h index 44f48ae5a..a547ca93a 100644 --- a/src/libstat/stat_internal.h +++ b/src/libstat/stat_internal.h @@ -41,6 +41,7 @@ struct rspamd_classifier { gulong ham_learns; struct rspamd_classifier_config *cfg; struct rspamd_stat_classifier *subrs; + gpointer specific; }; struct rspamd_statfile { @@ -85,6 +86,9 @@ struct rspamd_stat_ctx { GPtrArray *classifiers; /* struct rspamd_classifier */ GQueue *async_elts; /* struct rspamd_stat_async_elt */ struct rspamd_config *cfg; + + gint lua_stat_tokens_ref; + /* Global tokenizer */ struct rspamd_stat_tokenizer *tokenizer; gpointer tkcf; diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index ca51d7b02..d097e12e0 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -32,273 +32,85 @@ static const gdouble similarity_treshold = 80.0; static void -rspamd_stat_tokenize_header (struct rspamd_task *task, - const gchar *name, const gchar *prefix, GArray *ar) -{ - struct rspamd_mime_header *cur; - GPtrArray *hdrs; - guint i; - rspamd_stat_token_t str; - - hdrs = g_hash_table_lookup (task->raw_headers, name); - str.flags = RSPAMD_STAT_TOKEN_FLAG_META; - - if (hdrs != NULL) { - - PTR_ARRAY_FOREACH (hdrs, i, cur) { - if (cur->name != NULL) { - str.begin = cur->name; - str.len = strlen (cur->name); - g_array_append_val (ar, str); - } - if (cur->decoded != NULL) { - str.begin = cur->decoded; - str.len = strlen (cur->decoded); - g_array_append_val (ar, str); - } - else if (cur->value != NULL) { - str.begin = cur->value; - str.len = strlen (cur->value); - g_array_append_val (ar, str); - } - } - - msg_debug_task ("added stat tokens for header '%s'", name); - } -} - -static void rspamd_stat_tokenize_parts_metadata (struct rspamd_stat_ctx *st_ctx, struct rspamd_task *task) { - struct rspamd_image *img; - struct rspamd_mime_part *part; - struct rspamd_mime_text_part *tp; - GList *cur; GArray *ar; rspamd_stat_token_t elt; guint i; - gchar tmpbuf[128]; lua_State *L = task->cfg->lua_state; - const gchar *headers_hash; - struct rspamd_mime_header *hdr; ar = g_array_sized_new (FALSE, FALSE, sizeof (elt), 16); + memset (&elt, 0, sizeof (elt)); elt.flags = RSPAMD_STAT_TOKEN_FLAG_META; - /* Insert images */ - for (i = 0; i < task->parts->len; i ++) { - part = g_ptr_array_index (task->parts, i); - - if ((part->flags & RSPAMD_MIME_PART_IMAGE) && part->specific.img) { - img = part->specific.img; - - /* If an image has a linked HTML part, then we push its details to the stat */ - if (img->html_image) { - elt.begin = (gchar *)"image"; - elt.len = 5; - g_array_append_val (ar, elt); - elt.begin = (gchar *)&img->html_image->height; - elt.len = sizeof (img->html_image->height); - g_array_append_val (ar, elt); - elt.begin = (gchar *)&img->html_image->width; - elt.len = sizeof (img->html_image->width); - g_array_append_val (ar, elt); - elt.begin = (gchar *)&img->type; - elt.len = sizeof (img->type); - g_array_append_val (ar, elt); - - if (img->filename) { - elt.begin = (gchar *)img->filename; - elt.len = strlen (elt.begin); - g_array_append_val (ar, elt); - } + if (st_ctx->lua_stat_tokens_ref != -1) { + gint err_idx, ret; + GString *tb; + struct rspamd_task **ptask; - msg_debug_task ("added stat tokens for image '%s'", img->html_image->src); - } - } - else if (part->cd && part->cd->filename.len > 0) { - elt.begin = (gchar *)part->cd->filename.begin; - elt.len = part->cd->filename.len; - g_array_append_val (ar, elt); - } - } + lua_pushcfunction (L, &rspamd_lua_traceback); + err_idx = lua_gettop (L); + lua_rawgeti (L, LUA_REGISTRYINDEX, st_ctx->lua_stat_tokens_ref); - /* Process mime parts */ - for (i = 0; i < task->parts->len; i ++) { - part = g_ptr_array_index (task->parts, i); + ptask = lua_newuserdata (L, sizeof (*ptask)); + *ptask = task; + rspamd_lua_setclass (L, "rspamd{task}", -1); - if (IS_CT_MULTIPART (part->ct)) { - elt.begin = (gchar *)part->ct->boundary.begin; - elt.len = part->ct->boundary.len; + if ((ret = lua_pcall (L, 1, 1, err_idx)) != 0) { + tb = lua_touserdata (L, -1); + msg_err_task ("call to stat_tokens lua " + "script failed (%d): %v", ret, tb); - if (elt.len) { - msg_debug_task ("added stat tokens for mime boundary '%*s'", - (gint)elt.len, elt.begin); - g_array_append_val (ar, elt); + if (tb) { + g_string_free (tb, TRUE); } - - if (part->parsed_data.len > 1) { - rspamd_snprintf (tmpbuf, sizeof (tmpbuf), "mime%d:%dlog", - i, (gint)log2 (part->parsed_data.len)); - elt.begin = rspamd_mempool_strdup (task->task_pool, tmpbuf); - elt.len = strlen (elt.begin); - g_array_append_val (ar, elt); - } - } - } - - /* Process text parts metadata */ - for (i = 0; i < task->text_parts->len; i ++) { - tp = g_ptr_array_index (task->text_parts, i); - - if (tp->language != NULL && tp->language[0] != '\0') { - elt.begin = (gchar *)tp->language; - elt.len = strlen (elt.begin); - msg_debug_task ("added stat tokens for part language '%s'", elt.begin); - g_array_append_val (ar, elt); - } - if (tp->real_charset != NULL) { - elt.begin = (gchar *)tp->real_charset; - elt.len = strlen (elt.begin); - msg_debug_task ("added stat tokens for part charset '%s'", elt.begin); - g_array_append_val (ar, elt); - } - } - - cur = g_list_first (task->cfg->classify_headers); - - while (cur) { - rspamd_stat_tokenize_header (task, cur->data, "UA:", ar); - - cur = g_list_next (cur); - } - - /* Use headers order */ - headers_hash = rspamd_mempool_get_variable (task->task_pool, - RSPAMD_MEMPOOL_HEADERS_HASH); - - if (headers_hash) { - elt.begin = (gchar *)headers_hash; - elt.len = 16; - g_array_append_val (ar, elt); - } - - /* Use more precise headers order */ - cur = g_list_first (task->headers_order->head); - while (cur) { - hdr = cur->data; - - if (hdr->name && hdr->type != RSPAMD_HEADER_RECEIVED) { - elt.begin = hdr->name; - elt.len = strlen (hdr->name); - g_array_append_val (ar, elt); } + else { + if (lua_type (L, -1) != LUA_TTABLE) { + msg_err_task ("stat_tokens invocation must return " + "table and not %s", + lua_typename (L, lua_type (L, -1))); + } + else { + guint vlen; + rspamd_ftok_t tok; - cur = g_list_next (cur); - } - - /* Use metatokens plugin from Lua */ - lua_getglobal (L, "rspamd_plugins"); - - if (lua_type (L, -1) == LUA_TTABLE) { - lua_pushstring (L, "stat_metatokens"); - lua_gettable (L, -2); - - if (lua_type (L, -1) == LUA_TTABLE) { - gint old_top; + vlen = rspamd_lua_table_size (L, -1); - old_top = lua_gettop (L); - lua_pushstring (L, "callback"); - lua_gettable (L, -2); + for (i = 0; i < vlen; i ++) { + lua_rawgeti (L, -1, i + 1); + tok.begin = lua_tolstring (L, -1, &tok.len); - if (lua_type (L, -1) == LUA_TFUNCTION) { - struct rspamd_task **ptask; + if (tok.begin && tok.len > 0) { + elt.original.begin = + rspamd_mempool_ftokdup (task->task_pool, &tok); + elt.original.len = tok.len; + elt.stemmed.begin = elt.original.begin; + elt.stemmed.len = elt.original.len; + elt.normalized.begin = elt.original.begin; + elt.normalized.len = elt.original.len; - ptask = lua_newuserdata (L, sizeof (*ptask)); - rspamd_lua_setclass (L, "rspamd{task}", -1); - *ptask = task; + g_array_append_val (ar, elt); + } - if (lua_pcall (L, 1, LUA_MULTRET, 0) != 0) { - msg_err_task ("stat_metatokens failed: %s", - lua_tostring (L, -1)); lua_pop (L, 1); - } else { - if (lua_gettop (L) > old_top && - lua_istable (L, old_top + 1)) { - lua_pushvalue (L, old_top + 1); - /* Iterate over table of tables */ - for (lua_pushnil (L); lua_next (L, -2); - lua_pop (L, 1)) { - elt.flags = RSPAMD_STAT_TOKEN_FLAG_META| - RSPAMD_STAT_TOKEN_FLAG_LUA_META; - - if (lua_isnumber (L, -1)) { - gdouble num = lua_tonumber (L, -1); - guint8 *pnum = rspamd_mempool_alloc ( - task->task_pool, - sizeof (num)); - - msg_debug_task ("got metatoken number: %.2f", - num); - memcpy (pnum, &num, sizeof (num)); - elt.begin = (gchar *) pnum; - elt.len = sizeof (num); - g_array_append_val (ar, elt); - } else if (lua_isstring (L, -1)) { - const gchar *str; - gsize tlen; - - str = lua_tolstring (L, -1, &tlen); - guint8 *pstr = rspamd_mempool_alloc ( - task->task_pool, - tlen); - memcpy (pstr, str, tlen); - - msg_debug_task ("got metatoken string: %*s", - (gint) tlen, str); - elt.begin = (gchar *) pstr; - elt.len = tlen; - g_array_append_val (ar, elt); - } - else if (lua_istable (L, -1)) { - /* Treat that as unigramms */ - for (lua_pushnil (L); lua_next (L, -2); - lua_pop (L, 1)) { - if (lua_isstring (L, -1)) { - const gchar *str; - gsize tlen; - - str = lua_tolstring (L, -1, &tlen); - guint8 *pstr = rspamd_mempool_alloc ( - task->task_pool, - tlen); - memcpy (pstr, str, tlen); - - msg_debug_task ("got unigramm " - "metatoken string: %*s", - (gint) tlen, str); - elt.begin = (gchar *) pstr; - elt.len = tlen; - elt.flags |= RSPAMD_STAT_TOKEN_FLAG_UNIGRAM; - g_array_append_val (ar, elt); - } - } - } - } - } } } } + + lua_settop (L, 0); } - lua_settop (L, 0); - st_ctx->tokenizer->tokenize_func (st_ctx, - task->task_pool, - ar, - TRUE, - "META:", - task->tokens); + + if (ar->len > 0) { + st_ctx->tokenizer->tokenize_func (st_ctx, + task, + ar, + TRUE, + "M", + task->tokens); + } rspamd_mempool_add_destructor (task->task_pool, rspamd_array_free_hard, ar); @@ -313,10 +125,7 @@ rspamd_stat_process_tokenize (struct rspamd_stat_ctx *st_ctx, { struct rspamd_mime_text_part *part; rspamd_cryptobox_hash_state_t hst; - rspamd_stat_token_t *tok; rspamd_token_t *st_tok; - GArray *words; - gchar *sub = NULL; guint i, reserved_len = 0; gdouble *pdiff; guchar hout[rspamd_cryptobox_HASHBYTES]; @@ -347,55 +156,26 @@ rspamd_stat_process_tokenize (struct rspamd_stat_ctx *st_ctx, part = g_ptr_array_index (task->text_parts, i); if (!IS_PART_EMPTY (part) && part->utf_words != NULL) { - st_ctx->tokenizer->tokenize_func (st_ctx, task->task_pool, + st_ctx->tokenizer->tokenize_func (st_ctx, task, part->utf_words, IS_PART_UTF (part), NULL, task->tokens); } if (pdiff != NULL && (1.0 - *pdiff) * 100.0 > similarity_treshold) { - msg_debug_task ("message has two common parts (%.2f), so skip the last one", + msg_debug_bayes ("message has two common parts (%.2f), so skip the last one", *pdiff); break; } } - if (task->subject != NULL) { - sub = task->subject; - } - - if (sub != NULL) { - UText utxt = UTEXT_INITIALIZER; - UErrorCode uc_err = U_ZERO_ERROR; - gsize slen = strlen (sub); - - utext_openUTF8 (&utxt, - sub, - slen, - &uc_err); - - words = rspamd_tokenize_text (sub, slen, &utxt, RSPAMD_TOKENIZE_UTF, - NULL, NULL, NULL); - - if (words != NULL) { - - for (i = 0; i < words->len; i ++) { - tok = &g_array_index (words, rspamd_stat_token_t, i); - tok->flags |= RSPAMD_STAT_TOKEN_FLAG_SUBJECT; - } - - st_ctx->tokenizer->tokenize_func (st_ctx, - task->task_pool, - words, - TRUE, - "SUBJECT", - task->tokens); - - rspamd_mempool_add_destructor (task->task_pool, - rspamd_array_free_hard, words); - } - - utext_close (&utxt); + if (task->meta_words != NULL) { + st_ctx->tokenizer->tokenize_func (st_ctx, + task, + task->meta_words, + TRUE, + "SUBJECT", + task->tokens); } rspamd_stat_tokenize_parts_metadata (st_ctx, task); @@ -445,10 +225,10 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx, continue; } - if (!rspamd_symbols_cache_is_symbol_enabled (task, task->cfg->cache, + if (!rspamd_symcache_is_symbol_enabled (task, task->cfg->cache, st->stcf->symbol)) { g_ptr_array_index (task->stat_runtimes, i) = NULL; - msg_debug_task ("symbol %s is disabled, skip classification", + msg_debug_bayes ("symbol %s is disabled, skip classification", st->stcf->symbol); continue; } @@ -550,6 +330,12 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx, return; } + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index (st_ctx->classifiers, i); + cl->spam_learns = 0; + cl->ham_learns = 0; + } + for (i = 0; i < st_ctx->statfiles->len; i++) { st = g_ptr_array_index (st_ctx->statfiles, i); cl = st->classifier; @@ -591,7 +377,7 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx, if (bk_run == NULL) { skip = TRUE; - msg_debug_task ("disable classifier %s as statfile symbol %s is disabled", + msg_debug_bayes ("disable classifier %s as statfile symbol %s is disabled", cl->cfg->name, st->stcf->symbol); break; } @@ -600,7 +386,7 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx, if (!skip) { if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) { - msg_debug_task ( + msg_debug_bayes ( "<%s> contains less tokens than required for %s classifier: " "%ud < %ud", task->message_id, @@ -610,7 +396,7 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx, continue; } else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) { - msg_debug_task ( + msg_debug_bayes ( "<%s> contains more tokens than allowed for %s classifier: " "%ud > %ud", task->message_id, @@ -740,7 +526,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx, if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL && *err == NULL) { /* Do not learn twice */ - g_set_error (err, rspamd_stat_quark (), 404, "<%s> has been already " + g_set_error (err, rspamd_stat_quark (), 208, "<%s> has been already " "learned as %s, ignore it", task->message_id, spam ? "spam" : "ham"); @@ -849,7 +635,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx, if (!learned && err && *err == NULL) { if (too_large) { - g_set_error (err, rspamd_stat_quark (), 400, + g_set_error (err, rspamd_stat_quark (), 204, "<%s> contains more tokens than allowed for %s classifier: " "%d > %d", task->message_id, @@ -858,7 +644,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx, cl->cfg->max_tokens); } else if (too_small) { - g_set_error (err, rspamd_stat_quark (), 400, + g_set_error (err, rspamd_stat_quark (), 204, "<%s> contains less tokens than required for %s classifier: " "%d < %d", task->message_id, @@ -867,7 +653,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx, cl->cfg->min_tokens); } else if (conditionally_skipped) { - g_set_error (err, rspamd_stat_quark (), 410, + g_set_error (err, rspamd_stat_quark (), 204, "<%s> is skipped for %s classifier: " "%s", task->message_id, @@ -1107,7 +893,7 @@ rspamd_stat_has_classifier_symbols (struct rspamd_task *task, if (rspamd_task_find_symbol_result (task, st->stcf->symbol)) { if (is_spam == !!st->stcf->is_spam) { - msg_debug_task ("do not autolearn %s as symbol %s is already " + msg_debug_bayes ("do not autolearn %s as symbol %s is already " "added", is_spam ? "spam" : "ham", st->stcf->symbol); return TRUE; diff --git a/src/libstat/tokenizers/osb.c b/src/libstat/tokenizers/osb.c index f6f46c580..0b53f8af9 100644 --- a/src/libstat/tokenizers/osb.c +++ b/src/libstat/tokenizers/osb.c @@ -17,8 +17,10 @@ * OSB tokenizer */ + #include "tokenizers.h" #include "stat_internal.h" +#include "libmime/lang_detection.h" /* Size for features pipe */ #define DEFAULT_FEATURE_WINDOW_SIZE 5 @@ -259,11 +261,11 @@ struct token_pipe_entry { gint rspamd_tokenizer_osb (struct rspamd_stat_ctx *ctx, - rspamd_mempool_t *pool, - GArray *words, - gboolean is_utf, - const gchar *prefix, - GPtrArray *result) + struct rspamd_task *task, + GArray *words, + gboolean is_utf, + const gchar *prefix, + GPtrArray *result) { rspamd_token_t *new_tok = NULL; rspamd_stat_token_t *token; @@ -302,23 +304,40 @@ rspamd_tokenizer_osb (struct rspamd_stat_ctx *ctx, for (w = 0; w < words->len; w ++) { token = &g_array_index (words, rspamd_stat_token_t, w); token_flags = token->flags; + const gchar *begin; + gsize len; + + if (token->flags & + (RSPAMD_STAT_TOKEN_FLAG_STOP_WORD|RSPAMD_STAT_TOKEN_FLAG_SKIPPED)) { + /* Skip stop/skipped words */ + continue; + } + + if (token->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) { + begin = token->stemmed.begin; + len = token->stemmed.len; + } + else { + begin = token->original.begin; + len = token->original.len; + } if (osb_cf->ht == RSPAMD_OSB_HASH_COMPAT) { rspamd_ftok_t ftok; - ftok.begin = token->begin; - ftok.len = token->len; + ftok.begin = begin; + ftok.len = len; cur = rspamd_fstrhash_lc (&ftok, is_utf); } else { /* We know that the words are normalized */ if (osb_cf->ht == RSPAMD_OSB_HASH_XXHASH) { cur = rspamd_cryptobox_fast_hash_specific (RSPAMD_CRYPTOBOX_XXHASH64, - token->begin, token->len, osb_cf->seed); + begin, len, osb_cf->seed); } else { - rspamd_cryptobox_siphash ((guchar *)&cur, token->begin, - token->len, osb_cf->sk); + rspamd_cryptobox_siphash ((guchar *)&cur, begin, + len, osb_cf->sk); if (prefix) { cur ^= seed; @@ -327,7 +346,7 @@ rspamd_tokenizer_osb (struct rspamd_stat_ctx *ctx, } if (token_flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) { - new_tok = rspamd_mempool_alloc0 (pool, token_size); + new_tok = rspamd_mempool_alloc0 (task->task_pool, token_size); new_tok->flags = token_flags; new_tok->t1 = token; new_tok->t2 = token; @@ -339,7 +358,7 @@ rspamd_tokenizer_osb (struct rspamd_stat_ctx *ctx, } #define ADD_TOKEN do {\ - new_tok = rspamd_mempool_alloc0 (pool, token_size); \ + new_tok = rspamd_mempool_alloc0 (task->task_pool, token_size); \ new_tok->flags = token_flags; \ new_tok->t1 = hashpipe[0].t; \ new_tok->t2 = hashpipe[i].t; \ @@ -354,7 +373,7 @@ rspamd_tokenizer_osb (struct rspamd_stat_ctx *ctx, else { \ new_tok->data = hashpipe[0].h * primes[0] + hashpipe[i].h * primes[i << 1]; \ } \ - new_tok->window_idx = i + 1; \ + new_tok->window_idx = i; \ g_ptr_array_add (result, new_tok); \ } while(0) @@ -375,7 +394,9 @@ rspamd_tokenizer_osb (struct rspamd_stat_ctx *ctx, processed++; for (i = 1; i < window_size; i++) { - ADD_TOKEN; + if (!(hashpipe[i].t->flags & RSPAMD_STAT_TOKEN_FLAG_EXCEPTION)) { + ADD_TOKEN; + } } } } diff --git a/src/libstat/tokenizers/tokenizers.c b/src/libstat/tokenizers/tokenizers.c index c8e8e44df..acbbcf2f0 100644 --- a/src/libstat/tokenizers/tokenizers.c +++ b/src/libstat/tokenizers/tokenizers.c @@ -20,11 +20,19 @@ #include "rspamd.h" #include "tokenizers.h" #include "stat_internal.h" -#include "../../../contrib/mumhash/mum.h" +#include "contrib/mumhash/mum.h" +#include "libmime/lang_detection.h" +#include "libstemmer.h" + #include <unicode/utf8.h> #include <unicode/uchar.h> #include <unicode/uiter.h> #include <unicode/ubrk.h> +#include <unicode/ucnv.h> +#if U_ICU_VERSION_MAJOR_NUM >= 44 +#include <unicode/unorm2.h> +#endif + #include <math.h> typedef gboolean (*token_get_function) (rspamd_stat_token_t * buf, gchar const **pos, @@ -80,33 +88,33 @@ rspamd_tokenizer_get_word_raw (rspamd_stat_token_t * buf, ex = (*exceptions)->data; } - if (token->begin == NULL || *cur == NULL) { + if (token->original.begin == NULL || *cur == NULL) { if (ex != NULL) { if (ex->pos == 0) { - token->begin = buf->begin + ex->len; - token->len = ex->len; + token->original.begin = buf->original.begin + ex->len; + token->original.len = ex->len; token->flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; } else { - token->begin = buf->begin; - token->len = 0; + token->original.begin = buf->original.begin; + token->original.len = 0; } } else { - token->begin = buf->begin; - token->len = 0; + token->original.begin = buf->original.begin; + token->original.len = 0; } - *cur = token->begin; + *cur = token->original.begin; } - token->len = 0; + token->original.len = 0; - pos = *cur - buf->begin; - if (pos >= buf->len) { + pos = *cur - buf->original.begin; + if (pos >= buf->original.len) { return FALSE; } - remain = buf->len - pos; + remain = buf->original.len - pos; p = *cur; /* Skip non delimiters symbols */ @@ -122,7 +130,7 @@ rspamd_tokenizer_get_word_raw (rspamd_stat_token_t * buf, remain--; } while (remain > 0 && t_delimiters[(guchar)*p]); - token->begin = p; + token->original.begin = p; while (remain > 0 && !t_delimiters[(guchar)*p]) { if (ex != NULL && ex->pos == pos) { @@ -130,7 +138,7 @@ rspamd_tokenizer_get_word_raw (rspamd_stat_token_t * buf, *cur = p + ex->len; return TRUE; } - token->len++; + token->original.len++; pos++; remain--; p++; @@ -141,7 +149,7 @@ rspamd_tokenizer_get_word_raw (rspamd_stat_token_t * buf, } if (rl) { - *rl = token->len; + *rl = token->original.len; } token->flags = RSPAMD_STAT_TOKEN_FLAG_TEXT; @@ -164,12 +172,12 @@ rspamd_tokenize_check_limit (gboolean decay, static const gdouble avg_word_len = 6.0; if (!decay) { - if (token->len >= sizeof (guint64)) { + if (token->original.len >= sizeof (guint64)) { #ifdef _MUM_UNALIGNED_ACCESS - *hv = mum_hash_step (*hv, *(guint64 *)token->begin); + *hv = mum_hash_step (*hv, *(guint64 *)token->original.begin); #else guint64 tmp; - memcpy (&tmp, token->begin, sizeof (tmp)); + memcpy (&tmp, token->original.begin, sizeof (tmp)); *hv = mum_hash_step (*hv, tmp); #endif } @@ -221,7 +229,7 @@ rspamd_utf_word_valid (const gchar *text, const gchar *end, U8_NEXT (text, start, finish, c); - if (u_isalnum (c)) { + if (u_isJavaIDPart (c)) { return TRUE; } @@ -237,13 +245,51 @@ rspamd_utf_word_valid (const gchar *text, const gchar *end, } \ } while(0) +static inline void +rspamd_tokenize_exception (struct rspamd_process_exception *ex, GArray *res) +{ + rspamd_stat_token_t token; + + memset (&token, 0, sizeof (token)); + + if (ex->type == RSPAMD_EXCEPTION_GENERIC) { + token.original.begin = "!!EX!!"; + token.original.len = sizeof ("!!EX!!") - 1; + token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; + + g_array_append_val (res, token); + token.flags = 0; + } + else if (ex->type == RSPAMD_EXCEPTION_URL) { + struct rspamd_url *uri; + + uri = ex->ptr; + + if (uri && uri->tldlen > 0) { + token.original.begin = uri->tld; + token.original.len = uri->tldlen; + + } + else { + token.original.begin = "!!EX!!"; + token.original.len = sizeof ("!!EX!!") - 1; + } + + token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; + g_array_append_val (res, token); + token.flags = 0; + } +} + + GArray * rspamd_tokenize_text (const gchar *text, gsize len, const UText *utxt, enum rspamd_tokenize_type how, struct rspamd_config *cfg, GList *exceptions, - guint64 *hash) + guint64 *hash, + GArray *cur_words) { rspamd_stat_token_t token, buf; const gchar *pos = NULL; @@ -257,15 +303,14 @@ rspamd_tokenize_text (const gchar *text, gsize len, static UBreakIterator* bi = NULL; if (text == NULL) { - return NULL; + return cur_words; } - buf.begin = text; - buf.len = len; + buf.original.begin = text; + buf.original.len = len; buf.flags = 0; - token.begin = NULL; - token.len = 0; - token.flags = 0; + + memset (&token, 0, sizeof (token)); if (cfg != NULL) { min_len = cfg->min_word_len; @@ -274,31 +319,36 @@ rspamd_tokenize_text (const gchar *text, gsize len, initial_size = word_decay * 2; } - res = g_array_sized_new (FALSE, FALSE, sizeof (rspamd_stat_token_t), - initial_size); + if (!cur_words) { + res = g_array_sized_new (FALSE, FALSE, sizeof (rspamd_stat_token_t), + initial_size); + } + else { + res = cur_words; + } if (G_UNLIKELY (how == RSPAMD_TOKENIZE_RAW || utxt == NULL)) { while (rspamd_tokenizer_get_word_raw (&buf, &pos, &token, &cur, &l, FALSE)) { if (l == 0 || (min_len > 0 && l < min_len) || (max_len > 0 && l > max_len)) { - token.begin = pos; + token.original.begin = pos; continue; } - if (token.len > 0 && + if (token.original.len > 0 && rspamd_tokenize_check_limit (decay, word_decay, res->len, &hv, &prob, &token, pos - text, len)) { if (!decay) { decay = TRUE; } else { - token.begin = pos; + token.original.begin = pos; continue; } } g_array_append_val (res, token); - token.begin = pos; + token.original.begin = pos; } } else { @@ -323,7 +373,7 @@ rspamd_tokenize_text (const gchar *text, gsize len, while (p != UBRK_DONE) { start_over: - token.len = 0; + token.original.len = 0; if (p > last) { if (ex && cur) { @@ -334,15 +384,7 @@ start_over: while (cur && ex->pos <= last) { /* We have an exception at the beginning, skip those */ last += ex->len; - - if (ex->type == RSPAMD_EXCEPTION_URL) { - token.begin = "!!EX!!"; - token.len = sizeof ("!!EX!!") - 1; - token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; - - g_array_append_val (res, token); - token.flags = 0; - } + rspamd_tokenize_exception (ex, res); if (last > p) { /* Exception spread over the boundaries */ @@ -363,8 +405,8 @@ start_over: /* Append the first part */ if (rspamd_utf_word_valid (text, text + len, last, ex->pos)) { - token.begin = text + last; - token.len = ex->pos - last; + token.original.begin = text + last; + token.original.len = ex->pos - last; token.flags = 0; g_array_append_val (res, token); } @@ -372,13 +414,7 @@ start_over: /* Process the current exception */ last += ex->len + (ex->pos - last); - if (ex->type == RSPAMD_EXCEPTION_URL) { - token.begin = "!!EX!!"; - token.len = sizeof ("!!EX!!") - 1; - token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; - - g_array_append_val (res, token); - } + rspamd_tokenize_exception (ex, res); if (last > p) { /* Exception spread over the boundaries */ @@ -394,9 +430,10 @@ start_over: } else if (p > last) { if (rspamd_utf_word_valid (text, text + len, last, p)) { - token.begin = text + last; - token.len = p - last; - token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT; + token.original.begin = text + last; + token.original.len = p - last; + token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT | + RSPAMD_STAT_TOKEN_FLAG_UTF; } } } @@ -408,40 +445,43 @@ start_over: } if (rspamd_utf_word_valid (text, text + len, last, p)) { - token.begin = text + last; - token.len = p - last; - token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT; + token.original.begin = text + last; + token.original.len = p - last; + token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT | + RSPAMD_STAT_TOKEN_FLAG_UTF; } } else { /* No exceptions within boundary */ if (rspamd_utf_word_valid (text, text + len, last, p)) { - token.begin = text + last; - token.len = p - last; - token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT; + token.original.begin = text + last; + token.original.len = p - last; + token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT | + RSPAMD_STAT_TOKEN_FLAG_UTF; } } } else { if (rspamd_utf_word_valid (text, text + len, last, p)) { - token.begin = text + last; - token.len = p - last; - token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT; + token.original.begin = text + last; + token.original.len = p - last; + token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT | + RSPAMD_STAT_TOKEN_FLAG_UTF; } } - if (token.len > 0 && + if (token.original.len > 0 && rspamd_tokenize_check_limit (decay, word_decay, res->len, &hv, &prob, &token, p, len)) { if (!decay) { decay = TRUE; } else { - token.len = 0; + token.flags |= RSPAMD_STAT_TOKEN_FLAG_SKIPPED; } } } - if (token.len > 0) { + if (token.original.len > 0) { g_array_append_val (res, token); } @@ -463,6 +503,347 @@ start_over: #undef SHIFT_EX -/* - * vi:ts=4 - */ +static void +rspamd_add_metawords_from_str (const gchar *beg, gsize len, + struct rspamd_task *task) +{ + UText utxt = UTEXT_INITIALIZER; + UErrorCode uc_err = U_ZERO_ERROR; + guint i = 0; + UChar32 uc; + gboolean valid_utf = TRUE; + + while (i < len) { + U8_NEXT (beg, i, len, uc); + + if (((gint32) uc) < 0) { + valid_utf = FALSE; + break; + } + +#if U_ICU_VERSION_MAJOR_NUM < 50 + if (u_isalpha (uc)) { + gint32 sc = ublock_getCode (uc); + + if (sc == UBLOCK_THAI) { + valid_utf = FALSE; + msg_info_task ("enable workaround for Thai characters for old libicu"); + break; + } + } +#endif + } + + if (valid_utf) { + utext_openUTF8 (&utxt, + beg, + len, + &uc_err); + + task->meta_words = rspamd_tokenize_text (beg, len, + &utxt, RSPAMD_TOKENIZE_UTF, + task->cfg, NULL, NULL, task->meta_words); + + utext_close (&utxt); + } + else { + task->meta_words = rspamd_tokenize_text (beg, len, + NULL, RSPAMD_TOKENIZE_RAW, + task->cfg, NULL, NULL, task->meta_words); + } +} + +void +rspamd_tokenize_meta_words (struct rspamd_task *task) +{ + guint i = 0; + rspamd_stat_token_t *tok; + + if (task->subject) { + rspamd_add_metawords_from_str (task->subject, strlen (task->subject), task); + } + + if (task->from_mime && task->from_mime->len > 0) { + struct rspamd_email_address *addr; + + addr = g_ptr_array_index (task->from_mime, 0); + + if (addr->name) { + rspamd_add_metawords_from_str (addr->name, strlen (addr->name), task); + } + } + + if (task->meta_words != NULL) { + const gchar *language = NULL; + + if (task->text_parts && task->text_parts->len > 0) { + struct rspamd_mime_text_part *tp = g_ptr_array_index (task->text_parts, 0); + + if (tp->language) { + language = tp->language; + } + } + + rspamd_normalize_words (task->meta_words, task->task_pool); + rspamd_stem_words (task->meta_words, task->task_pool, language, + task->lang_det); + + for (i = 0; i < task->meta_words->len; i++) { + tok = &g_array_index (task->meta_words, rspamd_stat_token_t, i); + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_HEADER; + } + } +} + +static inline void +rspamd_uchars_to_ucs32 (const UChar *src, gsize srclen, + rspamd_stat_token_t *tok, + rspamd_mempool_t *pool) +{ + UChar32 *dest, t, *d; + gint32 i = 0; + + dest = rspamd_mempool_alloc (pool, srclen * sizeof (UChar32)); + d = dest; + + while (i < srclen) { + U16_NEXT_UNSAFE (src, i, t); + + if (u_isgraph (t)) { + *d++ = u_tolower (t); + } + else { + /* Invisible spaces ! */ + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_INVISIBLE_SPACES; + } + } + + tok->unicode.begin = dest; + tok->unicode.len = d - dest; +} + +static inline void +rspamd_ucs32_to_normalised (rspamd_stat_token_t *tok, + rspamd_mempool_t *pool) +{ + guint i, doff = 0; + gsize utflen = 0; + gchar *dest; + UChar32 t; + + for (i = 0; i < tok->unicode.len; i ++) { + utflen += U8_LENGTH (tok->unicode.begin[i]); + } + + dest = rspamd_mempool_alloc (pool, utflen + 1); + + for (i = 0; i < tok->unicode.len; i ++) { + t = tok->unicode.begin[i]; + U8_APPEND_UNSAFE (dest, doff, t); + } + + g_assert (doff <= utflen); + dest[doff] = '\0'; + + tok->normalized.len = doff; + tok->normalized.begin = dest; +} + +void +rspamd_normalize_single_word (rspamd_stat_token_t *tok, rspamd_mempool_t *pool) +{ + UErrorCode uc_err = U_ZERO_ERROR; + UConverter *utf8_converter; + UChar tmpbuf[1024]; /* Assume that we have no longer words... */ + gsize ulen; + + utf8_converter = rspamd_get_utf8_converter (); + + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UTF) { + ulen = ucnv_toUChars (utf8_converter, + tmpbuf, + G_N_ELEMENTS (tmpbuf), + tok->original.begin, + tok->original.len, + &uc_err); + + /* Now, we need to understand if we need to normalise the word */ + if (!U_SUCCESS (uc_err)) { + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE; + tok->unicode.begin = NULL; + tok->unicode.len = 0; + tok->normalized.begin = NULL; + tok->normalized.len = 0; + } + else { +#if U_ICU_VERSION_MAJOR_NUM >= 44 + const UNormalizer2 *norm = rspamd_get_unicode_normalizer (); + gint32 end; + + /* We can now check if we need to decompose */ + end = unorm2_spanQuickCheckYes (norm, tmpbuf, ulen, &uc_err); + + if (!U_SUCCESS (uc_err)) { + rspamd_uchars_to_ucs32 (tmpbuf, ulen, tok, pool); + tok->normalized.begin = NULL; + tok->normalized.len = 0; + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE; + } + else { + if (end == ulen) { + /* Already normalised, just lowercase */ + rspamd_uchars_to_ucs32 (tmpbuf, ulen, tok, pool); + rspamd_ucs32_to_normalised (tok, pool); + } + else { + /* Perform normalization */ + UChar normbuf[1024]; + + g_assert (end < G_N_ELEMENTS (normbuf)); + /* First part */ + memcpy (normbuf, tmpbuf, end * sizeof (UChar)); + /* Second part */ + ulen = unorm2_normalizeSecondAndAppend (norm, + normbuf, end, + G_N_ELEMENTS (normbuf), + tmpbuf + end, + ulen - end, + &uc_err); + + if (!U_SUCCESS (uc_err)) { + if (uc_err != U_BUFFER_OVERFLOW_ERROR) { + msg_warn_pool_check ("cannot normalise text '%*s': %s", + (gint)tok->original.len, tok->original.begin, + u_errorName (uc_err)); + rspamd_uchars_to_ucs32 (tmpbuf, ulen, tok, pool); + rspamd_ucs32_to_normalised (tok, pool); + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE; + } + } + else { + /* Copy normalised back */ + rspamd_uchars_to_ucs32 (normbuf, ulen, tok, pool); + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_NORMALISED; + rspamd_ucs32_to_normalised (tok, pool); + } + } + } +#else + /* Legacy version with no unorm2 interface */ + rspamd_uchars_to_ucs32 (tmpbuf, ulen, tok, pool); + rspamd_ucs32_to_normalised (tok, pool); +#endif + } + } + else { + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) { + /* Simple lowercase */ + gchar *dest; + + dest = rspamd_mempool_alloc (pool, tok->original.len + 1); + rspamd_strlcpy (dest, tok->original.begin, tok->original.len + 1); + rspamd_str_lc (dest, tok->original.len); + tok->normalized.len = tok->original.len; + tok->normalized.begin = dest; + } + } +} + +void +rspamd_normalize_words (GArray *words, rspamd_mempool_t *pool) +{ + rspamd_stat_token_t *tok; + guint i; + + for (i = 0; i < words->len; i++) { + tok = &g_array_index (words, rspamd_stat_token_t, i); + rspamd_normalize_single_word (tok, pool); + } +} + +void +rspamd_stem_words (GArray *words, rspamd_mempool_t *pool, + const gchar *language, + struct rspamd_lang_detector *d) +{ + static GHashTable *stemmers = NULL; + struct sb_stemmer *stem = NULL; + guint i; + rspamd_stat_token_t *tok; + gchar *dest; + gsize dlen; + + if (!stemmers) { + stemmers = g_hash_table_new (rspamd_strcase_hash, + rspamd_strcase_equal); + } + + if (language && language[0] != '\0') { + stem = g_hash_table_lookup (stemmers, language); + + if (stem == NULL) { + + stem = sb_stemmer_new (language, "UTF_8"); + + if (stem == NULL) { + msg_debug_pool ( + "<%s> cannot create lemmatizer for %s language", + language); + g_hash_table_insert (stemmers, g_strdup (language), + GINT_TO_POINTER (-1)); + } + else { + g_hash_table_insert (stemmers, g_strdup (language), + stem); + } + } + else if (stem == GINT_TO_POINTER (-1)) { + /* Negative cache */ + stem = NULL; + } + } + for (i = 0; i < words->len; i++) { + tok = &g_array_index (words, rspamd_stat_token_t, i); + + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UTF) { + if (stem) { + const gchar *stemmed = NULL; + + stemmed = sb_stemmer_stem (stem, + tok->normalized.begin, tok->normalized.len); + + dlen = stemmed ? strlen (stemmed) : 0; + + if (dlen > 0) { + dest = rspamd_mempool_alloc (pool, dlen + 1); + memcpy (dest, stemmed, dlen); + dest[dlen] = '\0'; + tok->stemmed.len = dlen; + tok->stemmed.begin = dest; + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_STEMMED; + } + else { + /* Fallback */ + tok->stemmed.len = tok->normalized.len; + tok->stemmed.begin = tok->normalized.begin; + } + } + else { + tok->stemmed.len = tok->normalized.len; + tok->stemmed.begin = tok->normalized.begin; + } + + if (tok->stemmed.len > 0 && d != NULL && + rspamd_language_detector_is_stop_word (d, tok->stemmed.begin, tok->stemmed.len)) { + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_STOP_WORD; + } + } + else { + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) { + /* Raw text, lowercase */ + tok->stemmed.len = tok->normalized.len; + tok->stemmed.begin = tok->normalized.begin; + } + } + } +}
\ No newline at end of file diff --git a/src/libstat/tokenizers/tokenizers.h b/src/libstat/tokenizers/tokenizers.h index 6c538eafc..784426d31 100644 --- a/src/libstat/tokenizers/tokenizers.h +++ b/src/libstat/tokenizers/tokenizers.h @@ -18,13 +18,13 @@ struct rspamd_stat_ctx; struct rspamd_stat_tokenizer { gchar *name; gpointer (*get_config) (rspamd_mempool_t *pool, - struct rspamd_tokenizer_config *cf, gsize *len); + struct rspamd_tokenizer_config *cf, gsize *len); gint (*tokenize_func)(struct rspamd_stat_ctx *ctx, - rspamd_mempool_t *pool, - GArray *words, - gboolean is_utf, - const gchar *prefix, - GPtrArray *result); + struct rspamd_task *task, + GArray *words, + gboolean is_utf, + const gchar *prefix, + GPtrArray *result); }; enum rspamd_tokenize_type { @@ -43,20 +43,29 @@ GArray * rspamd_tokenize_text (const gchar *text, gsize len, enum rspamd_tokenize_type how, struct rspamd_config *cfg, GList *exceptions, - guint64 *hash); + guint64 *hash, + GArray *cur_words); /* OSB tokenize function */ gint rspamd_tokenizer_osb (struct rspamd_stat_ctx *ctx, - rspamd_mempool_t *pool, - GArray *words, - gboolean is_utf, - const gchar *prefix, - GPtrArray *result); + struct rspamd_task *task, + GArray *words, + gboolean is_utf, + const gchar *prefix, + GPtrArray *result); gpointer rspamd_tokenizer_osb_get_config (rspamd_mempool_t *pool, - struct rspamd_tokenizer_config *cf, - gsize *len); + struct rspamd_tokenizer_config *cf, + gsize *len); +struct rspamd_lang_detector; +void rspamd_normalize_single_word (rspamd_stat_token_t *tok, rspamd_mempool_t *pool); +void rspamd_normalize_words (GArray *words, rspamd_mempool_t *pool); +void rspamd_stem_words (GArray *words, rspamd_mempool_t *pool, + const gchar *language, + struct rspamd_lang_detector *d); + +void rspamd_tokenize_meta_words (struct rspamd_task *task); #endif /* * vi:ts=4 |