Ver código fonte

[Rework] Rework learn and add classify condition

tags/3.1
Vsevolod Stakhov 2 anos atrás
pai
commit
718238fd33

+ 1
- 0
src/libserver/cfg_file.h Ver arquivo

@@ -192,6 +192,7 @@ struct rspamd_classifier_config {
const gchar *backend; /**< name of statfile's backend */
ucl_object_t *opts; /**< other options */
GList *learn_conditions; /**< list of learn condition callbacks */
GList *classify_conditions; /**< list of classify condition callbacks */
gchar *name; /**< unique name of classifier */
guint32 min_tokens; /**< minimal number of tokens to process classifier */
guint32 max_tokens; /**< maximum number of tokens */

+ 28
- 2
src/libserver/cfg_rcl.c Ver arquivo

@@ -1299,7 +1299,7 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool,
ccf->tokenizer = tkcf;

/* Handle lua conditions */
val = ucl_object_lookup_any (obj, "condition", "learn_condition", NULL);
val = ucl_object_lookup_any (obj, "learn_condition", NULL);

if (val) {
LL_FOREACH (val, cur) {
@@ -1310,7 +1310,7 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool,

lua_script = ucl_object_tolstring(cur, &slen);
ref_idx = rspamd_lua_function_ref_from_str(L,
lua_script, slen, err);
lua_script, slen, "learn_condition", err);

if (ref_idx == LUA_NOREF) {
return FALSE;
@@ -1325,6 +1325,32 @@ rspamd_rcl_classifier_handler (rspamd_mempool_t *pool,
}
}

val = ucl_object_lookup_any (obj, "classify_condition", NULL);

if (val) {
LL_FOREACH (val, cur) {
if (ucl_object_type(cur) == UCL_STRING) {
const gchar *lua_script;
gsize slen;
gint ref_idx;

lua_script = ucl_object_tolstring(cur, &slen);
ref_idx = rspamd_lua_function_ref_from_str(L,
lua_script, slen, "classify_condition", err);

if (ref_idx == LUA_NOREF) {
return FALSE;
}

rspamd_lua_add_ref_dtor (L, cfg->cfg_pool, ref_idx);
ccf->classify_conditions = rspamd_mempool_glist_append(
cfg->cfg_pool,
ccf->classify_conditions,
GINT_TO_POINTER (ref_idx));
}
}
}

ccf->opts = (ucl_object_t *)obj;
cfg->classifiers = g_list_prepend (cfg->classifiers, ccf);


+ 100
- 80
src/libstat/stat_process.c Ver arquivo

@@ -190,9 +190,75 @@ rspamd_stat_process_tokenize (struct rspamd_stat_ctx *st_ctx,
b32_hout, g_free);
}

static gboolean
rspamd_stat_classifier_is_skipped (struct rspamd_task *task,
struct rspamd_classifier *cl, gboolean is_learn, gboolean is_spam)
{
GList *cur = is_learn ? cl->cfg->learn_conditions : cl->cfg->classify_conditions;
lua_State *L = task->cfg->lua_state;
gboolean ret = FALSE;

while (cur) {
gint cb_ref = GPOINTER_TO_INT (cur->data);
gint old_top = lua_gettop (L);

lua_rawgeti (L, LUA_REGISTRYINDEX, cb_ref);
/* Push task and two booleans: is_spam and is_unlearn */
struct rspamd_task **ptask = lua_newuserdata (L, sizeof (*ptask));
*ptask = task;
rspamd_lua_setclass (L, "rspamd{task}", -1);

if (is_learn) {
lua_pushboolean(L, is_spam);
lua_pushboolean(L,
task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false);
}

if (lua_pcall (L, 3, LUA_MULTRET, 0) != 0) {
msg_err_task ("call to %s failed: %s",
"condition callback",
lua_tostring (L, -1));
}
else {
if (lua_isboolean (L, 1)) {
if (!lua_toboolean (L, 1)) {
ret = TRUE;
}
}

if (lua_isstring (L, 2)) {
if (ret) {
msg_notice_task ("%s condition for classifier %s returned: %s; skip classifier",
is_learn ? "learn" : "classify", cl->cfg->name,
lua_tostring(L, 2));
}
else {
msg_info_task ("%s condition for classifier %s returned: %s",
is_learn ? "learn" : "classify", cl->cfg->name,
lua_tostring(L, 2));
}
}
else if (ret) {
msg_notice_task("%s condition for classifier %s returned false; skip classifier",
is_learn ? "learn" : "classify", cl->cfg->name);
}

if (ret) {
lua_settop (L, old_top);
break;
}
}

lua_settop (L, old_top);
cur = g_list_next (cur);
}

return ret;
}

static void
rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
struct rspamd_task *task, gboolean learn)
struct rspamd_task *task, gboolean is_learn, gboolean is_spam)
{
guint i;
struct rspamd_statfile *st;
@@ -207,12 +273,39 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
rspamd_mempool_add_destructor (task->task_pool,
rspamd_ptr_array_free_hard, task->stat_runtimes);

/* Temporary set all stat_runtimes to some max size to distinguish from NULL */
for (i = 0; i < st_ctx->statfiles->len; i ++) {
g_ptr_array_index (task->stat_runtimes, i) = GSIZE_TO_POINTER(G_MAXSIZE);
}

for (i = 0; i < st_ctx->classifiers->len; i++) {
struct rspamd_classifier *cl = g_ptr_array_index (st_ctx->classifiers, i);
gboolean skip_classifier = FALSE;

if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
skip_classifier = TRUE;
}
else {
if (rspamd_stat_classifier_is_skipped (task, cl, is_learn , is_spam)) {
skip_classifier = TRUE;
}
}

if (skip_classifier) {
/* Set NULL for all statfiles indexed by id */
for (int j = 0; j < cl->statfiles_ids->len; j++) {
int id = g_array_index (cl->statfiles_ids, gint, j);
g_ptr_array_index (task->stat_runtimes, id) = NULL;
}
}
}

for (i = 0; i < st_ctx->statfiles->len; i ++) {
st = g_ptr_array_index (st_ctx->statfiles, i);
g_assert (st != NULL);

if (st->classifier->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
g_ptr_array_index (task->stat_runtimes, i) = NULL;
if (g_ptr_array_index (task->stat_runtimes, i) == NULL) {
/* The whole classifier is skipped */
continue;
}

@@ -224,7 +317,7 @@ rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
continue;
}

bk_run = st->backend->runtime (task, st->stcf, learn, st->bkcf);
bk_run = st->backend->runtime (task, st->stcf, is_learn, st->bkcf);

if (bk_run == NULL) {
msg_err_task ("cannot init backend %s for statfile %s",
@@ -249,11 +342,6 @@ rspamd_stat_backends_process (struct rspamd_stat_ctx *st_ctx,
for (i = 0; i < st_ctx->statfiles->len; i++) {
st = g_ptr_array_index (st_ctx->statfiles, i);
cl = st->classifier;

if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
continue;
}

bk_run = g_ptr_array_index (task->stat_runtimes, i);

if (bk_run != NULL) {
@@ -302,10 +390,6 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,
st = g_ptr_array_index (st_ctx->statfiles, i);
cl = st->classifier;

if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
continue;
}

bk_run = g_ptr_array_index (task->stat_runtimes, i);
g_assert (st != NULL);

@@ -332,10 +416,6 @@ rspamd_stat_classifiers_process (struct rspamd_stat_ctx *st_ctx,

/* Do not process classifiers on backend failures */
for (j = 0; j < cl->statfiles_ids->len; j++) {
if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) {
continue;
}

id = g_array_index (cl->statfiles_ids, gint, j);
bk_run = g_ptr_array_index (task->stat_runtimes, id);
st = g_ptr_array_index (st_ctx->statfiles, id);
@@ -406,7 +486,7 @@ rspamd_stat_classify (struct rspamd_task *task, lua_State *L, guint stage,

if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_PRE) {
/* Preprocess tokens */
rspamd_stat_preprocess (st_ctx, task, FALSE);
rspamd_stat_preprocess (st_ctx, task, FALSE, FALSE);
}
else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS) {
/* Process backends */
@@ -490,13 +570,7 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
{
struct rspamd_classifier *cl, *sel = NULL;
guint i;
gboolean learned = FALSE, too_small = FALSE, too_large = FALSE,
conditionally_skipped = FALSE;
lua_State *L;
struct rspamd_task **ptask;
GList *cur;
gint cb_ref;
gchar *cond_str = NULL;
gboolean learned = FALSE, too_small = FALSE, too_large = FALSE;

if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL &&
*err == NULL) {
@@ -544,52 +618,6 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
continue;
}

/* Check all conditions for this classifier */
cur = cl->cfg->learn_conditions;
L = task->cfg->lua_state;

while (cur) {
cb_ref = GPOINTER_TO_INT (cur->data);

gint old_top = lua_gettop (L);
lua_rawgeti (L, LUA_REGISTRYINDEX, cb_ref);
/* Push task and two booleans: is_spam and is_unlearn */
ptask = lua_newuserdata (L, sizeof (*ptask));
*ptask = task;
rspamd_lua_setclass (L, "rspamd{task}", -1);
lua_pushboolean (L, spam);
lua_pushboolean (L,
task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false);

if (lua_pcall (L, 3, LUA_MULTRET, 0) != 0) {
msg_err_task ("call to %s failed: %s",
"condition callback",
lua_tostring (L, -1));
}
else {
if (lua_isboolean (L, 1)) {
if (!lua_toboolean (L, 1)) {
conditionally_skipped = TRUE;
/* Also check for error string if needed */
if (lua_isstring (L, 2)) {
cond_str = rspamd_mempool_strdup (task->task_pool,
lua_tostring (L, 2));
}

lua_settop (L, old_top);
break;
}
}
}

lua_settop (L, old_top);
cur = g_list_next (cur);
}

if (conditionally_skipped) {
break;
}

if (cl->subrs->learn_spam_func (cl, task->tokens, task, spam,
task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) {
learned = TRUE;
@@ -627,14 +655,6 @@ rspamd_stat_classifiers_learn (struct rspamd_stat_ctx *st_ctx,
task->tokens->len,
sel->cfg->min_tokens);
}
else if (conditionally_skipped) {
g_set_error (err, rspamd_stat_quark (), 204,
"<%s> is skipped for %s classifier: "
"%s",
MESSAGE_FIELD (task, message_id),
sel->cfg->name,
cond_str ? cond_str : "unknown reason");
}
}

return learned;
@@ -828,7 +848,7 @@ rspamd_stat_learn (struct rspamd_task *task,

if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) {
/* Process classifiers */
rspamd_stat_preprocess (st_ctx, task, TRUE);
rspamd_stat_preprocess (st_ctx, task, TRUE, spam);

if (!rspamd_stat_cache_check (st_ctx, task, classifier, spam, err)) {
return RSPAMD_STAT_PROCESS_ERROR;

+ 10
- 6
src/lua/lua_common.c Ver arquivo

@@ -2294,7 +2294,7 @@ rspamd_lua_require_function (lua_State *L, const gchar *modname,

gint
rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
GError **err)
const gchar *modname, GError **err)
{
gint err_idx, ref_idx;

@@ -2302,11 +2302,12 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
err_idx = lua_gettop (L);

/* Load file */
if (luaL_loadbuffer (L, str, slen, "lua_embedded_str") != 0) {
if (luaL_loadbuffer (L, str, slen, modname) != 0) {
g_set_error (err,
lua_error_quark(),
EINVAL,
"cannot load lua script: %s",
"%s: cannot load lua script: %s",
modname,
lua_tostring (L, -1));
lua_settop (L, err_idx - 1); /* Error function */

@@ -2318,7 +2319,8 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
g_set_error (err,
lua_error_quark(),
EINVAL,
"cannot init lua script: %s",
"%s: cannot init lua script: %s",
modname,
lua_tostring (L, -1));
lua_settop (L, err_idx - 1);

@@ -2329,8 +2331,10 @@ rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
g_set_error (err,
lua_error_quark(),
EINVAL,
"cannot init lua script: "
"must return function");
"%s: cannot init lua script: "
"must return function not %s",
modname,
lua_typename (L, lua_type (L, -1)));
lua_settop (L, err_idx - 1);

return LUA_NOREF;

+ 1
- 1
src/lua/lua_common.h Ver arquivo

@@ -572,7 +572,7 @@ void rspamd_lua_add_ref_dtor (lua_State *L, rspamd_mempool_t *pool,
* @return
*/
gint rspamd_lua_function_ref_from_str (lua_State *L, const gchar *str, gsize slen,
GError **err);
const gchar *modname, GError **err);

/**
* Tries to load some module using `require` and get some method from it

Carregando…
Cancelar
Salvar