diff options
-rw-r--r-- | src/lua/lua_common.c | 12 | ||||
-rw-r--r-- | src/lua/lua_common.h | 15 | ||||
-rw-r--r-- | src/lua/lua_task.c | 192 | ||||
-rw-r--r-- | src/lua/lua_url.c | 145 | ||||
-rw-r--r-- | src/lua/lua_url.h | 71 |
5 files changed, 250 insertions, 185 deletions
diff --git a/src/lua/lua_common.c b/src/lua/lua_common.c index 2be91140a..9c4a5d8d1 100644 --- a/src/lua/lua_common.c +++ b/src/lua/lua_common.c @@ -198,6 +198,18 @@ rspamd_lua_setclass (lua_State * L, const gchar *classname, gint objidx) } void +rspamd_lua_class_metatable (lua_State *L, const gchar *classname) +{ + khiter_t k; + + k = kh_get (lua_class_set, lua_classes, classname); + + g_assert (k != kh_end (lua_classes)); + lua_rawgetp (L, LUA_REGISTRYINDEX, + RSPAMD_LIGHTUSERDATA_MASK (kh_key (lua_classes, k))); +} + +void rspamd_lua_add_metamethod (lua_State *L, const gchar *classname, luaL_Reg *meth) { diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h index 5edec663b..296b8f326 100644 --- a/src/lua/lua_common.h +++ b/src/lua/lua_common.h @@ -63,11 +63,9 @@ static inline void lua_rawsetp (lua_State *L, int i, const void *p) { #endif /* Interface definitions */ -#define LUA_FUNCTION_DEF(class, name) static int lua_ ## class ## _ ## name ( \ - lua_State * L) -#define LUA_PUBLIC_FUNCTION_DEF(class, name) int lua_ ## class ## _ ## name ( \ - lua_State * L) -#define LUA_INTERFACE_DEF(class, name) { # name, lua_ ## class ## _ ## name } +#define LUA_FUNCTION_DEF(class, name) static int lua_##class##_##name (lua_State * L) +#define LUA_PUBLIC_FUNCTION_DEF(class, name) int lua_##class##_##name (lua_State * L) +#define LUA_INTERFACE_DEF(class, name) { #name, lua_##class##_##name } #ifdef __cplusplus extern "C" { @@ -162,6 +160,13 @@ void rspamd_lua_new_class (lua_State *L, void rspamd_lua_setclass (lua_State *L, const gchar *classname, gint objidx); /** + * Pushes the metatable for specific class on top of the stack + * @param L + * @param classname + */ +void rspamd_lua_class_metatable (lua_State *L, const gchar *classname); + +/** * Adds a new field to the class (metatable) identified by `classname` * @param L * @param classname diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c index b891d7d99..5c7a8b0a4 100644 --- a/src/lua/lua_task.c +++ b/src/lua/lua_task.c @@ -14,6 +14,8 @@ * limitations under the License. */ #include "lua_common.h" +#include "lua_url.h" + #include "message.h" #include "images.h" #include "archives.h" @@ -2245,61 +2247,7 @@ lua_task_append_message (lua_State * L) return 0; } -struct lua_tree_cb_data { - lua_State *L; - int i; - gint mask; - gint need_images; - gdouble skip_prob; - guint64 xoroshiro_state[4]; -}; - -static void -lua_tree_url_callback (gpointer key, gpointer value, gpointer ud) -{ - struct rspamd_lua_url *lua_url; - struct rspamd_url *url = (struct rspamd_url *)value; - struct lua_tree_cb_data *cb = ud; - - if (url->protocol & cb->mask) { - if (!cb->need_images && (url->flags & RSPAMD_URL_FLAG_IMAGE)) { - return; - } - - if (cb->skip_prob > 0) { - gdouble coin = rspamd_random_double_fast_seed (cb->xoroshiro_state); - - if (coin < cb->skip_prob) { - return; - } - } - - lua_url = lua_newuserdata (cb->L, sizeof (struct rspamd_lua_url)); - rspamd_lua_setclass (cb->L, "rspamd{url}", -1); - lua_url->url = url; - lua_rawseti (cb->L, -2, cb->i++); - } -} - -static inline gsize -lua_task_urls_adjust_skip_prob (struct rspamd_task *task, - struct lua_tree_cb_data *cb, gsize sz, gsize max_urls) -{ - if (max_urls > 0 && sz > max_urls) { - cb->skip_prob = 1.0 - ((gdouble)max_urls) / (gdouble)sz; - /* - * Use task dependent probabilistic seed to ensure that - * consequent task:get_urls return the same list of urls - */ - memcpy (&cb->xoroshiro_state[0], &task->task_timestamp, - MIN (sizeof (cb->xoroshiro_state[0]), sizeof (task->task_timestamp))); - memcpy (&cb->xoroshiro_state[1], MESSAGE_FIELD (task, digest), - sizeof (cb->xoroshiro_state[1]) * 3); - sz = max_urls; - } - return sz; -} static gint lua_task_get_urls (lua_State * L) @@ -2307,12 +2255,7 @@ lua_task_get_urls (lua_State * L) LUA_TRACE_POINT; struct rspamd_task *task = lua_check_task (L, 1); struct lua_tree_cb_data cb; - gint protocols_mask = 0; - static const gint default_mask = PROTOCOL_HTTP|PROTOCOL_HTTPS| - PROTOCOL_FILE|PROTOCOL_FTP; - const gchar *cache_name = "emails+urls"; struct rspamd_url *u; - gboolean need_images = FALSE; gsize sz, max_urls = 0; if (task) { @@ -2326,135 +2269,26 @@ lua_task_get_urls (lua_State * L) return 1; } - if (lua_gettop (L) >= 2) { - if (lua_type (L, 2) == LUA_TBOOLEAN) { - protocols_mask = default_mask; - if (lua_toboolean (L, 2)) { - protocols_mask |= PROTOCOL_MAILTO; - } - } - else if (lua_type (L, 2) == LUA_TTABLE) { - for (lua_pushnil (L); lua_next (L, 2); lua_pop (L, 1)) { - int nmask; - const gchar *pname = lua_tostring (L, -1); - - nmask = rspamd_url_protocol_from_string (pname); - - if (nmask != PROTOCOL_UNKNOWN) { - protocols_mask |= nmask; - } - else { - msg_info ("bad url protocol: %s", pname); - } - } - } - else if (lua_type (L, 2) == LUA_TSTRING) { - const gchar *plist = lua_tostring (L, 2); - gchar **strvec; - gchar * const *cvec; - - strvec = g_strsplit_set (plist, ",;", -1); - cvec = strvec; - - while (*cvec) { - int nmask; - - nmask = rspamd_url_protocol_from_string (*cvec); - - if (nmask != PROTOCOL_UNKNOWN) { - protocols_mask |= nmask; - } - else { - msg_info ("bad url protocol: %s", *cvec); - } - - cvec ++; - } - - g_strfreev (strvec); - } - else { - protocols_mask = default_mask; - } - - if (lua_type (L, 3) == LUA_TBOOLEAN) { - need_images = lua_toboolean (L, 3); - } - } - else { - protocols_mask = default_mask; + if (!lua_url_cbdata_fill (L, 2, &cb)) { + return luaL_error (L, "invalid arguments"); } memset (&cb, 0, sizeof (cb)); - cb.i = 1; - cb.L = L; - cb.mask = protocols_mask; - cb.need_images = need_images; - - if (protocols_mask & PROTOCOL_MAILTO) { - if (need_images) { - cache_name = "emails+urls+img"; - } - else { - cache_name = "emails+urls"; - } - sz = kh_size (MESSAGE_FIELD (task, urls)); + sz = kh_size (MESSAGE_FIELD (task, urls)); + sz = lua_url_adjust_skip_prob (task->task_timestamp, + MESSAGE_FIELD (task, digest), &cb, sz, max_urls); - sz = lua_task_urls_adjust_skip_prob (task, &cb, sz, max_urls); + lua_createtable (L, sz, 0); - if (protocols_mask == (default_mask|PROTOCOL_MAILTO)) { - /* Can use cached version */ - if (!lua_task_get_cached (L, task, cache_name)) { - lua_createtable (L, sz, 0); - kh_foreach_key (MESSAGE_FIELD (task, urls), u, { - lua_tree_url_callback (u, u, &cb); - }); - lua_task_set_cached (L, task, cache_name, -1); - } - } - else { - lua_createtable (L, sz, 0); - kh_foreach_key (MESSAGE_FIELD (task, urls), u, { - lua_tree_url_callback (u, u, &cb); - }); - } - - } - else { - if (need_images) { - cache_name = "urls+img"; - } - else { - cache_name = "urls"; - } - - sz = kh_size (MESSAGE_FIELD (task, urls)); - sz = lua_task_urls_adjust_skip_prob (task, &cb, sz, max_urls); + kh_foreach_key (MESSAGE_FIELD (task, urls), u, { + lua_tree_url_callback (u, u, &cb); + }); - if (protocols_mask == (default_mask)) { - if (!lua_task_get_cached (L, task, cache_name)) { - lua_createtable (L, sz, 0); - kh_foreach_key (MESSAGE_FIELD (task, urls), u, { - if (!(u->protocol & PROTOCOL_MAILTO)) { - lua_tree_url_callback (u, u, &cb); - } - }); - lua_task_set_cached (L, task, cache_name, -1); - } - } - else { - lua_createtable (L, sz, 0); - kh_foreach_key (MESSAGE_FIELD (task, urls), u, { - if (!(u->protocol & PROTOCOL_MAILTO)) { - lua_tree_url_callback (u, u, &cb); - } - }); - } - } + lua_url_cbdata_dtor (&cb); } else { - return luaL_error (L, "invalid arguments"); + return luaL_error (L, "invalid arguments, no task"); } return 1; diff --git a/src/lua/lua_url.c b/src/lua/lua_url.c index efd34dc6c..6540919ea 100644 --- a/src/lua/lua_url.c +++ b/src/lua/lua_url.c @@ -14,7 +14,8 @@ * limitations under the License. */ #include "lua_common.h" -#include "contrib/uthash/utlist.h" +#include "lua_url.h" + /*** * @module rspamd_url @@ -903,6 +904,148 @@ lua_url_get_flags (lua_State *L) #undef PUSH_FLAG +void +lua_tree_url_callback (gpointer key, gpointer value, gpointer ud) +{ + struct rspamd_lua_url *lua_url; + struct rspamd_url *url = (struct rspamd_url *)value; + struct lua_tree_cb_data *cb = ud; + + if (url->protocol & cb->mask) { + if (!cb->need_images && (url->flags & RSPAMD_URL_FLAG_IMAGE)) { + return; + } + + if (cb->skip_prob > 0) { + gdouble coin = rspamd_random_double_fast_seed (cb->xoroshiro_state); + + if (coin < cb->skip_prob) { + return; + } + } + + lua_url = lua_newuserdata (cb->L, sizeof (struct rspamd_lua_url)); + lua_pushvalue (cb->L, cb->metatable_pos); + lua_setmetatable (cb->L, -2); + lua_url->url = url; + lua_rawseti (cb->L, -2, cb->i++); + } +} + +gboolean +lua_url_cbdata_fill (lua_State *L, gint pos, struct lua_tree_cb_data *cbd) +{ + gboolean need_images = FALSE; + gint protocols_mask = 0; + static const gint default_mask = PROTOCOL_HTTP|PROTOCOL_HTTPS| + PROTOCOL_FILE|PROTOCOL_FTP; + gint pos_arg_type = lua_type (L, pos); + + if (pos_arg_type == LUA_TBOOLEAN) { + protocols_mask = default_mask; + if (lua_toboolean (L, 2)) { + protocols_mask |= PROTOCOL_MAILTO; + } + } + else if (pos_arg_type == LUA_TTABLE) { + for (lua_pushnil (L); lua_next (L, pos); lua_pop (L, 1)) { + int nmask; + const gchar *pname = lua_tostring (L, -1); + + nmask = rspamd_url_protocol_from_string (pname); + + if (nmask != PROTOCOL_UNKNOWN) { + protocols_mask |= nmask; + } + else { + msg_info ("bad url protocol: %s", pname); + return FALSE; + } + } + } + else if (pos_arg_type == LUA_TSTRING) { + const gchar *plist = lua_tostring (L, pos); + gchar **strvec; + gchar * const *cvec; + + strvec = g_strsplit_set (plist, ",;", -1); + cvec = strvec; + + while (*cvec) { + int nmask; + + nmask = rspamd_url_protocol_from_string (*cvec); + + if (nmask != PROTOCOL_UNKNOWN) { + protocols_mask |= nmask; + } + else { + msg_info ("bad url protocol: %s", *cvec); + return FALSE; + } + + cvec ++; + } + + g_strfreev (strvec); + } + else if (pos_arg_type == LUA_TNONE || pos_arg_type == LUA_TNIL) { + protocols_mask = default_mask; + } + else { + return FALSE; + } + + if (lua_type (L, pos + 1) == LUA_TBOOLEAN) { + need_images = lua_toboolean (L, pos + 1); + } + + memset (cbd, 0, sizeof (*cbd)); + + cbd->i = 1; + cbd->L = L; + cbd->mask = protocols_mask; + cbd->need_images = need_images; + + /* This needs to be removed from the stack */ + rspamd_lua_class_metatable (L, "rspamd{url}"); + cbd->metatable_pos = lua_gettop (L); + (void)lua_checkstack (L, cbd->metatable_pos + 4); + + return TRUE; +} + +void +lua_url_cbdata_dtor (struct lua_tree_cb_data *cbd) +{ + if (cbd->metatable_pos != -1) { + lua_remove (cbd->L, cbd->metatable_pos); + } +} + +gsize +lua_url_adjust_skip_prob (gdouble timestamp, + guchar *digest, + struct lua_tree_cb_data *cb, + gsize sz, + gsize max_urls) +{ + if (max_urls > 0 && sz > max_urls) { + cb->skip_prob = 1.0 - ((gdouble)max_urls) / (gdouble)sz; + /* + * Use task dependent probabilistic seed to ensure that + * consequent task:get_urls return the same list of urls + */ + memcpy (&cb->xoroshiro_state[0], ×tamp, + MIN (sizeof (cb->xoroshiro_state[0]), sizeof (timestamp))); + memcpy (&cb->xoroshiro_state[1], digest, + sizeof (cb->xoroshiro_state[1]) * 3); + sz = max_urls; + } + + return sz; +} + static gint lua_load_url (lua_State * L) { diff --git a/src/lua/lua_url.h b/src/lua/lua_url.h new file mode 100644 index 000000000..57d20f920 --- /dev/null +++ b/src/lua/lua_url.h @@ -0,0 +1,71 @@ +/*- + * Copyright 2020 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef RSPAMD_LUA_URL_H +#define RSPAMD_LUA_URL_H + +#include "lua_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct lua_tree_cb_data { + lua_State *L; + int i; + int metatable_pos; + gint mask; + gint need_images; + gdouble skip_prob; + guint64 xoroshiro_state[4]; +}; + +void lua_tree_url_callback (gpointer key, gpointer value, gpointer ud); + +/** + * Fills a cbdata table based on the parameter at position pos + * @param L + * @param pos + * @param cbd + * @return + */ +gboolean lua_url_cbdata_fill (lua_State *L, gint pos, struct lua_tree_cb_data *cbd); + +/** + * Cleanup url cbdata + * @param cbd + */ +void lua_url_cbdata_dtor (struct lua_tree_cb_data *cbd); + +/** + * Adjust probabilistic skip of the urls + * @param timestamp + * @param digest + * @param cb + * @param sz + * @param max_urls + * @return + */ +gsize lua_url_adjust_skip_prob (gdouble timestamp, + guchar *digest, + struct lua_tree_cb_data *cb, + gsize sz, + gsize max_urls); + +#ifdef __cplusplus +} +#endif + +#endif |