From 68badebdac453aef0f8bc5af96e9a289aefc02e5 Mon Sep 17 00:00:00 2001 From: Andrew Lewis Date: Mon, 16 Nov 2020 20:13:03 +0200 Subject: [PATCH] [Minor] rspamd_text:byte() metamethod --- contrib/lua-fun/fun.lua | 2 +- lualib/lua_content/pdf.lua | 2 +- lualib/lua_magic/heuristics.lua | 2 +- src/lua/lua_text.c | 40 +++++++++++++++++++++++---------- test/lua/unit/rspamd_text.lua | 30 +++++++++++++++++++++++++ 5 files changed, 61 insertions(+), 15 deletions(-) create mode 100644 test/lua/unit/rspamd_text.lua diff --git a/contrib/lua-fun/fun.lua b/contrib/lua-fun/fun.lua index e16fe07ae..2c712d306 100644 --- a/contrib/lua-fun/fun.lua +++ b/contrib/lua-fun/fun.lua @@ -96,7 +96,7 @@ local text_gen = function(param, state) if state > #param then return nil end - local r = string.char(param:at(state)) + local r = string.char(param:byte(state)) return state, r end diff --git a/lualib/lua_content/pdf.lua b/lualib/lua_content/pdf.lua index 141a07a5f..11d5cab46 100644 --- a/lualib/lua_content/pdf.lua +++ b/lualib/lua_content/pdf.lua @@ -954,7 +954,7 @@ local function attach_pdf_streams(task, input, pdf) end -- Strip the first \n while first < last do - local chr = input:at(first) + local chr = input:byte(first) if chr ~= 13 and chr ~= 10 then break end first = first + 1 end diff --git a/lualib/lua_magic/heuristics.lua b/lualib/lua_magic/heuristics.lua index 6822abe1e..02bc2b4a2 100644 --- a/lualib/lua_magic/heuristics.lua +++ b/lualib/lua_magic/heuristics.lua @@ -181,7 +181,7 @@ local function detect_ole_format(input, log_obj, _, part) end local function process_dir_entry(offset) - local dtype = input:at(offset + 66) + local dtype = input:byte(offset + 66) lua_util.debugm(N, log_obj, "dtype: %s, offset: %s", dtype, offset) if dtype then diff --git a/src/lua/lua_text.c b/src/lua/lua_text.c index 5bbcfb96e..37e1752c1 100644 --- a/src/lua/lua_text.c +++ b/src/lua/lua_text.c @@ -58,6 +58,14 @@ LUA_FUNCTION_DEF (text, randombytes); * @return {rspamd_text} resulting text */ LUA_FUNCTION_DEF (text, fromtable); +/*** + * @method rspamd_text:byte(pos[, pos2]) + * Returns a byte at the position `pos` or bytes from `pos` to `pos2` if specified + * @param {integer} pos index + * @param {integer} pos2 index + * @return {integer} byte at the position `pos` or varargs of bytes + */ +LUA_FUNCTION_DEF (text, byte); /*** * @method rspamd_text:len() * Returns length of a string @@ -226,6 +234,7 @@ static const struct luaL_reg textlib_m[] = { LUA_INTERFACE_DEF (text, split), LUA_INTERFACE_DEF (text, at), LUA_INTERFACE_DEF (text, memchr), + LUA_INTERFACE_DEF (text, byte), LUA_INTERFACE_DEF (text, bytes), LUA_INTERFACE_DEF (text, lower), LUA_INTERFACE_DEF (text, exclude_chars), @@ -946,24 +955,31 @@ lua_text_split (lua_State *L) static gint lua_text_at (lua_State *L) +{ + return lua_text_byte(L); +} + +static gint +lua_text_byte (lua_State *L) { LUA_TRACE_POINT; struct rspamd_lua_text *t = lua_check_text (L, 1); - gint pos = lua_tointeger (L, 2); - - if (t) { - if (pos > 0 && pos <= t->len) { - lua_pushinteger (L, t->start[pos - 1]); - } - else { - lua_pushnil (L); - } - } - else { + if (!t) { return luaL_error (L, "invalid arguments"); } - return 1; + gsize start = relative_pos_start (luaL_optinteger (L, 2, 1), t->len); + gsize end = relative_pos_end (luaL_optinteger (L, 3, start), t->len); + start--; + + if (start >= end) { + return 0; + } + + for (gsize i = start; i < end; i++) { + lua_pushinteger (L, t->start[i]); + } + return end - start; } static gint diff --git a/test/lua/unit/rspamd_text.lua b/test/lua/unit/rspamd_text.lua new file mode 100644 index 000000000..269b49150 --- /dev/null +++ b/test/lua/unit/rspamd_text.lua @@ -0,0 +1,30 @@ +context("Rspamd_text:byte() test", function() + local lua_util = require "lua_util" + local rspamd_text = require "rspamd_text" + + local str = 'OMG' + local txt = rspamd_text.fromstring(str) + local fmt = 'case rspamd_text:byte(%s,%s)' + local cases = { + {'1', 'nil'}, + {'nil', '1'}, + } + + for start = -4, 4 do + for stop = -4, 4 do + table.insert(cases, {tostring(start), tostring(stop)}) + end + end + + for _, case in ipairs(cases) do + local name = string.format(fmt, case[1], case[2]) + test(name, function() + local txt_bytes = {txt:byte(tonumber(case[1]), tonumber(case[2]))} + local str_bytes = {str:byte(tonumber(case[1]), tonumber(case[2]))} + assert_rspamd_table_eq({ + expect = str_bytes, + actual = txt_bytes + }) + end) + end +end) -- 2.39.5