diff options
Diffstat (limited to 'src/libserver')
-rw-r--r-- | src/libserver/cfg_file.h | 28 | ||||
-rw-r--r-- | src/libserver/cfg_rcl.cxx | 193 | ||||
-rw-r--r-- | src/libserver/cfg_utils.cxx | 186 | ||||
-rw-r--r-- | src/libserver/task.c | 43 | ||||
-rw-r--r-- | src/libserver/task.h | 6 |
5 files changed, 426 insertions, 30 deletions
diff --git a/src/libserver/cfg_file.h b/src/libserver/cfg_file.h index 36941da7a..355046cac 100644 --- a/src/libserver/cfg_file.h +++ b/src/libserver/cfg_file.h @@ -139,7 +139,10 @@ struct rspamd_statfile_config { char *symbol; /**< symbol of statfile */ char *label; /**< label of this statfile */ ucl_object_t *opts; /**< other options */ - gboolean is_spam; /**< spam flag */ + char *class_name; /**< class name for multi-class classification */ + unsigned int class_index; /**< class index for O(1) lookup during classification */ + gboolean is_spam; /**< DEPRECATED: spam flag - use class_name instead */ + gboolean is_spam_converted; /**< TRUE if class_name was converted from is_spam flag */ struct rspamd_classifier_config *clcf; /**< parent pointer of classifier configuration */ gpointer data; /**< opaque data */ }; @@ -182,6 +185,8 @@ struct rspamd_classifier_config { double min_prob_strength; /**< use only tokens with probability in [0.5 - MPS, 0.5 + MPS] */ unsigned int min_learns; /**< minimum number of learns for each statfile */ unsigned int flags; + GHashTable *class_labels; /**< class_name -> backend_symbol mapping for multi-class */ + GPtrArray *class_names; /**< ordered list of class names */ }; struct rspamd_worker_bind_conf { @@ -621,12 +626,25 @@ void rspamd_config_insert_classify_symbols(struct rspamd_config *cfg); */ gboolean rspamd_config_check_statfiles(struct rspamd_classifier_config *cf); -/* - * Find classifier config by name +/** + * Multi-class configuration helpers + */ +gboolean rspamd_config_parse_class_labels(const ucl_object_t *obj, + GHashTable **class_labels); + +gboolean rspamd_config_migrate_binary_config(struct rspamd_statfile_config *stcf); + +gboolean rspamd_config_validate_class_config(struct rspamd_classifier_config *ccf, + GError **err); + +const char *rspamd_config_get_class_label(struct rspamd_classifier_config *ccf, + const char *class_name); + +/** + * Find classifier by name */ struct rspamd_classifier_config *rspamd_config_find_classifier( - struct rspamd_config *cfg, - const char *name); + struct rspamd_config *cfg, const char *name); void rspamd_ucl_add_conf_macros(struct ucl_parser *parser, struct rspamd_config *cfg); diff --git a/src/libserver/cfg_rcl.cxx b/src/libserver/cfg_rcl.cxx index 0a48e8a4f..da5845917 100644 --- a/src/libserver/cfg_rcl.cxx +++ b/src/libserver/cfg_rcl.cxx @@ -1197,31 +1197,73 @@ rspamd_rcl_statfile_handler(rspamd_mempool_t *pool, const ucl_object_t *obj, st->opts = (ucl_object_t *) obj; st->clcf = ccf; - const auto *val = ucl_object_lookup(obj, "spam"); - if (val == nullptr) { + /* Handle migration from old 'spam' field to new 'class' field */ + const auto *class_val = ucl_object_lookup(obj, "class"); + const auto *spam_val = ucl_object_lookup(obj, "spam"); + + if (class_val != nullptr && spam_val != nullptr) { + msg_warn_config("statfile %s has both 'class' and 'spam' fields, using 'class' field", + st->symbol); + } + + if (class_val == nullptr && spam_val == nullptr) { + /* Neither field present, try to guess by symbol name */ msg_info_config( - "statfile %s has no explicit 'spam' setting, trying to guess by symbol", + "statfile %s has no explicit 'class' or 'spam' setting, trying to guess by symbol", st->symbol); if (rspamd_substring_search_caseless(st->symbol, strlen(st->symbol), "spam", 4) != -1) { st->is_spam = TRUE; + st->class_name = rspamd_mempool_strdup(pool, "spam"); + st->is_spam_converted = TRUE; } else if (rspamd_substring_search_caseless(st->symbol, strlen(st->symbol), "ham", 3) != -1) { st->is_spam = FALSE; + st->class_name = rspamd_mempool_strdup(pool, "ham"); + st->is_spam_converted = TRUE; } else { g_set_error(err, CFG_RCL_ERROR, EINVAL, - "cannot guess spam setting from %s", + "cannot guess class setting from %s, please specify 'class' field", st->symbol); return FALSE; } - msg_info_config("guessed that statfile with symbol %s is %s", - st->symbol, - st->is_spam ? "spam" : "ham"); + msg_info_config("guessed that statfile with symbol %s has class '%s'", + st->symbol, st->class_name); + } + else if (class_val == nullptr && spam_val != nullptr) { + /* Only spam field present - migrate to class */ + msg_warn_config("statfile %s uses deprecated 'spam' field, please use 'class' instead", + st->symbol); + if (st->is_spam) { + st->class_name = rspamd_mempool_strdup(pool, "spam"); + } + else { + st->class_name = rspamd_mempool_strdup(pool, "ham"); + } + st->is_spam_converted = TRUE; } + else if (class_val != nullptr && spam_val == nullptr) { + /* Only class field present - set is_spam for backward compatibility */ + if (st->class_name != nullptr) { + if (strcmp(st->class_name, "spam") == 0) { + st->is_spam = TRUE; + } + else if (strcmp(st->class_name, "ham") == 0) { + st->is_spam = FALSE; + } + else { + /* For non-binary classes, default to not spam */ + st->is_spam = FALSE; + } + msg_debug_config("statfile %s with class '%s' set is_spam=%s for compatibility", + st->symbol, st->class_name, st->is_spam ? "true" : "false"); + } + } + /* If both fields are present, class takes precedence and was already parsed by the default parser */ return TRUE; } @@ -1229,6 +1271,31 @@ rspamd_rcl_statfile_handler(rspamd_mempool_t *pool, const ucl_object_t *obj, } static gboolean +rspamd_rcl_class_labels_handler(rspamd_mempool_t *pool, + const ucl_object_t *obj, + const char *key, + gpointer ud, + struct rspamd_rcl_section *section, + GError **err) +{ + auto *ccf = static_cast<rspamd_classifier_config *>(ud); + + if (obj->type != UCL_OBJECT) { + g_set_error(err, CFG_RCL_ERROR, EINVAL, + "class_labels must be an object"); + return FALSE; + } + + if (!rspamd_config_parse_class_labels(obj, &ccf->class_labels)) { + g_set_error(err, CFG_RCL_ERROR, EINVAL, + "invalid class_labels configuration"); + return FALSE; + } + + return TRUE; +} + +static gboolean rspamd_rcl_classifier_handler(rspamd_mempool_t *pool, const ucl_object_t *obj, const char *key, @@ -1301,6 +1368,22 @@ rspamd_rcl_classifier_handler(rspamd_mempool_t *pool, } } } + else if (g_ascii_strcasecmp(st_key, "class_labels") == 0) { + /* Parse class_labels configuration directly */ + if (ucl_object_type(val) != UCL_OBJECT) { + g_set_error(err, CFG_RCL_ERROR, EINVAL, + "class_labels must be an object"); + ucl_object_iterate_free(it); + return FALSE; + } + + if (!rspamd_config_parse_class_labels(val, &ccf->class_labels)) { + g_set_error(err, CFG_RCL_ERROR, EINVAL, + "invalid class_labels configuration"); + ucl_object_iterate_free(it); + return FALSE; + } + } } } @@ -1375,8 +1458,80 @@ rspamd_rcl_classifier_handler(rspamd_mempool_t *pool, } ccf->opts = (ucl_object_t *) obj; + + /* Validate multi-class configuration */ + GError *validation_err = nullptr; + if (!rspamd_config_validate_class_config(ccf, &validation_err)) { + if (validation_err) { + g_propagate_error(err, validation_err); + } + else { + g_set_error(err, CFG_RCL_ERROR, EINVAL, + "multi-class configuration validation failed for classifier '%s'", + ccf->name ? ccf->name : "unknown"); + } + return FALSE; + } + cfg->classifiers = g_list_prepend(cfg->classifiers, ccf); + /* Populate class_names array from statfiles - only for explicit multiclass configs */ + if (ccf->statfiles) { + GList *cur = ccf->statfiles; + gboolean has_explicit_classes = FALSE; + + /* Check if any statfile uses explicit class declaration (not converted from is_spam) */ + cur = ccf->statfiles; + while (cur) { + struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data; + msg_debug("checking statfile %s: class_name=%s, is_spam_converted=%s", + stcf->symbol, stcf->class_name ? stcf->class_name : "NULL", + stcf->is_spam_converted ? "true" : "false"); + if (stcf->class_name && !stcf->is_spam_converted) { + has_explicit_classes = TRUE; + break; + } + cur = g_list_next(cur); + } + + msg_debug("has_explicit_classes = %s", has_explicit_classes ? "true" : "false"); + + /* Only populate class_names for explicit multiclass configurations */ + if (has_explicit_classes) { + msg_debug("populating class_names for multiclass configuration"); + } + else { + msg_debug("skipping class_names population for binary configuration"); + } + + if (has_explicit_classes) { + ccf->class_names = g_ptr_array_new(); + + cur = ccf->statfiles; + while (cur) { + struct rspamd_statfile_config *stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name) { + /* Check if class already exists */ + bool found = false; + for (unsigned int i = 0; i < ccf->class_names->len; i++) { + if (strcmp((char *) g_ptr_array_index(ccf->class_names, i), stcf->class_name) == 0) { + stcf->class_index = i; /* Store the index for O(1) lookup */ + found = true; + break; + } + } + + if (!found) { + /* Add new class */ + stcf->class_index = ccf->class_names->len; + g_ptr_array_add(ccf->class_names, g_strdup(stcf->class_name)); + } + } + cur = g_list_next(cur); + } + } + } + return TRUE; } @@ -2457,7 +2612,7 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections) FALSE, TRUE, cfg->doc_strings, - "CLassifier options"); + "Classifier options"); /* Default classifier is 'bayes' for now */ sub->default_key = "bayes"; @@ -2476,7 +2631,7 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections) rspamd_rcl_add_default_handler(sub, "min_prob_strength", rspamd_rcl_parse_struct_double, - G_STRUCT_OFFSET(struct rspamd_classifier_config, min_token_hits), + G_STRUCT_OFFSET(struct rspamd_classifier_config, min_prob_strength), 0, "Use only tokens with probability in [0.5 - MPS, 0.5 + MPS]"); rspamd_rcl_add_default_handler(sub, @@ -2505,6 +2660,18 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections) "Name of classifier"); /* + * Multi-class configuration + */ + rspamd_rcl_add_section_doc(&top, sub, + "class_labels", nullptr, + rspamd_rcl_class_labels_handler, + UCL_OBJECT, + FALSE, + TRUE, + sub->doc_ref, + "Class to backend label mapping for multi-class classification"); + + /* * Statfile defaults */ auto *ssub = rspamd_rcl_add_section_doc(&top, sub, @@ -2522,11 +2689,17 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections) 0, "Statfile unique label"); rspamd_rcl_add_default_handler(ssub, + "class", + rspamd_rcl_parse_struct_string, + G_STRUCT_OFFSET(struct rspamd_statfile_config, class_name), + 0, + "Class name for multi-class classification"); + rspamd_rcl_add_default_handler(ssub, "spam", rspamd_rcl_parse_struct_boolean, G_STRUCT_OFFSET(struct rspamd_statfile_config, is_spam), 0, - "Sets if this statfile contains spam samples"); + "DEPRECATED: Sets if this statfile contains spam samples (use 'class' instead)"); } if (!(skip_sections && g_hash_table_lookup(skip_sections, "composite"))) { diff --git a/src/libserver/cfg_utils.cxx b/src/libserver/cfg_utils.cxx index c7bb20210..c22a9b877 100644 --- a/src/libserver/cfg_utils.cxx +++ b/src/libserver/cfg_utils.cxx @@ -3042,3 +3042,189 @@ rspamd_ip_is_local_cfg(struct rspamd_config *cfg, return FALSE; } + +gboolean +rspamd_config_parse_class_labels(const ucl_object_t *obj, GHashTable **class_labels) +{ + const ucl_object_t *cur; + ucl_object_iter_t it = nullptr; + + if (!obj || ucl_object_type(obj) != UCL_OBJECT) { + return FALSE; + } + + if (*class_labels == nullptr) { + *class_labels = g_hash_table_new_full(g_str_hash, g_str_equal, + g_free, g_free); + } + + while ((cur = ucl_object_iterate(obj, &it, true)) != nullptr) { + const char *class_name = ucl_object_key(cur); + const char *label = ucl_object_tostring(cur); + + if (class_name && label) { + /* Validate class name: alphanumeric + underscore, max 32 chars */ + if (strlen(class_name) > 32) { + msg_err("class name '%s' is too long (max 32 characters)", class_name); + g_hash_table_destroy(*class_labels); + *class_labels = nullptr; + return FALSE; + } + + for (const char *p = class_name; *p; p++) { + if (!g_ascii_isalnum(*p) && *p != '_') { + msg_err("class name '%s' contains invalid character '%c'", class_name, *p); + g_hash_table_destroy(*class_labels); + *class_labels = nullptr; + return FALSE; + } + } + + /* Validate label uniqueness */ + if (g_hash_table_lookup(*class_labels, label)) { + msg_err("backend label '%s' is used by multiple classes", label); + g_hash_table_destroy(*class_labels); + *class_labels = nullptr; + return FALSE; + } + } + + g_hash_table_insert(*class_labels, g_strdup(class_name), g_strdup(label)); + } + + return g_hash_table_size(*class_labels) > 0; +} + +gboolean +rspamd_config_migrate_binary_config(struct rspamd_statfile_config *stcf) +{ + if (stcf->class_name != nullptr) { + /* Already migrated or using new format */ + return TRUE; + } + + if (stcf->is_spam) { + stcf->class_name = g_strdup("spam"); + msg_info("migrated statfile '%s' from is_spam=true to class='spam'", + stcf->symbol ? stcf->symbol : "unknown"); + } + else { + stcf->class_name = g_strdup("ham"); + msg_info("migrated statfile '%s' from is_spam=false to class='ham'", + stcf->symbol ? stcf->symbol : "unknown"); + } + + return TRUE; +} + +gboolean +rspamd_config_validate_class_config(struct rspamd_classifier_config *ccf, GError **err) +{ + GList *cur; + GHashTable *seen_classes = nullptr; + struct rspamd_statfile_config *stcf; + unsigned int class_count = 0; + + if (!ccf || !ccf->statfiles) { + g_set_error(err, g_quark_from_static_string("config"), 1, + "classifier has no statfiles defined"); + return FALSE; + } + + seen_classes = g_hash_table_new_full(g_str_hash, g_str_equal, g_free, nullptr); + + /* Iterate through statfiles and collect classes */ + cur = ccf->statfiles; + while (cur) { + stcf = (struct rspamd_statfile_config *) cur->data; + + /* Migrate binary config if needed */ + if (!rspamd_config_migrate_binary_config(stcf)) { + g_set_error(err, g_quark_from_static_string("config"), 1, + "failed to migrate binary config for statfile '%s'", + stcf->symbol ? stcf->symbol : "unknown"); + g_hash_table_destroy(seen_classes); + return FALSE; + } + + /* Check class name */ + if (!stcf->class_name || strlen(stcf->class_name) == 0) { + g_set_error(err, g_quark_from_static_string("config"), 1, + "statfile '%s' has no class defined", + stcf->symbol ? stcf->symbol : "unknown"); + g_hash_table_destroy(seen_classes); + return FALSE; + } + + /* Track unique classes */ + if (!g_hash_table_contains(seen_classes, stcf->class_name)) { + g_hash_table_insert(seen_classes, g_strdup(stcf->class_name), GINT_TO_POINTER(1)); + class_count++; + } + + cur = g_list_next(cur); + } + + /* Validate class count */ + if (class_count < 2) { + g_set_error(err, g_quark_from_static_string("config"), 1, + "classifier must have at least 2 classes, found %ud", class_count); + g_hash_table_destroy(seen_classes); + return FALSE; + } + + if (class_count > 20) { + msg_warn("classifier has %ud classes, performance may be degraded above 20 classes", + class_count); + } + + /* Initialize classifier class tracking - only for explicit multiclass configurations */ + gboolean has_explicit_classes = FALSE; + + /* Check if any statfile uses explicit class declaration (not converted from is_spam) */ + cur = ccf->statfiles; + while (cur) { + stcf = (struct rspamd_statfile_config *) cur->data; + if (stcf->class_name && !stcf->is_spam_converted) { + has_explicit_classes = TRUE; + break; + } + cur = g_list_next(cur); + } + + /* Only populate class_names for explicit multiclass configurations */ + if (has_explicit_classes) { + if (ccf->class_names) { + g_ptr_array_unref(ccf->class_names); + } + ccf->class_names = g_ptr_array_new_with_free_func(g_free); + + /* Populate class names array */ + GHashTableIter iter; + gpointer key, value; + g_hash_table_iter_init(&iter, seen_classes); + while (g_hash_table_iter_next(&iter, &key, &value)) { + g_ptr_array_add(ccf->class_names, g_strdup((const char *) key)); + } + } + else { + /* Binary configuration - ensure class_names is NULL */ + if (ccf->class_names) { + g_ptr_array_unref(ccf->class_names); + ccf->class_names = nullptr; + } + } + + g_hash_table_destroy(seen_classes); + return TRUE; +} + +const char * +rspamd_config_get_class_label(struct rspamd_classifier_config *ccf, const char *class_name) +{ + if (!ccf || !ccf->class_labels || !class_name) { + return nullptr; + } + + return (const char *) g_hash_table_lookup(ccf->class_labels, class_name); +} diff --git a/src/libserver/task.c b/src/libserver/task.c index 9f5b1f00a..f655ab11b 100644 --- a/src/libserver/task.c +++ b/src/libserver/task.c @@ -730,7 +730,7 @@ rspamd_task_process(struct rspamd_task *task, unsigned int stages) if (all_done && (task->flags & RSPAMD_TASK_FLAG_LEARN_AUTO) && !RSPAMD_TASK_IS_EMPTY(task) && - !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM))) { + !(task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM | RSPAMD_TASK_FLAG_LEARN_CLASS))) { rspamd_stat_check_autolearn(task); } break; @@ -738,12 +738,32 @@ rspamd_task_process(struct rspamd_task *task, unsigned int stages) case RSPAMD_TASK_STAGE_LEARN: case RSPAMD_TASK_STAGE_LEARN_PRE: case RSPAMD_TASK_STAGE_LEARN_POST: - if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM)) { + if (task->flags & (RSPAMD_TASK_FLAG_LEARN_SPAM | RSPAMD_TASK_FLAG_LEARN_HAM | RSPAMD_TASK_FLAG_LEARN_CLASS)) { if (task->err == NULL) { - if (!rspamd_stat_learn(task, - task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM, - task->cfg->lua_state, task->classifier, - st, &stat_error)) { + gboolean learn_result = FALSE; + + if (task->flags & RSPAMD_TASK_FLAG_LEARN_CLASS) { + /* Multi-class learning */ + const char *autolearn_class = rspamd_task_get_autolearn_class(task); + if (autolearn_class) { + learn_result = rspamd_stat_learn_class(task, autolearn_class, + task->cfg->lua_state, task->classifier, + st, &stat_error); + } + else { + g_set_error(&stat_error, g_quark_from_static_string("stat"), 500, + "No autolearn class specified for multi-class learning"); + } + } + else { + /* Legacy binary learning */ + learn_result = rspamd_stat_learn(task, + task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM, + task->cfg->lua_state, task->classifier, + st, &stat_error); + } + + if (!learn_result) { if (stat_error == NULL) { g_set_error(&stat_error, @@ -922,15 +942,14 @@ rspamd_learn_task_spam(struct rspamd_task *task, const char *classifier, GError **err) { + /* Use unified class-based approach internally */ + const char *class_name = is_spam ? "spam" : "ham"; + /* Disable learn auto flag to avoid bad learn codes */ task->flags &= ~RSPAMD_TASK_FLAG_LEARN_AUTO; - if (is_spam) { - task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; - } - else { - task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; - } + /* Use the unified class-based learning approach */ + rspamd_task_set_autolearn_class(task, class_name); task->classifier = classifier; diff --git a/src/libserver/task.h b/src/libserver/task.h index 1c1778fee..a1742e160 100644 --- a/src/libserver/task.h +++ b/src/libserver/task.h @@ -104,9 +104,9 @@ enum rspamd_task_stage { #define RSPAMD_TASK_FLAG_LEARN_SPAM (1u << 12u) #define RSPAMD_TASK_FLAG_LEARN_HAM (1u << 13u) #define RSPAMD_TASK_FLAG_LEARN_AUTO (1u << 14u) +#define RSPAMD_TASK_FLAG_LEARN_CLASS (1u << 25u) #define RSPAMD_TASK_FLAG_BROKEN_HEADERS (1u << 15u) -#define RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS (1u << 16u) -#define RSPAMD_TASK_FLAG_HAS_HAM_TOKENS (1u << 17u) +/* Removed RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS and RSPAMD_TASK_FLAG_HAS_HAM_TOKENS - not needed in multi-class */ #define RSPAMD_TASK_FLAG_EMPTY (1u << 18u) #define RSPAMD_TASK_FLAG_PROFILE (1u << 19u) #define RSPAMD_TASK_FLAG_GREYLISTED (1u << 20u) @@ -114,7 +114,7 @@ enum rspamd_task_stage { #define RSPAMD_TASK_FLAG_SSL (1u << 22u) #define RSPAMD_TASK_FLAG_BAD_UNICODE (1u << 23u) #define RSPAMD_TASK_FLAG_MESSAGE_REWRITE (1u << 24u) -#define RSPAMD_TASK_FLAG_MAX_SHIFT (24u) +#define RSPAMD_TASK_FLAG_MAX_SHIFT (25u) /* Request has been done by a local client */ #define RSPAMD_TASK_PROTOCOL_FLAG_LOCAL_CLIENT (1u << 1u) |