/* * Copyright (c) 2009, Rambler media * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * THIS SOFTWARE IS PROVIDED BY Rambler media ''AS IS'' AND ANY * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL Rambler BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ #include "config.h" #include "util.h" #include "main.h" #include "message.h" #include "protocol.h" #include "upstream.h" #include "cfg_file.h" #include "modules.h" #include "tokenizers/tokenizers.h" #include "classifiers/classifiers.h" #include "binlog.h" #include "statfile_sync.h" #define END "END" CRLF /* 120 seconds for controller's IO */ #define CONTROLLER_IO_TIMEOUT 120 enum command_type { COMMAND_PASSWORD, COMMAND_QUIT, COMMAND_RELOAD, COMMAND_STAT, COMMAND_SHUTDOWN, COMMAND_UPTIME, COMMAND_LEARN, COMMAND_HELP, COMMAND_COUNTERS, COMMAND_SYNC, COMMAND_WEIGHTS }; struct controller_command { char *command; gboolean privilleged; enum command_type type; }; struct custom_controller_command { const char *command; gboolean privilleged; gboolean require_message; controller_func_t handler; }; static struct controller_command commands[] = { {"password", FALSE, COMMAND_PASSWORD}, {"quit", FALSE, COMMAND_QUIT}, {"reload", TRUE, COMMAND_RELOAD}, {"stat", FALSE, COMMAND_STAT}, {"shutdown", TRUE, COMMAND_SHUTDOWN}, {"uptime", FALSE, COMMAND_UPTIME}, {"learn", TRUE, COMMAND_LEARN}, {"weights", FALSE, COMMAND_WEIGHTS}, {"help", FALSE, COMMAND_HELP}, {"counters", FALSE, COMMAND_COUNTERS}, {"sync", FALSE, COMMAND_SYNC} }; static GList *custom_commands = NULL; static GCompletion *comp; static time_t start_time; static char greetingbuf[1024]; extern rspamd_hash_t *counters; static gboolean controller_write_socket (void *arg); #ifndef HAVE_SA_SIGINFO static void sig_handler (int signo) #else static void sig_handler (int signo, siginfo_t *info, void *unused) #endif { switch (signo) { case SIGINT: case SIGTERM: _exit (1); break; } } static void sigusr_handler (int fd, short what, void *arg) { struct rspamd_worker *worker = (struct rspamd_worker *)arg; /* Do not accept new connections, preparing to end worker's process */ struct timeval tv; tv.tv_sec = 2; tv.tv_usec = 0; event_del (&worker->sig_ev); event_del (&worker->bind_ev); msg_info ("controller's shutdown is pending in %d sec", 2); event_loopexit (&tv); return; } static gchar * completion_func (gpointer elem) { struct controller_command *cmd = (struct controller_command *)elem; return cmd->command; } static void free_session (void *ud) { GList *part; struct mime_part *p; struct controller_session *session = ud; msg_debug ("freeing session %p", session); while ((part = g_list_first (session->parts))) { session->parts = g_list_remove_link (session->parts, part); p = (struct mime_part *)part->data; g_byte_array_free (p->content, FALSE); g_list_free_1 (part); } rspamd_remove_dispatcher (session->dispatcher); close (session->sock); memory_pool_delete (session->session_pool); g_free (session); } static int check_auth (struct controller_command *cmd, struct controller_session *session) { char out_buf[128]; int r; if (cmd->privilleged && !session->authorized) { r = snprintf (out_buf, sizeof (out_buf), "not authorized" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return 0; } return 1; } static void counter_write_callback (gpointer key, gpointer value, void *data) { struct controller_session *session = data; struct counter_data *cd = value; char *name = key; char out_buf[128]; int r; r = snprintf (out_buf, sizeof (out_buf), "%s: %llu" CRLF, name, (unsigned long long int)cd->value); rspamd_dispatcher_write (session->dispatcher, out_buf, r, TRUE, FALSE); } static gboolean write_whole_statfile (struct controller_session *session, char *symbol, struct classifier_config *ccf) { stat_file_t *statfile; struct statfile *st; char out_buf[BUFSIZ]; int i; uint64_t rev, ti, len, pos; char *out; struct rspamd_binlog_element log_elt; struct stat_file_block *stat_elt; statfile = get_statfile_by_symbol (session->worker->srv->statfile_pool, ccf, symbol, &st, FALSE); if (statfile == NULL) { return FALSE; } /* Begin to copy all blocks into array */ statfile_get_revision (statfile, &rev, (time_t *)&ti); if (ti == 0) { /* Not tracked file */ ti = time (NULL); statfile_set_revision (statfile, rev, ti); } len = statfile->cur_section.length * sizeof (struct rspamd_binlog_element); out = memory_pool_alloc (session->session_pool, len); for (i = 0, pos = 0; i < statfile->cur_section.length; i ++) { stat_elt = (struct stat_file_block *)((u_char *)statfile->map + statfile->seek_pos + i * sizeof (struct stat_file_block)); if (fabs (stat_elt->value) > 0.001) { /* Write only those values which value is not 0 */ log_elt.h1 = stat_elt->hash1; log_elt.h2 = stat_elt->hash2; log_elt.value = stat_elt->value; memcpy (out + pos, &log_elt, sizeof (log_elt)); pos += sizeof (struct rspamd_binlog_element); } } i = rspamd_snprintf (out_buf, sizeof (out_buf), "%uL %uL %uL" CRLF, rev, ti, pos); rspamd_dispatcher_write (session->dispatcher, out_buf, i, TRUE, FALSE); if (!rspamd_dispatcher_write (session->dispatcher, out, pos, TRUE, TRUE)) { return FALSE; } return TRUE; } static gboolean process_sync_command (struct controller_session *session, char **args) { char out_buf[BUFSIZ], *arg, *err_str, *symbol; int r; uint64_t rev, time; struct statfile *st; struct classifier_config *ccf; GList *cur; struct rspamd_binlog *binlog; GByteArray *data = NULL; arg = *args; if (!arg || *arg == '\0') { msg_info ("bad arguments to sync command, need symbol"); return FALSE; } symbol = arg; arg = *(args + 1); if (!arg || *arg == '\0') { msg_info ("bad arguments to sync command, need revision"); return FALSE; } rev = strtoull (arg, &err_str, 10); if (err_str && *err_str != 0) { msg_info ("bad arguments to sync commanc: %s", arg); return FALSE; } arg = *(args + 2); if (!arg || *arg == '\0') { msg_info ("bad arguments to sync command, need time"); return FALSE; } time = strtoull (arg, &err_str, 10); if (err_str && *err_str != 0) { msg_info ("bad arguments to sync commanc: %s", arg); return FALSE; } ccf = g_hash_table_lookup (session->cfg->classifiers_symbols, symbol); if (ccf == NULL) { msg_info ("bad symbol: %s", symbol); return FALSE; } cur = g_list_first (ccf->statfiles); while (cur) { st = cur->data; if (strcmp (symbol, st->symbol) == 0) { break; } st = NULL; cur = g_list_next (cur); } if (st == NULL) { msg_info ("bad symbol: %s", symbol); return FALSE; } binlog = get_binlog_by_statfile (st); if (binlog == NULL) { msg_info ("cannot open binlog: %s", symbol); return FALSE; } while (binlog_sync (binlog, rev, &time, &data)) { r = snprintf (out_buf, sizeof (out_buf), "%lu %lu %lu" CRLF, (long unsigned)rev, (long unsigned)time, (long unsigned)data->len); rspamd_dispatcher_write (session->dispatcher, out_buf, r, TRUE, FALSE); if (!rspamd_dispatcher_write (session->dispatcher, data->data, data->len, TRUE, FALSE)) { if (data != NULL) { g_free (data); } return FALSE; } rev ++; } if (time == 0) { if (data != NULL) { g_free (data); } return write_whole_statfile (session, symbol, ccf); } if (data != NULL) { g_free (data); } return TRUE; } static gboolean process_stat_command (struct controller_session *session) { char out_buf[BUFSIZ], *numbuf; int r; uint64_t used, total, rev; time_t ti; memory_pool_stat_t mem_st; struct classifier_config *ccf; stat_file_t *statfile; struct statfile *st; GList *cur_cl, *cur_st; memory_pool_stat (&mem_st); r = snprintf (out_buf, sizeof (out_buf), "Messages scanned: %u" CRLF, session->worker->srv->stat->messages_scanned); if (session->worker->srv->stat->messages_scanned > 0) { r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages treated as spam: %u, %.2f%%" CRLF, session->worker->srv->stat->messages_spam, (double)session->worker->srv->stat->messages_spam / (double)session->worker->srv->stat->messages_scanned * 100.); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages treated as ham: %u, %.2f%%" CRLF, session->worker->srv->stat->messages_ham, (double)session->worker->srv->stat->messages_ham / (double)session->worker->srv->stat->messages_scanned * 100.); } r += snprintf (out_buf + r, sizeof (out_buf) - r, "Messages learned: %u" CRLF, session->worker->srv->stat->messages_learned); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Connections count: %u" CRLF, session->worker->srv->stat->connections_count); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Control connections count: %u" CRLF, session->worker->srv->stat->control_connections_count); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Pools allocated: %ld" CRLF, (long int)mem_st.pools_allocated); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Pools freed: %ld" CRLF, (long int)mem_st.pools_freed); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Bytes allocated: %ld" CRLF, (long int)mem_st.bytes_allocated); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Memory chunks allocated: %ld" CRLF, (long int)mem_st.chunks_allocated); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Shared chunks allocated: %ld" CRLF, (long int)mem_st.shared_chunks_allocated); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Chunks freed: %ld" CRLF, (long int)mem_st.chunks_freed); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Oversized chunks: %ld" CRLF, (long int)mem_st.oversized_chunks); /* Now write statistics for each statfile */ cur_cl = g_list_first (session->cfg->classifiers); while (cur_cl) { ccf = cur_cl->data; cur_st = g_list_first (ccf->statfiles); while (cur_st) { st = cur_st->data; if ((statfile = statfile_pool_is_open (session->worker->srv->statfile_pool, st->path)) == NULL) { statfile = statfile_pool_open (session->worker->srv->statfile_pool, st->path, st->size, FALSE); } if (statfile) { used = statfile_get_used_blocks (statfile); total = statfile_get_total_blocks (statfile); statfile_get_revision (statfile, &rev, &ti); if (total != (uint64_t)-1 && used != (uint64_t)-1) { numbuf = g_format_size_for_display (st->size); r += snprintf (out_buf + r, sizeof (out_buf) - r, "Statfile: %s (version %lu); length: %s; free blocks: %lu; total blocks: %lu; free: %.2f%%" CRLF, st->symbol, (long unsigned)rev, numbuf, (long unsigned)(total - used), (long unsigned)total, (double)((double)(total - used) / (double)total) * 100.); g_free (numbuf); } } cur_st = g_list_next (cur_st); } cur_cl = g_list_next (cur_cl); } return rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); } static void process_command (struct controller_command *cmd, char **cmd_args, struct controller_session *session) { char out_buf[BUFSIZ], *arg, *err_str; int r = 0, days, hours, minutes; time_t uptime; unsigned long size = 0; struct classifier_config *cl; char *password = g_hash_table_lookup (session->worker->cf->params, "password"); switch (cmd->type) { case COMMAND_PASSWORD: arg = *cmd_args; if (!arg || *arg == '\0') { msg_debug ("empty password passed"); r = snprintf (out_buf, sizeof (out_buf), "password command requires one argument" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } if (password == NULL) { r = snprintf (out_buf, sizeof (out_buf), "password command disabled in config, authorized access unallowed" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } if (strncmp (arg, password, strlen (arg)) == 0) { session->authorized = 1; r = snprintf (out_buf, sizeof (out_buf), "password accepted" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); } else { session->authorized = 0; r = snprintf (out_buf, sizeof (out_buf), "password NOT accepted" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); } break; case COMMAND_QUIT: session->state = STATE_QUIT; break; case COMMAND_RELOAD: if (check_auth (cmd, session)) { r = snprintf (out_buf, sizeof (out_buf), "reload request sent" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); kill (getppid (), SIGHUP); } break; case COMMAND_STAT: if (check_auth (cmd, session)) { (void)process_stat_command (session); } break; case COMMAND_SHUTDOWN: if (check_auth (cmd, session)) { r = snprintf (out_buf, sizeof (out_buf), "shutdown request sent" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); kill (getppid (), SIGTERM); } break; case COMMAND_UPTIME: if (check_auth (cmd, session)) { uptime = time (NULL) - start_time; /* If uptime more than 2 hours, print as a number of days. */ if (uptime >= 2 * 3600) { days = uptime / 86400; hours = uptime / 3600 - days * 3600; minutes = uptime / 60 - hours * 60 - days * 3600; r = snprintf (out_buf, sizeof (out_buf), "%d day%s %d hour%s %d minute%s" CRLF, days, days > 1 ? "s" : " ", hours, hours > 1 ? "s" : " ", minutes, minutes > 1 ? "s" : " "); } /* If uptime is less than 1 minute print only seconds */ else if (uptime / 60 == 0) { r = snprintf (out_buf, sizeof (out_buf), "%d second%s" CRLF, (int)uptime, (int)uptime > 1 ? "s" : " "); } /* Else print the minutes and seconds. */ else { hours = uptime / 3600; minutes = uptime / 60 - hours * 3600; uptime -= hours * 3600 + minutes * 60; r = snprintf (out_buf, sizeof (out_buf), "%d hour%s %d minite%s %d second%s" CRLF, hours, hours > 1 ? "s" : " ", minutes, minutes > 1 ? "s" : " ", (int)uptime, uptime > 1 ? "s" : " "); } rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); } break; case COMMAND_LEARN: if (check_auth (cmd, session)) { arg = *cmd_args; if (!arg || *arg == '\0') { msg_debug ("no statfile specified in learn command"); r = snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: stat filename and its size" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } arg = *(cmd_args + 1); if (arg == NULL || *arg == '\0') { msg_debug ("no statfile size specified in learn command"); r = snprintf (out_buf, sizeof (out_buf), "learn command requires at least two arguments: stat filename and its size" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } size = strtoul (arg, &err_str, 10); if (err_str && *err_str != '\0') { msg_debug ("message size is invalid: %s", arg); r = snprintf (out_buf, sizeof (out_buf), "learn size is invalid" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } session->learn_symbol = memory_pool_strdup (session->session_pool, *cmd_args); cl = g_hash_table_lookup (session->cfg->classifiers_symbols, *cmd_args); if (cl == NULL) { r = snprintf (out_buf, sizeof (out_buf), "statfile %s is not defined" CRLF, *cmd_args); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } session->learn_classifier = cl; /* By default learn positive */ session->in_class = 1; session->learn_multiplier = 1.; /* Get all arguments */ while (*cmd_args++) { arg = *cmd_args; if (arg && *arg == '-') { switch (*(arg + 1)) { case 'r': arg = *(cmd_args + 1); if (!arg || *arg == '\0') { r = snprintf (out_buf, sizeof (out_buf), "recipient is not defined" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } session->learn_rcpt = memory_pool_strdup (session->session_pool, arg); break; case 'f': arg = *(cmd_args + 1); if (!arg || *arg == '\0') { r = snprintf (out_buf, sizeof (out_buf), "from is not defined" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } session->learn_from = memory_pool_strdup (session->session_pool, arg); break; case 'n': session->in_class = 0; break; case 'm': arg = *(cmd_args + 1); if (!arg || *arg == '\0') { r = snprintf (out_buf, sizeof (out_buf), "recipient is not defined" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } session->learn_multiplier = strtod (arg, NULL); break; default: r = snprintf (out_buf, sizeof (out_buf), "tokenizer is not defined" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } } } rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size); session->state = STATE_LEARN; } break; case COMMAND_WEIGHTS: arg = *cmd_args; if (!arg || *arg == '\0') { msg_debug ("no statfile specified in weights command"); r = snprintf (out_buf, sizeof (out_buf), "weights command requires two arguments: statfile and message size" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } arg = *(cmd_args + 1); if (arg == NULL || *arg == '\0') { msg_debug ("no message size specified in weights command"); r = snprintf (out_buf, sizeof (out_buf), "weights command requires two arguments: statfile and message size" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } size = strtoul (arg, &err_str, 10); if (err_str && *err_str != '\0') { msg_debug ("message size is invalid: %s", arg); r = snprintf (out_buf, sizeof (out_buf), "message size is invalid" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } cl = g_hash_table_lookup (session->cfg->classifiers_symbols, *cmd_args); if (cl == NULL) { r = snprintf (out_buf, sizeof (out_buf), "statfile %s is not defined" CRLF, *cmd_args); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } session->learn_classifier = cl; rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size); session->state = STATE_WEIGHTS; break; case COMMAND_SYNC: if (!process_sync_command (session, cmd_args)) { r = snprintf (out_buf, sizeof (out_buf), "FAIL" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return; } break; case COMMAND_HELP: r = snprintf (out_buf, sizeof (out_buf), "Rspamd CLI commands (* - privilleged command):" CRLF " help - this help message" CRLF "(*) learn [-r recipient] [-m multiplier] [-f from] [-n] - learn message to specified statfile" CRLF " quit - quit CLI session" CRLF "(*) reload - reload rspamd" CRLF "(*) shutdown - shutdown rspamd" CRLF " stat - show different rspamd stat" CRLF " counters - show rspamd counters" CRLF " uptime - rspamd uptime" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); break; case COMMAND_COUNTERS: rspamd_hash_foreach (counters, counter_write_callback, session); break; } } static gboolean process_custom_command (char *line, char **cmd_args, struct controller_session *session) { GList *cur; struct custom_controller_command *cmd; cur = custom_commands; while (cur) { cmd = cur->data; if (g_ascii_strcasecmp (cmd->command, line) == 0) { /* Call handler */ cmd->handler (cmd_args, session); return TRUE; } cur = g_list_next (cur); } return FALSE; } static gboolean controller_read_socket (f_str_t * in, void *arg) { struct controller_session *session = (struct controller_session *)arg; struct classifier_ctx *cls_ctx; stat_file_t *statfile; struct statfile *st; int len, i, r; char *s, **params, *cmd, out_buf[128]; struct worker_task *task; struct mime_text_part *part; GList *comp_list, *cur = NULL; GTree *tokens = NULL; f_str_t c; double sum; switch (session->state) { case STATE_COMMAND: s = fstrcstr (in, session->session_pool); params = g_strsplit (s, " ", -1); memory_pool_add_destructor (session->session_pool, (pool_destruct_func) g_strfreev, params); len = g_strv_length (params); if (len > 0) { cmd = g_strstrip (params[0]); comp_list = g_completion_complete (comp, cmd, NULL); switch (g_list_length (comp_list)) { case 1: process_command ((struct controller_command *)comp_list->data, ¶ms[1], session); break; case 0: if (!process_custom_command (cmd, ¶ms[1], session)) { msg_debug ("'%s'", cmd); i = snprintf (out_buf, sizeof (out_buf), "Unknown command" CRLF); if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { return FALSE; } } break; default: msg_debug ("'%s'", cmd); i = snprintf (out_buf, sizeof (out_buf), "Ambigious command" CRLF); if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { return FALSE; } break; } } if (session->state == STATE_COMMAND) { session->state = STATE_REPLY; } if (session->state != STATE_LEARN && session->state != STATE_WEIGHTS && session->state != STATE_OTHER) { if (!rspamd_dispatcher_write (session->dispatcher, END, sizeof (END) - 1, FALSE, TRUE)) { return FALSE; } } break; case STATE_LEARN: session->learn_buf = in; task = construct_task (session->worker); task->msg = memory_pool_alloc (task->task_pool, sizeof (f_str_t)); task->msg->begin = in->begin; task->msg->len = in->len; r = process_message (task); if (r == -1) { msg_warn ("processing of message failed"); free_task (task, FALSE); session->state = STATE_REPLY; r = snprintf (out_buf, sizeof (out_buf), "cannot process message" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return FALSE; } if ((s = g_hash_table_lookup (session->learn_classifier->opts, "header")) != NULL) { cur = message_get_header (task->task_pool, task->message, s); if (cur) { memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, cur); } } else { cur = g_list_first (task->text_parts); } while (cur) { if (s != NULL) { c.len = strlen (cur->data); c.begin = cur->data; } else { part = cur->data; if (part->is_empty) { cur = g_list_next (cur); continue; } c.begin = part->content->data; c.len = part->content->len; } if (!session->learn_classifier->tokenizer->tokenize_func (session->learn_classifier->tokenizer, session->session_pool, &c, &tokens)) { i = snprintf (out_buf, sizeof (out_buf), "learn fail, tokenizer error" CRLF); free_task (task, FALSE); if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { return FALSE; } session->state = STATE_REPLY; return TRUE; } cur = g_list_next (cur); } /* Handle messages without text */ if (tokens == NULL) { i = snprintf (out_buf, sizeof (out_buf), "learn fail, no tokens can be extracted (no text data)" CRLF); free_task (task, FALSE); if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { return FALSE; } session->state = STATE_REPLY; return TRUE; } /* Get or create statfile */ statfile = get_statfile_by_symbol (session->worker->srv->statfile_pool, session->learn_classifier, session->learn_symbol, &st, TRUE); if (statfile == NULL) { free_task (task, FALSE); i = snprintf (out_buf, sizeof (out_buf), "learn failed" CRLF); if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { return FALSE; } return TRUE; } /* Init classifier */ cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier); /* XXX: remove this awful legacy */ session->learn_classifier->classifier->learn_func (cls_ctx, session->worker->srv->statfile_pool, statfile, tokens, session->in_class, &sum, session->learn_multiplier); session->worker->srv->stat->messages_learned++; maybe_write_binlog (session->learn_classifier, st, statfile, tokens); if (st->normalizer != NULL) { sum = st->normalizer (sum, st->normalizer_data); } free_task (task, FALSE); i = snprintf (out_buf, sizeof (out_buf), "learn ok, sum weight: %.2f" CRLF, sum); if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { return FALSE; } session->state = STATE_REPLY; break; case STATE_WEIGHTS: session->learn_buf = in; task = construct_task (session->worker); task->msg = memory_pool_alloc (task->task_pool, sizeof (f_str_t)); task->msg->begin = in->begin; task->msg->len = in->len; r = process_message (task); if (r == -1) { msg_warn ("processing of message failed"); free_task (task, FALSE); session->state = STATE_REPLY; r = snprintf (out_buf, sizeof (out_buf), "cannot process message" CRLF); rspamd_dispatcher_write (session->dispatcher, out_buf, r, FALSE, FALSE); return FALSE; } cur = g_list_first (task->text_parts); while (cur) { part = cur->data; if (part->is_empty) { cur = g_list_next (cur); continue; } c.begin = part->content->data; c.len = part->content->len; if (!session->learn_classifier->tokenizer->tokenize_func (session->learn_classifier->tokenizer, session->session_pool, &c, &tokens)) { i = snprintf (out_buf, sizeof (out_buf), "weights fail, tokenizer error" CRLF); free_task (task, FALSE); if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { return FALSE; } session->state = STATE_REPLY; return TRUE; } cur = g_list_next (cur); } /* Handle messages without text */ if (tokens == NULL) { i = snprintf (out_buf, sizeof (out_buf), "weights fail, no tokens can be extracted (no text data)" CRLF); free_task (task, FALSE); if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { return FALSE; } session->state = STATE_REPLY; return TRUE; } /* Init classifier */ cls_ctx = session->learn_classifier->classifier->init_func (session->session_pool, session->learn_classifier); cur = session->learn_classifier->classifier->weights_func (cls_ctx, session->worker->srv->statfile_pool, tokens, task); i = 0; struct classify_weight *w; while (cur) { w = cur->data; i += snprintf (out_buf + i, sizeof (out_buf) - i, "%s: %.2f" CRLF, w->name, w->weight); cur = g_list_next (cur); } if (!rspamd_dispatcher_write (session->dispatcher, out_buf, i, FALSE, FALSE)) { return FALSE; } free_task (task, FALSE); session->state = STATE_REPLY; break; case STATE_OTHER: if (session->other_handler) { session->other_handler (session, in); } session->state = STATE_REPLY; break; case STATE_WAIT: rspamd_dispatcher_pause (session->dispatcher); break; default: msg_debug ("unknown state while reading %d", session->state); break; } if (session->state == STATE_REPLY || session->state == STATE_QUIT) { (void)controller_write_socket (session); } return TRUE; } static gboolean controller_write_socket (void *arg) { struct controller_session *session = (struct controller_session *)arg; if (session->state == STATE_QUIT) { /* Free buffers */ destroy_session (session->s); return FALSE; } else if (session->state == STATE_REPLY) { session->state = STATE_COMMAND; rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_LINE, BUFSIZ); } return TRUE; } static void controller_err_socket (GError * err, void *arg) { struct controller_session *session = (struct controller_session *)arg; if (err->code != EOF) { msg_info ("abnormally closing control connection, error: %s", err->message); } /* Free buffers */ destroy_session (session->s); } static void accept_socket (int fd, short what, void *arg) { struct rspamd_worker *worker = (struct rspamd_worker *)arg; struct sockaddr_storage ss; struct controller_session *new_session; struct timeval *io_tv; socklen_t addrlen = sizeof (ss); int nfd; if ((nfd = accept_from_socket (fd, (struct sockaddr *)&ss, &addrlen)) == -1) { msg_warn ("accept failed: %s", strerror (errno)); return; } new_session = g_malloc (sizeof (struct controller_session)); if (new_session == NULL) { msg_err ("cannot allocate memory for task, %s", strerror (errno)); return; } bzero (new_session, sizeof (struct controller_session)); new_session->worker = worker; new_session->sock = nfd; new_session->cfg = worker->srv->cfg; new_session->state = STATE_COMMAND; new_session->session_pool = memory_pool_new (memory_pool_get_size () - 1); worker->srv->stat->control_connections_count++; /* Set up dispatcher */ io_tv = memory_pool_alloc (new_session->session_pool, sizeof (struct timeval)); io_tv->tv_sec = CONTROLLER_IO_TIMEOUT; io_tv->tv_usec = 0; new_session->s = new_async_session (new_session->session_pool, free_session, new_session); new_session->dispatcher = rspamd_create_dispatcher (nfd, BUFFER_LINE, controller_read_socket, controller_write_socket, controller_err_socket, io_tv, (void *)new_session); rspamd_dispatcher_write (new_session->dispatcher, greetingbuf, strlen (greetingbuf), FALSE, FALSE); } void start_controller (struct rspamd_worker *worker) { struct sigaction signals; int i; GList *comp_list = NULL; char *hostbuf; long int hostmax; worker->srv->pid = getpid (); event_init (); g_mime_init (0); init_signals (&signals, sig_handler); sigprocmask (SIG_UNBLOCK, &signals.sa_mask, NULL); /* SIGUSR2 handler */ signal_set (&worker->sig_ev, SIGUSR2, sigusr_handler, (void *)worker); signal_add (&worker->sig_ev, NULL); start_time = time (NULL); /* Start statfile synchronization */ if (!start_statfile_sync (worker->srv->statfile_pool, worker->srv->cfg)) { msg_info ("cannot start statfile synchronization, statfiles would not be synchronized"); } /* Init command completion */ for (i = 0; i < G_N_ELEMENTS (commands); i++) { comp_list = g_list_prepend (comp_list, &commands[i]); } comp = g_completion_new (completion_func); g_completion_add_items (comp, comp_list); /* Fill hostname buf */ hostmax = sysconf (_SC_HOST_NAME_MAX) + 1; hostbuf = alloca (hostmax); gethostname (hostbuf, hostmax); hostbuf[hostmax - 1] = '\0'; snprintf (greetingbuf, sizeof (greetingbuf), "Rspamd version %s is running on %s" CRLF, RVERSION, hostbuf); /* Accept event */ event_set (&worker->bind_ev, worker->cf->listen_sock, EV_READ | EV_PERSIST, accept_socket, (void *)worker); event_add (&worker->bind_ev, NULL); gperf_profiler_init (worker->srv->cfg, "controller"); event_loop (0); exit (EXIT_SUCCESS); } void register_custom_controller_command (const char *name, controller_func_t handler, gboolean privilleged, gboolean require_message) { struct custom_controller_command *cmd; cmd = g_malloc (sizeof (struct custom_controller_command)); cmd->command = name; cmd->handler = handler; cmd->privilleged = privilleged; cmd->require_message = require_message; custom_commands = g_list_prepend (custom_commands, cmd); } /* * vi:ts=4 */