]> source.dussan.org Git - rspamd.git/commitdiff
Implement redis learning
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 8 Jan 2016 15:16:52 +0000 (15:16 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 8 Jan 2016 15:16:52 +0000 (15:16 +0000)
src/libstat/backends/redis_backend.c

index 2eb0368410d6a4212655e5b81d004373df3ea523..cf65d91b04b78241463eea052805c95c6228fc79 100644 (file)
@@ -278,11 +278,11 @@ rspamd_redis_expand_object (const gchar *pattern,
 
 static rspamd_fstring_t *
 rspamd_redis_tokens_to_query (struct rspamd_task *task, GPtrArray *tokens,
-               const gchar *arg0, const gchar *arg1)
+               const gchar *arg0, const gchar *arg1, gboolean learn, gint idx)
 {
        rspamd_fstring_t *out;
        rspamd_token_t *tok;
-       gchar numbuf[64];
+       gchar n0[64], n1[64];
        guint i, l0, l1;
        guint64 num;
 
@@ -291,16 +291,36 @@ rspamd_redis_tokens_to_query (struct rspamd_task *task, GPtrArray *tokens,
        l0 = strlen (arg0);
        l1 = strlen (arg1);
        out = rspamd_fstring_sized_new (1024);
-       rspamd_printf_fstring (&out, "*%d\r\n$%d\r\n%s\r\n$%d\r\n%s\r\n",
-                       tokens->len + 2,
+       rspamd_printf_fstring (&out, "*%d\r\n"
+                       "$%d\r\n"
+                       "%s\r\n"
+                       "$%d\r\n"
+                       "%s\r\n",
+                       learn ? (tokens->len * 2 + 2) : (tokens->len + 2),
                        l0, arg0,
                        l1, arg1);
 
        for (i = 0; i < tokens->len; i ++) {
                tok = g_ptr_array_index (tokens, i);
                memcpy (&num, tok->data, sizeof (num));
-               l0 = rspamd_snprintf (numbuf, sizeof (numbuf), "%uL", num);
-               rspamd_printf_fstring (&out, "$%d\r\n%s\r\n", l0, numbuf);
+               l0 = rspamd_snprintf (n0, sizeof (n0), "%uL", num);
+
+               if (learn) {
+                       if (tok->values[idx] == (guint64)tok->values[idx]) {
+                               l1 = rspamd_snprintf (n1, sizeof (n1), "%uL",
+                                               (guint64)tok->values[idx]);
+                       }
+                       else {
+                               l1 = rspamd_snprintf (n1, sizeof (n1), "%f",
+                                               (guint64)tok->values[idx]);
+                       }
+
+                       rspamd_printf_fstring (&out, "$%d\r\n%s\r\n"
+                                       "$%d\r\n%s\r\n", l0, n0, l1, n1);
+               }
+               else {
+                       rspamd_printf_fstring (&out, "$%d\r\n%s\r\n", l0, n0);
+               }
        }
 
        rspamd_mempool_add_destructor (task->task_pool,
@@ -317,7 +337,35 @@ rspamd_redis_fin (gpointer data)
 
        if (rt->conn_state != RSPAMD_REDIS_CONNECTED) {
                redisAsyncFree (rt->redis);
+               rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
+       }
+
+       event_del (&rt->timeout_event);
+}
+
+static void
+rspamd_redis_fin_learn (gpointer data)
+{
+       struct redis_stat_runtime *rt = REDIS_RUNTIME (data);
+
+       if (rt->conn_state != RSPAMD_REDIS_CONNECTED) {
+               redisAsyncFree (rt->redis);
+               rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
+       }
+
+       event_del (&rt->timeout_event);
+}
+
+static void
+rspamd_redis_fin_set_learns (gpointer data)
+{
+       struct redis_stat_runtime *rt = REDIS_RUNTIME (data);
+
+       if (rt->conn_state != RSPAMD_REDIS_CONNECTED) {
+               redisAsyncFree (rt->redis);
+               rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
        }
+
        event_del (&rt->timeout_event);
 }
 
@@ -433,6 +481,27 @@ rspamd_redis_processed (redisAsyncContext *c, gpointer r, gpointer priv)
        }
 }
 
+/* Called when we have set tokens during learning */
+static void
+rspamd_redis_learned (redisAsyncContext *c, gpointer r, gpointer priv)
+{
+       struct redis_stat_runtime *rt = REDIS_RUNTIME (priv);
+       struct rspamd_task *task;
+
+       task = rt->task;
+
+       if (c->err == 0) {
+               rspamd_upstream_ok (rt->selected);
+               rspamd_session_remove_event (task->s, rspamd_redis_fin_learn, rt);
+       }
+       else {
+               msg_err_task ("error getting reply from redis server %s: %s",
+                               rspamd_upstream_name (rt->selected), c->errstr);
+               rspamd_upstream_fail (rt->selected);
+               rspamd_session_remove_event (task->s, rspamd_redis_fin_learn, rt);
+       }
+}
+
 gpointer
 rspamd_redis_init (struct rspamd_stat_ctx *ctx,
                struct rspamd_config *cfg, struct rspamd_statfile *st)
@@ -520,7 +589,7 @@ rspamd_redis_runtime (struct rspamd_task *task,
        g_assert (stcf != NULL);
 
        if (learn && ctx->write_servers == NULL) {
-               msg_err ("no write servers defined for %s, cannot learn", stcf->symbol);
+               msg_err_task ("no write servers defined for %s, cannot learn", stcf->symbol);
                return NULL;
        }
 
@@ -538,7 +607,7 @@ rspamd_redis_runtime (struct rspamd_task *task,
        }
 
        if (up == NULL) {
-               msg_err ("no upstreams reachable");
+               msg_err_task ("no upstreams reachable");
                return NULL;
        }
 
@@ -598,13 +667,14 @@ rspamd_redis_process_tokens (struct rspamd_task *task,
        struct timeval tv;
        gint ret;
 
-       if (tokens == NULL || tokens->len == 0 || rt->redis == NULL) {
+       if (tokens == NULL || tokens->len == 0 || rt->redis == NULL ||
+                       rt->conn_state != RSPAMD_REDIS_CONNECTED) {
                return FALSE;
        }
 
        rt->id = id;
        query = rspamd_redis_tokens_to_query (task, tokens,
-                       "HMGET", rt->redis_object_expanded);
+                       "HMGET", rt->redis_object_expanded, FALSE, -1);
        g_assert (query != NULL);
 
        ret = redisAsyncFormattedCommand (rt->redis, rspamd_redis_processed, rt,
@@ -621,7 +691,6 @@ rspamd_redis_process_tokens (struct rspamd_task *task,
        }
        else {
                msg_err_task ("call to redis failed: %s", rt->redis->errstr);
-               g_assert (0);
        }
 
        return FALSE;
@@ -632,6 +701,13 @@ rspamd_redis_finalize_process (struct rspamd_task *task, gpointer runtime,
                gpointer ctx)
 {
        struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime);
+
+       if (rt->conn_state == RSPAMD_REDIS_CONNECTED) {
+               event_del (&rt->timeout_event);
+               redisAsyncFree (rt->redis);
+
+               rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
+       }
 }
 
 gboolean
@@ -639,6 +715,65 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
                gint id, gpointer p)
 {
        struct redis_stat_runtime *rt = REDIS_RUNTIME (p);
+       struct upstream *up;
+       rspamd_inet_addr_t *addr;
+       struct timeval tv;
+       rspamd_fstring_t *query;
+       gint ret;
+
+       if (rt->conn_state != RSPAMD_REDIS_DISCONNECTED) {
+               /* We are likely in some bad state */
+               msg_err_task ("invalid state for function: %d", rt->conn_state);
+
+               return FALSE;
+       }
+
+       up = rspamd_upstream_get (rt->ctx->write_servers,
+                       RSPAMD_UPSTREAM_MASTER_SLAVE,
+                       NULL,
+                       0);
+
+       if (up == NULL) {
+               msg_err_task ("no upstreams reachable");
+               return FALSE;
+       }
+
+       rt->selected = up;
+
+       addr = rspamd_upstream_addr (up);
+       g_assert (addr != NULL);
+       rt->redis = redisAsyncConnect (rspamd_inet_address_to_string (addr),
+                       rspamd_inet_address_get_port (addr));
+       g_assert (rt->redis != NULL);
+
+       redisLibeventAttach (rt->redis, task->ev_base);
+
+       event_set (&rt->timeout_event, -1, EV_TIMEOUT, rspamd_redis_timeout, rt);
+       event_base_set (task->ev_base, &rt->timeout_event);
+       double_to_tv (rt->ctx->timeout, &tv);
+       event_add (&rt->timeout_event, &tv);
+
+       rt->id = id;
+       query = rspamd_redis_tokens_to_query (task, tokens,
+                       "HMSET", rt->redis_object_expanded, TRUE, id);
+       g_assert (query != NULL);
+
+       ret = redisAsyncFormattedCommand (rt->redis, rspamd_redis_learned, rt,
+                       query->str, query->len);
+
+       if (ret == REDIS_OK) {
+               rspamd_session_add_event (task->s, rspamd_redis_fin_learn, rt,
+                               rspamd_redis_stat_quark ());
+               /* Reset timeout */
+               event_del (&rt->timeout_event);
+               double_to_tv (rt->ctx->timeout, &tv);
+               event_add (&rt->timeout_event, &tv);
+
+               return TRUE;
+       }
+       else {
+               msg_err_task ("call to redis failed: %s", rt->redis->errstr);
+       }
 
        return FALSE;
 }
@@ -649,6 +784,13 @@ rspamd_redis_finalize_learn (struct rspamd_task *task, gpointer runtime,
                gpointer ctx)
 {
        struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime);
+
+       if (rt->conn_state == RSPAMD_REDIS_CONNECTED) {
+               event_del (&rt->timeout_event);
+               redisAsyncFree (rt->redis);
+
+               rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
+       }
 }
 
 gulong