aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2016-01-08 15:16:52 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2016-01-08 15:16:52 +0000
commit3209a937dba3ee45e9317c502056d9035a06f951 (patch)
tree8c653f5215f4457e7e55440cf9f4adf9684097e4
parent8fc2ba3e4e466db8f89f6a17152d66d13660b686 (diff)
downloadrspamd-3209a937dba3ee45e9317c502056d9035a06f951.tar.gz
rspamd-3209a937dba3ee45e9317c502056d9035a06f951.zip
Implement redis learning
-rw-r--r--src/libstat/backends/redis_backend.c164
1 files changed, 153 insertions, 11 deletions
diff --git a/src/libstat/backends/redis_backend.c b/src/libstat/backends/redis_backend.c
index 2eb036841..cf65d91b0 100644
--- a/src/libstat/backends/redis_backend.c
+++ b/src/libstat/backends/redis_backend.c
@@ -278,11 +278,11 @@ rspamd_redis_expand_object (const gchar *pattern,
static rspamd_fstring_t *
rspamd_redis_tokens_to_query (struct rspamd_task *task, GPtrArray *tokens,
- const gchar *arg0, const gchar *arg1)
+ const gchar *arg0, const gchar *arg1, gboolean learn, gint idx)
{
rspamd_fstring_t *out;
rspamd_token_t *tok;
- gchar numbuf[64];
+ gchar n0[64], n1[64];
guint i, l0, l1;
guint64 num;
@@ -291,16 +291,36 @@ rspamd_redis_tokens_to_query (struct rspamd_task *task, GPtrArray *tokens,
l0 = strlen (arg0);
l1 = strlen (arg1);
out = rspamd_fstring_sized_new (1024);
- rspamd_printf_fstring (&out, "*%d\r\n$%d\r\n%s\r\n$%d\r\n%s\r\n",
- tokens->len + 2,
+ rspamd_printf_fstring (&out, "*%d\r\n"
+ "$%d\r\n"
+ "%s\r\n"
+ "$%d\r\n"
+ "%s\r\n",
+ learn ? (tokens->len * 2 + 2) : (tokens->len + 2),
l0, arg0,
l1, arg1);
for (i = 0; i < tokens->len; i ++) {
tok = g_ptr_array_index (tokens, i);
memcpy (&num, tok->data, sizeof (num));
- l0 = rspamd_snprintf (numbuf, sizeof (numbuf), "%uL", num);
- rspamd_printf_fstring (&out, "$%d\r\n%s\r\n", l0, numbuf);
+ l0 = rspamd_snprintf (n0, sizeof (n0), "%uL", num);
+
+ if (learn) {
+ if (tok->values[idx] == (guint64)tok->values[idx]) {
+ l1 = rspamd_snprintf (n1, sizeof (n1), "%uL",
+ (guint64)tok->values[idx]);
+ }
+ else {
+ l1 = rspamd_snprintf (n1, sizeof (n1), "%f",
+ (guint64)tok->values[idx]);
+ }
+
+ rspamd_printf_fstring (&out, "$%d\r\n%s\r\n"
+ "$%d\r\n%s\r\n", l0, n0, l1, n1);
+ }
+ else {
+ rspamd_printf_fstring (&out, "$%d\r\n%s\r\n", l0, n0);
+ }
}
rspamd_mempool_add_destructor (task->task_pool,
@@ -317,7 +337,35 @@ rspamd_redis_fin (gpointer data)
if (rt->conn_state != RSPAMD_REDIS_CONNECTED) {
redisAsyncFree (rt->redis);
+ rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
+ }
+
+ event_del (&rt->timeout_event);
+}
+
+static void
+rspamd_redis_fin_learn (gpointer data)
+{
+ struct redis_stat_runtime *rt = REDIS_RUNTIME (data);
+
+ if (rt->conn_state != RSPAMD_REDIS_CONNECTED) {
+ redisAsyncFree (rt->redis);
+ rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
+ }
+
+ event_del (&rt->timeout_event);
+}
+
+static void
+rspamd_redis_fin_set_learns (gpointer data)
+{
+ struct redis_stat_runtime *rt = REDIS_RUNTIME (data);
+
+ if (rt->conn_state != RSPAMD_REDIS_CONNECTED) {
+ redisAsyncFree (rt->redis);
+ rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
}
+
event_del (&rt->timeout_event);
}
@@ -433,6 +481,27 @@ rspamd_redis_processed (redisAsyncContext *c, gpointer r, gpointer priv)
}
}
+/* Called when we have set tokens during learning */
+static void
+rspamd_redis_learned (redisAsyncContext *c, gpointer r, gpointer priv)
+{
+ struct redis_stat_runtime *rt = REDIS_RUNTIME (priv);
+ struct rspamd_task *task;
+
+ task = rt->task;
+
+ if (c->err == 0) {
+ rspamd_upstream_ok (rt->selected);
+ rspamd_session_remove_event (task->s, rspamd_redis_fin_learn, rt);
+ }
+ else {
+ msg_err_task ("error getting reply from redis server %s: %s",
+ rspamd_upstream_name (rt->selected), c->errstr);
+ rspamd_upstream_fail (rt->selected);
+ rspamd_session_remove_event (task->s, rspamd_redis_fin_learn, rt);
+ }
+}
+
gpointer
rspamd_redis_init (struct rspamd_stat_ctx *ctx,
struct rspamd_config *cfg, struct rspamd_statfile *st)
@@ -520,7 +589,7 @@ rspamd_redis_runtime (struct rspamd_task *task,
g_assert (stcf != NULL);
if (learn && ctx->write_servers == NULL) {
- msg_err ("no write servers defined for %s, cannot learn", stcf->symbol);
+ msg_err_task ("no write servers defined for %s, cannot learn", stcf->symbol);
return NULL;
}
@@ -538,7 +607,7 @@ rspamd_redis_runtime (struct rspamd_task *task,
}
if (up == NULL) {
- msg_err ("no upstreams reachable");
+ msg_err_task ("no upstreams reachable");
return NULL;
}
@@ -598,13 +667,14 @@ rspamd_redis_process_tokens (struct rspamd_task *task,
struct timeval tv;
gint ret;
- if (tokens == NULL || tokens->len == 0 || rt->redis == NULL) {
+ if (tokens == NULL || tokens->len == 0 || rt->redis == NULL ||
+ rt->conn_state != RSPAMD_REDIS_CONNECTED) {
return FALSE;
}
rt->id = id;
query = rspamd_redis_tokens_to_query (task, tokens,
- "HMGET", rt->redis_object_expanded);
+ "HMGET", rt->redis_object_expanded, FALSE, -1);
g_assert (query != NULL);
ret = redisAsyncFormattedCommand (rt->redis, rspamd_redis_processed, rt,
@@ -621,7 +691,6 @@ rspamd_redis_process_tokens (struct rspamd_task *task,
}
else {
msg_err_task ("call to redis failed: %s", rt->redis->errstr);
- g_assert (0);
}
return FALSE;
@@ -632,6 +701,13 @@ rspamd_redis_finalize_process (struct rspamd_task *task, gpointer runtime,
gpointer ctx)
{
struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime);
+
+ if (rt->conn_state == RSPAMD_REDIS_CONNECTED) {
+ event_del (&rt->timeout_event);
+ redisAsyncFree (rt->redis);
+
+ rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
+ }
}
gboolean
@@ -639,6 +715,65 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
gint id, gpointer p)
{
struct redis_stat_runtime *rt = REDIS_RUNTIME (p);
+ struct upstream *up;
+ rspamd_inet_addr_t *addr;
+ struct timeval tv;
+ rspamd_fstring_t *query;
+ gint ret;
+
+ if (rt->conn_state != RSPAMD_REDIS_DISCONNECTED) {
+ /* We are likely in some bad state */
+ msg_err_task ("invalid state for function: %d", rt->conn_state);
+
+ return FALSE;
+ }
+
+ up = rspamd_upstream_get (rt->ctx->write_servers,
+ RSPAMD_UPSTREAM_MASTER_SLAVE,
+ NULL,
+ 0);
+
+ if (up == NULL) {
+ msg_err_task ("no upstreams reachable");
+ return FALSE;
+ }
+
+ rt->selected = up;
+
+ addr = rspamd_upstream_addr (up);
+ g_assert (addr != NULL);
+ rt->redis = redisAsyncConnect (rspamd_inet_address_to_string (addr),
+ rspamd_inet_address_get_port (addr));
+ g_assert (rt->redis != NULL);
+
+ redisLibeventAttach (rt->redis, task->ev_base);
+
+ event_set (&rt->timeout_event, -1, EV_TIMEOUT, rspamd_redis_timeout, rt);
+ event_base_set (task->ev_base, &rt->timeout_event);
+ double_to_tv (rt->ctx->timeout, &tv);
+ event_add (&rt->timeout_event, &tv);
+
+ rt->id = id;
+ query = rspamd_redis_tokens_to_query (task, tokens,
+ "HMSET", rt->redis_object_expanded, TRUE, id);
+ g_assert (query != NULL);
+
+ ret = redisAsyncFormattedCommand (rt->redis, rspamd_redis_learned, rt,
+ query->str, query->len);
+
+ if (ret == REDIS_OK) {
+ rspamd_session_add_event (task->s, rspamd_redis_fin_learn, rt,
+ rspamd_redis_stat_quark ());
+ /* Reset timeout */
+ event_del (&rt->timeout_event);
+ double_to_tv (rt->ctx->timeout, &tv);
+ event_add (&rt->timeout_event, &tv);
+
+ return TRUE;
+ }
+ else {
+ msg_err_task ("call to redis failed: %s", rt->redis->errstr);
+ }
return FALSE;
}
@@ -649,6 +784,13 @@ rspamd_redis_finalize_learn (struct rspamd_task *task, gpointer runtime,
gpointer ctx)
{
struct redis_stat_runtime *rt = REDIS_RUNTIME (runtime);
+
+ if (rt->conn_state == RSPAMD_REDIS_CONNECTED) {
+ event_del (&rt->timeout_event);
+ redisAsyncFree (rt->redis);
+
+ rt->conn_state = RSPAMD_REDIS_DISCONNECTED;
+ }
}
gulong