+/*
+ * Callback table of a statistics backend. The `ctx` member is handed back
+ * to the callbacks as their trailing `ctx` argument by the callers.
+ */
struct rspamd_stat_backend {
const char *name;
gpointer (*init)(struct rspamd_stat_ctx *ctx, struct rspamd_config *cfg);
- gpointer (*runtime)(struct rspamd_statfile_config *stcf, gpointer ctx);
+ /* Obtain the runtime for a statfile; `learn` is TRUE on the learning path, FALSE otherwise */
+ gpointer (*runtime)(struct rspamd_statfile_config *stcf, gboolean learn, gpointer ctx);
gboolean (*process_token)(struct token_node_s *tok,
struct rspamd_token_result *res, gpointer ctx);
+ /* Learning counterpart of process_token: invoked for each token/result pair while learning */
+ gboolean (*learn_token)(struct token_node_s *tok,
+ struct rspamd_token_result *res, gpointer ctx);
gulong (*total_learns)(struct rspamd_statfile_runtime *runtime, gpointer ctx);
gpointer ctx;
};
+/*
+ * Entry points of the mmapped-file statistics backend. Their signatures
+ * mirror the callbacks declared in struct rspamd_stat_backend.
+ */
gpointer rspamd_mmaped_file_init(struct rspamd_stat_ctx *ctx, struct rspamd_config *cfg);
-gpointer rspamd_mmaped_file_runtime (struct rspamd_statfile_config *stcf, gpointer ctx);
+/* `learn` requests creation of the statfile when it is not open yet */
+gpointer rspamd_mmaped_file_runtime (struct rspamd_statfile_config *stcf,
+ gboolean learn, gpointer ctx);
gboolean rspamd_mmaped_file_process_token (struct token_node_s *tok,
struct rspamd_token_result *res,
gpointer ctx);
+/* Writes res->value into the statfile block selected by the token hashes */
+gboolean rspamd_mmaped_file_learn_token (struct token_node_s *tok,
+ struct rspamd_token_result *res,
+ gpointer ctx);
gulong rspamd_mmaped_file_total_learns (struct rspamd_statfile_runtime *runtime,
gpointer ctx);
rspamd_mmaped_file_t * file,
guint32 h1,
guint32 h2,
- time_t now,
double value)
{
rspamd_mmaped_file_set_block_common (pool, file, h1, h2, value);
}
+/*
+ * Return the mmapped file that backs the statfile `stcf`.
+ *
+ * An already opened file is looked up first. When none is found and `learn`
+ * is TRUE, the "filename" (string) and "size" (integer) options are read
+ * from stcf->opts, the file is created and then opened.
+ *
+ * Returns the file handle as an opaque pointer, or NULL when the file is
+ * not open and either `learn` is FALSE or a required option is missing or
+ * has the wrong type (an error is logged in the latter case).
+ */
gpointer
-rspamd_mmaped_file_runtime (struct rspamd_statfile_config *stcf, gpointer p)
+rspamd_mmaped_file_runtime (struct rspamd_statfile_config *stcf, gboolean learn,
+ gpointer p)
{
rspamd_mmaped_file_ctx *ctx = (rspamd_mmaped_file_ctx *)p;
rspamd_mmaped_file_t *mf;
+ const ucl_object_t *filenameo, *sizeo;
+ const gchar *filename;
+ gsize size;
g_assert (ctx != NULL);
mf = rspamd_mmaped_file_is_open (ctx, stcf);
+ if (mf == NULL && learn) {
+ /* Statfile is not open yet and we are learning: create it from the configured options */
+
+ filenameo = ucl_object_find_key (stcf->opts, "filename");
+ if (filenameo == NULL || ucl_object_type (filenameo) != UCL_STRING) {
+ msg_err ("statfile %s has no filename defined", stcf->symbol);
+ return NULL;
+ }
+
+ filename = ucl_object_tostring (filenameo);
+
+ sizeo = ucl_object_find_key (stcf->opts, "size");
+ if (sizeo == NULL || ucl_object_type (sizeo) != UCL_INT) {
+ msg_err ("statfile %s has no size defined", stcf->symbol);
+ return NULL;
+ }
+
+ size = ucl_object_toint (sizeo);
+ /* NOTE(review): any result of the create call is discarded; failure only shows up as a NULL from open below */
+ rspamd_mmaped_file_create (ctx, filename, size, stcf);
+
+ mf = rspamd_mmaped_file_open (ctx, filename, size, stcf);
+ }
+
return (gpointer)mf;
}
return FALSE;
}
+/*
+ * Store the value of a single token into the mmapped statfile.
+ *
+ * The statfile is taken from res->st_runtime->backend_runtime and the block
+ * is addressed by two 32-bit hashes copied from the first eight bytes of
+ * the token data; res->value is what gets written.
+ *
+ * Returns TRUE when a value greater than zero was written. Returns FALSE
+ * when the value is not positive, or when there is no statfile for this
+ * result (in which case res->value is reset to 0.0 and nothing is written).
+ */
+gboolean
+rspamd_mmaped_file_learn_token (rspamd_token_t *tok,
+ struct rspamd_token_result *res,
+ gpointer p)
+{
+ rspamd_mmaped_file_ctx *ctx = (rspamd_mmaped_file_ctx *)p;
+ rspamd_mmaped_file_t *mf;
+ guint32 h1, h2;
+
+ g_assert (res != NULL);
+ g_assert (p != NULL);
+ g_assert (res->st_runtime != NULL);
+ g_assert (tok != NULL);
+ g_assert (tok->datalen >= sizeof (guint32) * 2);
+
+ mf = (rspamd_mmaped_file_t *)res->st_runtime->backend_runtime;
+
+ if (mf == NULL) {
+ /* Statfile does not exist: nothing can be written, report a zero value */
+ res->value = 0.0;
+ return FALSE;
+ }
+
+ /* memcpy instead of a pointer cast: token data carries no alignment guarantee here */
+ memcpy (&h1, tok->data, sizeof (h1));
+ memcpy (&h2, tok->data + sizeof (h1), sizeof (h2));
+ rspamd_mmaped_file_set_block (ctx, mf, h1, h2, res->value);
+
+ if (res->value > 0.0) {
+ return TRUE;
+ }
+
+ return FALSE;
+}
+
+
gulong
rspamd_mmaped_file_total_learns (struct rspamd_statfile_runtime *runtime,
gpointer ctx)
static GList*
rspamd_stat_preprocess (struct rspamd_stat_ctx *st_ctx,
struct rspamd_task *task, struct rspamd_tokenizer_runtime *tklist,
- lua_State *L, GError **err)
+ lua_State *L, gboolean learn, GError **err)
{
struct rspamd_classifier_config *clcf;
struct rspamd_statfile_config *stcf;
continue;
}
- backend_runtime = bk->runtime (stcf, bk->ctx);
+ backend_runtime = bk->runtime (stcf, learn, bk->ctx);
st_runtime = rspamd_mempool_alloc0 (task->task_pool,
sizeof (*st_runtime));
cbdata.results_count = result_size;
cbdata.classifier_runtimes = cl_runtimes;
cbdata.task = task;
-
- /* Allocate token results */
cbdata.tok = cl_runtime->tok;
g_tree_foreach (cl_runtime->tok->tokens, preprocess_init_stat_token,
&cbdata);
}
/* Initialize classifiers and statfiles runtime */
- if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, err))
+ if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, FALSE, err))
== NULL) {
return FALSE;
}
return ret;
}
+/*
+ * g_tree_foreach() callback of the learning path: hands one token to the
+ * learn_token hook of every statfile backend of every classifier runtime.
+ *
+ * k - tree key (unused)
+ * v - the rspamd_token_t to learn
+ * d - struct preprocess_cb_data carrying the classifier runtimes, the
+ *     tokenizer runtime and the task
+ *
+ * The token's `results` array is indexed by a counter that advances once
+ * per statfile, so each statfile is learned through its own result slot.
+ *
+ * Returns TRUE (which stops the tree traversal) as soon as a classifier has
+ * learned more tokens than its max_tokens limit allows, FALSE otherwise.
+ */
+static gboolean
+rspamd_stat_learn_token (gpointer k, gpointer v, gpointer d)
+{
+	rspamd_token_t *t = (rspamd_token_t *)v;
+	struct preprocess_cb_data *cbdata = (struct preprocess_cb_data *)d;
+	struct rspamd_statfile_runtime *st_runtime;
+	struct rspamd_classifier_runtime *cl_runtime;
+	struct rspamd_token_result *res;
+	GList *cur, *curst;
+	gint i = 0;
+
+	cur = g_list_first (cbdata->classifier_runtimes);
+
+	while (cur) {
+		cl_runtime = (struct rspamd_classifier_runtime *)cur->data;
+
+		if (cl_runtime->clcf->min_tokens > 0 &&
+			(guint32)g_tree_nnodes (cbdata->tok->tokens) < cl_runtime->clcf->min_tokens) {
+			/* Skip this classifier */
+			/* NOTE(review): the result index is not advanced for a skipped classifier; confirm this matches how results are populated */
+			msg_debug ("<%s> contains less tokens than required for %s classifier: "
+					"%ud < %ud", cbdata->task->message_id, cl_runtime->clcf->name,
+					g_tree_nnodes (cbdata->tok->tokens),
+					cl_runtime->clcf->min_tokens);
+			cur = g_list_next (cur);
+			continue;
+		}
+
+		/* The result at the current index is used to reach this classifier's statfile list */
+		res = &g_array_index (t->results, struct rspamd_token_result, i);
+
+		curst = res->cl_runtime->st_runtime;
+
+		while (curst) {
+			st_runtime = (struct rspamd_statfile_runtime *)curst->data;
+			/*
+			 * Re-fetch the result for every statfile: the backend takes its
+			 * runtime and value from the result, so reusing the first slot
+			 * would learn every statfile through the first one.
+			 */
+			res = &g_array_index (t->results, struct rspamd_token_result, i);
+
+			if (st_runtime->backend->learn_token (t, res,
+					st_runtime->backend->ctx)) {
+				cl_runtime->processed_tokens ++;
+
+				if (cl_runtime->clcf->max_tokens > 0 &&
+					cl_runtime->processed_tokens > cl_runtime->clcf->max_tokens) {
+					/* Log the message id, not the task pointer, for the %s conversion */
+					msg_debug ("<%s> contains more tokens than allowed for %s classifier: "
+							"%ud > %ud", cbdata->task->message_id, cl_runtime->clcf->name,
+							cl_runtime->processed_tokens,
+							cl_runtime->clcf->max_tokens);
+
+					return TRUE;
+				}
+			}
+
+			i ++;
+			curst = g_list_next (curst);
+		}
+		cur = g_list_next (cur);
+	}
+
+
+	return FALSE;
+}
+
+
gboolean
rspamd_stat_learn (struct rspamd_task *task, gboolean spam, lua_State *L,
GError **err)
struct rspamd_tokenizer_runtime *tklist = NULL, *tok;
struct rspamd_classifier_runtime *cl_run;
struct classifier_ctx *cl_ctx;
+ struct preprocess_cb_data cbdata;
GList *cl_runtimes;
GList *cur;
gboolean ret = FALSE;
}
/* Initialize classifiers and statfiles runtime */
- if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, err))
+ if ((cl_runtimes = rspamd_stat_preprocess (st_ctx, task, tklist, L, TRUE, err))
== NULL) {
return FALSE;
}
cl_ctx = cl_run->cl->init_func (task->task_pool, cl_run->clcf);
if (cl_ctx != NULL) {
- ret |= cl_run->cl->learn_spam_func (cl_ctx, cl_run->tok->tokens,
- cl_run, task, spam, err);
+ if (cl_run->cl->learn_spam_func (cl_ctx, cl_run->tok->tokens,
+ cl_run, task, spam, err)) {
+ ret = TRUE;
+
+ cbdata.classifier_runtimes = cur;
+ cbdata.task = task;
+ cbdata.tok = cl_run->tok;
+ g_tree_foreach (cl_run->tok->tokens, rspamd_stat_learn_token,
+ &cbdata);
+
+ }
+ else {
+ return FALSE;
+ }
+
}
}