aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2020-11-16 19:12:59 +0000
committerGitHub <noreply@github.com>2020-11-16 19:12:59 +0000
commitae73d813970522f5f3de5399ddde0f1f367c9479 (patch)
tree44e688977cc8bccd39d8aa6a56009889a5342548
parent0daf76ab882d2e2bda63a507e075ca05b59ee7f4 (diff)
parent68badebdac453aef0f8bc5af96e9a289aefc02e5 (diff)
downloadrspamd-ae73d813970522f5f3de5399ddde0f1f367c9479.tar.gz
rspamd-ae73d813970522f5f3de5399ddde0f1f367c9479.zip
Merge pull request #3553 from fatalbanana/byte
[Minor] rspamd_text:byte() metamethod
-rw-r--r--contrib/lua-fun/fun.lua2
-rw-r--r--lualib/lua_content/pdf.lua2
-rw-r--r--lualib/lua_magic/heuristics.lua2
-rw-r--r--src/lua/lua_text.c40
-rw-r--r--test/lua/unit/rspamd_text.lua30
5 files changed, 61 insertions, 15 deletions
diff --git a/contrib/lua-fun/fun.lua b/contrib/lua-fun/fun.lua
index e16fe07ae..2c712d306 100644
--- a/contrib/lua-fun/fun.lua
+++ b/contrib/lua-fun/fun.lua
@@ -96,7 +96,7 @@ local text_gen = function(param, state)
if state > #param then
return nil
end
- local r = string.char(param:at(state))
+ local r = string.char(param:byte(state))
return state, r
end
diff --git a/lualib/lua_content/pdf.lua b/lualib/lua_content/pdf.lua
index 141a07a5f..11d5cab46 100644
--- a/lualib/lua_content/pdf.lua
+++ b/lualib/lua_content/pdf.lua
@@ -954,7 +954,7 @@ local function attach_pdf_streams(task, input, pdf)
end
-- Strip the first \n
while first < last do
- local chr = input:at(first)
+ local chr = input:byte(first)
if chr ~= 13 and chr ~= 10 then break end
first = first + 1
end
diff --git a/lualib/lua_magic/heuristics.lua b/lualib/lua_magic/heuristics.lua
index 6822abe1e..02bc2b4a2 100644
--- a/lualib/lua_magic/heuristics.lua
+++ b/lualib/lua_magic/heuristics.lua
@@ -181,7 +181,7 @@ local function detect_ole_format(input, log_obj, _, part)
end
local function process_dir_entry(offset)
- local dtype = input:at(offset + 66)
+ local dtype = input:byte(offset + 66)
lua_util.debugm(N, log_obj, "dtype: %s, offset: %s", dtype, offset)
if dtype then
diff --git a/src/lua/lua_text.c b/src/lua/lua_text.c
index feef9111f..99b8b8151 100644
--- a/src/lua/lua_text.c
+++ b/src/lua/lua_text.c
@@ -59,6 +59,14 @@ LUA_FUNCTION_DEF (text, randombytes);
*/
LUA_FUNCTION_DEF (text, fromtable);
/***
+ * @method rspamd_text:byte(pos[, pos2])
+ * Returns a byte at the position `pos` or bytes from `pos` to `pos2` if specified
+ * @param {integer} pos index
+ * @param {integer} pos2 index
+ * @return {integer} byte at the position `pos` or varargs of bytes
+ */
+LUA_FUNCTION_DEF (text, byte);
+/***
* @method rspamd_text:len()
* Returns length of a string
* @return {number} length of string in **bytes**
@@ -240,6 +248,7 @@ static const struct luaL_reg textlib_m[] = {
LUA_INTERFACE_DEF (text, split),
LUA_INTERFACE_DEF (text, at),
LUA_INTERFACE_DEF (text, memchr),
+ LUA_INTERFACE_DEF (text, byte),
LUA_INTERFACE_DEF (text, bytes),
LUA_INTERFACE_DEF (text, lower),
LUA_INTERFACE_DEF (text, exclude_chars),
@@ -962,23 +971,30 @@ lua_text_split (lua_State *L)
static gint
lua_text_at (lua_State *L)
{
+ return lua_text_byte(L);
+}
+
+static gint
+lua_text_byte (lua_State *L)
+{
LUA_TRACE_POINT;
struct rspamd_lua_text *t = lua_check_text (L, 1);
- gint pos = lua_tointeger (L, 2);
-
- if (t) {
- if (pos > 0 && pos <= t->len) {
- lua_pushinteger (L, t->start[pos - 1]);
- }
- else {
- lua_pushnil (L);
- }
- }
- else {
+ if (!t) {
return luaL_error (L, "invalid arguments");
}
- return 1;
+ gsize start = relative_pos_start (luaL_optinteger (L, 2, 1), t->len);
+ gsize end = relative_pos_end (luaL_optinteger (L, 3, start), t->len);
+ start--;
+
+ if (start >= end) {
+ return 0;
+ }
+
+ for (gsize i = start; i < end; i++) {
+ lua_pushinteger (L, t->start[i]);
+ }
+ return end - start;
}
static gint
diff --git a/test/lua/unit/rspamd_text.lua b/test/lua/unit/rspamd_text.lua
new file mode 100644
index 000000000..269b49150
--- /dev/null
+++ b/test/lua/unit/rspamd_text.lua
@@ -0,0 +1,30 @@
+context("Rspamd_text:byte() test", function()
+ local lua_util = require "lua_util"
+ local rspamd_text = require "rspamd_text"
+
+ local str = 'OMG'
+ local txt = rspamd_text.fromstring(str)
+ local fmt = 'case rspamd_text:byte(%s,%s)'
+ local cases = {
+ {'1', 'nil'},
+ {'nil', '1'},
+ }
+
+ for start = -4, 4 do
+ for stop = -4, 4 do
+ table.insert(cases, {tostring(start), tostring(stop)})
+ end
+ end
+
+ for _, case in ipairs(cases) do
+ local name = string.format(fmt, case[1], case[2])
+ test(name, function()
+ local txt_bytes = {txt:byte(tonumber(case[1]), tonumber(case[2]))}
+ local str_bytes = {str:byte(tonumber(case[1]), tonumber(case[2]))}
+ assert_rspamd_table_eq({
+ expect = str_bytes,
+ actual = txt_bytes
+ })
+ end)
+ end
+end)