summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2016-01-08 17:21:12 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2016-01-08 17:21:12 +0000
commit2d6b01959049355b4a75b9e2d667b6c885a17312 (patch)
tree353db462a9b68279ccbcaa268259a21536a4b82e /src
parent79402fe47a7d2474199f01d20c42d51f8f231336 (diff)
downloadrspamd-2d6b01959049355b4a75b9e2d667b6c885a17312.tar.gz
rspamd-2d6b01959049355b4a75b9e2d667b6c885a17312.zip
Fix setting of number of learns.
Diffstat (limited to 'src')
-rw-r--r--src/libstat/backends/redis_backend.c88
-rw-r--r--src/libstat/stat_process.c2
2 files changed, 64 insertions, 26 deletions
diff --git a/src/libstat/backends/redis_backend.c b/src/libstat/backends/redis_backend.c
index a9fbe3993..4813fb88d 100644
--- a/src/libstat/backends/redis_backend.c
+++ b/src/libstat/backends/redis_backend.c
@@ -345,9 +345,6 @@ rspamd_redis_tokens_to_query (struct rspamd_task *task, GPtrArray *tokens,
}
}
- rspamd_mempool_add_destructor (task->task_pool,
- (rspamd_mempool_destruct_t)rspamd_fstring_free, out);
-
return out;
}
@@ -358,7 +355,6 @@ rspamd_redis_fin (gpointer data)
struct redis_stat_runtime *rt = REDIS_RUNTIME (data);
if (rt->conn_state != RSPAMD_REDIS_CONNECTED) {
- redisAsyncFree (rt->redis);
rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
}
@@ -371,20 +367,6 @@ rspamd_redis_fin_learn (gpointer data)
struct redis_stat_runtime *rt = REDIS_RUNTIME (data);
if (rt->conn_state != RSPAMD_REDIS_CONNECTED) {
- redisAsyncFree (rt->redis);
- rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
- }
-
- event_del (&rt->timeout_event);
-}
-
-static void
-rspamd_redis_fin_set_learns (gpointer data)
-{
- struct redis_stat_runtime *rt = REDIS_RUNTIME (data);
-
- if (rt->conn_state != RSPAMD_REDIS_CONNECTED) {
- redisAsyncFree (rt->redis);
rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
}
@@ -403,7 +385,7 @@ rspamd_redis_timeout (gint fd, short what, gpointer d)
rspamd_upstream_name (rt->selected));
rspamd_upstream_fail (rt->selected);
rt->conn_state = RSPAMD_REDIS_TIMEDOUT;
- rspamd_session_remove_event (task->s, rspamd_redis_fin, d);
+ redisAsyncFree (rt->redis);
}
/* Called when we have connected to the redis server and got stats */
@@ -413,15 +395,25 @@ rspamd_redis_connected (redisAsyncContext *c, gpointer r, gpointer priv)
struct redis_stat_runtime *rt = REDIS_RUNTIME (priv);
redisReply *reply = r;
struct rspamd_task *task;
+ gulong val;
task = rt->task;
if (c->err == 0) {
if (r != NULL) {
- if (reply->type == REDIS_REPLY_INTEGER) {
+ if (G_LIKELY (reply->type == REDIS_REPLY_INTEGER)) {
rt->learned = reply->integer;
}
+ else if (reply->type == REDIS_REPLY_STRING) {
+ rspamd_strtoul (reply->str, reply->len, &val);
+ rt->learned = val;
+ }
else {
+ if (reply->type != REDIS_REPLY_NIL) {
+ msg_err_task ("bad learned type for %s: %d",
+ rt->stcf->symbol, reply->type);
+ }
+
rt->learned = 0;
}
@@ -677,7 +669,7 @@ rspamd_redis_runtime (struct rspamd_task *task,
event_add (&rt->timeout_event, &tv);
redisAsyncCommand (rt->redis, rspamd_redis_connected, rt, "HGET %s %s",
- rt->redis_object_expanded, "learned");
+ rt->redis_object_expanded, "learns");
return rt;
}
@@ -718,6 +710,8 @@ rspamd_redis_process_tokens (struct rspamd_task *task,
"HMGET", rt->redis_object_expanded, FALSE, -1,
rt->stcf->clcf->flags & RSPAMD_FLAG_CLASSIFIER_INTEGER);
g_assert (query != NULL);
+ rspamd_mempool_add_destructor (task->task_pool,
+ (rspamd_mempool_destruct_t)rspamd_fstring_free, query);
ret = redisAsyncFormattedCommand (rt->redis, rspamd_redis_processed, rt,
query->str, query->len);
@@ -762,6 +756,7 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
struct timeval tv;
rspamd_fstring_t *query;
const gchar *redis_cmd;
+ rspamd_token_t *tok;
gint ret;
if (rt->conn_state != RSPAMD_REDIS_DISCONNECTED) {
@@ -809,6 +804,46 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
rt->stcf->clcf->flags & RSPAMD_FLAG_CLASSIFIER_INTEGER);
g_assert (query != NULL);
+ /*
+ * XXX:
+ * Dirty hack: we get a token and check if it's value is -1 or 1, so
+ * we could understand that we are learning or unlearning
+ */
+
+ tok = g_ptr_array_index (task->tokens, 0);
+
+ if (tok->values[id] > 0) {
+ rspamd_printf_fstring (&query, ""
+ "*4\r\n"
+ "$7\r\n"
+ "HINCRBY\r\n"
+ "$%d\r\n"
+ "%s\r\n"
+ "$6\r\n"
+ "learns\r\n"
+ "$1\r\n"
+ "1\r\n",
+ strlen (rt->redis_object_expanded),
+ rt->redis_object_expanded);
+ }
+ else {
+ rspamd_printf_fstring (&query, ""
+ "*4\r\n"
+ "$7\r\n"
+ "HINCRBY\r\n"
+ "$%d\r\n"
+ "%s\r\n"
+ "$6\r\n"
+ "learns\r\n"
+ "$2\r\n"
+ "-1\r\n",
+ strlen (rt->redis_object_expanded),
+ rt->redis_object_expanded);
+ }
+
+ rspamd_mempool_add_destructor (task->task_pool,
+ (rspamd_mempool_destruct_t)rspamd_fstring_free, query);
+
ret = redisAsyncFormattedCommand (rt->redis, rspamd_redis_learned, rt,
query->str, query->len);
@@ -819,6 +854,7 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
event_del (&rt->timeout_event);
double_to_tv (rt->ctx->timeout, &tv);
event_add (&rt->timeout_event, &tv);
+ rt->conn_state = RSPAMD_REDIS_CONNECTED;
return TRUE;
}
@@ -850,7 +886,7 @@ rspamd_redis_total_learns (struct rspamd_task *task, gpointer runtime,
{
struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime);
- return 0;
+ return rt->learned;
}
gulong
@@ -859,7 +895,8 @@ rspamd_redis_inc_learns (struct rspamd_task *task, gpointer runtime,
{
struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime);
- return 0;
+ /* XXX: may cause races */
+ return rt->learned + 1;
}
gulong
@@ -868,7 +905,8 @@ rspamd_redis_dec_learns (struct rspamd_task *task, gpointer runtime,
{
struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime);
- return 0;
+ /* XXX: may cause races */
+ return rt->learned + 1;
}
gulong
@@ -877,7 +915,7 @@ rspamd_redis_learns (struct rspamd_task *task, gpointer runtime,
{
struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime);
- return 0;
+ return rt->learned;
}
ucl_object_t *
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index 45a32b57c..81986f552 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -540,7 +540,7 @@ rspamd_stat_backends_learn (struct rspamd_stat_ctx *st_ctx,
if (!!spam == !!st->stcf->is_spam) {
st->backend->inc_learns (task, bk_run, st_ctx);
}
- else {
+ else if (task->flags & RSPAMD_TASK_FLAG_UNLEARN) {
st->backend->dec_learns (task, bk_run, st_ctx);
}
}