]> source.dussan.org Git - rspamd.git/commitdiff
Fix setting of number of learns.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 8 Jan 2016 17:21:12 +0000 (17:21 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 8 Jan 2016 17:21:12 +0000 (17:21 +0000)
src/libstat/backends/redis_backend.c
src/libstat/stat_process.c

index a9fbe3993c9d14aabb8cc91e3ac2e11866d5bb9f..4813fb88d51502c96657d91df6b9df871c3bad7d 100644 (file)
@@ -345,9 +345,6 @@ rspamd_redis_tokens_to_query (struct rspamd_task *task, GPtrArray *tokens,
                }
        }
 
-       rspamd_mempool_add_destructor (task->task_pool,
-                       (rspamd_mempool_destruct_t)rspamd_fstring_free, out);
-
        return out;
 }
 
@@ -358,7 +355,6 @@ rspamd_redis_fin (gpointer data)
        struct redis_stat_runtime *rt = REDIS_RUNTIME (data);
 
        if (rt->conn_state != RSPAMD_REDIS_CONNECTED) {
-               redisAsyncFree (rt->redis);
                rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
        }
 
@@ -371,20 +367,6 @@ rspamd_redis_fin_learn (gpointer data)
        struct redis_stat_runtime *rt = REDIS_RUNTIME (data);
 
        if (rt->conn_state != RSPAMD_REDIS_CONNECTED) {
-               redisAsyncFree (rt->redis);
-               rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
-       }
-
-       event_del (&rt->timeout_event);
-}
-
-static void
-rspamd_redis_fin_set_learns (gpointer data)
-{
-       struct redis_stat_runtime *rt = REDIS_RUNTIME (data);
-
-       if (rt->conn_state != RSPAMD_REDIS_CONNECTED) {
-               redisAsyncFree (rt->redis);
                rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
        }
 
@@ -403,7 +385,7 @@ rspamd_redis_timeout (gint fd, short what, gpointer d)
                        rspamd_upstream_name (rt->selected));
        rspamd_upstream_fail (rt->selected);
        rt->conn_state = RSPAMD_REDIS_TIMEDOUT;
-       rspamd_session_remove_event (task->s, rspamd_redis_fin, d);
+       redisAsyncFree (rt->redis);
 }
 
 /* Called when we have connected to the redis server and got stats */
@@ -413,15 +395,25 @@ rspamd_redis_connected (redisAsyncContext *c, gpointer r, gpointer priv)
        struct redis_stat_runtime *rt = REDIS_RUNTIME (priv);
        redisReply *reply = r;
        struct rspamd_task *task;
+       gulong val;
 
        task = rt->task;
 
        if (c->err == 0) {
                if (r != NULL) {
-                       if (reply->type == REDIS_REPLY_INTEGER) {
+                       if (G_LIKELY (reply->type == REDIS_REPLY_INTEGER)) {
                                rt->learned = reply->integer;
                        }
+                       else if (reply->type == REDIS_REPLY_STRING) {
+                               rspamd_strtoul (reply->str, reply->len, &val);
+                               rt->learned = val;
+                       }
                        else {
+                               if (reply->type != REDIS_REPLY_NIL) {
+                                       msg_err_task ("bad learned type for %s: %d",
+                                               rt->stcf->symbol, reply->type);
+                               }
+
                                rt->learned = 0;
                        }
 
@@ -677,7 +669,7 @@ rspamd_redis_runtime (struct rspamd_task *task,
        event_add (&rt->timeout_event, &tv);
 
        redisAsyncCommand (rt->redis, rspamd_redis_connected, rt, "HGET %s %s",
-                       rt->redis_object_expanded, "learned");
+                       rt->redis_object_expanded, "learns");
 
        return rt;
 }
@@ -718,6 +710,8 @@ rspamd_redis_process_tokens (struct rspamd_task *task,
                        "HMGET", rt->redis_object_expanded, FALSE, -1,
                        rt->stcf->clcf->flags & RSPAMD_FLAG_CLASSIFIER_INTEGER);
        g_assert (query != NULL);
+       rspamd_mempool_add_destructor (task->task_pool,
+                               (rspamd_mempool_destruct_t)rspamd_fstring_free, query);
 
        ret = redisAsyncFormattedCommand (rt->redis, rspamd_redis_processed, rt,
                        query->str, query->len);
@@ -762,6 +756,7 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
        struct timeval tv;
        rspamd_fstring_t *query;
        const gchar *redis_cmd;
+       rspamd_token_t *tok;
        gint ret;
 
        if (rt->conn_state != RSPAMD_REDIS_DISCONNECTED) {
@@ -809,6 +804,46 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
                        rt->stcf->clcf->flags & RSPAMD_FLAG_CLASSIFIER_INTEGER);
        g_assert (query != NULL);
 
+       /*
+        * XXX:
+        * Dirty hack: we get a token and check if it's value is -1 or 1, so
+        * we could understand that we are learning or unlearning
+        */
+
+       tok = g_ptr_array_index (task->tokens, 0);
+
+       if (tok->values[id] > 0) {
+               rspamd_printf_fstring (&query, ""
+                               "*4\r\n"
+                               "$7\r\n"
+                               "HINCRBY\r\n"
+                               "$%d\r\n"
+                               "%s\r\n"
+                               "$6\r\n"
+                               "learns\r\n"
+                               "$1\r\n"
+                               "1\r\n",
+                               strlen (rt->redis_object_expanded),
+                               rt->redis_object_expanded);
+       }
+       else {
+               rspamd_printf_fstring (&query, ""
+                               "*4\r\n"
+                               "$7\r\n"
+                               "HINCRBY\r\n"
+                               "$%d\r\n"
+                               "%s\r\n"
+                               "$6\r\n"
+                               "learns\r\n"
+                               "$2\r\n"
+                               "-1\r\n",
+                               strlen (rt->redis_object_expanded),
+                               rt->redis_object_expanded);
+       }
+
+       rspamd_mempool_add_destructor (task->task_pool,
+                               (rspamd_mempool_destruct_t)rspamd_fstring_free, query);
+
        ret = redisAsyncFormattedCommand (rt->redis, rspamd_redis_learned, rt,
                        query->str, query->len);
 
@@ -819,6 +854,7 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
                event_del (&rt->timeout_event);
                double_to_tv (rt->ctx->timeout, &tv);
                event_add (&rt->timeout_event, &tv);
+               rt->conn_state = RSPAMD_REDIS_CONNECTED;
 
                return TRUE;
        }
@@ -850,7 +886,7 @@ rspamd_redis_total_learns (struct rspamd_task *task, gpointer runtime,
 {
        struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime);
 
-       return 0;
+       return rt->learned;
 }
 
 gulong
@@ -859,7 +895,8 @@ rspamd_redis_inc_learns (struct rspamd_task *task, gpointer runtime,
 {
        struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime);
 
-       return 0;
+       /* XXX: may cause races */
+       return rt->learned + 1;
 }
 
 gulong
@@ -868,7 +905,8 @@ rspamd_redis_dec_learns (struct rspamd_task *task, gpointer runtime,
 {
        struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime);
 
-       return 0;
+       /* XXX: may cause races */
+       return rt->learned + 1;
 }
 
 gulong
@@ -877,7 +915,7 @@ rspamd_redis_learns (struct rspamd_task *task, gpointer runtime,
 {
        struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime);
 
-       return 0;
+       return rt->learned;
 }
 
 ucl_object_t *
index 45a32b57c78f39eacb7c11237dccb995230acb45..81986f5520d8ec4d83693f45de2b943e603de738 100644 (file)
@@ -540,7 +540,7 @@ rspamd_stat_backends_learn (struct rspamd_stat_ctx *st_ctx,
                                if (!!spam == !!st->stcf->is_spam) {
                                        st->backend->inc_learns (task, bk_run, st_ctx);
                                }
-                               else {
+                               else if (task->flags & RSPAMD_TASK_FLAG_UNLEARN) {
                                        st->backend->dec_learns (task, bk_run, st_ctx);
                                }
                        }