Allow learning from lua_task.

This commit is contained in:
Vsevolod Stakhov 2014-08-14 13:18:31 +01:00
parent 506b3b37ed
commit 11aaf8a93a

View File

@ -91,7 +91,7 @@ LUA_FUNCTION_DEF (task, get_message_id);
LUA_FUNCTION_DEF (task, get_timeval);
LUA_FUNCTION_DEF (task, get_metric_score);
LUA_FUNCTION_DEF (task, get_metric_action);
LUA_FUNCTION_DEF (task, learn_statfile);
LUA_FUNCTION_DEF (task, learn);
static const struct luaL_reg tasklib_f[] = {
LUA_INTERFACE_DEF (task, create_empty),
@ -142,7 +142,7 @@ static const struct luaL_reg tasklib_m[] = {
LUA_INTERFACE_DEF (task, get_timeval),
LUA_INTERFACE_DEF (task, get_metric_score),
LUA_INTERFACE_DEF (task, get_metric_action),
LUA_INTERFACE_DEF (task, learn_statfile),
LUA_INTERFACE_DEF (task, learn),
{"__tostring", lua_class_tostring},
{NULL, NULL}
};
@ -1295,57 +1295,45 @@ lua_task_get_timeval (lua_State *L)
static gint
lua_task_learn_statfile (lua_State *L)
lua_task_learn (lua_State *L)
{
struct rspamd_task *task = lua_check_task (L);
const gchar *symbol;
gboolean is_spam = FALSE;
const gchar *clname;
struct rspamd_classifier_config *cl;
GTree *tokens;
struct rspamd_statfile_config *st;
stat_file_t *statfile;
struct classifier_ctx *ctx;
GError *err = NULL;
int ret = 1;
symbol = luaL_checkstring (L, 2);
if (task && symbol) {
cl = g_hash_table_lookup (task->cfg->classifiers_symbols, symbol);
if (cl == NULL) {
msg_warn ("classifier for symbol %s is not found", symbol);
lua_pushboolean (L, FALSE);
return 1;
}
ctx = cl->classifier->init_func (task->task_pool, cl);
if ((tokens =
g_hash_table_lookup (task->tokens, cl->tokenizer)) == NULL) {
msg_warn ("no tokens found learn failed!");
lua_pushboolean (L, FALSE);
return 1;
}
statfile = get_statfile_by_symbol (task->worker->srv->statfile_pool,
ctx->cfg,
symbol,
&st,
TRUE);
if (statfile == NULL) {
msg_warn ("opening statfile failed!");
lua_pushboolean (L, FALSE);
return 1;
}
cl->classifier->learn_func (ctx,
task->worker->srv->statfile_pool,
symbol,
tokens,
TRUE,
NULL,
1.,
NULL);
maybe_write_binlog (ctx->cfg, st, statfile, tokens);
lua_pushboolean (L, TRUE);
is_spam = lua_toboolean(L, 2);
if (lua_gettop (L) > 2) {
clname = luaL_checkstring (L, 3);
}
else {
clname = "bayes";
}
return 1;
cl = rspamd_config_find_classifier (task->cfg, clname);
if (cl == NULL) {
msg_warn ("classifier %s is not found", clname);
lua_pushboolean (L, FALSE);
lua_pushstring (L, "classifier not found");
ret = 2;
}
else {
if (!learn_task_spam (cl, task, is_spam, &err)) {
lua_pushboolean (L, FALSE);
if (err != NULL) {
lua_pushstring (L, err->message);
ret = 2;
}
}
else {
lua_pushboolean (L, TRUE);
}
}
return ret;
}
static gint