diff options
-rw-r--r-- | lualib/lua_bayes_redis.lua | 16 | ||||
-rw-r--r-- | lualib/redis_scripts/bayes_learn.lua | 20 |
2 files changed, 31 insertions, 5 deletions
diff --git a/lualib/lua_bayes_redis.lua b/lualib/lua_bayes_redis.lua index 96a505518..5ad5c3514 100644 --- a/lualib/lua_bayes_redis.lua +++ b/lualib/lua_bayes_redis.lua @@ -42,7 +42,7 @@ local function gen_classify_functor(redis_params, classify_script_id) end local function gen_learn_functor(redis_params, learn_script_id) - return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback) + return function(task, expanded_key, id, is_spam, symbol, is_unlearn, stat_tokens, callback, maybe_text_tokens) local function learn_redis_cb(err, data) lua_util.debugm(N, task, 'learn redis cb: %s, %s', err, data) if err then @@ -52,9 +52,17 @@ local function gen_learn_functor(redis_params, learn_script_id) end end - lua_redis.exec_redis_script(learn_script_id, - { task = task, is_write = false, key = expanded_key }, - learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens }) + if maybe_text_tokens then + lua_redis.exec_redis_script(learn_script_id, + { task = task, is_write = false, key = expanded_key }, + learn_redis_cb, + { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens, maybe_text_tokens }) + else + lua_redis.exec_redis_script(learn_script_id, + { task = task, is_write = false, key = expanded_key }, + learn_redis_cb, { expanded_key, tostring(is_spam), symbol, tostring(is_unlearn), stat_tokens }) + end + end end diff --git a/lualib/redis_scripts/bayes_learn.lua b/lualib/redis_scripts/bayes_learn.lua index 244be43f6..b0ed1dd4b 100644 --- a/lualib/redis_scripts/bayes_learn.lua +++ b/lualib/redis_scripts/bayes_learn.lua @@ -5,12 +5,18 @@ -- key3 - string symbol -- key4 - boolean is_unlearn -- key5 - set of tokens encoded in messagepack array of strings +-- key6 - set of text tokens (if any) encoded in messagepack array of strings (size must be twice of `KEYS[5]`) local prefix = KEYS[1] local is_spam = KEYS[2] == 'true' and true or false local symbol = KEYS[3] local is_unlearn = KEYS[4] == 'true' and true or false local input_tokens = cmsgpack.unpack(KEYS[5]) +local text_tokens + +if KEYS[6] then + text_tokens = cmsgpack.unpack(KEYS[6]) +end local hash_key = is_spam and 'S' or 'H' local learned_key = is_spam and 'learns_spam' or 'learns_ham' @@ -19,6 +25,18 @@ redis.call('SADD', symbol .. '_keys', prefix) redis.call('HSET', prefix, 'version', '2') -- new schema redis.call('HINCRBY', prefix, learned_key, is_unlearn and -1 or 1) -- increase or decrease learned count -for _, token in ipairs(input_tokens) do +for i, token in ipairs(input_tokens) do redis.call('HINCRBY', token, hash_key, 1) + if text_tokens then + local tok1 = text_tokens[i * 2 - 1] + local tok2 = text_tokens[i * 2] + + if tok2 then + redis.call('HSET', token, 'tokens', string.format('%s:%s', tok1, tok2)) + else + redis.call('HSET', token, 'tokens', tok1) + end + + redis.call('ZINCRBY', prefix .. '_z', token, is_unlearn and -1 or 1) + end end
\ No newline at end of file |