diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-01-08 17:21:12 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-01-08 17:21:12 +0000 |
commit | 2d6b01959049355b4a75b9e2d667b6c885a17312 (patch) | |
tree | 353db462a9b68279ccbcaa268259a21536a4b82e /src | |
parent | 79402fe47a7d2474199f01d20c42d51f8f231336 (diff) | |
download | rspamd-2d6b01959049355b4a75b9e2d667b6c885a17312.tar.gz rspamd-2d6b01959049355b4a75b9e2d667b6c885a17312.zip |
Fix setting of number of learns.
Diffstat (limited to 'src')
-rw-r--r-- | src/libstat/backends/redis_backend.c | 88 | ||||
-rw-r--r-- | src/libstat/stat_process.c | 2 |
2 files changed, 64 insertions, 26 deletions
diff --git a/src/libstat/backends/redis_backend.c b/src/libstat/backends/redis_backend.c index a9fbe3993..4813fb88d 100644 --- a/src/libstat/backends/redis_backend.c +++ b/src/libstat/backends/redis_backend.c @@ -345,9 +345,6 @@ rspamd_redis_tokens_to_query (struct rspamd_task *task, GPtrArray *tokens, } } - rspamd_mempool_add_destructor (task->task_pool, - (rspamd_mempool_destruct_t)rspamd_fstring_free, out); - return out; } @@ -358,7 +355,6 @@ rspamd_redis_fin (gpointer data) struct redis_stat_runtime *rt = REDIS_RUNTIME (data); if (rt->conn_state != RSPAMD_REDIS_CONNECTED) { - redisAsyncFree (rt->redis); rt->conn_state = RSPAMD_REDIS_DISCONNECTED; } @@ -371,20 +367,6 @@ rspamd_redis_fin_learn (gpointer data) struct redis_stat_runtime *rt = REDIS_RUNTIME (data); if (rt->conn_state != RSPAMD_REDIS_CONNECTED) { - redisAsyncFree (rt->redis); - rt->conn_state = RSPAMD_REDIS_DISCONNECTED; - } - - event_del (&rt->timeout_event); -} - -static void -rspamd_redis_fin_set_learns (gpointer data) -{ - struct redis_stat_runtime *rt = REDIS_RUNTIME (data); - - if (rt->conn_state != RSPAMD_REDIS_CONNECTED) { - redisAsyncFree (rt->redis); rt->conn_state = RSPAMD_REDIS_DISCONNECTED; } @@ -403,7 +385,7 @@ rspamd_redis_timeout (gint fd, short what, gpointer d) rspamd_upstream_name (rt->selected)); rspamd_upstream_fail (rt->selected); rt->conn_state = RSPAMD_REDIS_TIMEDOUT; - rspamd_session_remove_event (task->s, rspamd_redis_fin, d); + redisAsyncFree (rt->redis); } /* Called when we have connected to the redis server and got stats */ @@ -413,15 +395,25 @@ rspamd_redis_connected (redisAsyncContext *c, gpointer r, gpointer priv) struct redis_stat_runtime *rt = REDIS_RUNTIME (priv); redisReply *reply = r; struct rspamd_task *task; + gulong val; task = rt->task; if (c->err == 0) { if (r != NULL) { - if (reply->type == REDIS_REPLY_INTEGER) { + if (G_LIKELY (reply->type == REDIS_REPLY_INTEGER)) { rt->learned = reply->integer; } + else if (reply->type == REDIS_REPLY_STRING) { + rspamd_strtoul (reply->str, reply->len, &val); + rt->learned = val; + } else { + if (reply->type != REDIS_REPLY_NIL) { + msg_err_task ("bad learned type for %s: %d", + rt->stcf->symbol, reply->type); + } + rt->learned = 0; } @@ -677,7 +669,7 @@ rspamd_redis_runtime (struct rspamd_task *task, event_add (&rt->timeout_event, &tv); redisAsyncCommand (rt->redis, rspamd_redis_connected, rt, "HGET %s %s", - rt->redis_object_expanded, "learned"); + rt->redis_object_expanded, "learns"); return rt; } @@ -718,6 +710,8 @@ rspamd_redis_process_tokens (struct rspamd_task *task, "HMGET", rt->redis_object_expanded, FALSE, -1, rt->stcf->clcf->flags & RSPAMD_FLAG_CLASSIFIER_INTEGER); g_assert (query != NULL); + rspamd_mempool_add_destructor (task->task_pool, + (rspamd_mempool_destruct_t)rspamd_fstring_free, query); ret = redisAsyncFormattedCommand (rt->redis, rspamd_redis_processed, rt, query->str, query->len); @@ -762,6 +756,7 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens, struct timeval tv; rspamd_fstring_t *query; const gchar *redis_cmd; + rspamd_token_t *tok; gint ret; if (rt->conn_state != RSPAMD_REDIS_DISCONNECTED) { @@ -809,6 +804,46 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens, rt->stcf->clcf->flags & RSPAMD_FLAG_CLASSIFIER_INTEGER); g_assert (query != NULL); + /* + * XXX: + * Dirty hack: we get a token and check if it's value is -1 or 1, so + * we could understand that we are learning or unlearning + */ + + tok = g_ptr_array_index (task->tokens, 0); + + if (tok->values[id] > 0) { + rspamd_printf_fstring (&query, "" + "*4\r\n" + "$7\r\n" + "HINCRBY\r\n" + "$%d\r\n" + "%s\r\n" + "$6\r\n" + "learns\r\n" + "$1\r\n" + "1\r\n", + strlen (rt->redis_object_expanded), + rt->redis_object_expanded); + } + else { + rspamd_printf_fstring (&query, "" + "*4\r\n" + "$7\r\n" + "HINCRBY\r\n" + "$%d\r\n" + "%s\r\n" + "$6\r\n" + "learns\r\n" + "$2\r\n" + "-1\r\n", + strlen (rt->redis_object_expanded), + rt->redis_object_expanded); + } + + rspamd_mempool_add_destructor (task->task_pool, + (rspamd_mempool_destruct_t)rspamd_fstring_free, query); + ret = redisAsyncFormattedCommand (rt->redis, rspamd_redis_learned, rt, query->str, query->len); @@ -819,6 +854,7 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens, event_del (&rt->timeout_event); double_to_tv (rt->ctx->timeout, &tv); event_add (&rt->timeout_event, &tv); + rt->conn_state = RSPAMD_REDIS_CONNECTED; return TRUE; } @@ -850,7 +886,7 @@ rspamd_redis_total_learns (struct rspamd_task *task, gpointer runtime, { struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime); - return 0; + return rt->learned; } gulong @@ -859,7 +895,8 @@ rspamd_redis_inc_learns (struct rspamd_task *task, gpointer runtime, { struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime); - return 0; + /* XXX: may cause races */ + return rt->learned + 1; } gulong @@ -868,7 +905,8 @@ rspamd_redis_dec_learns (struct rspamd_task *task, gpointer runtime, { struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime); - return 0; + /* XXX: may cause races */ + return rt->learned + 1; } gulong @@ -877,7 +915,7 @@ rspamd_redis_learns (struct rspamd_task *task, gpointer runtime, { struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime); - return 0; + return rt->learned; } ucl_object_t * diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index 45a32b57c..81986f552 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -540,7 +540,7 @@ rspamd_stat_backends_learn (struct rspamd_stat_ctx *st_ctx, if (!!spam == !!st->stcf->is_spam) { st->backend->inc_learns (task, bk_run, st_ctx); } - else { + else if (task->flags & RSPAMD_TASK_FLAG_UNLEARN) { st->backend->dec_learns (task, bk_run, st_ctx); } } |