diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-01-08 15:16:52 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-01-08 15:16:52 +0000 |
commit | 3209a937dba3ee45e9317c502056d9035a06f951 (patch) | |
tree | 8c653f5215f4457e7e55440cf9f4adf9684097e4 /src | |
parent | 8fc2ba3e4e466db8f89f6a17152d66d13660b686 (diff) | |
download | rspamd-3209a937dba3ee45e9317c502056d9035a06f951.tar.gz rspamd-3209a937dba3ee45e9317c502056d9035a06f951.zip |
Implement redis learning
Diffstat (limited to 'src')
-rw-r--r-- | src/libstat/backends/redis_backend.c | 164 |
1 files changed, 153 insertions, 11 deletions
diff --git a/src/libstat/backends/redis_backend.c b/src/libstat/backends/redis_backend.c index 2eb036841..cf65d91b0 100644 --- a/src/libstat/backends/redis_backend.c +++ b/src/libstat/backends/redis_backend.c @@ -278,11 +278,11 @@ rspamd_redis_expand_object (const gchar *pattern, static rspamd_fstring_t * rspamd_redis_tokens_to_query (struct rspamd_task *task, GPtrArray *tokens, - const gchar *arg0, const gchar *arg1) + const gchar *arg0, const gchar *arg1, gboolean learn, gint idx) { rspamd_fstring_t *out; rspamd_token_t *tok; - gchar numbuf[64]; + gchar n0[64], n1[64]; guint i, l0, l1; guint64 num; @@ -291,16 +291,36 @@ rspamd_redis_tokens_to_query (struct rspamd_task *task, GPtrArray *tokens, l0 = strlen (arg0); l1 = strlen (arg1); out = rspamd_fstring_sized_new (1024); - rspamd_printf_fstring (&out, "*%d\r\n$%d\r\n%s\r\n$%d\r\n%s\r\n", - tokens->len + 2, + rspamd_printf_fstring (&out, "*%d\r\n" + "$%d\r\n" + "%s\r\n" + "$%d\r\n" + "%s\r\n", + learn ? (tokens->len * 2 + 2) : (tokens->len + 2), l0, arg0, l1, arg1); for (i = 0; i < tokens->len; i ++) { tok = g_ptr_array_index (tokens, i); memcpy (&num, tok->data, sizeof (num)); - l0 = rspamd_snprintf (numbuf, sizeof (numbuf), "%uL", num); - rspamd_printf_fstring (&out, "$%d\r\n%s\r\n", l0, numbuf); + l0 = rspamd_snprintf (n0, sizeof (n0), "%uL", num); + + if (learn) { + if (tok->values[idx] == (guint64)tok->values[idx]) { + l1 = rspamd_snprintf (n1, sizeof (n1), "%uL", + (guint64)tok->values[idx]); + } + else { + l1 = rspamd_snprintf (n1, sizeof (n1), "%f", + (guint64)tok->values[idx]); + } + + rspamd_printf_fstring (&out, "$%d\r\n%s\r\n" + "$%d\r\n%s\r\n", l0, n0, l1, n1); + } + else { + rspamd_printf_fstring (&out, "$%d\r\n%s\r\n", l0, n0); + } } rspamd_mempool_add_destructor (task->task_pool, @@ -317,7 +337,35 @@ rspamd_redis_fin (gpointer data) if (rt->conn_state != RSPAMD_REDIS_CONNECTED) { redisAsyncFree (rt->redis); + rt->conn_state = RSPAMD_REDIS_DISCONNECTED; + } + + event_del (&rt->timeout_event); +} + +static void +rspamd_redis_fin_learn (gpointer data) +{ + struct redis_stat_runtime *rt = REDIS_RUNTIME (data); + + if (rt->conn_state != RSPAMD_REDIS_CONNECTED) { + redisAsyncFree (rt->redis); + rt->conn_state = RSPAMD_REDIS_DISCONNECTED; + } + + event_del (&rt->timeout_event); +} + +static void +rspamd_redis_fin_set_learns (gpointer data) +{ + struct redis_stat_runtime *rt = REDIS_RUNTIME (data); + + if (rt->conn_state != RSPAMD_REDIS_CONNECTED) { + redisAsyncFree (rt->redis); + rt->conn_state = RSPAMD_REDIS_DISCONNECTED; } + event_del (&rt->timeout_event); } @@ -433,6 +481,27 @@ rspamd_redis_processed (redisAsyncContext *c, gpointer r, gpointer priv) } } +/* Called when we have set tokens during learning */ +static void +rspamd_redis_learned (redisAsyncContext *c, gpointer r, gpointer priv) +{ + struct redis_stat_runtime *rt = REDIS_RUNTIME (priv); + struct rspamd_task *task; + + task = rt->task; + + if (c->err == 0) { + rspamd_upstream_ok (rt->selected); + rspamd_session_remove_event (task->s, rspamd_redis_fin_learn, rt); + } + else { + msg_err_task ("error getting reply from redis server %s: %s", + rspamd_upstream_name (rt->selected), c->errstr); + rspamd_upstream_fail (rt->selected); + rspamd_session_remove_event (task->s, rspamd_redis_fin_learn, rt); + } +} + gpointer rspamd_redis_init (struct rspamd_stat_ctx *ctx, struct rspamd_config *cfg, struct rspamd_statfile *st) @@ -520,7 +589,7 @@ rspamd_redis_runtime (struct rspamd_task *task, g_assert (stcf != NULL); if (learn && ctx->write_servers == NULL) { - msg_err ("no write servers defined for %s, cannot learn", stcf->symbol); + msg_err_task ("no write servers defined for %s, cannot learn", stcf->symbol); return NULL; } @@ -538,7 +607,7 @@ rspamd_redis_runtime (struct rspamd_task *task, } if (up == NULL) { - msg_err ("no upstreams reachable"); + msg_err_task ("no upstreams reachable"); return NULL; } @@ -598,13 +667,14 @@ rspamd_redis_process_tokens (struct rspamd_task *task, struct timeval tv; gint ret; - if (tokens == NULL || tokens->len == 0 || rt->redis == NULL) { + if (tokens == NULL || tokens->len == 0 || rt->redis == NULL || + rt->conn_state != RSPAMD_REDIS_CONNECTED) { return FALSE; } rt->id = id; query = rspamd_redis_tokens_to_query (task, tokens, - "HMGET", rt->redis_object_expanded); + "HMGET", rt->redis_object_expanded, FALSE, -1); g_assert (query != NULL); ret = redisAsyncFormattedCommand (rt->redis, rspamd_redis_processed, rt, @@ -621,7 +691,6 @@ rspamd_redis_process_tokens (struct rspamd_task *task, } else { msg_err_task ("call to redis failed: %s", rt->redis->errstr); - g_assert (0); } return FALSE; @@ -632,6 +701,13 @@ rspamd_redis_finalize_process (struct rspamd_task *task, gpointer runtime, gpointer ctx) { struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime); + + if (rt->conn_state == RSPAMD_REDIS_CONNECTED) { + event_del (&rt->timeout_event); + redisAsyncFree (rt->redis); + + rt->conn_state = RSPAMD_REDIS_DISCONNECTED; + } } gboolean @@ -639,6 +715,65 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens, gint id, gpointer p) { struct redis_stat_runtime *rt = REDIS_RUNTIME (p); + struct upstream *up; + rspamd_inet_addr_t *addr; + struct timeval tv; + rspamd_fstring_t *query; + gint ret; + + if (rt->conn_state != RSPAMD_REDIS_DISCONNECTED) { + /* We are likely in some bad state */ + msg_err_task ("invalid state for function: %d", rt->conn_state); + + return FALSE; + } + + up = rspamd_upstream_get (rt->ctx->write_servers, + RSPAMD_UPSTREAM_MASTER_SLAVE, + NULL, + 0); + + if (up == NULL) { + msg_err_task ("no upstreams reachable"); + return FALSE; + } + + rt->selected = up; + + addr = rspamd_upstream_addr (up); + g_assert (addr != NULL); + rt->redis = redisAsyncConnect (rspamd_inet_address_to_string (addr), + rspamd_inet_address_get_port (addr)); + g_assert (rt->redis != NULL); + + redisLibeventAttach (rt->redis, task->ev_base); + + event_set (&rt->timeout_event, -1, EV_TIMEOUT, rspamd_redis_timeout, rt); + event_base_set (task->ev_base, &rt->timeout_event); + double_to_tv (rt->ctx->timeout, &tv); + event_add (&rt->timeout_event, &tv); + + rt->id = id; + query = rspamd_redis_tokens_to_query (task, tokens, + "HMSET", rt->redis_object_expanded, TRUE, id); + g_assert (query != NULL); + + ret = redisAsyncFormattedCommand (rt->redis, rspamd_redis_learned, rt, + query->str, query->len); + + if (ret == REDIS_OK) { + rspamd_session_add_event (task->s, rspamd_redis_fin_learn, rt, + rspamd_redis_stat_quark ()); + /* Reset timeout */ + event_del (&rt->timeout_event); + double_to_tv (rt->ctx->timeout, &tv); + event_add (&rt->timeout_event, &tv); + + return TRUE; + } + else { + msg_err_task ("call to redis failed: %s", rt->redis->errstr); + } return FALSE; } @@ -649,6 +784,13 @@ rspamd_redis_finalize_learn (struct rspamd_task *task, gpointer runtime, gpointer ctx) { struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime); + + if (rt->conn_state == RSPAMD_REDIS_CONNECTED) { + event_del (&rt->timeout_event); + redisAsyncFree (rt->redis); + + rt->conn_state = RSPAMD_REDIS_DISCONNECTED; + } } gulong |