mirror of
https://github.com/rspamd/rspamd.git
synced 2024-09-13 23:56:50 +02:00
Allow learning from lua_task.
This commit is contained in:
parent
506b3b37ed
commit
11aaf8a93a
@ -91,7 +91,7 @@ LUA_FUNCTION_DEF (task, get_message_id);
|
||||
LUA_FUNCTION_DEF (task, get_timeval);
|
||||
LUA_FUNCTION_DEF (task, get_metric_score);
|
||||
LUA_FUNCTION_DEF (task, get_metric_action);
|
||||
LUA_FUNCTION_DEF (task, learn_statfile);
|
||||
LUA_FUNCTION_DEF (task, learn);
|
||||
|
||||
static const struct luaL_reg tasklib_f[] = {
|
||||
LUA_INTERFACE_DEF (task, create_empty),
|
||||
@ -142,7 +142,7 @@ static const struct luaL_reg tasklib_m[] = {
|
||||
LUA_INTERFACE_DEF (task, get_timeval),
|
||||
LUA_INTERFACE_DEF (task, get_metric_score),
|
||||
LUA_INTERFACE_DEF (task, get_metric_action),
|
||||
LUA_INTERFACE_DEF (task, learn_statfile),
|
||||
LUA_INTERFACE_DEF (task, learn),
|
||||
{"__tostring", lua_class_tostring},
|
||||
{NULL, NULL}
|
||||
};
|
||||
@ -1295,57 +1295,45 @@ lua_task_get_timeval (lua_State *L)
|
||||
|
||||
|
||||
static gint
|
||||
lua_task_learn_statfile (lua_State *L)
|
||||
lua_task_learn (lua_State *L)
|
||||
{
|
||||
struct rspamd_task *task = lua_check_task (L);
|
||||
const gchar *symbol;
|
||||
gboolean is_spam = FALSE;
|
||||
const gchar *clname;
|
||||
struct rspamd_classifier_config *cl;
|
||||
GTree *tokens;
|
||||
struct rspamd_statfile_config *st;
|
||||
stat_file_t *statfile;
|
||||
struct classifier_ctx *ctx;
|
||||
GError *err = NULL;
|
||||
int ret = 1;
|
||||
|
||||
symbol = luaL_checkstring (L, 2);
|
||||
|
||||
if (task && symbol) {
|
||||
cl = g_hash_table_lookup (task->cfg->classifiers_symbols, symbol);
|
||||
if (cl == NULL) {
|
||||
msg_warn ("classifier for symbol %s is not found", symbol);
|
||||
lua_pushboolean (L, FALSE);
|
||||
return 1;
|
||||
}
|
||||
ctx = cl->classifier->init_func (task->task_pool, cl);
|
||||
if ((tokens =
|
||||
g_hash_table_lookup (task->tokens, cl->tokenizer)) == NULL) {
|
||||
msg_warn ("no tokens found learn failed!");
|
||||
lua_pushboolean (L, FALSE);
|
||||
return 1;
|
||||
}
|
||||
statfile = get_statfile_by_symbol (task->worker->srv->statfile_pool,
|
||||
ctx->cfg,
|
||||
symbol,
|
||||
&st,
|
||||
TRUE);
|
||||
|
||||
if (statfile == NULL) {
|
||||
msg_warn ("opening statfile failed!");
|
||||
lua_pushboolean (L, FALSE);
|
||||
return 1;
|
||||
}
|
||||
|
||||
cl->classifier->learn_func (ctx,
|
||||
task->worker->srv->statfile_pool,
|
||||
symbol,
|
||||
tokens,
|
||||
TRUE,
|
||||
NULL,
|
||||
1.,
|
||||
NULL);
|
||||
maybe_write_binlog (ctx->cfg, st, statfile, tokens);
|
||||
lua_pushboolean (L, TRUE);
|
||||
is_spam = lua_toboolean(L, 2);
|
||||
if (lua_gettop (L) > 2) {
|
||||
clname = luaL_checkstring (L, 3);
|
||||
}
|
||||
else {
|
||||
clname = "bayes";
|
||||
}
|
||||
|
||||
return 1;
|
||||
cl = rspamd_config_find_classifier (task->cfg, clname);
|
||||
|
||||
if (cl == NULL) {
|
||||
msg_warn ("classifier %s is not found", clname);
|
||||
lua_pushboolean (L, FALSE);
|
||||
lua_pushstring (L, "classifier not found");
|
||||
ret = 2;
|
||||
}
|
||||
else {
|
||||
if (!learn_task_spam (cl, task, is_spam, &err)) {
|
||||
lua_pushboolean (L, FALSE);
|
||||
if (err != NULL) {
|
||||
lua_pushstring (L, err->message);
|
||||
ret = 2;
|
||||
}
|
||||
}
|
||||
else {
|
||||
lua_pushboolean (L, TRUE);
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
static gint
|
||||
|
Loading…
Reference in New Issue
Block a user