]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Rework urls extraction
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 21 Apr 2020 12:09:16 +0000 (13:09 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 21 Apr 2020 12:09:16 +0000 (13:09 +0100)
src/lua/lua_common.c
src/lua/lua_common.h
src/lua/lua_task.c
src/lua/lua_url.c
src/lua/lua_url.h [new file with mode: 0644]

index 2be91140acc89790a87ffef57e6c135dd7472ece..9c4a5d8d12f8e38cfee6859e81f7bae7f0531cf2 100644 (file)
@@ -197,6 +197,18 @@ rspamd_lua_setclass (lua_State * L, const gchar *classname, gint objidx)
        lua_setmetatable (L, objidx);
 }
 
+void
+rspamd_lua_class_metatable (lua_State *L, const gchar *classname)
+{
+       khiter_t k;
+
+       k = kh_get (lua_class_set, lua_classes, classname);
+
+       g_assert (k != kh_end (lua_classes));
+       lua_rawgetp (L, LUA_REGISTRYINDEX,
+                       RSPAMD_LIGHTUSERDATA_MASK (kh_key (lua_classes, k)));
+}
+
 void
 rspamd_lua_add_metamethod (lua_State *L, const gchar *classname,
                                                                luaL_Reg *meth)
index 5edec663b91c6fc6971eaa848347cfe6cd546b54..296b8f3265d104c3c016873e0dcbd3ff6f48114e 100644 (file)
@@ -63,11 +63,9 @@ static inline void lua_rawsetp (lua_State *L, int i, const void *p) {
 #endif
 
 /* Interface definitions */
-#define LUA_FUNCTION_DEF(class, name) static int lua_ ## class ## _ ## name ( \
-               lua_State * L)
-#define LUA_PUBLIC_FUNCTION_DEF(class, name) int lua_ ## class ## _ ## name ( \
-               lua_State * L)
-#define LUA_INTERFACE_DEF(class, name) { # name, lua_ ## class ## _ ## name }
+#define LUA_FUNCTION_DEF(class, name) static int lua_##class##_##name (lua_State * L)
+#define LUA_PUBLIC_FUNCTION_DEF(class, name) int lua_##class##_##name (lua_State * L)
+#define LUA_INTERFACE_DEF(class, name) { #name, lua_##class##_##name }
 
 #ifdef  __cplusplus
 extern "C" {
@@ -161,6 +159,13 @@ void rspamd_lua_new_class (lua_State *L,
  */
 void rspamd_lua_setclass (lua_State *L, const gchar *classname, gint objidx);
 
+/**
+ * Pushes the metatable for specific class on top of the stack
+ * @param L
+ * @param classname
+ */
+void rspamd_lua_class_metatable (lua_State *L, const gchar *classname);
+
 /**
  * Adds a new field to the class (metatable) identified by `classname`
  * @param L
index b891d7d99a1ed65203ca1eea6e6d208869e9f0eb..5c7a8b0a427613ef01049ebda3be1d71a7e481b2 100644 (file)
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
 #include "lua_common.h"
+#include "lua_url.h"
+
 #include "message.h"
 #include "images.h"
 #include "archives.h"
@@ -2245,61 +2247,7 @@ lua_task_append_message (lua_State * L)
        return 0;
 }
 
-struct lua_tree_cb_data {
-       lua_State *L;
-       int i;
-       gint mask;
-       gint need_images;
-       gdouble skip_prob;
-       guint64 xoroshiro_state[4];
-};
-
-static void
-lua_tree_url_callback (gpointer key, gpointer value, gpointer ud)
-{
-       struct rspamd_lua_url *lua_url;
-       struct rspamd_url *url = (struct rspamd_url *)value;
-       struct lua_tree_cb_data *cb = ud;
-
-       if (url->protocol & cb->mask) {
-               if (!cb->need_images && (url->flags & RSPAMD_URL_FLAG_IMAGE)) {
-                       return;
-               }
-
-               if (cb->skip_prob > 0) {
-                       gdouble coin = rspamd_random_double_fast_seed (cb->xoroshiro_state);
-
-                       if (coin < cb->skip_prob) {
-                               return;
-                       }
-               }
-
-               lua_url = lua_newuserdata (cb->L, sizeof (struct rspamd_lua_url));
-               rspamd_lua_setclass (cb->L, "rspamd{url}", -1);
-               lua_url->url = url;
-               lua_rawseti (cb->L, -2, cb->i++);
-       }
-}
-
-static inline gsize
-lua_task_urls_adjust_skip_prob (struct rspamd_task *task,
-               struct lua_tree_cb_data *cb, gsize sz, gsize max_urls)
-{
-       if (max_urls > 0 && sz > max_urls) {
-               cb->skip_prob = 1.0 - ((gdouble)max_urls) / (gdouble)sz;
-               /*
-                * Use task dependent probabilistic seed to ensure that
-                * consequent task:get_urls return the same list of urls
-                */
-               memcpy (&cb->xoroshiro_state[0], &task->task_timestamp,
-                               MIN (sizeof (cb->xoroshiro_state[0]), sizeof (task->task_timestamp)));
-               memcpy (&cb->xoroshiro_state[1], MESSAGE_FIELD (task, digest),
-                               sizeof (cb->xoroshiro_state[1]) * 3);
-               sz = max_urls;
-       }
 
-       return sz;
-}
 
 static gint
 lua_task_get_urls (lua_State * L)
@@ -2307,12 +2255,7 @@ lua_task_get_urls (lua_State * L)
        LUA_TRACE_POINT;
        struct rspamd_task *task = lua_check_task (L, 1);
        struct lua_tree_cb_data cb;
-       gint protocols_mask = 0;
-       static const gint default_mask = PROTOCOL_HTTP|PROTOCOL_HTTPS|
-                       PROTOCOL_FILE|PROTOCOL_FTP;
-       const gchar *cache_name = "emails+urls";
        struct rspamd_url *u;
-       gboolean need_images = FALSE;
        gsize sz, max_urls = 0;
 
        if (task) {
@@ -2326,135 +2269,26 @@ lua_task_get_urls (lua_State * L)
                        return 1;
                }
 
-               if (lua_gettop (L) >= 2) {
-                       if (lua_type (L, 2) == LUA_TBOOLEAN) {
-                               protocols_mask = default_mask;
-                               if (lua_toboolean (L, 2)) {
-                                       protocols_mask |= PROTOCOL_MAILTO;
-                               }
-                       }
-                       else if (lua_type (L, 2) == LUA_TTABLE) {
-                               for (lua_pushnil (L); lua_next (L, 2); lua_pop (L, 1)) {
-                                       int nmask;
-                                       const gchar *pname = lua_tostring (L, -1);
-
-                                       nmask = rspamd_url_protocol_from_string (pname);
-
-                                       if (nmask != PROTOCOL_UNKNOWN) {
-                                               protocols_mask |= nmask;
-                                       }
-                                       else {
-                                               msg_info ("bad url protocol: %s", pname);
-                                       }
-                               }
-                       }
-                       else if (lua_type (L, 2) == LUA_TSTRING) {
-                               const gchar *plist = lua_tostring (L, 2);
-                               gchar **strvec;
-                               gchar * const *cvec;
-
-                               strvec = g_strsplit_set (plist, ",;", -1);
-                               cvec = strvec;
-
-                               while (*cvec) {
-                                       int nmask;
-
-                                       nmask = rspamd_url_protocol_from_string (*cvec);
-
-                                       if (nmask != PROTOCOL_UNKNOWN) {
-                                               protocols_mask |= nmask;
-                                       }
-                                       else {
-                                               msg_info ("bad url protocol: %s", *cvec);
-                                       }
-
-                                       cvec ++;
-                               }
-
-                               g_strfreev (strvec);
-                       }
-                       else {
-                               protocols_mask = default_mask;
-                       }
-
-                       if (lua_type (L, 3) == LUA_TBOOLEAN) {
-                               need_images = lua_toboolean (L, 3);
-                       }
-               }
-               else {
-                       protocols_mask = default_mask;
+               if (!lua_url_cbdata_fill (L, 2, &cb)) {
+                       return luaL_error (L, "invalid arguments");
                }
 
                memset (&cb, 0, sizeof (cb));
-               cb.i = 1;
-               cb.L = L;
-               cb.mask = protocols_mask;
-               cb.need_images = need_images;
-
-               if (protocols_mask & PROTOCOL_MAILTO) {
-                       if (need_images) {
-                               cache_name = "emails+urls+img";
-                       }
-                       else {
-                               cache_name = "emails+urls";
-                       }
 
-                       sz = kh_size (MESSAGE_FIELD (task, urls));
+               sz = kh_size (MESSAGE_FIELD (task, urls));
+               sz = lua_url_adjust_skip_prob (task->task_timestamp,
+                               MESSAGE_FIELD (task, digest), &cb, sz, max_urls);
 
-                       sz = lua_task_urls_adjust_skip_prob (task, &cb, sz, max_urls);
+               lua_createtable (L, sz, 0);
 
-                       if (protocols_mask == (default_mask|PROTOCOL_MAILTO)) {
-                               /* Can use cached version */
-                               if (!lua_task_get_cached (L, task, cache_name)) {
-                                       lua_createtable (L, sz, 0);
-                                       kh_foreach_key (MESSAGE_FIELD (task, urls), u, {
-                                               lua_tree_url_callback (u, u, &cb);
-                                       });
-                                       lua_task_set_cached (L, task, cache_name, -1);
-                               }
-                       }
-                       else {
-                               lua_createtable (L, sz, 0);
-                               kh_foreach_key (MESSAGE_FIELD (task, urls), u, {
-                                       lua_tree_url_callback (u, u, &cb);
-                               });
-                       }
-
-               }
-               else {
-                       if (need_images) {
-                               cache_name = "urls+img";
-                       }
-                       else {
-                               cache_name = "urls";
-                       }
-
-                       sz = kh_size (MESSAGE_FIELD (task, urls));
-                       sz = lua_task_urls_adjust_skip_prob (task, &cb, sz, max_urls);
+               kh_foreach_key (MESSAGE_FIELD (task, urls), u, {
+                       lua_tree_url_callback (u, u, &cb);
+               });
 
-                       if (protocols_mask == (default_mask)) {
-                               if (!lua_task_get_cached (L, task, cache_name)) {
-                                       lua_createtable (L, sz, 0);
-                                       kh_foreach_key (MESSAGE_FIELD (task, urls), u, {
-                                               if (!(u->protocol & PROTOCOL_MAILTO)) {
-                                                       lua_tree_url_callback (u, u, &cb);
-                                               }
-                                       });
-                                       lua_task_set_cached (L, task, cache_name, -1);
-                               }
-                       }
-                       else {
-                               lua_createtable (L, sz, 0);
-                               kh_foreach_key (MESSAGE_FIELD (task, urls), u, {
-                                       if (!(u->protocol & PROTOCOL_MAILTO)) {
-                                               lua_tree_url_callback (u, u, &cb);
-                                       }
-                               });
-                       }
-               }
+               lua_url_cbdata_dtor (&cb);
        }
        else {
-               return luaL_error (L, "invalid arguments");
+               return luaL_error (L, "invalid arguments, no task");
        }
 
        return 1;
index efd34dc6c74231f73e40b6ba76ea2d55b1eb7a21..6540919eaf6d23dff2f105eff52d4a29e15659b3 100644 (file)
@@ -14,7 +14,8 @@
  * limitations under the License.
  */
 #include "lua_common.h"
-#include "contrib/uthash/utlist.h"
+#include "lua_url.h"
+
 
 /***
  * @module rspamd_url
@@ -903,6 +904,148 @@ lua_url_get_flags (lua_State *L)
 
 #undef PUSH_FLAG
 
+void
+lua_tree_url_callback (gpointer key, gpointer value, gpointer ud)
+{
+       struct rspamd_lua_url *lua_url;
+       struct rspamd_url *url = (struct rspamd_url *)value;
+       struct lua_tree_cb_data *cb = ud;
+
+       if (url->protocol & cb->mask) {
+               if (!cb->need_images && (url->flags & RSPAMD_URL_FLAG_IMAGE)) {
+                       return;
+               }
+
+               if (cb->skip_prob > 0) {
+                       gdouble coin = rspamd_random_double_fast_seed (cb->xoroshiro_state);
+
+                       if (coin < cb->skip_prob) {
+                               return;
+                       }
+               }
+
+               lua_url = lua_newuserdata (cb->L, sizeof (struct rspamd_lua_url));
+               lua_pushvalue (cb->L, cb->metatable_pos);
+               lua_setmetatable (cb->L, -2);
+               lua_url->url = url;
+               lua_rawseti (cb->L, -2, cb->i++);
+       }
+}
+
+gboolean
+lua_url_cbdata_fill (lua_State *L, gint pos, struct lua_tree_cb_data *cbd)
+{
+       gboolean need_images = FALSE;
+       gint protocols_mask = 0;
+       static const gint default_mask = PROTOCOL_HTTP|PROTOCOL_HTTPS|
+                                                                        PROTOCOL_FILE|PROTOCOL_FTP;
+       gint pos_arg_type = lua_type (L, pos);
+
+       if (pos_arg_type == LUA_TBOOLEAN) {
+               protocols_mask = default_mask;
+               if (lua_toboolean (L, 2)) {
+                       protocols_mask |= PROTOCOL_MAILTO;
+               }
+       }
+       else if (pos_arg_type == LUA_TTABLE) {
+               for (lua_pushnil (L); lua_next (L, pos); lua_pop (L, 1)) {
+                       int nmask;
+                       const gchar *pname = lua_tostring (L, -1);
+
+                       nmask = rspamd_url_protocol_from_string (pname);
+
+                       if (nmask != PROTOCOL_UNKNOWN) {
+                               protocols_mask |= nmask;
+                       }
+                       else {
+                               msg_info ("bad url protocol: %s", pname);
+                               return FALSE;
+                       }
+               }
+       }
+       else if (pos_arg_type == LUA_TSTRING) {
+               const gchar *plist = lua_tostring (L, pos);
+               gchar **strvec;
+               gchar * const *cvec;
+
+               strvec = g_strsplit_set (plist, ",;", -1);
+               cvec = strvec;
+
+               while (*cvec) {
+                       int nmask;
+
+                       nmask = rspamd_url_protocol_from_string (*cvec);
+
+                       if (nmask != PROTOCOL_UNKNOWN) {
+                               protocols_mask |= nmask;
+                       }
+                       else {
+                               msg_info ("bad url protocol: %s", *cvec);
+                               return FALSE;
+                       }
+
+                       cvec ++;
+               }
+
+               g_strfreev (strvec);
+       }
+       else if (pos_arg_type == LUA_TNONE || pos_arg_type == LUA_TNIL) {
+               protocols_mask = default_mask;
+       }
+       else {
+               return FALSE;
+       }
+
+       if (lua_type (L, pos + 1) == LUA_TBOOLEAN) {
+               need_images = lua_toboolean (L, pos + 1);
+       }
+
+       memset (cbd, 0, sizeof (*cbd));
+
+       cbd->i = 1;
+       cbd->L = L;
+       cbd->mask = protocols_mask;
+       cbd->need_images = need_images;
+
+       /* This needs to be removed from the stack */
+       rspamd_lua_class_metatable (L, "rspamd{url}");
+       cbd->metatable_pos = lua_gettop (L);
+       (void)lua_checkstack (L, cbd->metatable_pos + 4);
+
+       return TRUE;
+}
+
+void
+lua_url_cbdata_dtor (struct lua_tree_cb_data *cbd)
+{
+       if (cbd->metatable_pos != -1) {
+               lua_remove (cbd->L, cbd->metatable_pos);
+       }
+}
+
+gsize
+lua_url_adjust_skip_prob (gdouble timestamp,
+                                                 guchar *digest,
+                                                 struct lua_tree_cb_data *cb,
+                                                 gsize sz,
+                                                 gsize max_urls)
+{
+       if (max_urls > 0 && sz > max_urls) {
+               cb->skip_prob = 1.0 - ((gdouble)max_urls) / (gdouble)sz;
+               /*
+                * Use task dependent probabilistic seed to ensure that
+                * consequent task:get_urls return the same list of urls
+                */
+               memcpy (&cb->xoroshiro_state[0], &timestamp,
+                               MIN (sizeof (cb->xoroshiro_state[0]), sizeof (timestamp)));
+               memcpy (&cb->xoroshiro_state[1], digest,
+                               sizeof (cb->xoroshiro_state[1]) * 3);
+               sz = max_urls;
+       }
+
+       return sz;
+}
+
 static gint
 lua_load_url (lua_State * L)
 {
diff --git a/src/lua/lua_url.h b/src/lua/lua_url.h
new file mode 100644 (file)
index 0000000..57d20f9
--- /dev/null
@@ -0,0 +1,71 @@
+/*-
+ * Copyright 2020 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef RSPAMD_LUA_URL_H
+#define RSPAMD_LUA_URL_H
+
+#include "lua_common.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+struct lua_tree_cb_data {
+       lua_State *L;
+       int i;
+       int metatable_pos;
+       gint mask;
+       gint need_images;
+       gdouble skip_prob;
+       guint64 xoroshiro_state[4];
+};
+
+void lua_tree_url_callback (gpointer key, gpointer value, gpointer ud);
+
+/**
+ * Fills a cbdata table based on the parameter at position pos
+ * @param L
+ * @param pos
+ * @param cbd
+ * @return
+ */
+gboolean lua_url_cbdata_fill (lua_State *L, gint pos, struct lua_tree_cb_data *cbd);
+
+/**
+ * Cleanup url cbdata
+ * @param cbd
+ */
+void lua_url_cbdata_dtor (struct lua_tree_cb_data *cbd);
+
+/**
+ * Adjust probabilistic skip of the urls
+ * @param timestamp
+ * @param digest
+ * @param cb
+ * @param sz
+ * @param max_urls
+ * @return
+ */
+gsize lua_url_adjust_skip_prob (gdouble timestamp,
+                                                               guchar *digest,
+                                                               struct lua_tree_cb_data *cb,
+                                                               gsize sz,
+                                                               gsize max_urls);
+
+#ifdef  __cplusplus
+}
+#endif
+
+#endif