From: Vsevolod Stakhov Date: Fri, 10 Jun 2011 13:28:19 +0000 (+0400) Subject: * Add LRU caching structure X-Git-Tag: 0.3.14~6 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=2b5a8d60da266be61d35b285bc14a8c5d71798e7;p=rspamd.git * Add LRU caching structure * Add SPF records cache * Add ability to parse doubles to xmlrpc Several fixes to dns interface. Trie plugin now checks urls as well. --- diff --git a/src/dns.c b/src/dns.c index 71feeee17..da467f664 100644 --- a/src/dns.c +++ b/src/dns.c @@ -907,26 +907,27 @@ dns_parse_rr (guint8 *in, union rspamd_reply_element *elt, guint8 **pos, struct return 0; } -static struct rspamd_dns_reply * -dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver, struct rspamd_dns_request **req_out) +static gboolean +dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver, + struct rspamd_dns_request **req_out, struct rspamd_dns_reply **_rep) { struct dns_header *header = (struct dns_header *)in; - struct rspamd_dns_request *req; - struct rspamd_dns_reply *rep; - union rspamd_reply_element *elt; - guint8 *pos; + struct rspamd_dns_request *req; + struct rspamd_dns_reply *rep; + union rspamd_reply_element *elt; + guint8 *pos; gint i, t; /* First check header fields */ if (header->qr == 0) { msg_info ("got request while waiting for reply"); - return NULL; + return FALSE; } /* Now try to find corresponding request */ if ((req = g_hash_table_lookup (resolver->requests, GUINT_TO_POINTER (header->qid))) == NULL) { /* No such requests found */ - return NULL; + return FALSE; } *req_out = req; /* @@ -934,7 +935,7 @@ dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver, struc * request QR section and reply QR section */ if ((pos = dns_request_reply_cmp (req, in + sizeof (struct dns_header), r - sizeof (struct dns_header))) == NULL) { - return NULL; + return FALSE; } /* * Remove delayed retransmits for this packet @@ -949,24 +950,27 @@ dns_parse_reply (guint8 *in, gint r, struct rspamd_dns_resolver *resolver, struc rep->elements = NULL; rep->code = header->rcode; - r -= pos - in; - /* Extract RR records */ - for (i = 0; i < ntohs (header->ancount); i ++) { - elt = memory_pool_alloc (req->pool, sizeof (union rspamd_reply_element)); - t = dns_parse_rr (in, elt, &pos, rep, &r); - if (t == -1) { - msg_info ("incomplete reply"); - break; + if (rep->code == DNS_RC_NOERROR) { + r -= pos - in; + /* Extract RR records */ + for (i = 0; i < ntohs (header->ancount); i ++) { + elt = memory_pool_alloc (req->pool, sizeof (union rspamd_reply_element)); + t = dns_parse_rr (in, elt, &pos, rep, &r); + if (t == -1) { + msg_info ("incomplete reply"); + break; + } + else if (t == 1) { + rep->elements = g_list_prepend (rep->elements, elt); + } } - else if (t == 1) { - rep->elements = g_list_prepend (rep->elements, elt); + if (rep->elements) { + memory_pool_add_destructor (req->pool, (pool_destruct_func)g_list_free, rep->elements); } } - if (rep->elements) { - memory_pool_add_destructor (req->pool, (pool_destruct_func)g_list_free, rep->elements); - } - return rep; + *_rep = rep; + return TRUE; } static void @@ -994,24 +998,25 @@ dns_check_throttling (struct rspamd_dns_resolver *resolver) static void dns_read_cb (gint fd, short what, void *arg) { - struct rspamd_dns_resolver *resolver = arg; - struct rspamd_dns_request *req = NULL; + struct rspamd_dns_resolver *resolver = arg; + struct rspamd_dns_request *req = NULL; gint r; - struct rspamd_dns_reply *rep; - guint8 in[UDP_PACKET_SIZE]; + struct rspamd_dns_reply *rep; + guint8 in[UDP_PACKET_SIZE]; /* This function is called each time when we have data on one of server's sockets */ /* First read packet from socket */ r = read (fd, in, sizeof (in)); if (r > sizeof (struct dns_header) + sizeof (struct dns_query)) { - if ((rep = dns_parse_reply (in, r, resolver, &req)) != NULL) { + if (dns_parse_reply (in, r, resolver, &req, &rep)) { /* Decrease errors count */ if (rep->request->resolver->errors > 0) { rep->request->resolver->errors --; } upstream_ok (&rep->request->server->up, rep->request->time); rep->request->func (rep, rep->request->arg); + remove_normal_event (req->session, dns_fin_cb, req); } } } diff --git a/src/hash.c b/src/hash.c index a023bcbf4..1d10e1048 100644 --- a/src/hash.c +++ b/src/hash.c @@ -301,6 +301,132 @@ rspamd_hash_foreach (rspamd_hash_t * hash, GHFunc func, gpointer user_data) } } +/** + * LRU hashing + */ + +static void +rspamd_lru_hash_destroy_node (gpointer v) +{ + rspamd_lru_element_t *node = v; + + if (node->hash->value_destroy) { + node->hash->value_destroy (node->data); + } + + g_slice_free1 (sizeof (rspamd_lru_element_t), node); +} + +static rspamd_lru_element_t* +rspamd_lru_create_node (rspamd_lru_hash_t *hash, gpointer key, gpointer value, time_t now) +{ + rspamd_lru_element_t *node; + + node = g_slice_alloc (sizeof (rspamd_lru_element_t)); + node->hash = hash; + node->data = value; + node->key = key; + node->store_time = now; + + return node; +} + +/** + * Create new lru hash + * @param maxsize maximum elements in a hash + * @param maxage maximum age of elemnt + * @param hash_func pointer to hash function + * @param key_equal_func pointer to function for comparing keys + * @return new rspamd_hash object + */ +rspamd_lru_hash_t* +rspamd_lru_hash_new (GHashFunc hash_func, GEqualFunc key_equal_func, gint maxsize, gint maxage, + GDestroyNotify key_destroy, GDestroyNotify value_destroy) +{ + rspamd_lru_hash_t *new; + + new = g_malloc (sizeof (rspamd_lru_hash_t)); + new->storage = g_hash_table_new_full (hash_func, key_equal_func, key_destroy, rspamd_lru_hash_destroy_node); + new->maxage = maxage; + new->maxsize = maxsize; + new->value_destroy = value_destroy; + new->q = g_queue_new (); + + return new; +} +/** + * Lookup item from hash + * @param hash hash object + * @param key key to find + * @return value of key or NULL if key is not found + */ +gpointer +rspamd_lru_hash_lookup (rspamd_lru_hash_t *hash, gpointer key, time_t now) +{ + rspamd_lru_element_t *res; + + if ((res = g_hash_table_lookup (hash->storage, key)) != NULL) { + if (now - res->store_time > hash->maxage) { + /* Expire elements from queue tail */ + res = g_queue_pop_tail (hash->q); + + while (res != NULL && now - res->store_time > hash->maxage) { + g_hash_table_remove (hash->storage, res->key); + res = g_queue_pop_tail (hash->q); + } + /* Restore last element */ + if (res != NULL) { + g_queue_push_tail (hash->q, res); + } + + return NULL; + } + } + + if (res) { + return res->data; + } + + return NULL; +} +/** + * Insert item in hash + * @param hash hash object + * @param key key to insert + * @param value value of key + */ +void +rspamd_lru_hash_insert (rspamd_lru_hash_t *hash, gpointer key, gpointer value, time_t now) +{ + rspamd_lru_element_t *res; + gint removed = 0; + + if (g_hash_table_size (hash->storage) >= hash->maxsize) { + /* Expire some elements */ + res = g_queue_pop_tail (hash->q); + while (res != NULL && now - res->store_time > hash->maxage) { + g_hash_table_remove (hash->storage, res->key); + res = g_queue_pop_tail (hash->q); + removed ++; + } + if (removed != 0 && res != NULL) { + g_queue_push_tail (hash->q, res); + } + } + + res = rspamd_lru_create_node (hash, key, value, now); + g_hash_table_insert (hash->storage, key, res); + g_queue_push_head (hash->q, res); +} + +void +rspamd_lru_hash_destroy (rspamd_lru_hash_t *hash) +{ + g_hash_table_destroy (hash->storage); + g_queue_free (hash->q); + g_free (hash); +} + /* * vi:ts=4 */ diff --git a/src/hash.h b/src/hash.h index 594b6c63b..1625aaba1 100644 --- a/src/hash.h +++ b/src/hash.h @@ -28,6 +28,22 @@ typedef struct rspamd_hash_s { memory_pool_t *pool; } rspamd_hash_t; +typedef struct rspamd_lru_hash_s { + gint maxsize; + gint maxage; + GHashTable *storage; + GDestroyNotify value_destroy; + GQueue *q; +} rspamd_lru_hash_t; + +typedef struct rspamd_lru_element_s { + gpointer data; + gpointer key; + time_t store_time; + rspamd_lru_hash_t *hash; +} rspamd_lru_element_t; + + #define rspamd_hash_size(x) (x)->nnodes /** @@ -79,6 +95,38 @@ gpointer rspamd_hash_lookup (rspamd_hash_t *hash, gpointer key); */ void rspamd_hash_foreach (rspamd_hash_t *hash, GHFunc func, gpointer user_data); +/** + * Create new lru hash + * @param maxsize maximum elements in a hash + * @param maxage maximum age of elemnt + * @param hash_func pointer to hash function + * @param key_equal_func pointer to function for comparing keys + * @return new rspamd_hash object + */ +rspamd_lru_hash_t* rspamd_lru_hash_new (GHashFunc hash_func, GEqualFunc key_equal_func, + gint maxsize, gint maxage, GDestroyNotify key_destroy, GDestroyNotify value_destroy); +/** + * Lookup item from hash + * @param hash hash object + * @param key key to find + * @return value of key or NULL if key is not found + */ +gpointer rspamd_lru_hash_lookup (rspamd_lru_hash_t *hash, gpointer key, time_t now); +/** + * Insert item in hash + * @param hash hash object + * @param key key to insert + * @param value value of key + */ +void rspamd_lru_hash_insert (rspamd_lru_hash_t *hash, gpointer key, gpointer value, time_t now); + +/** + * Remove lru hash + * @param hash hash object + */ + +void rspamd_lru_hash_destroy (rspamd_lru_hash_t *hash); + #endif /* diff --git a/src/lua/lua_http.c b/src/lua/lua_http.c index 25cfa6948..89bd024bf 100644 --- a/src/lua/lua_http.c +++ b/src/lua/lua_http.c @@ -66,6 +66,15 @@ lua_check_task (lua_State * L) return *((struct worker_task **)ud); } +static void +lua_http_fin (void *arg) +{ + struct lua_http_ud *ud = arg; + + rspamd_remove_dispatcher (ud->io_dispatcher); + close (ud->fd); +} + static void lua_http_push_error (gint code, struct lua_http_ud *ud) { @@ -93,6 +102,7 @@ lua_http_push_error (gint code, struct lua_http_ud *ud) } ud->parser_state = 3; + remove_normal_event (ud->task->s, lua_http_fin, ud); ud->task->save.saved--; if (ud->task->save.saved == 0) { @@ -140,6 +150,7 @@ lua_http_push_reply (f_str_t *in, struct lua_http_ud *ud) ud->headers = NULL; } + remove_normal_event (ud->task->s, lua_http_fin, ud); ud->task->save.saved--; if (ud->task->save.saved == 0) { /* Call other filters */ @@ -148,7 +159,6 @@ lua_http_push_reply (f_str_t *in, struct lua_http_ud *ud) } } - /* * Parsing utils */ @@ -246,8 +256,6 @@ lua_http_read_cb (f_str_t * in, void *arg) case 2: /* Get reply */ lua_http_push_reply (in, ud); - rspamd_remove_dispatcher (ud->io_dispatcher); - close (ud->fd); return FALSE; } @@ -264,10 +272,13 @@ lua_http_err_cb (GError * err, void *arg) if (ud->parser_state != 3) { lua_http_push_error (500, ud); } - rspamd_remove_dispatcher (ud->io_dispatcher); - close (ud->fd); + else { + remove_normal_event (ud->task->s, lua_http_fin, ud); + } } + + static void lua_http_dns_callback (struct rspamd_dns_reply *reply, gpointer arg) { @@ -304,6 +315,8 @@ lua_http_dns_callback (struct rspamd_dns_reply *reply, gpointer arg) close (ud->fd); return; } + + register_async_event (ud->task->s, lua_http_fin, ud, FALSE); } /** diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c index 2ab56b29c..0931eb13c 100644 --- a/src/lua/lua_task.c +++ b/src/lua/lua_task.c @@ -57,6 +57,7 @@ LUA_FUNCTION_DEF (task, resolve_dns_txt); LUA_FUNCTION_DEF (task, call_rspamd_function); LUA_FUNCTION_DEF (task, get_recipients); LUA_FUNCTION_DEF (task, get_from); +LUA_FUNCTION_DEF (task, get_user); LUA_FUNCTION_DEF (task, get_recipients_headers); LUA_FUNCTION_DEF (task, get_from_headers); LUA_FUNCTION_DEF (task, get_from_ip); @@ -86,6 +87,7 @@ static const struct luaL_reg tasklib_m[] = { LUA_INTERFACE_DEF (task, call_rspamd_function), LUA_INTERFACE_DEF (task, get_recipients), LUA_INTERFACE_DEF (task, get_from), + LUA_INTERFACE_DEF (task, get_user), LUA_INTERFACE_DEF (task, get_recipients_headers), LUA_INTERFACE_DEF (task, get_from_headers), LUA_INTERFACE_DEF (task, get_from_ip), @@ -833,6 +835,20 @@ lua_task_get_from (lua_State *L) return 1; } +static gint +lua_task_get_user (lua_State *L) +{ + struct worker_task *task = lua_check_task (L); + + if (task && task->user != NULL) { + lua_pushstring (L, task->user); + return 1; + } + + lua_pushnil (L); + return 1; +} + /* * Headers versions */ diff --git a/src/lua/lua_xmlrpc.c b/src/lua/lua_xmlrpc.c index f0ba2d6c0..4589405cc 100644 --- a/src/lua/lua_xmlrpc.c +++ b/src/lua/lua_xmlrpc.c @@ -38,6 +38,7 @@ struct lua_xmlrpc_ud { gint parser_state; gint depth; gint param_count; + gboolean got_text; lua_State *L; }; @@ -126,9 +127,15 @@ xmlrpc_start_element (GMarkupParseContext *context, const gchar *name, const gch } else if (g_ascii_strcasecmp (name, "string") == 0) { ud->parser_state = 11; + ud->got_text = FALSE; } else if (g_ascii_strcasecmp (name, "int") == 0) { ud->parser_state = 12; + ud->got_text = FALSE; + } + else if (g_ascii_strcasecmp (name, "double") == 0) { + ud->parser_state = 13; + ud->got_text = FALSE; } else { /* Error state */ @@ -171,9 +178,15 @@ xmlrpc_start_element (GMarkupParseContext *context, const gchar *name, const gch /* Primitives */ if (g_ascii_strcasecmp (name, "string") == 0) { ud->parser_state = 11; + ud->got_text = FALSE; } else if (g_ascii_strcasecmp (name, "int") == 0) { ud->parser_state = 12; + ud->got_text = FALSE; + } + else if (g_ascii_strcasecmp (name, "double") == 0) { + ud->parser_state = 13; + ud->got_text = FALSE; } /* Structure */ else if (g_ascii_strcasecmp (name, "struct") == 0) { @@ -300,7 +313,15 @@ xmlrpc_end_element (GMarkupParseContext *context, const gchar *name, gpointer us break; case 11: case 12: + case 13: /* Parse any values */ + /* Handle empty tags */ + if (!ud->got_text) { + lua_pushnil (ud->L); + } + else { + ud->got_text = FALSE; + } /* Primitives */ if (g_ascii_strcasecmp (name, "string") == 0) { ud->parser_state = 8; @@ -308,6 +329,9 @@ xmlrpc_end_element (GMarkupParseContext *context, const gchar *name, gpointer us else if (g_ascii_strcasecmp (name, "int") == 0) { ud->parser_state = 8; } + else if (g_ascii_strcasecmp (name, "double") == 0) { + ud->parser_state = 8; + } else { /* Error state */ ud->parser_state = 99; @@ -326,6 +350,7 @@ xmlrpc_text (GMarkupParseContext *context, const gchar *text, gsize text_len, gp { struct lua_xmlrpc_ud *ud = user_data; gint num; + gdouble dnum; /* Strip line */ while (g_ascii_isspace (*text) && text_len > 0) { @@ -352,7 +377,13 @@ xmlrpc_text (GMarkupParseContext *context, const gchar *text, gsize text_len, gp num = strtoul (text, NULL, 10); lua_pushinteger (ud->L, num); break; + case 13: + /* Push integer value */ + dnum = strtod (text, NULL); + lua_pushnumber (ud->L, dnum); + break; } + ud->got_text = TRUE; } } @@ -400,7 +431,8 @@ lua_xmlrpc_parse_reply (lua_State *L) static gint lua_xmlrpc_parse_table (lua_State *L, gint pos, gchar *databuf, gint pr, gsize size) { - gint r = pr; + gint r = pr, num; + double dnum; r += rspamd_snprintf (databuf + r, size - r, ""); lua_pushnil (L); /* first key */ @@ -415,8 +447,18 @@ lua_xmlrpc_parse_table (lua_State *L, gint pos, gchar *databuf, gint pr, gsize s lua_tostring (L, -2)); switch (lua_type (L, -1)) { case LUA_TNUMBER: - r += rspamd_snprintf (databuf + r, size - r, "%d", - lua_tointeger (L, -1)); + num = lua_tointeger (L, -1); + dnum = lua_tonumber (L, -1); + + /* Try to avoid conversion errors */ + if (dnum != (double)num) { + r += rspamd_snprintf (databuf + r, sizeof (databuf) - r, "%f", + dnum); + } + else { + r += rspamd_snprintf (databuf + r, sizeof (databuf) - r, "%d", + num); + } break; case LUA_TBOOLEAN: r += rspamd_snprintf (databuf + r, size - r, "%d", @@ -449,7 +491,8 @@ lua_xmlrpc_make_request (lua_State *L) { gchar databuf[BUFSIZ * 2]; const gchar *func; - gint r, top, i; + gint r, top, i, num; + double dnum; func = luaL_checkstring (L, 1); @@ -465,8 +508,18 @@ lua_xmlrpc_make_request (lua_State *L) r += rspamd_snprintf (databuf + r, sizeof (databuf) - r, ""); switch (lua_type (L, i)) { case LUA_TNUMBER: - r += rspamd_snprintf (databuf + r, sizeof (databuf) - r, "%d", - lua_tointeger (L, i)); + num = lua_tointeger (L, i); + dnum = lua_tonumber (L, i); + + /* Try to avoid conversion errors */ + if (dnum != (double)num) { + r += rspamd_snprintf (databuf + r, sizeof (databuf) - r, "%f", + dnum); + } + else { + r += rspamd_snprintf (databuf + r, sizeof (databuf) - r, "%d", + num); + } break; case LUA_TBOOLEAN: r += rspamd_snprintf (databuf + r, sizeof (databuf) - r, "%d", diff --git a/src/plugins/lua/trie.lua b/src/plugins/lua/trie.lua index 98248f29f..6b1782a00 100644 --- a/src/plugins/lua/trie.lua +++ b/src/plugins/lua/trie.lua @@ -66,6 +66,15 @@ function check_trie(task) if trie['trie']:search_task(task) then task:insert_result(trie['symbol'], 1) end + -- Search inside urls + urls = task:get_urls() + if urls then + for _,url in urls do + if trie['trie']:search_text(url:get_text()) then + task:insert_result(trie['symbol'], 1) + end + end + end end end diff --git a/src/plugins/spf.c b/src/plugins/spf.c index f5cbbe7b3..223bc1241 100644 --- a/src/plugins/spf.c +++ b/src/plugins/spf.c @@ -43,10 +43,13 @@ #include "../map.h" #include "../spf.h" #include "../cfg_xml.h" +#include "../hash.h" #define DEFAULT_SYMBOL_FAIL "R_SPF_FAIL" #define DEFAULT_SYMBOL_SOFTFAIL "R_SPF_SOFTFAIL" #define DEFAULT_SYMBOL_ALLOW "R_SPF_ALLOW" +#define DEFAULT_CACHE_SIZE 2048 +#define DEFAULT_CACHE_MAXAGE 86400 struct spf_ctx { gint (*filter) (struct worker_task * task); @@ -54,13 +57,16 @@ struct spf_ctx { gchar *symbol_softfail; gchar *symbol_allow; - memory_pool_t *spf_pool; - radix_tree_t *whitelist_ip; + memory_pool_t *spf_pool; + radix_tree_t *whitelist_ip; + rspamd_lru_hash_t *spf_hash; }; static struct spf_ctx *spf_module_ctx = NULL; static void spf_symbol_callback (struct worker_task *task, void *unused); +static GList * spf_record_copy (GList *addrs); +static void spf_record_destroy (gpointer list); gint spf_module_init (struct config_file *cfg, struct module_ctx **ctx) @@ -73,6 +79,8 @@ spf_module_init (struct config_file *cfg, struct module_ctx **ctx) register_module_opt ("spf", "symbol_fail", MODULE_OPT_TYPE_STRING); register_module_opt ("spf", "symbol_softfail", MODULE_OPT_TYPE_STRING); register_module_opt ("spf", "symbol_allow", MODULE_OPT_TYPE_STRING); + register_module_opt ("spf", "spf_cache_size", MODULE_OPT_TYPE_UINT); + register_module_opt ("spf", "spf_cache_expire", MODULE_OPT_TYPE_TIME); register_module_opt ("spf", "whitelist", MODULE_OPT_TYPE_MAP); return 0; @@ -82,8 +90,9 @@ spf_module_init (struct config_file *cfg, struct module_ctx **ctx) gint spf_module_config (struct config_file *cfg) { - gchar *value; + gchar *value; gint res = TRUE; + guint cache_size, cache_expire; spf_module_ctx->whitelist_ip = radix_tree_create (); @@ -105,6 +114,18 @@ spf_module_config (struct config_file *cfg) else { spf_module_ctx->symbol_allow = DEFAULT_SYMBOL_ALLOW; } + if ((value = get_module_opt (cfg, "spf", "spf_cache_size")) != NULL) { + cache_size = strtoul (value, NULL, 10); + } + else { + cache_size = DEFAULT_CACHE_SIZE; + } + if ((value = get_module_opt (cfg, "spf", "spf_cache_expire")) != NULL) { + cache_expire = parse_time (value, TIME_SECONDS) / 1000; + } + else { + cache_expire = DEFAULT_CACHE_MAXAGE; + } if ((value = get_module_opt (cfg, "spf", "whitelist")) != NULL) { if (! add_map (value, read_radix_list, fin_radix_list, (void **)&spf_module_ctx->whitelist_ip)) { msg_warn ("cannot load whitelist from %s", value); @@ -115,6 +136,9 @@ spf_module_config (struct config_file *cfg) register_virtual_symbol (&cfg->cache, spf_module_ctx->symbol_softfail, 1); register_virtual_symbol (&cfg->cache, spf_module_ctx->symbol_allow, 1); + spf_module_ctx->spf_hash = rspamd_lru_hash_new (rspamd_strcase_hash, rspamd_strcase_equal, + cache_size, cache_expire, g_free, spf_record_destroy); + return res; } @@ -175,7 +199,6 @@ spf_check_list (GList *list, struct worker_task *task) addr = cur->data; if (addr->is_list) { /* Recursive call */ - addr->data.list = g_list_reverse (addr->data.list); if (spf_check_list (addr->data.list, task)) { return TRUE; } @@ -194,9 +217,15 @@ spf_check_list (GList *list, struct worker_task *task) static void spf_plugin_callback (struct spf_record *record, struct worker_task *task) { + GList *l; if (record) { - record->addrs = g_list_reverse (record->addrs); - spf_check_list (record->addrs, task); + + if ((l = rspamd_lru_hash_lookup (spf_module_ctx->spf_hash, record->sender_domain, task->tv.tv_sec)) == NULL) { + l = spf_record_copy (record->addrs); + rspamd_lru_hash_insert (spf_module_ctx->spf_hash, g_strdup (record->sender_domain), + l, task->tv.tv_sec); + } + spf_check_list (l, task); } if (task->save.saved == 0) { @@ -211,10 +240,20 @@ spf_plugin_callback (struct spf_record *record, struct worker_task *task) static void spf_symbol_callback (struct worker_task *task, void *unused) { + gchar *domain; + GList *l; + if (task->from_addr.s_addr != INADDR_NONE && task->from_addr.s_addr != INADDR_ANY) { if (radix32tree_find (spf_module_ctx->whitelist_ip, ntohl (task->from_addr.s_addr)) == RADIX_NO_VALUE) { - if (!resolve_spf (task, spf_plugin_callback)) { - msg_info ("cannot make spf request for [%s]", task->message_id); + + domain = get_spf_domain (task); + if (domain) { + if ((l = rspamd_lru_hash_lookup (spf_module_ctx->spf_hash, domain, task->tv.tv_sec)) != NULL) { + spf_check_list (l, task); + } + else if (!resolve_spf (task, spf_plugin_callback)) { + msg_info ("cannot make spf request for [%s]", task->message_id); + } } } else { @@ -222,3 +261,54 @@ spf_symbol_callback (struct worker_task *task, void *unused) } } } + +/* + * Make a deep copy of list, note copy is REVERSED + */ +static GList * +spf_record_copy (GList *addrs) +{ + GList *cur, *newl = NULL; + struct spf_addr *addr, *newa; + + cur = addrs; + + while (cur) { + addr = cur->data; + newa = g_malloc (sizeof (struct spf_addr)); + memcpy (newa, addr, sizeof (struct spf_addr)); + if (addr->is_list) { + /* Recursive call */ + newa->data.list = spf_record_copy (addr->data.list); + } + newl = g_list_prepend (newl, newa); + cur = g_list_next (cur); + } + + return newl; +} + +/* + * Destroy allocated spf list + */ + + +static void +spf_record_destroy (gpointer list) +{ + GList *cur = list; + struct spf_addr *addr; + + while (cur) { + addr = cur->data; + if (addr->is_list) { + spf_record_destroy (addr->data.list); + } + else { + g_free (addr); + } + cur = g_list_next (cur); + } + + g_list_free (list); +} diff --git a/src/spf.c b/src/spf.c index 7c0e82ec6..c6498f69b 100644 --- a/src/spf.c +++ b/src/spf.c @@ -1171,13 +1171,47 @@ spf_dns_callback (struct rspamd_dns_reply *reply, gpointer arg) } } +gchar * +get_spf_domain (struct worker_task *task) +{ + gchar *domain, *res = NULL; + GList *domains; + + if (task->from && (domain = strchr (task->from, '@')) != NULL && *domain == '@') { + res = memory_pool_strdup (task->task_pool, domain + 1); + if ((domain = strchr (res, '>')) != NULL) { + *domain = '\0'; + } + } + else { + /* Extract from header */ + domains = message_get_header (task->task_pool, task->message, "From", FALSE); + + if (domains != NULL) { + res = memory_pool_strdup (task->task_pool, domains->data); + + if ((domain = strrchr (res, '@')) == NULL) { + g_list_free (domains); + return NULL; + } + res = memory_pool_strdup (task->task_pool, domain + 1); + g_list_free (domains); + + if ((domain = strchr (res, '>')) != NULL) { + *domain = '\0'; + } + } + } + + return res; +} gboolean resolve_spf (struct worker_task *task, spf_cb_t callback) { - struct spf_record *rec; + struct spf_record *rec; gchar *domain; - GList *domains; + GList *domains; rec = memory_pool_alloc0 (task->task_pool, sizeof (struct spf_record)); rec->task = task; diff --git a/src/spf.h b/src/spf.h index 7a801254e..9ea7d050a 100644 --- a/src/spf.h +++ b/src/spf.h @@ -60,5 +60,7 @@ struct spf_record { gboolean resolve_spf (struct worker_task *task, spf_cb_t callback); +gchar *get_spf_domain (struct worker_task *task); + #endif