aboutsummaryrefslogtreecommitdiffstats
path: root/src/libstat
diff options
context:
space:
mode:
Diffstat (limited to 'src/libstat')
-rw-r--r--src/libstat/backends/redis_backend.c246
-rw-r--r--src/libstat/classifiers/bayes.c145
-rw-r--r--src/libstat/classifiers/classifiers.h38
-rw-r--r--src/libstat/classifiers/lua_classifier.c15
-rw-r--r--src/libstat/learn_cache/redis_cache.c220
-rw-r--r--src/libstat/learn_cache/sqlite3_cache.c2
-rw-r--r--src/libstat/stat_api.h25
-rw-r--r--src/libstat/stat_config.c78
-rw-r--r--src/libstat/stat_internal.h4
-rw-r--r--src/libstat/stat_process.c366
-rw-r--r--src/libstat/tokenizers/osb.c49
-rw-r--r--src/libstat/tokenizers/tokenizers.c521
-rw-r--r--src/libstat/tokenizers/tokenizers.h37
13 files changed, 1054 insertions, 692 deletions
diff --git a/src/libstat/backends/redis_backend.c b/src/libstat/backends/redis_backend.c
index 74d8c3bf1..b003d5a27 100644
--- a/src/libstat/backends/redis_backend.c
+++ b/src/libstat/backends/redis_backend.c
@@ -42,9 +42,9 @@ INIT_LOG_MODULE(stat_redis)
#define REDIS_STAT_TIMEOUT 30
struct redis_stat_ctx {
+ lua_State *L;
struct rspamd_statfile_config *stcf;
- struct upstream_list *read_servers;
- struct upstream_list *write_servers;
+ gint conf_ref;
struct rspamd_stat_async_elt *stat_elt;
const gchar *redis_object;
const gchar *password;
@@ -104,12 +104,29 @@ struct rspamd_redis_stat_cbdata {
#define GET_TASK_ELT(task, elt) (task == NULL ? NULL : (task)->elt)
+static const gchar *M = "redis statistics";
+
static GQuark
rspamd_redis_stat_quark (void)
{
- return g_quark_from_static_string ("redis-statistics");
+ return g_quark_from_static_string (M);
}
+static inline struct upstream_list *
+rspamd_redis_get_servers (struct redis_stat_ctx *ctx,
+ const gchar *what)
+{
+ lua_State *L = ctx->L;
+ struct upstream_list *res;
+
+ lua_rawgeti (L, LUA_REGISTRYINDEX, ctx->conf_ref);
+ lua_pushstring (L, what);
+ lua_gettable (L, -2);
+ res = *((struct upstream_list**)lua_touserdata (L, -1));
+ lua_settop (L, 0);
+
+ return res;
+}
/*
* Non-static for lua unit testing
@@ -134,6 +151,7 @@ rspamd_redis_expand_object (const gchar *pattern,
GString *tb;
const gchar *rcpt = NULL;
gint err_idx;
+ gboolean expansion_errored = FALSE;
g_assert (ctx != NULL);
stcf = ctx->stcf;
@@ -196,6 +214,9 @@ rspamd_redis_expand_object (const gchar *pattern,
if (elt) {
tlen += strlen (elt);
}
+ else {
+ expansion_errored = TRUE;
+ }
break;
case 'r':
@@ -209,11 +230,15 @@ rspamd_redis_expand_object (const gchar *pattern,
if (elt) {
tlen += strlen (elt);
}
+ else {
+ expansion_errored = TRUE;
+ }
break;
case 'l':
if (stcf->label) {
tlen += strlen (stcf->label);
}
+ /* Label miss is OK */
break;
case 's':
if (ctx->new_schema) {
@@ -251,7 +276,8 @@ rspamd_redis_expand_object (const gchar *pattern,
}
}
- if (target == NULL || task == NULL) {
+
+ if (target == NULL || task == NULL || expansion_errored) {
return tlen;
}
@@ -501,14 +527,14 @@ rspamd_redis_tokens_to_query (struct rspamd_task *task,
"HSET %b_tokens %b %b:%b",
prefix, (size_t) prefix_len,
n0, (size_t) l0,
- tok->t1->begin, tok->t1->len,
- tok->t2->begin, tok->t2->len);
+ tok->t1->stemmed.begin, tok->t1->stemmed.len,
+ tok->t2->stemmed.begin, tok->t2->stemmed.len);
} else if (tok->t1) {
redisAsyncCommand (rt->redis, NULL, NULL,
"HSET %b_tokens %b %b",
prefix, (size_t) prefix_len,
n0, (size_t) l0,
- tok->t1->begin, tok->t1->len);
+ tok->t1->stemmed.begin, tok->t1->stemmed.len);
}
}
else {
@@ -522,14 +548,14 @@ rspamd_redis_tokens_to_query (struct rspamd_task *task,
"HSET %b %s %b:%b",
n0, (size_t) l0,
"tokens",
- tok->t1->begin, tok->t1->len,
- tok->t2->begin, tok->t2->len);
+ tok->t1->stemmed.begin, tok->t1->stemmed.len,
+ tok->t2->stemmed.begin, tok->t2->stemmed.len);
} else if (tok->t1) {
redisAsyncCommand (rt->redis, NULL, NULL,
"HSET %b %s %b",
n0, (size_t) l0,
"tokens",
- tok->t1->begin, tok->t1->len);
+ tok->t1->stemmed.begin, tok->t1->stemmed.len);
}
}
@@ -928,6 +954,7 @@ rspamd_redis_async_stat_cb (struct rspamd_stat_async_elt *elt, gpointer d)
struct rspamd_redis_stat_elt *redis_elt = elt->ud;
struct rspamd_redis_stat_cbdata *cbdata;
rspamd_inet_addr_t *addr;
+ struct upstream_list *ups;
g_assert (redis_elt != NULL);
@@ -941,8 +968,15 @@ rspamd_redis_async_stat_cb (struct rspamd_stat_async_elt *elt, gpointer d)
/* Disable further events unless needed */
elt->enabled = FALSE;
+ ups = rspamd_redis_get_servers (ctx, "read_servers");
+
+ if (!ups) {
+ return;
+ }
+
cbdata = g_malloc0 (sizeof (*cbdata));
- cbdata->selected = rspamd_upstream_get (ctx->read_servers,
+
+ cbdata->selected = rspamd_upstream_get (ups,
RSPAMD_UPSTREAM_ROUND_ROBIN,
NULL,
0);
@@ -1231,78 +1265,6 @@ rspamd_redis_learned (redisAsyncContext *c, gpointer r, gpointer priv)
rspamd_session_remove_event (task->s, rspamd_redis_fin_learn, rt);
}
}
-
-static gboolean
-rspamd_redis_try_ucl (struct redis_stat_ctx *backend,
- const ucl_object_t *obj,
- struct rspamd_config *cfg,
- const gchar *symbol)
-{
- const ucl_object_t *elt, *relt;
-
- elt = ucl_object_lookup_any (obj, "read_servers", "servers", NULL);
-
- if (elt == NULL) {
- return FALSE;
- }
-
- backend->read_servers = rspamd_upstreams_create (cfg->ups_ctx);
- if (!rspamd_upstreams_from_ucl (backend->read_servers, elt,
- REDIS_DEFAULT_PORT, NULL)) {
- msg_err ("statfile %s cannot get read servers configuration",
- symbol);
- return FALSE;
- }
-
- relt = elt;
-
- elt = ucl_object_lookup (obj, "write_servers");
- if (elt == NULL) {
- /* Use read servers as write ones */
- g_assert (relt != NULL);
- backend->write_servers = rspamd_upstreams_create (cfg->ups_ctx);
- if (!rspamd_upstreams_from_ucl (backend->write_servers, relt,
- REDIS_DEFAULT_PORT, NULL)) {
- msg_err ("statfile %s cannot get write servers configuration",
- symbol);
- return FALSE;
- }
- }
- else {
- backend->write_servers = rspamd_upstreams_create (cfg->ups_ctx);
- if (!rspamd_upstreams_from_ucl (backend->write_servers, elt,
- REDIS_DEFAULT_PORT, NULL)) {
- msg_err ("statfile %s cannot get write servers configuration",
- symbol);
- rspamd_upstreams_destroy (backend->write_servers);
- backend->write_servers = NULL;
- }
- }
-
- elt = ucl_object_lookup_any (obj, "db", "database", "dbname", NULL);
- if (elt) {
- if (ucl_object_type (elt) == UCL_STRING) {
- backend->dbname = ucl_object_tostring (elt);
- }
- else if (ucl_object_type (elt) == UCL_INT) {
- backend->dbname = ucl_object_tostring_forced (elt);
- }
- }
- else {
- backend->dbname = NULL;
- }
-
- elt = ucl_object_lookup (obj, "password");
- if (elt) {
- backend->password = ucl_object_tostring (elt);
- }
- else {
- backend->password = NULL;
- }
-
- return TRUE;
-}
-
static void
rspamd_redis_parse_classifier_opts (struct redis_stat_ctx *backend,
const ucl_object_t *obj,
@@ -1360,14 +1322,6 @@ rspamd_redis_parse_classifier_opts (struct redis_stat_ctx *backend,
backend->redis_object = ucl_object_tostring (elt);
}
- elt = ucl_object_lookup (obj, "timeout");
- if (elt) {
- backend->timeout = ucl_object_todouble (elt);
- }
- else {
- backend->timeout = REDIS_DEFAULT_TIMEOUT;
- }
-
elt = ucl_object_lookup (obj, "store_tokens");
if (elt) {
backend->store_tokens = ucl_object_toboolean (elt);
@@ -1414,24 +1368,27 @@ rspamd_redis_init (struct rspamd_stat_ctx *ctx,
struct rspamd_redis_stat_elt *st_elt;
const ucl_object_t *obj;
gboolean ret = FALSE;
+ gint conf_ref = -1;
+ lua_State *L = (lua_State *)cfg->lua_state;
backend = g_malloc0 (sizeof (*backend));
+ backend->L = L;
+ backend->timeout = REDIS_DEFAULT_TIMEOUT;
/* First search in backend configuration */
obj = ucl_object_lookup (st->classifier->cfg->opts, "backend");
if (obj != NULL && ucl_object_type (obj) == UCL_OBJECT) {
- ret = rspamd_redis_try_ucl (backend, obj, cfg, stf->symbol);
+ ret = rspamd_lua_try_load_redis (L, obj, cfg, &conf_ref);
}
/* Now try statfiles config */
- if (!ret) {
- ret = rspamd_redis_try_ucl (backend, stf->opts, cfg, stf->symbol);
+ if (!ret && stf->opts) {
+ ret = rspamd_lua_try_load_redis (L, stf->opts, cfg, &conf_ref);
}
/* Now try classifier config */
- if (!ret) {
- ret = rspamd_redis_try_ucl (backend, st->classifier->cfg->opts, cfg,
- stf->symbol);
+ if (!ret && st->classifier->cfg->opts) {
+ ret = rspamd_lua_try_load_redis (L, st->classifier->cfg->opts, cfg, &conf_ref);
}
/* Now try global redis settings */
@@ -1444,12 +1401,12 @@ rspamd_redis_init (struct rspamd_stat_ctx *ctx,
specific_obj = ucl_object_lookup (obj, "statistics");
if (specific_obj) {
- ret = rspamd_redis_try_ucl (backend, specific_obj, cfg,
- stf->symbol);
+ ret = rspamd_lua_try_load_redis (L,
+ specific_obj, cfg, &conf_ref);
}
else {
- ret = rspamd_redis_try_ucl (backend, obj, cfg,
- stf->symbol);
+ ret = rspamd_lua_try_load_redis (L,
+ obj, cfg, &conf_ref);
}
}
}
@@ -1460,6 +1417,36 @@ rspamd_redis_init (struct rspamd_stat_ctx *ctx,
return NULL;
}
+ backend->conf_ref = conf_ref;
+
+ /* Check some common table values */
+ lua_rawgeti (L, LUA_REGISTRYINDEX, conf_ref);
+
+ lua_pushstring (L, "timeout");
+ lua_gettable (L, -2);
+ if (lua_type (L, -1) == LUA_TNUMBER) {
+ backend->timeout = lua_tonumber (L, -1);
+ }
+ lua_pop (L, 1);
+
+ lua_pushstring (L, "db");
+ lua_gettable (L, -2);
+ if (lua_type (L, -1) == LUA_TSTRING) {
+ backend->dbname = rspamd_mempool_strdup (cfg->cfg_pool,
+ lua_tostring (L, -1));
+ }
+ lua_pop (L, 1);
+
+ lua_pushstring (L, "password");
+ lua_gettable (L, -2);
+ if (lua_type (L, -1) == LUA_TSTRING) {
+ backend->password = rspamd_mempool_strdup (cfg->cfg_pool,
+ lua_tostring (L, -1));
+ }
+ lua_pop (L, 1);
+
+ lua_settop (L, 0);
+
rspamd_redis_parse_classifier_opts (backend, st->classifier->cfg->opts, cfg);
stf->clcf->flags |= RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
backend->stcf = stf;
@@ -1485,24 +1472,35 @@ rspamd_redis_runtime (struct rspamd_task *task,
struct redis_stat_ctx *ctx = REDIS_CTX (c);
struct redis_stat_runtime *rt;
struct upstream *up;
+ struct upstream_list *ups;
+ char *object_expanded = NULL;
rspamd_inet_addr_t *addr;
g_assert (ctx != NULL);
g_assert (stcf != NULL);
- if (learn && ctx->write_servers == NULL) {
- msg_err_task ("no write servers defined for %s, cannot learn", stcf->symbol);
- return NULL;
- }
-
if (learn) {
- up = rspamd_upstream_get (ctx->write_servers,
+ ups = rspamd_redis_get_servers (ctx, "write_servers");
+
+ if (!ups) {
+ msg_err_task ("no write servers defined for %s, cannot learn",
+ stcf->symbol);
+ return NULL;
+ }
+ up = rspamd_upstream_get (ups,
RSPAMD_UPSTREAM_MASTER_SLAVE,
NULL,
0);
}
else {
- up = rspamd_upstream_get (ctx->read_servers,
+ ups = rspamd_redis_get_servers (ctx, "read_servers");
+
+ if (!ups) {
+ msg_err_task ("no read servers defined for %s, cannot stat",
+ stcf->symbol);
+ return NULL;
+ }
+ up = rspamd_upstream_get (ups,
RSPAMD_UPSTREAM_ROUND_ROBIN,
NULL,
0);
@@ -1513,15 +1511,22 @@ rspamd_redis_runtime (struct rspamd_task *task,
return NULL;
}
+ if (rspamd_redis_expand_object (ctx->redis_object, ctx, task,
+ &object_expanded) == 0) {
+ msg_err_task ("expansion for learning failed for symbol %s "
+ "(maybe learning per user classifier with no user or recipient)",
+ stcf->symbol);
+ return NULL;
+ }
+
rt = rspamd_mempool_alloc0 (task->task_pool, sizeof (*rt));
rspamd_mempool_add_destructor (task->task_pool,
rspamd_gerror_free_maybe, &rt->err);
- rspamd_redis_expand_object (ctx->redis_object, ctx, task,
- &rt->redis_object_expanded);
rt->selected = up;
rt->task = task;
rt->ctx = ctx;
rt->stcf = stcf;
+ rt->redis_object_expanded = object_expanded;
addr = rspamd_upstream_addr (up);
g_assert (addr != NULL);
@@ -1549,13 +1554,10 @@ void
rspamd_redis_close (gpointer p)
{
struct redis_stat_ctx *ctx = REDIS_CTX (p);
+ lua_State *L = ctx->L;
- if (ctx->read_servers) {
- rspamd_upstreams_destroy (ctx->read_servers);
- }
-
- if (ctx->write_servers) {
- rspamd_upstreams_destroy (ctx->write_servers);
+ if (ctx->conf_ref) {
+ luaL_unref (L, LUA_REGISTRYINDEX, ctx->conf_ref);
}
g_free (ctx);
@@ -1594,7 +1596,7 @@ rspamd_redis_process_tokens (struct rspamd_task *task,
if (redisAsyncCommand (rt->redis, rspamd_redis_connected, rt, "HGET %s %s",
rt->redis_object_expanded, learned_key) == REDIS_OK) {
- rspamd_session_add_event (task->s, NULL, rspamd_redis_fin, rt, rspamd_redis_stat_quark ());
+ rspamd_session_add_event (task->s, rspamd_redis_fin, rt, M);
rt->has_event = TRUE;
if (rspamd_event_pending (&rt->timeout_event, EV_TIMEOUT)) {
@@ -1657,6 +1659,7 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
{
struct redis_stat_runtime *rt = REDIS_RUNTIME (p);
struct upstream *up;
+ struct upstream_list *ups;
rspamd_inet_addr_t *addr;
struct timeval tv;
rspamd_fstring_t *query;
@@ -1670,7 +1673,12 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
return FALSE;
}
- up = rspamd_upstream_get (rt->ctx->write_servers,
+ ups = rspamd_redis_get_servers (rt->ctx, "write_servers");
+
+ if (!ups) {
+ return FALSE;
+ }
+ up = rspamd_upstream_get (ups,
RSPAMD_UPSTREAM_MASTER_SLAVE,
NULL,
0);
@@ -1798,7 +1806,7 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
"RSIG");
}
- rspamd_session_add_event (task->s, NULL, rspamd_redis_fin_learn, rt, rspamd_redis_stat_quark ());
+ rspamd_session_add_event (task->s, rspamd_redis_fin_learn, rt, M);
rt->has_event = TRUE;
/* Set timeout */
diff --git a/src/libstat/classifiers/bayes.c b/src/libstat/classifiers/bayes.c
index 5b6b5a0fe..2b0cf21e8 100644
--- a/src/libstat/classifiers/bayes.c
+++ b/src/libstat/classifiers/bayes.c
@@ -38,7 +38,7 @@
G_STRFUNC, \
__VA_ARGS__)
-INIT_LOG_MODULE(bayes)
+INIT_LOG_MODULE_PUBLIC(bayes)
static inline GQuark
bayes_error_quark (void)
@@ -80,6 +80,8 @@ inv_chi_square (struct rspamd_task *task, gdouble value, gint freedom_deg)
sum = prob;
+ msg_debug_bayes ("m: %f, prob: %g", m, prob);
+
/*
* m is our confidence in class
* prob is e ^ x (small value since x is normally less than zero
@@ -89,7 +91,7 @@ inv_chi_square (struct rspamd_task *task, gdouble value, gint freedom_deg)
for (i = 1; i < freedom_deg; i++) {
prob *= m / (gdouble)i;
sum += prob;
- msg_debug_bayes ("prob: %.6f, sum: %.6f", prob, sum);
+ msg_debug_bayes ("i=%d, prob: %g, sum: %g", i, prob, sum);
}
return MIN (1.0, sum);
@@ -109,7 +111,7 @@ struct bayes_task_closure {
* Mathematically we use pow(complexity, complexity), where complexity is the
* window index
*/
-static const double feature_weight[] = { 0, 1, 4, 27, 256, 3125, 46656, 823543 };
+static const double feature_weight[] = { 0, 3125, 256, 27, 1, 0, 0, 0 };
#define PROB_COMBINE(prob, cnt, weight, assumed) (((weight) * (assumed) + (cnt) * (prob)) / ((weight) + (cnt)))
/*
@@ -121,12 +123,12 @@ bayes_classify_token (struct rspamd_classifier *ctx,
{
guint i;
gint id;
- guint64 spam_count = 0, ham_count = 0, total_count = 0;
+ guint spam_count = 0, ham_count = 0, total_count = 0;
struct rspamd_statfile *st;
struct rspamd_task *task;
const gchar *token_type = "txt";
double spam_prob, spam_freq, ham_freq, bayes_spam_prob, bayes_ham_prob,
- ham_prob, fw, w, norm_sum, norm_sub, val;
+ ham_prob, fw, w, val;
task = cl->task;
@@ -145,8 +147,8 @@ bayes_classify_token (struct rspamd_classifier *ctx,
msg_debug_bayes (
"token(meta) %uL <%*s:%*s> probabilistically skipped",
tok->data,
- (int) tok->t1->len, tok->t1->begin,
- (int) tok->t2->len, tok->t2->begin);
+ (int) tok->t1->original.len, tok->t1->original.begin,
+ (int) tok->t2->original.len, tok->t2->original.begin);
}
return;
@@ -173,7 +175,7 @@ bayes_classify_token (struct rspamd_classifier *ctx,
}
/* Probability for this token */
- if (total_count > 0) {
+ if (total_count >= ctx->cfg->min_token_hits) {
spam_freq = ((double)spam_count / MAX (1., (double) ctx->spam_learns));
ham_freq = ((double)ham_count / MAX (1., (double)ctx->ham_learns));
spam_prob = spam_freq / (spam_freq + ham_freq);
@@ -187,19 +189,27 @@ bayes_classify_token (struct rspamd_classifier *ctx,
G_N_ELEMENTS (feature_weight)];
}
- norm_sum = (spam_freq + ham_freq) * (spam_freq + ham_freq);
- norm_sub = (spam_freq - ham_freq) * (spam_freq - ham_freq);
- w = (norm_sub) / (norm_sum) *
- (fw * total_count) / (4.0 * (1.0 + fw * total_count));
+ w = (fw * total_count) / (1.0 + fw * total_count);
+
bayes_spam_prob = PROB_COMBINE (spam_prob, total_count, w, 0.5);
- norm_sub = (ham_freq - spam_freq) * (ham_freq - spam_freq);
- w = (norm_sub) / (norm_sum) *
- (fw * total_count) / (4.0 * (1.0 + fw * total_count));
+
+ if ((bayes_spam_prob > 0.5 && bayes_spam_prob < 0.5 + ctx->cfg->min_prob_strength) ||
+ (bayes_spam_prob < 0.5 && bayes_spam_prob > 0.5 - ctx->cfg->min_prob_strength)) {
+ msg_debug_bayes (
+ "token %uL <%*s:%*s> skipped, prob not in range: %f",
+ tok->data,
+ (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+ (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+ bayes_spam_prob);
+
+ return;
+ }
+
bayes_ham_prob = PROB_COMBINE (ham_prob, total_count, w, 0.5);
- cl->spam_prob += log2 (bayes_spam_prob);
- cl->ham_prob += log2 (bayes_ham_prob);
+ cl->spam_prob += log (bayes_spam_prob);
+ cl->ham_prob += log (bayes_ham_prob);
cl->processed_tokens ++;
if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_META)) {
@@ -210,29 +220,31 @@ bayes_classify_token (struct rspamd_classifier *ctx,
}
if (tok->t1 && tok->t2) {
- msg_debug_bayes ("token(%s) %uL <%*s:%*s>: weight: %f, total_count: %L, "
- "spam_count: %L, ham_count: %L,"
+ msg_debug_bayes ("token(%s) %uL <%*s:%*s>: weight: %f, cf: %f, "
+ "total_count: %ud, "
+ "spam_count: %ud, ham_count: %ud,"
"spam_prob: %.3f, ham_prob: %.3f, "
"bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
"current spam prob: %.3f, current ham prob: %.3f",
token_type,
tok->data,
- (int) tok->t1->len, tok->t1->begin,
- (int) tok->t2->len, tok->t2->begin,
- fw, total_count, spam_count, ham_count,
+ (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+ (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
+ fw, w, total_count, spam_count, ham_count,
spam_prob, ham_prob,
bayes_spam_prob, bayes_ham_prob,
cl->spam_prob, cl->ham_prob);
}
else {
- msg_debug_bayes ("token(%s) %uL <?:?>: weight: %f, total_count: %L, "
- "spam_count: %L, ham_count: %L,"
+ msg_debug_bayes ("token(%s) %uL <?:?>: weight: %f, cf: %f, "
+ "total_count: %ud, "
+ "spam_count: %ud, ham_count: %ud,"
"spam_prob: %.3f, ham_prob: %.3f, "
"bayes_spam_prob: %.3f, bayes_ham_prob: %.3f, "
"current spam prob: %.3f, current ham prob: %.3f",
token_type,
tok->data,
- fw, total_count, spam_count, ham_count,
+ fw, w, total_count, spam_count, ham_count,
spam_prob, ham_prob,
bayes_spam_prob, bayes_ham_prob,
cl->spam_prob, cl->ham_prob);
@@ -243,13 +255,20 @@ bayes_classify_token (struct rspamd_classifier *ctx,
gboolean
-bayes_init (rspamd_mempool_t *pool, struct rspamd_classifier *cl)
+bayes_init (struct rspamd_config *cfg,
+ struct event_base *ev_base,
+ struct rspamd_classifier *cl)
{
cl->cfg->flags |= RSPAMD_FLAG_CLASSIFIER_INTEGER;
return TRUE;
}
+void
+bayes_fin (struct rspamd_classifier *cl)
+{
+}
+
gboolean
bayes_classify (struct rspamd_classifier * ctx,
GPtrArray *tokens,
@@ -318,14 +337,51 @@ bayes_classify (struct rspamd_classifier * ctx,
bayes_classify_token (ctx, tok, &cl);
}
- h = 1 - inv_chi_square (task, cl.spam_prob, cl.processed_tokens);
- s = 1 - inv_chi_square (task, cl.ham_prob, cl.processed_tokens);
+ if (cl.processed_tokens == 0) {
+ msg_info_bayes ("no tokens found in bayes database "
+ "(%ud total tokens, %ud text tokens), ignore stats",
+ tokens->len, text_tokens);
+
+ return TRUE;
+ }
+
+ if (ctx->cfg->min_tokens > 0 &&
+ cl.text_tokens < (gint)(ctx->cfg->min_tokens * 0.1)) {
+ msg_info_bayes ("ignore bayes probability since we have "
+ "found too few text tokens: %uL (of %ud checked), "
+ "at least %d is required",
+ cl.text_tokens,
+ text_tokens,
+ (gint)(ctx->cfg->min_tokens * 0.1));
+
+ return TRUE;
+ }
+
+ if (cl.spam_prob > -300 && cl.ham_prob > -300) {
+ /* Fisher value is low enough to apply inv_chi_square */
+ h = 1 - inv_chi_square (task, cl.spam_prob, cl.processed_tokens);
+ s = 1 - inv_chi_square (task, cl.ham_prob, cl.processed_tokens);
+ }
+ else {
+ /* Use naive method */
+ if (cl.spam_prob < cl.ham_prob) {
+ h = (1.0 - exp(cl.spam_prob - cl.ham_prob)) /
+ (1.0 + exp(cl.spam_prob - cl.ham_prob));
+ s = 1.0 - h;
+ }
+ else {
+ s = (1.0 - exp(cl.ham_prob - cl.spam_prob)) /
+ (1.0 + exp(cl.ham_prob - cl.spam_prob));
+ h = 1.0 - s;
+ }
+ }
if (isfinite (s) && isfinite (h)) {
final_prob = (s + 1.0 - h) / 2.;
msg_debug_bayes (
"<%s> got ham prob %.2f -> %.2f and spam prob %.2f -> %.2f,"
- " %L tokens processed of %ud total tokens (%uL text tokens)",
+ " %L tokens processed of %ud total tokens;"
+ " %uL text tokens found of %ud text tokens)",
task->message_id,
cl.ham_prob,
h,
@@ -333,7 +389,8 @@ bayes_classify (struct rspamd_classifier * ctx,
s,
cl.processed_tokens,
tokens->len,
- cl.text_tokens);
+ cl.text_tokens,
+ text_tokens);
}
else {
/*
@@ -357,17 +414,6 @@ bayes_classify (struct rspamd_classifier * ctx,
}
}
- if (ctx->cfg->min_tokens > 0 &&
- cl.text_tokens < (gint)(ctx->cfg->min_tokens * 0.1)) {
- msg_info_bayes ("ignore bayes probability %.2f since we have "
- "too few text tokens: %uL, at least %d is required",
- final_prob,
- cl.text_tokens,
- (gint)(ctx->cfg->min_tokens * 0.1));
-
- return TRUE;
- }
-
pprob = rspamd_mempool_alloc (task->task_pool, sizeof (*pprob));
*pprob = final_prob;
rspamd_mempool_set_variable (task->task_pool, "bayes_prob", pprob, NULL);
@@ -399,6 +445,19 @@ bayes_classify (struct rspamd_classifier * ctx,
(final_prob - 0.5) * 200.);
final_prob = rspamd_normalize_probability (final_prob, 0.5);
g_assert (st != NULL);
+
+ if (final_prob > 1 || final_prob < 0) {
+ msg_err_bayes ("internal error: probability %f is outside of the "
+ "allowed range [0..1]", final_prob);
+
+ if (final_prob > 1) {
+ final_prob = 1.0;
+ }
+ else {
+ final_prob = 0.0;
+ }
+ }
+
rspamd_task_insert_result (task,
st->stcf->symbol,
final_prob,
@@ -483,8 +542,8 @@ bayes_learn_spam (struct rspamd_classifier * ctx,
msg_debug_bayes ("token %uL <%*s:%*s>: window: %d, total_count: %d, "
"spam_count: %d, ham_count: %d",
tok->data,
- (int) tok->t1->len, tok->t1->begin,
- (int) tok->t2->len, tok->t2->begin,
+ (int) tok->t1->stemmed.len, tok->t1->stemmed.begin,
+ (int) tok->t2->stemmed.len, tok->t2->stemmed.begin,
tok->window_idx, total_cnt, spam_cnt, ham_cnt);
}
else {
diff --git a/src/libstat/classifiers/classifiers.h b/src/libstat/classifiers/classifiers.h
index e30f2153a..fd6daf433 100644
--- a/src/libstat/classifiers/classifiers.h
+++ b/src/libstat/classifiers/classifiers.h
@@ -3,6 +3,7 @@
#include "config.h"
#include "mem_pool.h"
+#include <event.h>
#define RSPAMD_DEFAULT_CLASSIFIER "bayes"
/* Consider this value as 0 */
@@ -10,28 +11,32 @@
struct rspamd_classifier_config;
struct rspamd_task;
+struct rspamd_config;
struct rspamd_classifier;
struct token_node_s;
struct rspamd_stat_classifier {
char *name;
- gboolean (*init_func)(rspamd_mempool_t *pool,
- struct rspamd_classifier *cl);
+ gboolean (*init_func)(struct rspamd_config *cfg,
+ struct event_base *ev_base,
+ struct rspamd_classifier *cl);
gboolean (*classify_func)(struct rspamd_classifier * ctx,
- GPtrArray *tokens,
- struct rspamd_task *task);
+ GPtrArray *tokens,
+ struct rspamd_task *task);
gboolean (*learn_spam_func)(struct rspamd_classifier * ctx,
- GPtrArray *input,
- struct rspamd_task *task,
- gboolean is_spam,
- gboolean unlearn,
- GError **err);
+ GPtrArray *input,
+ struct rspamd_task *task,
+ gboolean is_spam,
+ gboolean unlearn,
+ GError **err);
+ void (*fin_func)(struct rspamd_classifier *cl);
};
/* Bayes algorithm */
-gboolean bayes_init (rspamd_mempool_t *pool,
- struct rspamd_classifier *);
+gboolean bayes_init (struct rspamd_config *cfg,
+ struct event_base *ev_base,
+ struct rspamd_classifier *);
gboolean bayes_classify (struct rspamd_classifier *ctx,
GPtrArray *tokens,
struct rspamd_task *task);
@@ -41,10 +46,12 @@ gboolean bayes_learn_spam (struct rspamd_classifier *ctx,
gboolean is_spam,
gboolean unlearn,
GError **err);
+void bayes_fin (struct rspamd_classifier *);
/* Generic lua classifier */
-gboolean lua_classifier_init (rspamd_mempool_t *pool,
- struct rspamd_classifier *);
+gboolean lua_classifier_init (struct rspamd_config *cfg,
+ struct event_base *ev_base,
+ struct rspamd_classifier *);
gboolean lua_classifier_classify (struct rspamd_classifier *ctx,
GPtrArray *tokens,
struct rspamd_task *task);
@@ -55,6 +62,11 @@ gboolean lua_classifier_learn_spam (struct rspamd_classifier *ctx,
gboolean unlearn,
GError **err);
+extern guint rspamd_bayes_log_id;
+#define msg_debug_bayes(...) rspamd_conditional_debug_fast (NULL, task->from_addr, \
+ rspamd_bayes_log_id, "bayes", task->task_pool->tag.uid, \
+ G_STRFUNC, \
+ __VA_ARGS__)
#endif
/*
diff --git a/src/libstat/classifiers/lua_classifier.c b/src/libstat/classifiers/lua_classifier.c
index 7b495b165..83ce7b0e1 100644
--- a/src/libstat/classifiers/lua_classifier.c
+++ b/src/libstat/classifiers/lua_classifier.c
@@ -47,8 +47,9 @@ static GHashTable *lua_classifiers = NULL;
INIT_LOG_MODULE(luacl)
gboolean
-lua_classifier_init (rspamd_mempool_t *pool,
- struct rspamd_classifier *cl)
+lua_classifier_init (struct rspamd_config *cfg,
+ struct event_base *ev_base,
+ struct rspamd_classifier *cl)
{
struct rspamd_lua_classifier_ctx *ctx;
lua_State *L = cl->ctx->cfg->lua_state;
@@ -62,7 +63,7 @@ lua_classifier_init (rspamd_mempool_t *pool,
ctx = g_hash_table_lookup (lua_classifiers, cl->subrs->name);
if (ctx != NULL) {
- msg_err_pool ("duplicate lua classifier definition: %s",
+ msg_err_config ("duplicate lua classifier definition: %s",
cl->subrs->name);
return FALSE;
@@ -70,7 +71,7 @@ lua_classifier_init (rspamd_mempool_t *pool,
lua_getglobal (L, "rspamd_classifiers");
if (lua_type (L, -1) != LUA_TTABLE) {
- msg_err_pool ("cannot register classifier %s: no rspamd_classifier global",
+ msg_err_config ("cannot register classifier %s: no rspamd_classifier global",
cl->subrs->name);
lua_pop (L, 1);
@@ -81,7 +82,7 @@ lua_classifier_init (rspamd_mempool_t *pool,
lua_gettable (L, -2);
if (lua_type (L, -1) != LUA_TTABLE) {
- msg_err_pool ("cannot register classifier %s: bad lua type: %s",
+ msg_err_config ("cannot register classifier %s: bad lua type: %s",
cl->subrs->name, lua_typename (L, lua_type (L, -1)));
lua_pop (L, 2);
@@ -92,7 +93,7 @@ lua_classifier_init (rspamd_mempool_t *pool,
lua_gettable (L, -2);
if (lua_type (L, -1) != LUA_TFUNCTION) {
- msg_err_pool ("cannot register classifier %s: bad lua type for classify: %s",
+ msg_err_config ("cannot register classifier %s: bad lua type for classify: %s",
cl->subrs->name, lua_typename (L, lua_type (L, -1)));
lua_pop (L, 3);
@@ -105,7 +106,7 @@ lua_classifier_init (rspamd_mempool_t *pool,
lua_gettable (L, -2);
if (lua_type (L, -1) != LUA_TFUNCTION) {
- msg_err_pool ("cannot register classifier %s: bad lua type for learn: %s",
+ msg_err_config ("cannot register classifier %s: bad lua type for learn: %s",
cl->subrs->name, lua_typename (L, lua_type (L, -1)));
lua_pop (L, 3);
diff --git a/src/libstat/learn_cache/redis_cache.c b/src/libstat/learn_cache/redis_cache.c
index 11bc13aae..6a0aa1da7 100644
--- a/src/libstat/learn_cache/redis_cache.c
+++ b/src/libstat/learn_cache/redis_cache.c
@@ -22,20 +22,23 @@
#include "ucl.h"
#include "hiredis.h"
#include "adapters/libevent.h"
+#include "lua/lua_common.h"
#define REDIS_DEFAULT_TIMEOUT 0.5
#define REDIS_STAT_TIMEOUT 30
#define REDIS_DEFAULT_PORT 6379
#define DEFAULT_REDIS_KEY "learned_ids"
+static const gchar *M = "redis learn cache";
+
struct rspamd_redis_cache_ctx {
+ lua_State *L;
struct rspamd_statfile_config *stcf;
- struct upstream_list *read_servers;
- struct upstream_list *write_servers;
const gchar *password;
const gchar *dbname;
const gchar *redis_object;
gdouble timeout;
+ gint conf_ref;
};
struct rspamd_redis_cache_runtime {
@@ -50,7 +53,23 @@ struct rspamd_redis_cache_runtime {
static GQuark
rspamd_stat_cache_redis_quark (void)
{
- return g_quark_from_static_string ("redis-statistics");
+ return g_quark_from_static_string (M);
+}
+
+static inline struct upstream_list *
+rspamd_redis_get_servers (struct rspamd_redis_cache_ctx *ctx,
+ const gchar *what)
+{
+ lua_State *L = ctx->L;
+ struct upstream_list *res;
+
+ lua_rawgeti (L, LUA_REGISTRYINDEX, ctx->conf_ref);
+ lua_pushstring (L, what);
+ lua_gettable (L, -2);
+ res = *((struct upstream_list**)lua_touserdata (L, -1));
+ lua_settop (L, 0);
+
+ return res;
}
static void
@@ -208,94 +227,6 @@ rspamd_stat_cache_redis_generate_id (struct rspamd_task *task)
rspamd_mempool_set_variable (task->task_pool, "words_hash", b32out, g_free);
}
-static gboolean
-rspamd_redis_cache_try_ucl (struct rspamd_redis_cache_ctx *cache_ctx,
- const ucl_object_t *obj,
- struct rspamd_config *cfg,
- const gchar *symbol)
-{
- const ucl_object_t *elt, *relt;
-
- elt = ucl_object_lookup_any (obj, "read_servers", "servers", NULL);
-
- if (elt == NULL) {
- return FALSE;
- }
-
- cache_ctx->read_servers = rspamd_upstreams_create (cfg->ups_ctx);
- if (!rspamd_upstreams_from_ucl (cache_ctx->read_servers, elt,
- REDIS_DEFAULT_PORT, NULL)) {
- msg_err ("statfile %s cannot get read servers configuration",
- symbol);
- return FALSE;
- }
-
- relt = elt;
-
- elt = ucl_object_lookup (obj, "write_servers");
- if (elt == NULL) {
- /* Use read servers as write ones */
- g_assert (relt != NULL);
- cache_ctx->write_servers = rspamd_upstreams_create (cfg->ups_ctx);
- if (!rspamd_upstreams_from_ucl (cache_ctx->write_servers, relt,
- REDIS_DEFAULT_PORT, NULL)) {
- msg_err ("statfile %s cannot get write servers configuration",
- symbol);
- return FALSE;
- }
- }
- else {
- cache_ctx->write_servers = rspamd_upstreams_create (cfg->ups_ctx);
- if (!rspamd_upstreams_from_ucl (cache_ctx->write_servers, elt,
- REDIS_DEFAULT_PORT, NULL)) {
- msg_err ("statfile %s cannot get write servers configuration",
- symbol);
- rspamd_upstreams_destroy (cache_ctx->write_servers);
- cache_ctx->write_servers = NULL;
- }
- }
-
-
- elt = ucl_object_lookup (obj, "timeout");
- if (elt) {
- cache_ctx->timeout = ucl_object_todouble (elt);
- }
- else {
- cache_ctx->timeout = REDIS_DEFAULT_TIMEOUT;
- }
-
- elt = ucl_object_lookup (obj, "password");
- if (elt) {
- cache_ctx->password = ucl_object_tostring (elt);
- }
- else {
- cache_ctx->password = NULL;
- }
-
- elt = ucl_object_lookup_any (obj, "db", "database", "dbname", NULL);
- if (elt) {
- if (ucl_object_type (elt) == UCL_STRING) {
- cache_ctx->dbname = ucl_object_tostring (elt);
- }
- else if (ucl_object_type (elt) == UCL_INT) {
- cache_ctx->dbname = ucl_object_tostring_forced (elt);
- }
- }
- else {
- cache_ctx->dbname = NULL;
- }
-
- elt = ucl_object_lookup_any (obj, "cache_key", "key", NULL);
- if (elt == NULL || ucl_object_type (elt) != UCL_STRING) {
- cache_ctx->redis_object = DEFAULT_REDIS_KEY;
- }
- else {
- cache_ctx->redis_object = ucl_object_tostring (elt);
- }
-
- return TRUE;
-}
-
gpointer
rspamd_stat_cache_redis_init (struct rspamd_stat_ctx *ctx,
struct rspamd_config *cfg,
@@ -306,24 +237,27 @@ rspamd_stat_cache_redis_init (struct rspamd_stat_ctx *ctx,
struct rspamd_statfile_config *stf = st->stcf;
const ucl_object_t *obj;
gboolean ret = FALSE;
+ lua_State *L = (lua_State *)cfg->lua_state;
+ gint conf_ref = -1;
cache_ctx = g_malloc0 (sizeof (*cache_ctx));
+ cache_ctx->timeout = REDIS_DEFAULT_TIMEOUT;
+ cache_ctx->L = L;
/* First search in backend configuration */
obj = ucl_object_lookup (st->classifier->cfg->opts, "backend");
if (obj != NULL && ucl_object_type (obj) == UCL_OBJECT) {
- ret = rspamd_redis_cache_try_ucl (cache_ctx, obj, cfg, stf->symbol);
+ ret = rspamd_lua_try_load_redis (L, obj, cfg, &conf_ref);
}
/* Now try statfiles config */
- if (!ret) {
- ret = rspamd_redis_cache_try_ucl (cache_ctx, stf->opts, cfg, stf->symbol);
+ if (!ret && stf->opts) {
+ ret = rspamd_lua_try_load_redis (L, stf->opts, cfg, &conf_ref);
}
/* Now try classifier config */
- if (!ret) {
- ret = rspamd_redis_cache_try_ucl (cache_ctx, st->classifier->cfg->opts, cfg,
- stf->symbol);
+ if (!ret && st->classifier->cfg->opts) {
+ ret = rspamd_lua_try_load_redis (L, st->classifier->cfg->opts, cfg, &conf_ref);
}
/* Now try global redis settings */
@@ -336,23 +270,61 @@ rspamd_stat_cache_redis_init (struct rspamd_stat_ctx *ctx,
specific_obj = ucl_object_lookup (obj, "statistics");
if (specific_obj) {
- ret = rspamd_redis_cache_try_ucl (cache_ctx, specific_obj, cfg,
- stf->symbol);
+ ret = rspamd_lua_try_load_redis (L,
+ specific_obj, cfg, &conf_ref);
}
else {
- ret = rspamd_redis_cache_try_ucl (cache_ctx, obj, cfg,
- stf->symbol);
+ ret = rspamd_lua_try_load_redis (L,
+ obj, cfg, &conf_ref);
}
}
}
-
if (!ret) {
msg_err_config ("cannot init redis cache for %s", stf->symbol);
g_free (cache_ctx);
return NULL;
}
+ obj = ucl_object_lookup (st->classifier->cfg->opts, "cache_key");
+
+ if (obj) {
+ cache_ctx->redis_object = ucl_object_tostring (obj);
+ }
+ else {
+ cache_ctx->redis_object = DEFAULT_REDIS_KEY;
+ }
+
+ cache_ctx->conf_ref = conf_ref;
+
+ /* Check some common table values */
+ lua_rawgeti (L, LUA_REGISTRYINDEX, conf_ref);
+
+ lua_pushstring (L, "timeout");
+ lua_gettable (L, -2);
+ if (lua_type (L, -1) == LUA_TNUMBER) {
+ cache_ctx->timeout = lua_tonumber (L, -1);
+ }
+ lua_pop (L, 1);
+
+ lua_pushstring (L, "db");
+ lua_gettable (L, -2);
+ if (lua_type (L, -1) == LUA_TSTRING) {
+ cache_ctx->dbname = rspamd_mempool_strdup (cfg->cfg_pool,
+ lua_tostring (L, -1));
+ }
+ lua_pop (L, 1);
+
+ lua_pushstring (L, "password");
+ lua_gettable (L, -2);
+ if (lua_type (L, -1) == LUA_TSTRING) {
+ cache_ctx->password = rspamd_mempool_strdup (cfg->cfg_pool,
+ lua_tostring (L, -1));
+ }
+ lua_pop (L, 1);
+
+ lua_settop (L, 0);
+
cache_ctx->stcf = stf;
return (gpointer)cache_ctx;
@@ -365,28 +337,39 @@ rspamd_stat_cache_redis_runtime (struct rspamd_task *task,
struct rspamd_redis_cache_ctx *ctx = c;
struct rspamd_redis_cache_runtime *rt;
struct upstream *up;
+ struct upstream_list *ups;
rspamd_inet_addr_t *addr;
g_assert (ctx != NULL);
- if (learn && ctx->write_servers == NULL) {
- msg_err_task ("no write servers defined for %s, cannot learn",
- ctx->stcf->symbol);
- return NULL;
- }
-
if (task->tokens == NULL || task->tokens->len == 0) {
return NULL;
}
if (learn) {
- up = rspamd_upstream_get (ctx->write_servers,
+ ups = rspamd_redis_get_servers (ctx, "write_servers");
+
+ if (!ups) {
+ msg_err_task ("no write servers defined for %s, cannot learn",
+ ctx->stcf->symbol);
+ return NULL;
+ }
+
+ up = rspamd_upstream_get (ups,
RSPAMD_UPSTREAM_MASTER_SLAVE,
NULL,
0);
}
else {
- up = rspamd_upstream_get (ctx->read_servers,
+ ups = rspamd_redis_get_servers (ctx, "read_servers");
+
+ if (!ups) {
+ msg_err_task ("no read servers defined for %s, cannot check",
+ ctx->stcf->symbol);
+ return NULL;
+ }
+
+ up = rspamd_upstream_get (ups,
RSPAMD_UPSTREAM_ROUND_ROBIN,
NULL,
0);
@@ -453,7 +436,10 @@ rspamd_stat_cache_redis_check (struct rspamd_task *task,
if (redisAsyncCommand (rt->redis, rspamd_stat_cache_redis_get, rt,
"HGET %s %s",
rt->ctx->redis_object, h) == REDIS_OK) {
- rspamd_session_add_event (task->s, NULL, rspamd_redis_cache_fin, rt, rspamd_stat_cache_redis_quark ());
+ rspamd_session_add_event (task->s,
+ rspamd_redis_cache_fin,
+ rt,
+ M);
event_add (&rt->timeout_event, &tv);
rt->has_event = TRUE;
}
@@ -485,7 +471,8 @@ rspamd_stat_cache_redis_learn (struct rspamd_task *task,
if (redisAsyncCommand (rt->redis, rspamd_stat_cache_redis_set, rt,
"HSET %s %s %d",
rt->ctx->redis_object, h, flag) == REDIS_OK) {
- rspamd_session_add_event (task->s, NULL, rspamd_redis_cache_fin, rt, rspamd_stat_cache_redis_quark ());
+ rspamd_session_add_event (task->s,
+ rspamd_redis_cache_fin, rt, M);
event_add (&rt->timeout_event, &tv);
rt->has_event = TRUE;
}
@@ -497,5 +484,14 @@ rspamd_stat_cache_redis_learn (struct rspamd_task *task,
void
rspamd_stat_cache_redis_close (gpointer c)
{
+ struct rspamd_redis_cache_ctx *ctx = (struct rspamd_redis_cache_ctx *)c;
+ lua_State *L;
+
+ L = ctx->L;
+
+ if (ctx->conf_ref) {
+ luaL_unref (L, LUA_REGISTRYINDEX, ctx->conf_ref);
+ }
+ g_free (ctx);
}
diff --git a/src/libstat/learn_cache/sqlite3_cache.c b/src/libstat/learn_cache/sqlite3_cache.c
index 255c835bb..52921326d 100644
--- a/src/libstat/learn_cache/sqlite3_cache.c
+++ b/src/libstat/learn_cache/sqlite3_cache.c
@@ -221,6 +221,8 @@ rspamd_stat_cache_sqlite3_check (struct rspamd_task *task,
/* We have some existing record in the table */
if (!!flag == !!is_spam) {
/* Already learned */
+ msg_warn_task ("already seen stat hash: %*bs",
+ rspamd_cryptobox_HASHBYTES, out);
return RSPAMD_LEARN_INGORE;
}
else {
diff --git a/src/libstat/stat_api.h b/src/libstat/stat_api.h
index 84db8ee01..9dcd6f8e8 100644
--- a/src/libstat/stat_api.h
+++ b/src/libstat/stat_api.h
@@ -26,16 +26,25 @@
* High level statistics API
*/
-#define RSPAMD_STAT_TOKEN_FLAG_TEXT (1 << 0)
-#define RSPAMD_STAT_TOKEN_FLAG_META (1 << 1)
-#define RSPAMD_STAT_TOKEN_FLAG_LUA_META (1 << 2)
-#define RSPAMD_STAT_TOKEN_FLAG_EXCEPTION (1 << 3)
-#define RSPAMD_STAT_TOKEN_FLAG_SUBJECT (1 << 4)
-#define RSPAMD_STAT_TOKEN_FLAG_UNIGRAM (1 << 5)
+#define RSPAMD_STAT_TOKEN_FLAG_TEXT (1u << 0)
+#define RSPAMD_STAT_TOKEN_FLAG_META (1u << 1)
+#define RSPAMD_STAT_TOKEN_FLAG_LUA_META (1u << 2)
+#define RSPAMD_STAT_TOKEN_FLAG_EXCEPTION (1u << 3)
+#define RSPAMD_STAT_TOKEN_FLAG_HEADER (1u << 4)
+#define RSPAMD_STAT_TOKEN_FLAG_UNIGRAM (1u << 5)
+#define RSPAMD_STAT_TOKEN_FLAG_UTF (1u << 6)
+#define RSPAMD_STAT_TOKEN_FLAG_NORMALISED (1u << 7)
+#define RSPAMD_STAT_TOKEN_FLAG_STEMMED (1u << 8)
+#define RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE (1u << 9)
+#define RSPAMD_STAT_TOKEN_FLAG_STOP_WORD (1u << 9)
+#define RSPAMD_STAT_TOKEN_FLAG_SKIPPED (1u << 10)
+#define RSPAMD_STAT_TOKEN_FLAG_INVISIBLE_SPACES (1u << 11)
typedef struct rspamd_stat_token_s {
- const gchar *begin;
- gsize len;
+ rspamd_ftok_t original; /* utf8 raw */
+ rspamd_ftok_unicode_t unicode; /* array of unicode characters, normalized, lowercased */
+ rspamd_ftok_t normalized; /* normalized and lowercased utf8 */
+ rspamd_ftok_t stemmed; /* stemmed utf8 */
guint flags;
} rspamd_stat_token_t;
diff --git a/src/libstat/stat_config.c b/src/libstat/stat_config.c
index 9d1e57f13..101db4fe6 100644
--- a/src/libstat/stat_config.c
+++ b/src/libstat/stat_config.c
@@ -28,6 +28,7 @@ static struct rspamd_stat_classifier lua_classifier = {
.init_func = lua_classifier_init,
.classify_func = lua_classifier_classify,
.learn_spam_func = lua_classifier_learn_spam,
+ .fin_func = NULL,
};
static struct rspamd_stat_classifier stat_classifiers[] = {
@@ -36,6 +37,7 @@ static struct rspamd_stat_classifier stat_classifiers[] = {
.init_func = bayes_init,
.classify_func = bayes_classify,
.learn_spam_func = bayes_learn_spam,
+ .fin_func = bayes_fin,
}
};
@@ -162,6 +164,67 @@ rspamd_stat_init (struct rspamd_config *cfg, struct event_base *ev_base)
stat_ctx->classifiers = g_ptr_array_new ();
stat_ctx->async_elts = g_queue_new ();
stat_ctx->ev_base = ev_base;
+ stat_ctx->lua_stat_tokens_ref = -1;
+
+ /* Interact with lua_stat */
+ if (luaL_dostring (L, "return require \"lua_stat\"") != 0) {
+ msg_err_config ("cannot require lua_stat: %s",
+ lua_tostring (L, -1));
+ }
+ else {
+ if (lua_type (L, -1) != LUA_TTABLE) {
+ msg_err_config ("lua stat must return "
+ "table and not %s",
+ lua_typename (L, lua_type (L, -1)));
+ }
+ else {
+ lua_pushstring (L, "gen_stat_tokens");
+ lua_gettable (L, -2);
+
+ if (lua_type (L, -1) != LUA_TFUNCTION) {
+ msg_err_config ("gen_stat_tokens must return "
+ "function and not %s",
+ lua_typename (L, lua_type (L, -1)));
+ }
+ else {
+ /* Call this function to obtain closure */
+ gint err_idx, ret;
+ GString *tb;
+ struct rspamd_config **pcfg;
+
+ lua_pushcfunction (L, &rspamd_lua_traceback);
+ err_idx = lua_gettop (L);
+ lua_pushvalue (L, err_idx - 1);
+
+ pcfg = lua_newuserdata (L, sizeof (*pcfg));
+ *pcfg = cfg;
+ rspamd_lua_setclass (L, "rspamd{config}", -1);
+
+ if ((ret = lua_pcall (L, 1, 1, err_idx)) != 0) {
+ tb = lua_touserdata (L, -1);
+ msg_err_config ("call to gen_stat_tokens lua "
+ "script failed (%d): %v", ret, tb);
+
+ if (tb) {
+ g_string_free (tb, TRUE);
+ }
+ }
+ else {
+ if (lua_type (L, -1) != LUA_TFUNCTION) {
+ msg_err_config ("gen_stat_tokens invocation must return "
+ "function and not %s",
+ lua_typename (L, lua_type (L, -1)));
+ }
+ else {
+ stat_ctx->lua_stat_tokens_ref = luaL_ref (L, LUA_REGISTRYINDEX);
+ }
+ }
+ }
+ }
+ }
+
+ /* Cleanup mess */
+ lua_settop (L, 0);
/* Create statfiles from the classifiers */
cur = cfg->classifiers;
@@ -182,7 +245,7 @@ rspamd_stat_init (struct rspamd_config *cfg, struct event_base *ev_base)
continue;
}
- if (!cl->subrs->init_func (cfg->cfg_pool, cl)) {
+ if (!cl->subrs->init_func (cfg, ev_base, cl)) {
g_free (cl);
msg_err_config ("cannot init classifier type %s", clf->name);
cur = g_list_next (cur);
@@ -328,6 +391,11 @@ rspamd_stat_close (void)
}
g_array_free (cl->statfiles_ids, TRUE);
+
+ if (cl->subrs->fin_func) {
+ cl->subrs->fin_func (cl);
+ }
+
g_free (cl);
}
@@ -342,6 +410,12 @@ rspamd_stat_close (void)
g_queue_free (stat_ctx->async_elts);
g_ptr_array_free (st_ctx->statfiles, TRUE);
g_ptr_array_free (st_ctx->classifiers, TRUE);
+
+ if (st_ctx->lua_stat_tokens_ref != -1) {
+ luaL_unref (st_ctx->cfg->lua_state, LUA_REGISTRYINDEX,
+ st_ctx->lua_stat_tokens_ref);
+ }
+
g_free (st_ctx);
/* Set global var to NULL */
@@ -475,11 +549,11 @@ rspamd_stat_ctx_register_async (rspamd_stat_async_handler handler,
g_assert (st_ctx != NULL);
elt = g_malloc0 (sizeof (*elt));
- REF_INIT_RETAIN (elt, rspamd_async_elt_dtor);
elt->handler = handler;
elt->cleanup = cleanup;
elt->ud = d;
elt->timeout = timeout;
+ REF_INIT_RETAIN (elt, rspamd_async_elt_dtor);
/* Enabled by default */
diff --git a/src/libstat/stat_internal.h b/src/libstat/stat_internal.h
index 44f48ae5a..a547ca93a 100644
--- a/src/libstat/stat_internal.h
+++ b/src/libstat/stat_internal.h
@@ -41,6 +41,7 @@ struct rspamd_classifier {
gulong ham_learns;
struct rspamd_classifier_config *cfg;
struct rspamd_stat_classifier *subrs;
+ gpointer specific;
};
struct rspamd_statfile {
@@ -85,6 +86,9 @@ struct rspamd_stat_ctx {
GPtrArray *classifiers; /* struct rspamd_classifier */
GQueue *async_elts; /* struct rspamd_stat_async_elt */
struct rspamd_config *cfg;
+
+ gint lua_stat_tokens_ref;
+
/* Global tokenizer */
struct rspamd_stat_tokenizer *tokenizer;
gpointer tkcf;
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index ca51d7b02..d097e12e0 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -32,273 +32,85 @@
static const gdouble similarity_treshold = 80.0;
static void
-rspamd_stat_tokenize_header (struct rspamd_task *task,
- const gchar *name, const gchar *prefix, GArray *ar)
-{
- struct rspamd_mime_header *cur;
- GPtrArray *hdrs;
- guint i;
- rspamd_stat_token_t str;
-
- hdrs = g_hash_table_lookup (task->raw_headers, name);
- str.flags = RSPAMD_STAT_TOKEN_FLAG_META;
-
- if (hdrs != NULL) {
-
- PTR_ARRAY_FOREACH (hdrs, i, cur) {
- if (cur->name != NULL) {
- str.begin = cur->name;
- str.len = strlen (cur->name);
- g_array_append_val (ar, str);
- }
- if (cur->decoded != NULL) {
- str.begin = cur->decoded;
- str.len = strlen (cur->decoded);
- g_array_append_val (ar, str);
- }
- else if (cur->value != NULL) {
- str.begin = cur->value;
- str.len = strlen (cur->value);
- g_array_append_val (ar, str);
- }
- }
-
- msg_debug_task ("added stat tokens for header '%s'", name);
- }
-}
-
-static void
rspamd_stat_tokenize_parts_metadata (struct rspamd_stat_ctx *st_ctx,
struct rspamd_task *task)
{
- struct rspamd_image *img;
- struct rspamd_mime_part *part;
- struct rspamd_mime_text_part *tp;
- GList *cur;
GArray *ar;
rspamd_stat_token_t elt;
guint i;
- gchar tmpbuf[128];
lua_State *L = task->cfg->lua_state;
- const gchar *headers_hash;
- struct rspamd_mime_header *hdr;
ar = g_array_sized_new (FALSE, FALSE, sizeof (elt), 16);
+ memset (&elt, 0, sizeof (elt));
elt.flags = RSPAMD_STAT_TOKEN_FLAG_META;
- /* Insert images */
- for (i = 0; i < task->parts->len; i ++) {
- part = g_ptr_array_index (task->parts, i);
-
- if ((part->flags & RSPAMD_MIME_PART_IMAGE) && part->specific.img) {
- img = part->specific.img;
-
- /* If an image has a linked HTML part, then we push its details to the stat */
- if (img->html_image) {
- elt.begin = (gchar *)"image";
- elt.len = 5;
- g_array_append_val (ar, elt);
- elt.begin = (gchar *)&img->html_image->height;
- elt.len = sizeof (img->html_image->height);
- g_array_append_val (ar, elt);
- elt.begin = (gchar *)&img->html_image->width;
- elt.len = sizeof (img->html_image->width);
- g_array_append_val (ar, elt);
- elt.begin = (gchar *)&img->type;
- elt.len = sizeof (img->type);
- g_array_append_val (ar, elt);
-
- if (img->filename) {
- elt.begin = (gchar *)img->filename;
- elt.len = strlen (elt.begin);
- g_array_append_val (ar, elt);
- }
+ if (st_ctx->lua_stat_tokens_ref != -1) {
+ gint err_idx, ret;
+ GString *tb;
+ struct rspamd_task **ptask;
- msg_debug_task ("added stat tokens for image '%s'", img->html_image->src);
- }
- }
- else if (part->cd && part->cd->filename.len > 0) {
- elt.begin = (gchar *)part->cd->filename.begin;
- elt.len = part->cd->filename.len;
- g_array_append_val (ar, elt);
- }
- }
+ lua_pushcfunction (L, &rspamd_lua_traceback);
+ err_idx = lua_gettop (L);
+ lua_rawgeti (L, LUA_REGISTRYINDEX, st_ctx->lua_stat_tokens_ref);
- /* Process mime parts */
- for (i = 0; i < task->parts->len; i ++) {
- part = g_ptr_array_index (task->parts, i);
+ ptask = lua_newuserdata (L, sizeof (*ptask));
+ *ptask = task;
+ rspamd_lua_setclass (L, "rspamd{task}", -1);
- if (IS_CT_MULTIPART (part->ct)) {
- elt.begin = (gchar *)part->ct->boundary.begin;
- elt.len = part->ct->boundary.len;
+ if ((ret = lua_pcall (L, 1, 1, err_idx)) != 0) {
+ tb = lua_touserdata (L, -1);
+ msg_err_task ("call to stat_tokens lua "
+ "script failed (%d): %v", ret, tb);
- if (elt.len) {
- msg_debug_task ("added stat tokens for mime boundary '%*s'",
- (gint)elt.len, elt.begin);
- g_array_append_val (ar, elt);
+ if (tb) {
+ g_string_free (tb, TRUE);
}
-
- if (part->parsed_data.len > 1) {
- rspamd_snprintf (tmpbuf, sizeof (tmpbuf), "mime%d:%dlog",
- i, (gint)log2 (part->parsed_data.len));
- elt.begin = rspamd_mempool_strdup (task->task_pool, tmpbuf);
- elt.len = strlen (elt.begin);
- g_array_append_val (ar, elt);
- }
- }
- }
-
- /* Process text parts metadata */
- for (i = 0; i < task->text_parts->len; i ++) {
- tp = g_ptr_array_index (task->text_parts, i);
-
- if (tp->language != NULL && tp->language[0] != '\0') {
- elt.begin = (gchar *)tp->language;
- elt.len = strlen (elt.begin);
- msg_debug_task ("added stat tokens for part language '%s'", elt.begin);
- g_array_append_val (ar, elt);
- }
- if (tp->real_charset != NULL) {
- elt.begin = (gchar *)tp->real_charset;
- elt.len = strlen (elt.begin);
- msg_debug_task ("added stat tokens for part charset '%s'", elt.begin);
- g_array_append_val (ar, elt);
- }
- }
-
- cur = g_list_first (task->cfg->classify_headers);
-
- while (cur) {
- rspamd_stat_tokenize_header (task, cur->data, "UA:", ar);
-
- cur = g_list_next (cur);
- }
-
- /* Use headers order */
- headers_hash = rspamd_mempool_get_variable (task->task_pool,
- RSPAMD_MEMPOOL_HEADERS_HASH);
-
- if (headers_hash) {
- elt.begin = (gchar *)headers_hash;
- elt.len = 16;
- g_array_append_val (ar, elt);
- }
-
- /* Use more precise headers order */
- cur = g_list_first (task->headers_order->head);
- while (cur) {
- hdr = cur->data;
-
- if (hdr->name && hdr->type != RSPAMD_HEADER_RECEIVED) {
- elt.begin = hdr->name;
- elt.len = strlen (hdr->name);
- g_array_append_val (ar, elt);
}
+ else {
+ if (lua_type (L, -1) != LUA_TTABLE) {
+ msg_err_task ("stat_tokens invocation must return "
+ "table and not %s",
+ lua_typename (L, lua_type (L, -1)));
+ }
+ else {
+ guint vlen;
+ rspamd_ftok_t tok;
- cur = g_list_next (cur);
- }
-
- /* Use metatokens plugin from Lua */
- lua_getglobal (L, "rspamd_plugins");
-
- if (lua_type (L, -1) == LUA_TTABLE) {
- lua_pushstring (L, "stat_metatokens");
- lua_gettable (L, -2);
-
- if (lua_type (L, -1) == LUA_TTABLE) {
- gint old_top;
+ vlen = rspamd_lua_table_size (L, -1);
- old_top = lua_gettop (L);
- lua_pushstring (L, "callback");
- lua_gettable (L, -2);
+ for (i = 0; i < vlen; i ++) {
+ lua_rawgeti (L, -1, i + 1);
+ tok.begin = lua_tolstring (L, -1, &tok.len);
- if (lua_type (L, -1) == LUA_TFUNCTION) {
- struct rspamd_task **ptask;
+ if (tok.begin && tok.len > 0) {
+ elt.original.begin =
+ rspamd_mempool_ftokdup (task->task_pool, &tok);
+ elt.original.len = tok.len;
+ elt.stemmed.begin = elt.original.begin;
+ elt.stemmed.len = elt.original.len;
+ elt.normalized.begin = elt.original.begin;
+ elt.normalized.len = elt.original.len;
- ptask = lua_newuserdata (L, sizeof (*ptask));
- rspamd_lua_setclass (L, "rspamd{task}", -1);
- *ptask = task;
+ g_array_append_val (ar, elt);
+ }
- if (lua_pcall (L, 1, LUA_MULTRET, 0) != 0) {
- msg_err_task ("stat_metatokens failed: %s",
- lua_tostring (L, -1));
lua_pop (L, 1);
- } else {
- if (lua_gettop (L) > old_top &&
- lua_istable (L, old_top + 1)) {
- lua_pushvalue (L, old_top + 1);
- /* Iterate over table of tables */
- for (lua_pushnil (L); lua_next (L, -2);
- lua_pop (L, 1)) {
- elt.flags = RSPAMD_STAT_TOKEN_FLAG_META|
- RSPAMD_STAT_TOKEN_FLAG_LUA_META;
-
- if (lua_isnumber (L, -1)) {
- gdouble num = lua_tonumber (L, -1);
- guint8 *pnum = rspamd_mempool_alloc (
- task->task_pool,
- sizeof (num));
-
- msg_debug_task ("got metatoken number: %.2f",
- num);
- memcpy (pnum, &num, sizeof (num));
- elt.begin = (gchar *) pnum;
- elt.len = sizeof (num);
- g_array_append_val (ar, elt);
- } else if (lua_isstring (L, -1)) {
- const gchar *str;
- gsize tlen;
-
- str = lua_tolstring (L, -1, &tlen);
- guint8 *pstr = rspamd_mempool_alloc (
- task->task_pool,
- tlen);
- memcpy (pstr, str, tlen);
-
- msg_debug_task ("got metatoken string: %*s",
- (gint) tlen, str);
- elt.begin = (gchar *) pstr;
- elt.len = tlen;
- g_array_append_val (ar, elt);
- }
- else if (lua_istable (L, -1)) {
- /* Treat that as unigramms */
- for (lua_pushnil (L); lua_next (L, -2);
- lua_pop (L, 1)) {
- if (lua_isstring (L, -1)) {
- const gchar *str;
- gsize tlen;
-
- str = lua_tolstring (L, -1, &tlen);
- guint8 *pstr = rspamd_mempool_alloc (
- task->task_pool,
- tlen);
- memcpy (pstr, str, tlen);
-
- msg_debug_task ("got unigramm "
- "metatoken string: %*s",
- (gint) tlen, str);
- elt.begin = (gchar *) pstr;
- elt.len = tlen;
- elt.flags |= RSPAMD_STAT_TOKEN_FLAG_UNIGRAM;
- g_array_append_val (ar, elt);
- }
- }
- }
- }
- }
}
}
}
+
+ lua_settop (L, 0);
}
- lua_settop (L, 0);
- st_ctx->tokenizer->tokenize_func (st_ctx,
- task->task_pool,
- ar,
- TRUE,
- "META:",
- task->tokens);
+
+ if (ar->len > 0) {
+ st_ctx->tokenizer->tokenize_func (st_ctx,
+ task,
+ ar,
+ TRUE,
+ "M",
+ task->tokens);
+ }
rspamd_mempool_add_destructor (task->task_pool,
rspamd_array_free_hard, ar);
@@ -313,10 +125,7 @@ rspamd_stat_process_tokenize (struct rspamd_stat_ctx *st_ctx,
{
struct rspamd_mime_text_part *part;
rspamd_cryptobox_hash_state_t hst;
- rspamd_stat_token_t *tok;
rspamd_token_t *st_tok;
- GArray *words;
- gchar *sub = NULL;
guint i, reserved_len = 0;
gdouble *pdiff;
guchar hout[rspamd_cryptobox_HASHBYTES];
@@ -347,55 +156,26 @@ rspamd_stat_process_tokenize (struct rspamd_stat_ctx *st_ctx,
part = g_ptr_array_index (task->text_parts, i);
if (!IS_PART_EMPTY (part) && part->utf_words != NULL) {
- st_ctx->tokenizer->tokenize_func (st_ctx, task->task_pool,
+ st_ctx->tokenizer->tokenize_func (st_ctx, task,
part->utf_words, IS_PART_UTF (part),
NULL, task->tokens);
}
if (pdiff != NULL && (1.0 - *pdiff) * 100.0 > similarity_treshold) {
- msg_debug_task ("message has two common parts (%.2f), so skip the last one",
+ msg_debug_bayes ("message has two common parts (%.2f), so skip the last one",
*pdiff);
break;
}
}
- if (task->subject != NULL) {
- sub = task->subject;
- }
-
- if (sub != NULL) {
- UText utxt = UTEXT_INITIALIZER;
- UErrorCode uc_err = U_ZERO_ERROR;
- gsize slen = strlen (sub);
-
- utext_openUTF8 (&utxt,
- sub,
- slen,
- &uc_err);
-
- words = rspamd_tokenize_text (sub, slen, &utxt, RSPAMD_TOKENIZE_UTF,
- NULL, NULL, NULL);
-
- if (words != NULL) {
-
- for (i = 0; i < words->len; i ++) {
- tok = &g_array_index (words, rspamd_stat_token_t, i);
- tok->flags |= RSPAMD_STAT_TOKEN_FLAG_SUBJECT;
- }
-
- st_ctx->tokenizer->tokenize_func (st_ctx,
- task->task_pool,
- words,
- TRUE,
- "SUBJECT",
- task->tokens);
-
- rspamd_mempool_add_destructor (task->task_pool,
- rspamd_array_free_hard, words);
- }
-
- utext_close (&utxt);
+ if (task->meta_words != NULL) {
+ st_ctx->tokenizer->tokenize_func (st_ctx,
+ task,
+ task->meta_words,
+ TRUE,
+ "SUBJECT",
+ task->tokens);
}
rspamd_stat_tokenize_parts_metadata (st_ctx, task);
@@ -445,10 +225,10 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
continue;
}
- if (!rspamd_symbols_cache_is_symbol_enabled (task, task->cfg->cache,
+ if (!rspamd_symcache_is_symbol_enabled (task, task->cfg->cache,
st->stcf->symbol)) {
g_ptr_array_index (task->stat_runtimes, i) = NULL;
- msg_debug_task ("symbol %s is disabled, skip classification",
+ msg_debug_bayes ("symbol %s is disabled, skip classification",
st->stcf->symbol);
continue;
}
@@ -550,6 +330,12 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
return;
}
+ for (i = 0; i < st_ctx->classifiers->len; i++) {
+ cl = g_ptr_array_index (st_ctx->classifiers, i);
+ cl->spam_learns = 0;
+ cl->ham_learns = 0;
+ }
+
for (i = 0; i < st_ctx->statfiles->len; i++) {
st = g_ptr_array_index (st_ctx->statfiles, i);
cl = st->classifier;
@@ -591,7 +377,7 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
if (bk_run == NULL) {
skip = TRUE;
- msg_debug_task ("disable classifier %s as statfile symbol %s is disabled",
+ msg_debug_bayes ("disable classifier %s as statfile symbol %s is disabled",
cl->cfg->name, st->stcf->symbol);
break;
}
@@ -600,7 +386,7 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
if (!skip) {
if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) {
- msg_debug_task (
+ msg_debug_bayes (
"<%s> contains less tokens than required for %s classifier: "
"%ud < %ud",
task->message_id,
@@ -610,7 +396,7 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
continue;
}
else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) {
- msg_debug_task (
+ msg_debug_bayes (
"<%s> contains more tokens than allowed for %s classifier: "
"%ud > %ud",
task->message_id,
@@ -740,7 +526,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL &&
*err == NULL) {
/* Do not learn twice */
- g_set_error (err, rspamd_stat_quark (), 404, "<%s> has been already "
+ g_set_error (err, rspamd_stat_quark (), 208, "<%s> has been already "
"learned as %s, ignore it", task->message_id,
spam ? "spam" : "ham");
@@ -849,7 +635,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
if (!learned && err && *err == NULL) {
if (too_large) {
- g_set_error (err, rspamd_stat_quark (), 400,
+ g_set_error (err, rspamd_stat_quark (), 204,
"<%s> contains more tokens than allowed for %s classifier: "
"%d > %d",
task->message_id,
@@ -858,7 +644,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
cl->cfg->max_tokens);
}
else if (too_small) {
- g_set_error (err, rspamd_stat_quark (), 400,
+ g_set_error (err, rspamd_stat_quark (), 204,
"<%s> contains less tokens than required for %s classifier: "
"%d < %d",
task->message_id,
@@ -867,7 +653,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
cl->cfg->min_tokens);
}
else if (conditionally_skipped) {
- g_set_error (err, rspamd_stat_quark (), 410,
+ g_set_error (err, rspamd_stat_quark (), 204,
"<%s> is skipped for %s classifier: "
"%s",
task->message_id,
@@ -1107,7 +893,7 @@ rspamd_stat_has_classifier_symbols (struct rspamd_task *task,
if (rspamd_task_find_symbol_result (task, st->stcf->symbol)) {
if (is_spam == !!st->stcf->is_spam) {
- msg_debug_task ("do not autolearn %s as symbol %s is already "
+ msg_debug_bayes ("do not autolearn %s as symbol %s is already "
"added", is_spam ? "spam" : "ham", st->stcf->symbol);
return TRUE;
diff --git a/src/libstat/tokenizers/osb.c b/src/libstat/tokenizers/osb.c
index f6f46c580..0b53f8af9 100644
--- a/src/libstat/tokenizers/osb.c
+++ b/src/libstat/tokenizers/osb.c
@@ -17,8 +17,10 @@
* OSB tokenizer
*/
+
#include "tokenizers.h"
#include "stat_internal.h"
+#include "libmime/lang_detection.h"
/* Size for features pipe */
#define DEFAULT_FEATURE_WINDOW_SIZE 5
@@ -259,11 +261,11 @@ struct token_pipe_entry {
gint
rspamd_tokenizer_osb (struct rspamd_stat_ctx *ctx,
- rspamd_mempool_t *pool,
- GArray *words,
- gboolean is_utf,
- const gchar *prefix,
- GPtrArray *result)
+ struct rspamd_task *task,
+ GArray *words,
+ gboolean is_utf,
+ const gchar *prefix,
+ GPtrArray *result)
{
rspamd_token_t *new_tok = NULL;
rspamd_stat_token_t *token;
@@ -302,23 +304,40 @@ rspamd_tokenizer_osb (struct rspamd_stat_ctx *ctx,
for (w = 0; w < words->len; w ++) {
token = &g_array_index (words, rspamd_stat_token_t, w);
token_flags = token->flags;
+ const gchar *begin;
+ gsize len;
+
+ if (token->flags &
+ (RSPAMD_STAT_TOKEN_FLAG_STOP_WORD|RSPAMD_STAT_TOKEN_FLAG_SKIPPED)) {
+ /* Skip stop/skipped words */
+ continue;
+ }
+
+ if (token->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) {
+ begin = token->stemmed.begin;
+ len = token->stemmed.len;
+ }
+ else {
+ begin = token->original.begin;
+ len = token->original.len;
+ }
if (osb_cf->ht == RSPAMD_OSB_HASH_COMPAT) {
rspamd_ftok_t ftok;
- ftok.begin = token->begin;
- ftok.len = token->len;
+ ftok.begin = begin;
+ ftok.len = len;
cur = rspamd_fstrhash_lc (&ftok, is_utf);
}
else {
/* We know that the words are normalized */
if (osb_cf->ht == RSPAMD_OSB_HASH_XXHASH) {
cur = rspamd_cryptobox_fast_hash_specific (RSPAMD_CRYPTOBOX_XXHASH64,
- token->begin, token->len, osb_cf->seed);
+ begin, len, osb_cf->seed);
}
else {
- rspamd_cryptobox_siphash ((guchar *)&cur, token->begin,
- token->len, osb_cf->sk);
+ rspamd_cryptobox_siphash ((guchar *)&cur, begin,
+ len, osb_cf->sk);
if (prefix) {
cur ^= seed;
@@ -327,7 +346,7 @@ rspamd_tokenizer_osb (struct rspamd_stat_ctx *ctx,
}
if (token_flags & RSPAMD_STAT_TOKEN_FLAG_UNIGRAM) {
- new_tok = rspamd_mempool_alloc0 (pool, token_size);
+ new_tok = rspamd_mempool_alloc0 (task->task_pool, token_size);
new_tok->flags = token_flags;
new_tok->t1 = token;
new_tok->t2 = token;
@@ -339,7 +358,7 @@ rspamd_tokenizer_osb (struct rspamd_stat_ctx *ctx,
}
#define ADD_TOKEN do {\
- new_tok = rspamd_mempool_alloc0 (pool, token_size); \
+ new_tok = rspamd_mempool_alloc0 (task->task_pool, token_size); \
new_tok->flags = token_flags; \
new_tok->t1 = hashpipe[0].t; \
new_tok->t2 = hashpipe[i].t; \
@@ -354,7 +373,7 @@ rspamd_tokenizer_osb (struct rspamd_stat_ctx *ctx,
else { \
new_tok->data = hashpipe[0].h * primes[0] + hashpipe[i].h * primes[i << 1]; \
} \
- new_tok->window_idx = i + 1; \
+ new_tok->window_idx = i; \
g_ptr_array_add (result, new_tok); \
} while(0)
@@ -375,7 +394,9 @@ rspamd_tokenizer_osb (struct rspamd_stat_ctx *ctx,
processed++;
for (i = 1; i < window_size; i++) {
- ADD_TOKEN;
+ if (!(hashpipe[i].t->flags & RSPAMD_STAT_TOKEN_FLAG_EXCEPTION)) {
+ ADD_TOKEN;
+ }
}
}
}
diff --git a/src/libstat/tokenizers/tokenizers.c b/src/libstat/tokenizers/tokenizers.c
index c8e8e44df..acbbcf2f0 100644
--- a/src/libstat/tokenizers/tokenizers.c
+++ b/src/libstat/tokenizers/tokenizers.c
@@ -20,11 +20,19 @@
#include "rspamd.h"
#include "tokenizers.h"
#include "stat_internal.h"
-#include "../../../contrib/mumhash/mum.h"
+#include "contrib/mumhash/mum.h"
+#include "libmime/lang_detection.h"
+#include "libstemmer.h"
+
#include <unicode/utf8.h>
#include <unicode/uchar.h>
#include <unicode/uiter.h>
#include <unicode/ubrk.h>
+#include <unicode/ucnv.h>
+#if U_ICU_VERSION_MAJOR_NUM >= 44
+#include <unicode/unorm2.h>
+#endif
+
#include <math.h>
typedef gboolean (*token_get_function) (rspamd_stat_token_t * buf, gchar const **pos,
@@ -80,33 +88,33 @@ rspamd_tokenizer_get_word_raw (rspamd_stat_token_t * buf,
ex = (*exceptions)->data;
}
- if (token->begin == NULL || *cur == NULL) {
+ if (token->original.begin == NULL || *cur == NULL) {
if (ex != NULL) {
if (ex->pos == 0) {
- token->begin = buf->begin + ex->len;
- token->len = ex->len;
+ token->original.begin = buf->original.begin + ex->len;
+ token->original.len = ex->len;
token->flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION;
}
else {
- token->begin = buf->begin;
- token->len = 0;
+ token->original.begin = buf->original.begin;
+ token->original.len = 0;
}
}
else {
- token->begin = buf->begin;
- token->len = 0;
+ token->original.begin = buf->original.begin;
+ token->original.len = 0;
}
- *cur = token->begin;
+ *cur = token->original.begin;
}
- token->len = 0;
+ token->original.len = 0;
- pos = *cur - buf->begin;
- if (pos >= buf->len) {
+ pos = *cur - buf->original.begin;
+ if (pos >= buf->original.len) {
return FALSE;
}
- remain = buf->len - pos;
+ remain = buf->original.len - pos;
p = *cur;
/* Skip non delimiters symbols */
@@ -122,7 +130,7 @@ rspamd_tokenizer_get_word_raw (rspamd_stat_token_t * buf,
remain--;
} while (remain > 0 && t_delimiters[(guchar)*p]);
- token->begin = p;
+ token->original.begin = p;
while (remain > 0 && !t_delimiters[(guchar)*p]) {
if (ex != NULL && ex->pos == pos) {
@@ -130,7 +138,7 @@ rspamd_tokenizer_get_word_raw (rspamd_stat_token_t * buf,
*cur = p + ex->len;
return TRUE;
}
- token->len++;
+ token->original.len++;
pos++;
remain--;
p++;
@@ -141,7 +149,7 @@ rspamd_tokenizer_get_word_raw (rspamd_stat_token_t * buf,
}
if (rl) {
- *rl = token->len;
+ *rl = token->original.len;
}
token->flags = RSPAMD_STAT_TOKEN_FLAG_TEXT;
@@ -164,12 +172,12 @@ rspamd_tokenize_check_limit (gboolean decay,
static const gdouble avg_word_len = 6.0;
if (!decay) {
- if (token->len >= sizeof (guint64)) {
+ if (token->original.len >= sizeof (guint64)) {
#ifdef _MUM_UNALIGNED_ACCESS
- *hv = mum_hash_step (*hv, *(guint64 *)token->begin);
+ *hv = mum_hash_step (*hv, *(guint64 *)token->original.begin);
#else
guint64 tmp;
- memcpy (&tmp, token->begin, sizeof (tmp));
+ memcpy (&tmp, token->original.begin, sizeof (tmp));
*hv = mum_hash_step (*hv, tmp);
#endif
}
@@ -221,7 +229,7 @@ rspamd_utf_word_valid (const gchar *text, const gchar *end,
U8_NEXT (text, start, finish, c);
- if (u_isalnum (c)) {
+ if (u_isJavaIDPart (c)) {
return TRUE;
}
@@ -237,13 +245,51 @@ rspamd_utf_word_valid (const gchar *text, const gchar *end,
} \
} while(0)
+static inline void
+rspamd_tokenize_exception (struct rspamd_process_exception *ex, GArray *res)
+{
+ rspamd_stat_token_t token;
+
+ memset (&token, 0, sizeof (token));
+
+ if (ex->type == RSPAMD_EXCEPTION_GENERIC) {
+ token.original.begin = "!!EX!!";
+ token.original.len = sizeof ("!!EX!!") - 1;
+ token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION;
+
+ g_array_append_val (res, token);
+ token.flags = 0;
+ }
+ else if (ex->type == RSPAMD_EXCEPTION_URL) {
+ struct rspamd_url *uri;
+
+ uri = ex->ptr;
+
+ if (uri && uri->tldlen > 0) {
+ token.original.begin = uri->tld;
+ token.original.len = uri->tldlen;
+
+ }
+ else {
+ token.original.begin = "!!EX!!";
+ token.original.len = sizeof ("!!EX!!") - 1;
+ }
+
+ token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION;
+ g_array_append_val (res, token);
+ token.flags = 0;
+ }
+}
+
+
GArray *
rspamd_tokenize_text (const gchar *text, gsize len,
const UText *utxt,
enum rspamd_tokenize_type how,
struct rspamd_config *cfg,
GList *exceptions,
- guint64 *hash)
+ guint64 *hash,
+ GArray *cur_words)
{
rspamd_stat_token_t token, buf;
const gchar *pos = NULL;
@@ -257,15 +303,14 @@ rspamd_tokenize_text (const gchar *text, gsize len,
static UBreakIterator* bi = NULL;
if (text == NULL) {
- return NULL;
+ return cur_words;
}
- buf.begin = text;
- buf.len = len;
+ buf.original.begin = text;
+ buf.original.len = len;
buf.flags = 0;
- token.begin = NULL;
- token.len = 0;
- token.flags = 0;
+
+ memset (&token, 0, sizeof (token));
if (cfg != NULL) {
min_len = cfg->min_word_len;
@@ -274,31 +319,36 @@ rspamd_tokenize_text (const gchar *text, gsize len,
initial_size = word_decay * 2;
}
- res = g_array_sized_new (FALSE, FALSE, sizeof (rspamd_stat_token_t),
- initial_size);
+ if (!cur_words) {
+ res = g_array_sized_new (FALSE, FALSE, sizeof (rspamd_stat_token_t),
+ initial_size);
+ }
+ else {
+ res = cur_words;
+ }
if (G_UNLIKELY (how == RSPAMD_TOKENIZE_RAW || utxt == NULL)) {
while (rspamd_tokenizer_get_word_raw (&buf, &pos, &token, &cur, &l, FALSE)) {
if (l == 0 || (min_len > 0 && l < min_len) ||
(max_len > 0 && l > max_len)) {
- token.begin = pos;
+ token.original.begin = pos;
continue;
}
- if (token.len > 0 &&
+ if (token.original.len > 0 &&
rspamd_tokenize_check_limit (decay, word_decay, res->len,
&hv, &prob, &token, pos - text, len)) {
if (!decay) {
decay = TRUE;
}
else {
- token.begin = pos;
+ token.original.begin = pos;
continue;
}
}
g_array_append_val (res, token);
- token.begin = pos;
+ token.original.begin = pos;
}
}
else {
@@ -323,7 +373,7 @@ rspamd_tokenize_text (const gchar *text, gsize len,
while (p != UBRK_DONE) {
start_over:
- token.len = 0;
+ token.original.len = 0;
if (p > last) {
if (ex && cur) {
@@ -334,15 +384,7 @@ start_over:
while (cur && ex->pos <= last) {
/* We have an exception at the beginning, skip those */
last += ex->len;
-
- if (ex->type == RSPAMD_EXCEPTION_URL) {
- token.begin = "!!EX!!";
- token.len = sizeof ("!!EX!!") - 1;
- token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION;
-
- g_array_append_val (res, token);
- token.flags = 0;
- }
+ rspamd_tokenize_exception (ex, res);
if (last > p) {
/* Exception spread over the boundaries */
@@ -363,8 +405,8 @@ start_over:
/* Append the first part */
if (rspamd_utf_word_valid (text, text + len, last,
ex->pos)) {
- token.begin = text + last;
- token.len = ex->pos - last;
+ token.original.begin = text + last;
+ token.original.len = ex->pos - last;
token.flags = 0;
g_array_append_val (res, token);
}
@@ -372,13 +414,7 @@ start_over:
/* Process the current exception */
last += ex->len + (ex->pos - last);
- if (ex->type == RSPAMD_EXCEPTION_URL) {
- token.begin = "!!EX!!";
- token.len = sizeof ("!!EX!!") - 1;
- token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION;
-
- g_array_append_val (res, token);
- }
+ rspamd_tokenize_exception (ex, res);
if (last > p) {
/* Exception spread over the boundaries */
@@ -394,9 +430,10 @@ start_over:
}
else if (p > last) {
if (rspamd_utf_word_valid (text, text + len, last, p)) {
- token.begin = text + last;
- token.len = p - last;
- token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT;
+ token.original.begin = text + last;
+ token.original.len = p - last;
+ token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT |
+ RSPAMD_STAT_TOKEN_FLAG_UTF;
}
}
}
@@ -408,40 +445,43 @@ start_over:
}
if (rspamd_utf_word_valid (text, text + len, last, p)) {
- token.begin = text + last;
- token.len = p - last;
- token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT;
+ token.original.begin = text + last;
+ token.original.len = p - last;
+ token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT |
+ RSPAMD_STAT_TOKEN_FLAG_UTF;
}
}
else {
/* No exceptions within boundary */
if (rspamd_utf_word_valid (text, text + len, last, p)) {
- token.begin = text + last;
- token.len = p - last;
- token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT;
+ token.original.begin = text + last;
+ token.original.len = p - last;
+ token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT |
+ RSPAMD_STAT_TOKEN_FLAG_UTF;
}
}
}
else {
if (rspamd_utf_word_valid (text, text + len, last, p)) {
- token.begin = text + last;
- token.len = p - last;
- token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT;
+ token.original.begin = text + last;
+ token.original.len = p - last;
+ token.flags = RSPAMD_STAT_TOKEN_FLAG_TEXT |
+ RSPAMD_STAT_TOKEN_FLAG_UTF;
}
}
- if (token.len > 0 &&
+ if (token.original.len > 0 &&
rspamd_tokenize_check_limit (decay, word_decay, res->len,
&hv, &prob, &token, p, len)) {
if (!decay) {
decay = TRUE;
} else {
- token.len = 0;
+ token.flags |= RSPAMD_STAT_TOKEN_FLAG_SKIPPED;
}
}
}
- if (token.len > 0) {
+ if (token.original.len > 0) {
g_array_append_val (res, token);
}
@@ -463,6 +503,347 @@ start_over:
#undef SHIFT_EX
-/*
- * vi:ts=4
- */
+static void
+rspamd_add_metawords_from_str (const gchar *beg, gsize len,
+ struct rspamd_task *task)
+{
+ UText utxt = UTEXT_INITIALIZER;
+ UErrorCode uc_err = U_ZERO_ERROR;
+ guint i = 0;
+ UChar32 uc;
+ gboolean valid_utf = TRUE;
+
+ while (i < len) {
+ U8_NEXT (beg, i, len, uc);
+
+ if (((gint32) uc) < 0) {
+ valid_utf = FALSE;
+ break;
+ }
+
+#if U_ICU_VERSION_MAJOR_NUM < 50
+ if (u_isalpha (uc)) {
+ gint32 sc = ublock_getCode (uc);
+
+ if (sc == UBLOCK_THAI) {
+ valid_utf = FALSE;
+ msg_info_task ("enable workaround for Thai characters for old libicu");
+ break;
+ }
+ }
+#endif
+ }
+
+ if (valid_utf) {
+ utext_openUTF8 (&utxt,
+ beg,
+ len,
+ &uc_err);
+
+ task->meta_words = rspamd_tokenize_text (beg, len,
+ &utxt, RSPAMD_TOKENIZE_UTF,
+ task->cfg, NULL, NULL, task->meta_words);
+
+ utext_close (&utxt);
+ }
+ else {
+ task->meta_words = rspamd_tokenize_text (beg, len,
+ NULL, RSPAMD_TOKENIZE_RAW,
+ task->cfg, NULL, NULL, task->meta_words);
+ }
+}
+
+void
+rspamd_tokenize_meta_words (struct rspamd_task *task)
+{
+ guint i = 0;
+ rspamd_stat_token_t *tok;
+
+ if (task->subject) {
+ rspamd_add_metawords_from_str (task->subject, strlen (task->subject), task);
+ }
+
+ if (task->from_mime && task->from_mime->len > 0) {
+ struct rspamd_email_address *addr;
+
+ addr = g_ptr_array_index (task->from_mime, 0);
+
+ if (addr->name) {
+ rspamd_add_metawords_from_str (addr->name, strlen (addr->name), task);
+ }
+ }
+
+ if (task->meta_words != NULL) {
+ const gchar *language = NULL;
+
+ if (task->text_parts && task->text_parts->len > 0) {
+ struct rspamd_mime_text_part *tp = g_ptr_array_index (task->text_parts, 0);
+
+ if (tp->language) {
+ language = tp->language;
+ }
+ }
+
+ rspamd_normalize_words (task->meta_words, task->task_pool);
+ rspamd_stem_words (task->meta_words, task->task_pool, language,
+ task->lang_det);
+
+ for (i = 0; i < task->meta_words->len; i++) {
+ tok = &g_array_index (task->meta_words, rspamd_stat_token_t, i);
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_HEADER;
+ }
+ }
+}
+
+static inline void
+rspamd_uchars_to_ucs32 (const UChar *src, gsize srclen,
+ rspamd_stat_token_t *tok,
+ rspamd_mempool_t *pool)
+{
+ UChar32 *dest, t, *d;
+ gint32 i = 0;
+
+ dest = rspamd_mempool_alloc (pool, srclen * sizeof (UChar32));
+ d = dest;
+
+ while (i < srclen) {
+ U16_NEXT_UNSAFE (src, i, t);
+
+ if (u_isgraph (t)) {
+ *d++ = u_tolower (t);
+ }
+ else {
+ /* Invisible spaces ! */
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_INVISIBLE_SPACES;
+ }
+ }
+
+ tok->unicode.begin = dest;
+ tok->unicode.len = d - dest;
+}
+
+static inline void
+rspamd_ucs32_to_normalised (rspamd_stat_token_t *tok,
+ rspamd_mempool_t *pool)
+{
+ guint i, doff = 0;
+ gsize utflen = 0;
+ gchar *dest;
+ UChar32 t;
+
+ for (i = 0; i < tok->unicode.len; i ++) {
+ utflen += U8_LENGTH (tok->unicode.begin[i]);
+ }
+
+ dest = rspamd_mempool_alloc (pool, utflen + 1);
+
+ for (i = 0; i < tok->unicode.len; i ++) {
+ t = tok->unicode.begin[i];
+ U8_APPEND_UNSAFE (dest, doff, t);
+ }
+
+ g_assert (doff <= utflen);
+ dest[doff] = '\0';
+
+ tok->normalized.len = doff;
+ tok->normalized.begin = dest;
+}
+
+void
+rspamd_normalize_single_word (rspamd_stat_token_t *tok, rspamd_mempool_t *pool)
+{
+ UErrorCode uc_err = U_ZERO_ERROR;
+ UConverter *utf8_converter;
+ UChar tmpbuf[1024]; /* Assume that we have no longer words... */
+ gsize ulen;
+
+ utf8_converter = rspamd_get_utf8_converter ();
+
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UTF) {
+ ulen = ucnv_toUChars (utf8_converter,
+ tmpbuf,
+ G_N_ELEMENTS (tmpbuf),
+ tok->original.begin,
+ tok->original.len,
+ &uc_err);
+
+ /* Now, we need to understand if we need to normalise the word */
+ if (!U_SUCCESS (uc_err)) {
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE;
+ tok->unicode.begin = NULL;
+ tok->unicode.len = 0;
+ tok->normalized.begin = NULL;
+ tok->normalized.len = 0;
+ }
+ else {
+#if U_ICU_VERSION_MAJOR_NUM >= 44
+ const UNormalizer2 *norm = rspamd_get_unicode_normalizer ();
+ gint32 end;
+
+ /* We can now check if we need to decompose */
+ end = unorm2_spanQuickCheckYes (norm, tmpbuf, ulen, &uc_err);
+
+ if (!U_SUCCESS (uc_err)) {
+ rspamd_uchars_to_ucs32 (tmpbuf, ulen, tok, pool);
+ tok->normalized.begin = NULL;
+ tok->normalized.len = 0;
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE;
+ }
+ else {
+ if (end == ulen) {
+ /* Already normalised, just lowercase */
+ rspamd_uchars_to_ucs32 (tmpbuf, ulen, tok, pool);
+ rspamd_ucs32_to_normalised (tok, pool);
+ }
+ else {
+ /* Perform normalization */
+ UChar normbuf[1024];
+
+ g_assert (end < G_N_ELEMENTS (normbuf));
+ /* First part */
+ memcpy (normbuf, tmpbuf, end * sizeof (UChar));
+ /* Second part */
+ ulen = unorm2_normalizeSecondAndAppend (norm,
+ normbuf, end,
+ G_N_ELEMENTS (normbuf),
+ tmpbuf + end,
+ ulen - end,
+ &uc_err);
+
+ if (!U_SUCCESS (uc_err)) {
+ if (uc_err != U_BUFFER_OVERFLOW_ERROR) {
+ msg_warn_pool_check ("cannot normalise text '%*s': %s",
+ (gint)tok->original.len, tok->original.begin,
+ u_errorName (uc_err));
+ rspamd_uchars_to_ucs32 (tmpbuf, ulen, tok, pool);
+ rspamd_ucs32_to_normalised (tok, pool);
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE;
+ }
+ }
+ else {
+ /* Copy normalised back */
+ rspamd_uchars_to_ucs32 (normbuf, ulen, tok, pool);
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_NORMALISED;
+ rspamd_ucs32_to_normalised (tok, pool);
+ }
+ }
+ }
+#else
+ /* Legacy version with no unorm2 interface */
+ rspamd_uchars_to_ucs32 (tmpbuf, ulen, tok, pool);
+ rspamd_ucs32_to_normalised (tok, pool);
+#endif
+ }
+ }
+ else {
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) {
+ /* Simple lowercase */
+ gchar *dest;
+
+ dest = rspamd_mempool_alloc (pool, tok->original.len + 1);
+ rspamd_strlcpy (dest, tok->original.begin, tok->original.len + 1);
+ rspamd_str_lc (dest, tok->original.len);
+ tok->normalized.len = tok->original.len;
+ tok->normalized.begin = dest;
+ }
+ }
+}
+
+void
+rspamd_normalize_words (GArray *words, rspamd_mempool_t *pool)
+{
+ rspamd_stat_token_t *tok;
+ guint i;
+
+ for (i = 0; i < words->len; i++) {
+ tok = &g_array_index (words, rspamd_stat_token_t, i);
+ rspamd_normalize_single_word (tok, pool);
+ }
+}
+
+void
+rspamd_stem_words (GArray *words, rspamd_mempool_t *pool,
+ const gchar *language,
+ struct rspamd_lang_detector *d)
+{
+ static GHashTable *stemmers = NULL;
+ struct sb_stemmer *stem = NULL;
+ guint i;
+ rspamd_stat_token_t *tok;
+ gchar *dest;
+ gsize dlen;
+
+ if (!stemmers) {
+ stemmers = g_hash_table_new (rspamd_strcase_hash,
+ rspamd_strcase_equal);
+ }
+
+ if (language && language[0] != '\0') {
+ stem = g_hash_table_lookup (stemmers, language);
+
+ if (stem == NULL) {
+
+ stem = sb_stemmer_new (language, "UTF_8");
+
+ if (stem == NULL) {
+ msg_debug_pool (
+ "<%s> cannot create lemmatizer for %s language",
+ language);
+ g_hash_table_insert (stemmers, g_strdup (language),
+ GINT_TO_POINTER (-1));
+ }
+ else {
+ g_hash_table_insert (stemmers, g_strdup (language),
+ stem);
+ }
+ }
+ else if (stem == GINT_TO_POINTER (-1)) {
+ /* Negative cache */
+ stem = NULL;
+ }
+ }
+ for (i = 0; i < words->len; i++) {
+ tok = &g_array_index (words, rspamd_stat_token_t, i);
+
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UTF) {
+ if (stem) {
+ const gchar *stemmed = NULL;
+
+ stemmed = sb_stemmer_stem (stem,
+ tok->normalized.begin, tok->normalized.len);
+
+ dlen = stemmed ? strlen (stemmed) : 0;
+
+ if (dlen > 0) {
+ dest = rspamd_mempool_alloc (pool, dlen + 1);
+ memcpy (dest, stemmed, dlen);
+ dest[dlen] = '\0';
+ tok->stemmed.len = dlen;
+ tok->stemmed.begin = dest;
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_STEMMED;
+ }
+ else {
+ /* Fallback */
+ tok->stemmed.len = tok->normalized.len;
+ tok->stemmed.begin = tok->normalized.begin;
+ }
+ }
+ else {
+ tok->stemmed.len = tok->normalized.len;
+ tok->stemmed.begin = tok->normalized.begin;
+ }
+
+ if (tok->stemmed.len > 0 && d != NULL &&
+ rspamd_language_detector_is_stop_word (d, tok->stemmed.begin, tok->stemmed.len)) {
+ tok->flags |= RSPAMD_STAT_TOKEN_FLAG_STOP_WORD;
+ }
+ }
+ else {
+ if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) {
+ /* Raw text, lowercase */
+ tok->stemmed.len = tok->normalized.len;
+ tok->stemmed.begin = tok->normalized.begin;
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/src/libstat/tokenizers/tokenizers.h b/src/libstat/tokenizers/tokenizers.h
index 6c538eafc..784426d31 100644
--- a/src/libstat/tokenizers/tokenizers.h
+++ b/src/libstat/tokenizers/tokenizers.h
@@ -18,13 +18,13 @@ struct rspamd_stat_ctx;
struct rspamd_stat_tokenizer {
gchar *name;
gpointer (*get_config) (rspamd_mempool_t *pool,
- struct rspamd_tokenizer_config *cf, gsize *len);
+ struct rspamd_tokenizer_config *cf, gsize *len);
gint (*tokenize_func)(struct rspamd_stat_ctx *ctx,
- rspamd_mempool_t *pool,
- GArray *words,
- gboolean is_utf,
- const gchar *prefix,
- GPtrArray *result);
+ struct rspamd_task *task,
+ GArray *words,
+ gboolean is_utf,
+ const gchar *prefix,
+ GPtrArray *result);
};
enum rspamd_tokenize_type {
@@ -43,20 +43,29 @@ GArray * rspamd_tokenize_text (const gchar *text, gsize len,
enum rspamd_tokenize_type how,
struct rspamd_config *cfg,
GList *exceptions,
- guint64 *hash);
+ guint64 *hash,
+ GArray *cur_words);
/* OSB tokenize function */
gint rspamd_tokenizer_osb (struct rspamd_stat_ctx *ctx,
- rspamd_mempool_t *pool,
- GArray *words,
- gboolean is_utf,
- const gchar *prefix,
- GPtrArray *result);
+ struct rspamd_task *task,
+ GArray *words,
+ gboolean is_utf,
+ const gchar *prefix,
+ GPtrArray *result);
gpointer rspamd_tokenizer_osb_get_config (rspamd_mempool_t *pool,
- struct rspamd_tokenizer_config *cf,
- gsize *len);
+ struct rspamd_tokenizer_config *cf,
+ gsize *len);
+struct rspamd_lang_detector;
+void rspamd_normalize_single_word (rspamd_stat_token_t *tok, rspamd_mempool_t *pool);
+void rspamd_normalize_words (GArray *words, rspamd_mempool_t *pool);
+void rspamd_stem_words (GArray *words, rspamd_mempool_t *pool,
+ const gchar *language,
+ struct rspamd_lang_detector *d);
+
+void rspamd_tokenize_meta_words (struct rspamd_task *task);
#endif
/*
* vi:ts=4