]> source.dussan.org Git - rspamd.git/commitdiff
* Add LRU caching structure
author: Vsevolod Stakhov <vsevolod@rambler-co.ru>
Fri, 10 Jun 2011 13:28:19 +0000 (17:28 +0400)
committer: Vsevolod Stakhov <vsevolod@rambler-co.ru>
Fri, 10 Jun 2011 13:28:19 +0000 (17:28 +0400)
* Add SPF records cache
* Add ability to parse doubles to xmlrpc
Several fixes to dns interface.
Trie plugin now checks urls as well.

src/dns.c
src/hash.c
src/hash.h
src/lua/lua_http.c
src/lua/lua_task.c
src/lua/lua_xmlrpc.c
src/plugins/lua/trie.lua
src/plugins/spf.c
src/spf.c
src/spf.h

index 71feeee1765ac7031eb625fbcddc6a72831533b7..da467f664b128bf28a5ee6d5d8c1381ad5ffa0e6 100644 (file)
--- a/src/dns.c
+++ b/src/dns.c
@@ -907,26 +907,27 @@ dns_parse_rr (guint8 *in, union rspamd_reply_element *elt, guint8 **pos, struct
        return 0;
 }
 
-static struct rspamd_dns_reply *
-dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver, struct rspamd_dns_request **req_out)
+static gboolean
+dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver,
+               struct rspamd_dns_request **req_out, struct rspamd_dns_reply **_rep)
 {
        struct dns_header *header = (struct dns_header *)in;
-       struct rspamd_dns_request *req;
-       struct rspamd_dns_reply *rep;
-       union rspamd_reply_element *elt;
-       guint8 *pos;
+       struct rspamd_dns_request      *req;
+       struct rspamd_dns_reply        *rep;
+       union rspamd_reply_element     *elt;
+       guint8                         *pos;
        gint                            i, t;
        
        /* First check header fields */
        if (header->qr == 0) {
                msg_info ("got request while waiting for reply");
-               return NULL;
+               return FALSE;
        }
 
        /* Now try to find corresponding request */
        if ((req = g_hash_table_lookup (resolver->requests, GUINT_TO_POINTER (header->qid))) == NULL) {
                /* No such requests found */
-               return NULL;
+               return FALSE;
        }
        *req_out = req;
        /* 
@@ -934,7 +935,7 @@ dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver, struc
         * request QR section and reply QR section
         */
        if ((pos = dns_request_reply_cmp (req, in + sizeof (struct dns_header), r - sizeof (struct dns_header))) == NULL) {
-               return NULL;
+               return FALSE;
        }
        /*
         * Remove delayed retransmits for this packet
@@ -949,24 +950,27 @@ dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver, struc
        rep->elements = NULL;
        rep->code = header->rcode;
 
-       r -= pos - in;
-       /* Extract RR records */
-       for (i = 0; i < ntohs (header->ancount); i ++) {
-               elt = memory_pool_alloc (req->pool, sizeof (union rspamd_reply_element));
-               t = dns_parse_rr (in, elt, &pos, rep, &r);
-               if (t == -1) {
-                       msg_info ("incomplete reply");
-                       break;
+       if (rep->code == DNS_RC_NOERROR) {
+               r -= pos - in;
+               /* Extract RR records */
+               for (i = 0; i < ntohs (header->ancount); i ++) {
+                       elt = memory_pool_alloc (req->pool, sizeof (union rspamd_reply_element));
+                       t = dns_parse_rr (in, elt, &pos, rep, &r);
+                       if (t == -1) {
+                               msg_info ("incomplete reply");
+                               break;
+                       }
+                       else if (t == 1) {
+                               rep->elements = g_list_prepend (rep->elements, elt);
+                       }
                }
-               else if (t == 1) {
-                       rep->elements = g_list_prepend (rep->elements, elt);
+               if (rep->elements) {
+                       memory_pool_add_destructor (req->pool, (pool_destruct_func)g_list_free, rep->elements);
                }
        }
-       if (rep->elements) {
-               memory_pool_add_destructor (req->pool, (pool_destruct_func)g_list_free, rep->elements);
-       }
        
-       return rep;
+       *_rep = rep;
+       return TRUE;
 }
 
 static void
@@ -994,24 +998,25 @@ dns_check_throttling (struct rspamd_dns_resolver *resolver)
 static void
 dns_read_cb (gint fd, short what, void *arg)
 {
-       struct rspamd_dns_resolver *resolver = arg;
-       struct rspamd_dns_request *req = NULL;
+       struct rspamd_dns_resolver     *resolver = arg;
+       struct rspamd_dns_request      *req = NULL;
        gint                            r;
-       struct rspamd_dns_reply *rep;
-       guint8 in[UDP_PACKET_SIZE];
+       struct rspamd_dns_reply        *rep;
+       guint8                          in[UDP_PACKET_SIZE];
 
        /* This function is called each time when we have data on one of server's sockets */
        
        /* First read packet from socket */
        r = read (fd, in, sizeof (in));
        if (r > sizeof (struct dns_header) + sizeof (struct dns_query)) {
-               if ((rep = dns_parse_reply (in, r, resolver, &req)) != NULL) {
+               if (dns_parse_reply (in, r, resolver, &req, &rep)) {
                        /* Decrease errors count */
                        if (rep->request->resolver->errors > 0) {
                                rep->request->resolver->errors --;
                        }
                        upstream_ok (&rep->request->server->up, rep->request->time);
                        rep->request->func (rep, rep->request->arg);
+                       remove_normal_event (req->session, dns_fin_cb, req);
                }
        }
 }
index a023bcbf48c97833dfe6ff1cde8374ce478dd199..1d10e10482ddc4a384ff9f2fec1c6f819d9f335a 100644 (file)
@@ -301,6 +301,132 @@ rspamd_hash_foreach (rspamd_hash_t * hash, GHFunc func, gpointer user_data)
        }
 }
 
+/**
+ * LRU hashing
+ */
+
+/*
+ * Value destructor handed to the storage hash table: passes the stored
+ * payload to the user supplied callback (when one is set) and then
+ * releases the node itself. The node is not unlinked from the queue here.
+ */
+static void
+rspamd_lru_hash_destroy_node (gpointer v)
+{
+       rspamd_lru_element_t           *elt = (rspamd_lru_element_t *)v;
+       GDestroyNotify                  free_value = elt->hash->value_destroy;
+
+       if (free_value != NULL) {
+               free_value (elt->data);
+       }
+
+       g_slice_free1 (sizeof (*elt), elt);
+}
+
+/*
+ * Allocate a node binding key and value to the owning hash and stamp it
+ * with the time of insertion. Key and value are stored by pointer only.
+ */
+static rspamd_lru_element_t*
+rspamd_lru_create_node (rspamd_lru_hash_t *hash, gpointer key, gpointer value, time_t now)
+{
+       rspamd_lru_element_t           *elt = g_slice_alloc (sizeof (*elt));
+
+       elt->key = key;
+       elt->data = value;
+       elt->store_time = now;
+       elt->hash = hash;
+
+       return elt;
+}
+
+/**
+ * Create new lru hash
+ * @param hash_func pointer to hash function
+ * @param key_equal_func pointer to function for comparing keys
+ * @param maxsize maximum elements in a hash
+ * @param maxage maximum age of an element, measured as a difference of the
+ *        `now` timestamps passed to lookup and insert
+ * @param key_destroy destructor passed to the storage table for keys
+ * @param value_destroy destructor called for stored values (skipped when NULL)
+ * @return new lru hash object
+ */
+rspamd_lru_hash_t*
+rspamd_lru_hash_new (GHashFunc hash_func, GEqualFunc key_equal_func, gint maxsize, gint maxage,
+               GDestroyNotify key_destroy, GDestroyNotify value_destroy)
+{
+       rspamd_lru_hash_t              *lru = g_malloc (sizeof (*lru));
+
+       lru->maxsize = maxsize;
+       lru->maxage = maxage;
+       lru->value_destroy = value_destroy;
+       lru->q = g_queue_new ();
+       /* The table owns the nodes: removing an entry releases its node */
+       lru->storage = g_hash_table_new_full (hash_func, key_equal_func, key_destroy,
+                       rspamd_lru_hash_destroy_node);
+
+       return lru;
+}
+/**
+ * Lookup item from hash
+ * @param hash hash object
+ * @param key key to find
+ * @param now current time, compared against the store time of the entry
+ * @return value of key or NULL if key is not found or its entry is older
+ *         than maxage
+ */
+gpointer
+rspamd_lru_hash_lookup (rspamd_lru_hash_t *hash, gpointer key, time_t now)
+{
+       rspamd_lru_element_t           *found, *tail;
+
+       found = g_hash_table_lookup (hash->storage, key);
+       if (found == NULL) {
+               return NULL;
+       }
+       if (now - found->store_time <= hash->maxage) {
+               /* Still fresh */
+               return found->data;
+       }
+
+       /* The hit is stale: sweep expired elements from the queue tail */
+       for (tail = g_queue_pop_tail (hash->q);
+                       tail != NULL && now - tail->store_time > hash->maxage;
+                       tail = g_queue_pop_tail (hash->q)) {
+               g_hash_table_remove (hash->storage, tail->key);
+       }
+       /* The first live element was popped as well, put it back */
+       if (tail != NULL) {
+               g_queue_push_tail (hash->q, tail);
+       }
+
+       return NULL;
+}
+/**
+ * Insert item in hash
+ *
+ * The hash takes ownership of both key and value: they are released with the
+ * destructors given to rspamd_lru_hash_new when the entry is dropped. The key
+ * must therefore not be the very pointer already stored for an existing entry.
+ *
+ * @param hash hash object
+ * @param key key to insert
+ * @param value value of key
+ * @param now current time, recorded as the store time of the new entry
+ */
+void
+rspamd_lru_hash_insert (rspamd_lru_hash_t *hash, gpointer key, gpointer value, time_t now)
+{
+       rspamd_lru_element_t           *res;
+       gint                            removed = 0;
+
+       /*
+        * Drop a previous entry for the same key first. Letting
+        * g_hash_table_insert replace it would free the old node while it is
+        * still linked in the queue and would free the passed key that the new
+        * node is about to reference.
+        */
+       if ((res = g_hash_table_lookup (hash->storage, key)) != NULL) {
+               g_queue_remove (hash->q, res);
+               g_hash_table_remove (hash->storage, key);
+       }
+
+       if ((gint)g_hash_table_size (hash->storage) >= hash->maxsize) {
+               /* Expire some elements */
+               res = g_queue_pop_tail (hash->q);
+               while (res != NULL && now - res->store_time > hash->maxage) {
+                       g_hash_table_remove (hash->storage, res->key);
+                       res = g_queue_pop_tail (hash->q);
+                       removed ++;
+               }
+               if (res != NULL) {
+                       if (removed != 0) {
+                               /* Expiry made room, restore the first live element */
+                               g_queue_push_tail (hash->q, res);
+                       }
+                       else {
+                               /* Nothing expired: evict the oldest element to keep maxsize */
+                               g_hash_table_remove (hash->storage, res->key);
+                       }
+               }
+       }
+
+       res = rspamd_lru_create_node (hash, key, value, now);
+       g_hash_table_insert (hash->storage, key, res);
+       g_queue_push_head (hash->q, res);
+}
+
+
+/*
+ * Release the lru hash: destroying the storage table runs the key and node
+ * destructors for every entry, the queue only holds links to those nodes.
+ */
+void
+rspamd_lru_hash_destroy (rspamd_lru_hash_t *hash)
+{
+       GQueue                         *order = hash->q;
+       GHashTable                     *table = hash->storage;
+
+       g_hash_table_destroy (table);
+       g_queue_free (order);
+       g_free (hash);
+}
+
 /*
  * vi:ts=4
  */
index 594b6c63bb3c7e1ded6a814732da7e7bbe47d282..1625aaba1ff0fe0f013b79ee17dc23468e6982a6 100644 (file)
@@ -28,6 +28,22 @@ typedef struct rspamd_hash_s {
        memory_pool_t            *pool;
 } rspamd_hash_t;
 
+/* Size and age bounded cache: a hash table for lookup plus a queue ordered by insertion */
+typedef struct rspamd_lru_hash_s {
+       gint                      maxsize;          /* table size at which insert starts expiring elements */
+       gint                      maxage;           /* maximum allowed difference between `now` and store_time */
+       GHashTable               *storage;          /* key -> rspamd_lru_element_t */
+       GDestroyNotify            value_destroy;    /* called on element data when a node is destroyed, may be NULL */
+       GQueue                   *q;                /* new elements are pushed at the head, expired ones popped from the tail */
+} rspamd_lru_hash_t;
+
+/* Single cache entry, stored as the value in the hash table and as data in the queue */
+typedef struct rspamd_lru_element_s {
+       gpointer                  data;             /* user value */
+       gpointer                  key;              /* key the element was inserted with, used to remove it on expiry */
+       time_t                    store_time;       /* `now` passed at insertion */
+       rspamd_lru_hash_t        *hash;             /* owning hash, gives access to value_destroy */
+} rspamd_lru_element_t;
+
+
 #define rspamd_hash_size(x) (x)->nnodes
 
 /**
@@ -79,6 +95,38 @@ gpointer rspamd_hash_lookup (rspamd_hash_t *hash, gpointer key);
  */
 void rspamd_hash_foreach (rspamd_hash_t *hash, GHFunc func, gpointer user_data);
 
+/**
+ * Create new lru hash
+ * @param hash_func pointer to hash function
+ * @param key_equal_func pointer to function for comparing keys
+ * @param maxsize maximum elements in a hash
+ * @param maxage maximum age of an element, measured as a difference of the
+ *        `now` timestamps passed to lookup and insert
+ * @param key_destroy destructor for keys
+ * @param value_destroy destructor for values
+ * @return new lru hash object
+ */
+rspamd_lru_hash_t* rspamd_lru_hash_new (GHashFunc hash_func, GEqualFunc key_equal_func,
+               gint maxsize, gint maxage, GDestroyNotify key_destroy, GDestroyNotify value_destroy);
+/**
+ * Lookup item from hash
+ * @param hash hash object
+ * @param key key to find
+ * @param now current time, used to decide whether the entry has expired
+ * @return value of key or NULL if key is not found or its entry has expired
+ */
+gpointer rspamd_lru_hash_lookup (rspamd_lru_hash_t *hash, gpointer key, time_t now);
+/**
+ * Insert item in hash
+ * @param hash hash object
+ * @param key key to insert
+ * @param value value of key
+ * @param now current time, recorded as the store time of the entry
+ */
+void rspamd_lru_hash_insert (rspamd_lru_hash_t *hash, gpointer key, gpointer value, time_t now);
+
+/**
+ * Destroy lru hash and release all stored elements
+ * @param hash hash object
+ */
+
+void rspamd_lru_hash_destroy (rspamd_lru_hash_t *hash);
+
 #endif
 
 /*
index 25cfa6948184cce9f4039d8ad555e37ae1df9b7c..89bd024bf52bd5cda28bb9222d54d23f2ad090c4 100644 (file)
@@ -66,6 +66,15 @@ lua_check_task (lua_State * L)
        return *((struct worker_task **)ud);
 }
 
+/*
+ * Finaliser registered as the async event of an HTTP request: removes the
+ * IO dispatcher and closes the socket of the request.
+ */
+static void
+lua_http_fin (void *arg)
+{
+       struct lua_http_ud             *http_ud = (struct lua_http_ud *)arg;
+
+       rspamd_remove_dispatcher (http_ud->io_dispatcher);
+       close (http_ud->fd);
+}
+
 static void
 lua_http_push_error (gint code, struct lua_http_ud *ud)
 {
@@ -93,6 +102,7 @@ lua_http_push_error (gint code, struct lua_http_ud *ud)
        }
 
        ud->parser_state = 3;
+       remove_normal_event (ud->task->s, lua_http_fin, ud);
 
        ud->task->save.saved--;
        if (ud->task->save.saved == 0) {
@@ -140,6 +150,7 @@ lua_http_push_reply (f_str_t *in, struct lua_http_ud *ud)
                ud->headers = NULL;
        }
 
+       remove_normal_event (ud->task->s, lua_http_fin, ud);
        ud->task->save.saved--;
        if (ud->task->save.saved == 0) {
                /* Call other filters */
@@ -148,7 +159,6 @@ lua_http_push_reply (f_str_t *in, struct lua_http_ud *ud)
        }
 }
 
-
 /*
  * Parsing utils
  */
@@ -246,8 +256,6 @@ lua_http_read_cb (f_str_t * in, void *arg)
        case 2:
                /* Get reply */
                lua_http_push_reply (in, ud);
-               rspamd_remove_dispatcher (ud->io_dispatcher);
-               close (ud->fd);
                return FALSE;
        }
 
@@ -264,10 +272,13 @@ lua_http_err_cb (GError * err, void *arg)
        if (ud->parser_state != 3) {
                lua_http_push_error (500, ud);
        }
-       rspamd_remove_dispatcher (ud->io_dispatcher);
-       close (ud->fd);
+       else {
+               remove_normal_event (ud->task->s, lua_http_fin, ud);
+       }
 }
 
+
+
 static void
 lua_http_dns_callback (struct rspamd_dns_reply *reply, gpointer arg)
 {
@@ -304,6 +315,8 @@ lua_http_dns_callback (struct rspamd_dns_reply *reply, gpointer arg)
                close (ud->fd);
                return;
        }
+
+       register_async_event (ud->task->s, lua_http_fin, ud, FALSE);
 }
 
 /**
index 2ab56b29cc0f3453450e9a4891c163b390512538..0931eb13cbc51d7a532d4ccf17d4cbbc5b98293d 100644 (file)
@@ -57,6 +57,7 @@ LUA_FUNCTION_DEF (task, resolve_dns_txt);
 LUA_FUNCTION_DEF (task, call_rspamd_function);
 LUA_FUNCTION_DEF (task, get_recipients);
 LUA_FUNCTION_DEF (task, get_from);
+LUA_FUNCTION_DEF (task, get_user);
 LUA_FUNCTION_DEF (task, get_recipients_headers);
 LUA_FUNCTION_DEF (task, get_from_headers);
 LUA_FUNCTION_DEF (task, get_from_ip);
@@ -86,6 +87,7 @@ static const struct luaL_reg    tasklib_m[] = {
        LUA_INTERFACE_DEF (task, call_rspamd_function),
        LUA_INTERFACE_DEF (task, get_recipients),
        LUA_INTERFACE_DEF (task, get_from),
+       LUA_INTERFACE_DEF (task, get_user),
        LUA_INTERFACE_DEF (task, get_recipients_headers),
        LUA_INTERFACE_DEF (task, get_from_headers),
        LUA_INTERFACE_DEF (task, get_from_ip),
@@ -833,6 +835,20 @@ lua_task_get_from (lua_State *L)
        return 1;
 }
 
+/*
+ * Lua: task:get_user ()
+ * Pushes the user string of the task, or nil when the task is invalid or
+ * has no user set. Always returns a single value.
+ */
+static gint
+lua_task_get_user (lua_State *L)
+{
+       struct worker_task             *task = lua_check_task (L);
+
+       if (task == NULL || task->user == NULL) {
+               lua_pushnil (L);
+       }
+       else {
+               lua_pushstring (L, task->user);
+       }
+
+       return 1;
+}
+
 /*
  * Headers versions
  */
index f0ba2d6c035b0b3c3c80b65537d47f935f0d8836..4589405ccc6c7872668448a587113923dce4a052 100644 (file)
@@ -38,6 +38,7 @@ struct lua_xmlrpc_ud {
        gint parser_state;
        gint depth;
        gint param_count;
+       gboolean got_text;
        lua_State *L;
 };
 
@@ -126,9 +127,15 @@ xmlrpc_start_element (GMarkupParseContext *context, const gchar *name, const gch
                }
                else if (g_ascii_strcasecmp (name, "string") == 0) {
                        ud->parser_state = 11;
+                       ud->got_text = FALSE;
                }
                else if (g_ascii_strcasecmp (name, "int") == 0) {
                        ud->parser_state = 12;
+                       ud->got_text = FALSE;
+               }
+               else if (g_ascii_strcasecmp (name, "double") == 0) {
+                       ud->parser_state = 13;
+                       ud->got_text = FALSE;
                }
                else {
                        /* Error state */
@@ -171,9 +178,15 @@ xmlrpc_start_element (GMarkupParseContext *context, const gchar *name, const gch
                /* Primitives */
                if (g_ascii_strcasecmp (name, "string") == 0) {
                        ud->parser_state = 11;
+                       ud->got_text = FALSE;
                }
                else if (g_ascii_strcasecmp (name, "int") == 0) {
                        ud->parser_state = 12;
+                       ud->got_text = FALSE;
+               }
+               else if (g_ascii_strcasecmp (name, "double") == 0) {
+                       ud->parser_state = 13;
+                       ud->got_text = FALSE;
                }
                /* Structure */
                else if (g_ascii_strcasecmp (name, "struct") == 0) {
@@ -300,7 +313,15 @@ xmlrpc_end_element (GMarkupParseContext    *context, const gchar *name, gpointer us
                break;
        case 11:
        case 12:
+       case 13:
                /* Parse any values */
+               /* Handle empty tags */
+               if (!ud->got_text) {
+                       lua_pushnil (ud->L);
+               }
+               else {
+                       ud->got_text = FALSE;
+               }
                /* Primitives */
                if (g_ascii_strcasecmp (name, "string") == 0) {
                        ud->parser_state = 8;
@@ -308,6 +329,9 @@ xmlrpc_end_element (GMarkupParseContext     *context, const gchar *name, gpointer us
                else if (g_ascii_strcasecmp (name, "int") == 0) {
                        ud->parser_state = 8;
                }
+               else if (g_ascii_strcasecmp (name, "double") == 0) {
+                       ud->parser_state = 8;
+               }
                else {
                        /* Error state */
                        ud->parser_state = 99;
@@ -326,6 +350,7 @@ xmlrpc_text (GMarkupParseContext *context, const gchar *text, gsize text_len, gp
 {
        struct lua_xmlrpc_ud           *ud = user_data;
        gint                            num;
+       gdouble                         dnum;
 
        /* Strip line */
        while (g_ascii_isspace (*text) && text_len > 0) {
@@ -352,7 +377,13 @@ xmlrpc_text (GMarkupParseContext *context, const gchar *text, gsize text_len, gp
                        num = strtoul (text, NULL, 10);
                        lua_pushinteger (ud->L, num);
                        break;
+               case 13:
+                       /* Push integer value */
+                       dnum = strtod (text, NULL);
+                       lua_pushnumber (ud->L, dnum);
+                       break;
                }
+               ud->got_text = TRUE;
        }
 }
 
@@ -400,7 +431,8 @@ lua_xmlrpc_parse_reply (lua_State *L)
 static gint
 lua_xmlrpc_parse_table (lua_State *L, gint pos, gchar *databuf, gint pr, gsize size)
 {
-       gint                           r = pr;
+       gint                           r = pr, num;
+       double                         dnum;
 
        r += rspamd_snprintf (databuf + r, size - r, "<struct>");
        lua_pushnil (L);  /* first key */
@@ -415,8 +447,18 @@ lua_xmlrpc_parse_table (lua_State *L, gint pos, gchar *databuf, gint pr, gsize s
                                lua_tostring (L, -2));
                switch (lua_type (L, -1)) {
                case LUA_TNUMBER:
-                       r += rspamd_snprintf (databuf + r, size - r, "<int>%d</int>",
-                                       lua_tointeger (L, -1));
+                       num = lua_tointeger (L, -1);
+                       dnum = lua_tonumber (L, -1);
+
+                       /* Try to avoid conversion errors */
+                       if (dnum != (double)num) {
+                               r += rspamd_snprintf (databuf + r, sizeof (databuf) - r, "<double>%f</double>",
+                                               dnum);
+                       }
+                       else {
+                               r += rspamd_snprintf (databuf + r, sizeof (databuf) - r, "<int>%d</int>",
+                                               num);
+                       }
                        break;
                case LUA_TBOOLEAN:
                        r += rspamd_snprintf (databuf + r, size - r, "<boolean>%d</boolean>",
@@ -449,7 +491,8 @@ lua_xmlrpc_make_request (lua_State *L)
 {
        gchar                          databuf[BUFSIZ * 2];
        const gchar                   *func;
-       gint                           r, top, i;
+       gint                           r, top, i, num;
+       double                         dnum;
 
        func = luaL_checkstring (L, 1);
 
@@ -465,8 +508,18 @@ lua_xmlrpc_make_request (lua_State *L)
                        r += rspamd_snprintf (databuf + r, sizeof (databuf) - r, "<param><value>");
                        switch (lua_type (L, i)) {
                        case LUA_TNUMBER:
-                               r += rspamd_snprintf (databuf + r, sizeof (databuf) - r, "<int>%d</int>",
-                                               lua_tointeger (L, i));
+                               num = lua_tointeger (L, i);
+                               dnum = lua_tonumber (L, i);
+
+                               /* Try to avoid conversion errors */
+                               if (dnum != (double)num) {
+                                       r += rspamd_snprintf (databuf + r, sizeof (databuf) - r, "<double>%f</double>",
+                                                                                       dnum);
+                               }
+                               else {
+                                       r += rspamd_snprintf (databuf + r, sizeof (databuf) - r, "<int>%d</int>",
+                                               num);
+                               }
                                break;
                        case LUA_TBOOLEAN:
                                r += rspamd_snprintf (databuf + r, sizeof (databuf) - r, "<boolean>%d</boolean>",
index 98248f29ff8cea0be042c042ee1eb8b04569db98..6b1782a00dcf793bbf82a833b2de83912a3e8562 100644 (file)
@@ -66,6 +66,15 @@ function check_trie(task)
                if trie['trie']:search_task(task) then
                        task:insert_result(trie['symbol'], 1)
                end
+               -- Search inside urls
+               urls = task:get_urls()
+               if urls then
+                       for _,url in urls do
+                               if trie['trie']:search_text(url:get_text()) then
+                                       task:insert_result(trie['symbol'], 1)
+                               end
+                       end
+               end
        end
 end
 
index f5cbbe7b3b6b1ccc8e1f760c1281cc3138a34d07..223bc1241356c87d12334a8c158c5e632828e14b 100644 (file)
 #include "../map.h"
 #include "../spf.h"
 #include "../cfg_xml.h"
+#include "../hash.h"
 
 #define DEFAULT_SYMBOL_FAIL "R_SPF_FAIL"
 #define DEFAULT_SYMBOL_SOFTFAIL "R_SPF_SOFTFAIL"
 #define DEFAULT_SYMBOL_ALLOW "R_SPF_ALLOW"
+#define DEFAULT_CACHE_SIZE 2048
+#define DEFAULT_CACHE_MAXAGE 86400
 
 struct spf_ctx {
        gint                            (*filter) (struct worker_task * task);
@@ -54,13 +57,16 @@ struct spf_ctx {
        gchar                           *symbol_softfail;
        gchar                           *symbol_allow;
 
-       memory_pool_t                  *spf_pool;
-       radix_tree_t                   *whitelist_ip;
+       memory_pool_t                   *spf_pool;
+       radix_tree_t                    *whitelist_ip;
+       rspamd_lru_hash_t               *spf_hash;
 };
 
 static struct spf_ctx        *spf_module_ctx = NULL;
 
 static void                   spf_symbol_callback (struct worker_task *task, void *unused);
+static GList *                spf_record_copy (GList *addrs);
+static void                   spf_record_destroy (gpointer list);
 
 gint
 spf_module_init (struct config_file *cfg, struct module_ctx **ctx)
@@ -73,6 +79,8 @@ spf_module_init (struct config_file *cfg, struct module_ctx **ctx)
        register_module_opt ("spf", "symbol_fail", MODULE_OPT_TYPE_STRING);
        register_module_opt ("spf", "symbol_softfail", MODULE_OPT_TYPE_STRING);
        register_module_opt ("spf", "symbol_allow", MODULE_OPT_TYPE_STRING);
+       register_module_opt ("spf", "spf_cache_size", MODULE_OPT_TYPE_UINT);
+       register_module_opt ("spf", "spf_cache_expire", MODULE_OPT_TYPE_TIME);
        register_module_opt ("spf", "whitelist", MODULE_OPT_TYPE_MAP);
 
        return 0;
@@ -82,8 +90,9 @@ spf_module_init (struct config_file *cfg, struct module_ctx **ctx)
 gint
 spf_module_config (struct config_file *cfg)
 {
-       gchar                           *value;
+       gchar                          *value;
        gint                            res = TRUE;
+       guint                           cache_size, cache_expire;
 
        spf_module_ctx->whitelist_ip = radix_tree_create ();
        
@@ -105,6 +114,18 @@ spf_module_config (struct config_file *cfg)
        else {
                spf_module_ctx->symbol_allow = DEFAULT_SYMBOL_ALLOW;
        }
+       if ((value = get_module_opt (cfg, "spf", "spf_cache_size")) != NULL) {
+               cache_size = strtoul (value, NULL, 10);
+       }
+       else {
+               cache_size = DEFAULT_CACHE_SIZE;
+       }
+       if ((value = get_module_opt (cfg, "spf", "spf_cache_expire")) != NULL) {
+               cache_expire = parse_time (value, TIME_SECONDS) / 1000;
+       }
+       else {
+               cache_expire = DEFAULT_CACHE_MAXAGE;
+       }
        if ((value = get_module_opt (cfg, "spf", "whitelist")) != NULL) {
                if (! add_map (value, read_radix_list, fin_radix_list, (void **)&spf_module_ctx->whitelist_ip)) {
                        msg_warn ("cannot load whitelist from %s", value);
@@ -115,6 +136,9 @@ spf_module_config (struct config_file *cfg)
        register_virtual_symbol (&cfg->cache, spf_module_ctx->symbol_softfail, 1);
        register_virtual_symbol (&cfg->cache, spf_module_ctx->symbol_allow, 1);
 
+       spf_module_ctx->spf_hash = rspamd_lru_hash_new (rspamd_strcase_hash, rspamd_strcase_equal,
+                       cache_size, cache_expire, g_free, spf_record_destroy);
+
        return res;
 }
 
@@ -175,7 +199,6 @@ spf_check_list (GList *list, struct worker_task *task)
                addr = cur->data;
                if (addr->is_list) {
                        /* Recursive call */
-                       addr->data.list = g_list_reverse (addr->data.list);
                        if (spf_check_list (addr->data.list, task)) {
                                return TRUE;
                        }
@@ -194,9 +217,15 @@ spf_check_list (GList *list, struct worker_task *task)
 static void 
 spf_plugin_callback (struct spf_record *record, struct worker_task *task)
 {
+       GList                           *l;
        if (record) {
-               record->addrs = g_list_reverse (record->addrs);
-               spf_check_list (record->addrs, task);
+
+               if ((l = rspamd_lru_hash_lookup (spf_module_ctx->spf_hash, record->sender_domain, task->tv.tv_sec)) == NULL) {
+                       l = spf_record_copy (record->addrs);
+                       rspamd_lru_hash_insert (spf_module_ctx->spf_hash, g_strdup (record->sender_domain),
+                               l, task->tv.tv_sec);
+               }
+               spf_check_list (l, task);
        }
 
        if (task->save.saved == 0) {
@@ -211,10 +240,20 @@ spf_plugin_callback (struct spf_record *record, struct worker_task *task)
 static void 
 spf_symbol_callback (struct worker_task *task, void *unused)
 {
+       gchar                           *domain;
+       GList                           *l;
+
        if (task->from_addr.s_addr != INADDR_NONE && task->from_addr.s_addr != INADDR_ANY) {
                if (radix32tree_find (spf_module_ctx->whitelist_ip, ntohl (task->from_addr.s_addr)) == RADIX_NO_VALUE) {
-                       if (!resolve_spf (task, spf_plugin_callback)) {
-                               msg_info ("cannot make spf request for [%s]", task->message_id);
+
+                       domain = get_spf_domain (task);
+                       if (domain) {
+                               if ((l = rspamd_lru_hash_lookup (spf_module_ctx->spf_hash, domain, task->tv.tv_sec)) != NULL) {
+                                       spf_check_list (l, task);
+                               }
+                               else if (!resolve_spf (task, spf_plugin_callback)) {
+                                       msg_info ("cannot make spf request for [%s]", task->message_id);
+                               }
                        }
                }
                else {
@@ -222,3 +261,54 @@ spf_symbol_callback (struct worker_task *task, void *unused)
                }
        }
 }
+
+/*
+ * Deep copy of an spf address list. Elements are prepended to the result,
+ * so the copy (and every nested list in it) comes out in REVERSED order.
+ */
+static GList *
+spf_record_copy (GList *addrs)
+{
+       GList                           *it, *copy = NULL;
+       struct spf_addr                 *src, *dup;
+
+       for (it = addrs; it != NULL; it = g_list_next (it)) {
+               src = it->data;
+               dup = g_malloc (sizeof (struct spf_addr));
+               memcpy (dup, src, sizeof (struct spf_addr));
+               if (src->is_list) {
+                       /* Nested record: replace the shared pointer with its own copy */
+                       dup->data.list = spf_record_copy (src->data.list);
+               }
+               copy = g_list_prepend (copy, dup);
+       }
+
+       return copy;
+}
+
+/*
+ * Destroy an spf list allocated by spf_record_copy: every element is freed,
+ * nested lists are destroyed recursively before their holder is released.
+ */
+static void
+spf_record_destroy (gpointer list)
+{
+       GList                           *cur;
+       struct spf_addr                 *addr;
+
+       for (cur = list; cur != NULL; cur = g_list_next (cur)) {
+               addr = cur->data;
+               if (addr->is_list) {
+                       spf_record_destroy (addr->data.list);
+               }
+               /* The element itself is heap allocated in both cases */
+               g_free (addr);
+       }
+
+       g_list_free (list);
+}
index 7c0e82ec6e4a2cd606968a12df270fc237b5ea49..c6498f69bb0a6cc2d2260b1ea652cc6b8b01d99f 100644 (file)
--- a/src/spf.c
+++ b/src/spf.c
@@ -1171,13 +1171,47 @@ spf_dns_callback (struct rspamd_dns_reply *reply, gpointer arg)
        }
 }
 
+/*
+ * Extract the sender domain used for spf checks: the part after '@' of the
+ * envelope sender or, when that has no '@', of the first "From" header.
+ * The result is cut at the first '>' and allocated from the task pool;
+ * NULL is returned when no domain can be found.
+ */
+gchar *
+get_spf_domain (struct worker_task *task)
+{
+       gchar                           *at = NULL, *end, *res;
+       GList                           *hdrs;
+
+       if (task->from != NULL) {
+               at = strchr (task->from, '@');
+       }
+
+       if (at != NULL) {
+               res = memory_pool_strdup (task->task_pool, at + 1);
+       }
+       else {
+               /* Extract from header */
+               hdrs = message_get_header (task->task_pool, task->message, "From", FALSE);
+               if (hdrs == NULL) {
+                       return NULL;
+               }
+
+               res = memory_pool_strdup (task->task_pool, hdrs->data);
+               at = strrchr (res, '@');
+               g_list_free (hdrs);
+               if (at == NULL) {
+                       return NULL;
+               }
+               res = memory_pool_strdup (task->task_pool, at + 1);
+       }
+
+       if ((end = strchr (res, '>')) != NULL) {
+               *end = '\0';
+       }
+
+       return res;
+}
 
 gboolean
 resolve_spf (struct worker_task *task, spf_cb_t callback)
 {
-       struct spf_record *rec;
+       struct spf_record               *rec;
        gchar                           *domain;
-       GList *domains;
+       GList                           *domains;
 
        rec = memory_pool_alloc0 (task->task_pool, sizeof (struct spf_record));
        rec->task = task;
index 7a801254e2d91252e3f46c6378bb373b30b3ce5a..9ea7d050ac2b79c50a87e8f58cfbac28ed17e4a3 100644 (file)
--- a/src/spf.h
+++ b/src/spf.h
@@ -60,5 +60,7 @@ struct spf_record {
 
 gboolean resolve_spf (struct worker_task *task, spf_cb_t callback);
 
+gchar *get_spf_domain (struct worker_task *task);
+
 
 #endif