]> source.dussan.org Git - rspamd.git/commitdiff
[CritFix] Fix classifier learning with Redis backend
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 4 Apr 2017 16:28:30 +0000 (17:28 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 4 Apr 2017 16:28:30 +0000 (17:28 +0100)
src/libstat/backends/redis_backend.c

index 5c66c00a76ec154af6d4448e5c20bfc5af6f5047..25007d8832c3b6363eb55e537dc701584c1b1edb 100644 (file)
@@ -335,14 +335,20 @@ rspamd_redis_maybe_auth (struct redis_stat_ctx *ctx, redisAsyncContext *redis)
 }
 
 static rspamd_fstring_t *
-rspamd_redis_tokens_to_query (struct rspamd_task *task, GPtrArray *tokens,
-               const gchar *arg0, const gchar *arg1, gboolean learn, gint idx,
+rspamd_redis_tokens_to_query (struct rspamd_task *task,
+               struct redis_stat_runtime *rt,
+               GPtrArray *tokens,
+               const gchar *arg0,
+               const gchar *arg1,
+               gboolean learn,
+               gint idx,
                gboolean intvals)
 {
        rspamd_fstring_t *out;
        rspamd_token_t *tok;
        gchar n0[64], n1[64];
        guint i, l0, l1, larg0, larg1;
+       gint ret;
 
        g_assert (tokens != NULL);
 
@@ -350,13 +356,28 @@ rspamd_redis_tokens_to_query (struct rspamd_task *task, GPtrArray *tokens,
        larg1 = strlen (arg1);
        out = rspamd_fstring_sized_new (1024);
 
-       if (!learn) {
+       if (learn) {
+               rspamd_printf_fstring (&out, "*1\r\n$5\r\nMULTI\r\n");
+
+               ret = redisAsyncFormattedCommand (rt->redis, NULL, NULL,
+                               out->str, out->len);
+
+               if (ret != REDIS_OK) {
+                       msg_err_task ("call to redis failed: %s", rt->redis->errstr);
+                       rspamd_fstring_free (out);
+
+                       return NULL;
+               }
+
+               out->len = 0;
+       }
+       else {
                rspamd_printf_fstring (&out, ""
-                               "*%d\r\n"
-                               "$%d\r\n"
-                               "%s\r\n"
-                               "$%d\r\n"
-                               "%s\r\n",
+                                               "*%d\r\n"
+                                               "$%d\r\n"
+                                               "%s\r\n"
+                                               "$%d\r\n"
+                                               "%s\r\n",
                                (tokens->len + 2),
                                larg0, arg0,
                                larg1, arg1);
@@ -391,6 +412,18 @@ rspamd_redis_tokens_to_query (struct rspamd_task *task, GPtrArray *tokens,
                                        "%s\r\n"
                                        "$%d\r\n"
                                        "%s\r\n", l0, n0, l1, n1);
+
+                       ret = redisAsyncFormattedCommand (rt->redis, NULL, NULL,
+                                       out->str, out->len);
+
+                       if (ret != REDIS_OK) {
+                               msg_err_task ("call to redis failed: %s", rt->redis->errstr);
+                               rspamd_fstring_free (out);
+
+                               return NULL;
+                       }
+
+                       out->len = 0;
                }
                else {
                        l0 = rspamd_snprintf (n0, sizeof (n0), "%uL", tok->data);
@@ -1214,7 +1247,7 @@ rspamd_redis_process_tokens (struct rspamd_task *task,
                double_to_tv (rt->ctx->timeout, &tv);
                event_add (&rt->timeout_event, &tv);
 
-               query = rspamd_redis_tokens_to_query (task, tokens,
+               query = rspamd_redis_tokens_to_query (task, rt, tokens,
                                "HMGET", rt->redis_object_expanded, FALSE, -1,
                                rt->stcf->clcf->flags & RSPAMD_FLAG_CLASSIFIER_INTEGER);
                g_assert (query != NULL);
@@ -1265,6 +1298,7 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
        const gchar *redis_cmd;
        rspamd_token_t *tok;
        gint ret;
+       goffset off;
 
        up = rspamd_upstream_get (rt->ctx->write_servers,
                        RSPAMD_UPSTREAM_MASTER_SLAVE,
@@ -1307,10 +1341,11 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
        }
 
        rt->id = id;
-       query = rspamd_redis_tokens_to_query (task, tokens,
+       query = rspamd_redis_tokens_to_query (task, rt, tokens,
                        redis_cmd, rt->redis_object_expanded, TRUE, id,
                        rt->stcf->clcf->flags & RSPAMD_FLAG_CLASSIFIER_INTEGER);
        g_assert (query != NULL);
+       query->len = 0;
 
        /*
         * XXX:
@@ -1349,12 +1384,23 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
                                rt->redis_object_expanded);
        }
 
-       rspamd_mempool_add_destructor (task->task_pool,
-                               (rspamd_mempool_destruct_t)rspamd_fstring_free, query);
+       ret = redisAsyncFormattedCommand (rt->redis, NULL, NULL,
+                       query->str, query->len);
 
+       if (ret != REDIS_OK) {
+               msg_err_task ("call to redis failed: %s", rt->redis->errstr);
+               rspamd_fstring_free (query);
+
+               return FALSE;
+       }
+
+       off = query->len;
+       ret = rspamd_printf_fstring (&query, "*1\r\n$4\r\nEXEC\r\n");
        ret = redisAsyncFormattedCommand (rt->redis, rspamd_redis_learned, rt,
-                       query->str, query->len);
+                       query->str + off, ret);
 
+       rspamd_mempool_add_destructor (task->task_pool,
+                       (rspamd_mempool_destruct_t)rspamd_fstring_free, query);
        if (ret == REDIS_OK) {
                rspamd_session_add_event (task->s, rspamd_redis_fin_learn, rt,
                                rspamd_redis_stat_quark ());