diff options
author | Vsevolod Stakhov <vsevolod@rambler-co.ru> | 2011-07-14 17:48:51 +0400 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rambler-co.ru> | 2011-07-14 17:48:51 +0400 |
commit | de262993190d03e40143e83067fb348a33ff9dbd (patch) | |
tree | b8ab7237bb1d7d695ef92eaf17ae90f306fb215e | |
parent | 8f5fb08857e989193ea30828175f61599512add5 (diff) | |
download | rspamd-de262993190d03e40143e83067fb348a33ff9dbd.tar.gz rspamd-de262993190d03e40143e83067fb348a33ff9dbd.zip |
* Add learn_spam/learn_ham interface to librspamdclient and to rspamc
* Improve logic of io dispatcher restoration
Remove correction factor from bayes as it leads to classify errors.
-rw-r--r-- | lib/librspamdclient.c | 165 | ||||
-rw-r--r-- | lib/librspamdclient.h | 15 | ||||
-rw-r--r-- | src/buffer.c | 9 | ||||
-rw-r--r-- | src/buffer.h | 1 | ||||
-rw-r--r-- | src/classifiers/bayes.c | 2 | ||||
-rw-r--r-- | src/client/rspamc.c | 161 | ||||
-rw-r--r-- | src/controller.c | 21 | ||||
-rw-r--r-- | src/filter.c | 4 | ||||
-rw-r--r-- | src/main.h | 1 |
9 files changed, 332 insertions, 47 deletions
diff --git a/lib/librspamdclient.c b/lib/librspamdclient.c index 39adf0f18..759b822e7 100644 --- a/lib/librspamdclient.c +++ b/lib/librspamdclient.c @@ -1383,7 +1383,7 @@ rspamd_learn_memory (const guchar *message, gsize length, const gchar *symbol, c } /* Read greeting */ - if (! rspamd_read_controller_greeting(c, err)) { + if (! rspamd_read_controller_greeting (c, err)) { if (*err == NULL) { *err = g_error_new (G_RSPAMD_ERROR, errno, "Invalid greeting"); } @@ -1463,7 +1463,7 @@ rspamd_learn_fd (int fd, const gchar *symbol, const gchar *password, GError **er } /* Read greeting */ - if (! rspamd_read_controller_greeting(c, err)) { + if (! rspamd_read_controller_greeting (c, err)) { if (*err == NULL) { *err = g_error_new (G_RSPAMD_ERROR, errno, "Invalid greeting"); } @@ -1515,6 +1515,163 @@ rspamd_learn_fd (int fd, const gchar *symbol, const gchar *password, GError **er } /* + * Learn message from memory + */ +gboolean +rspamd_learn_spam_memory (const guchar *message, gsize length, const gchar *classifier, gboolean is_spam, const gchar *password, GError **err) +{ + struct rspamd_connection *c; + GString *in; + gchar *outbuf; + guint r; + static const gchar ok_str[] = "learn ok"; + + g_assert (client != NULL); + g_assert (length > 0); + + /* Connect to server */ + c = rspamd_connect_random_server (TRUE, err); + + if (c == NULL) { + return FALSE; + } + + /* Read greeting */ + if (! rspamd_read_controller_greeting (c, err)) { + if (*err == NULL) { + *err = g_error_new (G_RSPAMD_ERROR, errno, "Invalid greeting"); + } + return FALSE; + } + /* Perform auth */ + if (! rspamd_controller_auth (c, password, err)) { + if (*err == NULL) { + *err = g_error_new (G_RSPAMD_ERROR, errno, "Authentication error"); + } + return FALSE; + } + + r = length + sizeof ("learn_spam %s %uz\r\n") + strlen (classifier) + sizeof ("4294967296"); + outbuf = g_malloc (r); + r = snprintf (outbuf, r, "learn_%s %s %lu\r\n%s", is_spam ? "spam" : "ham", + classifier, (unsigned long)length, message); + in = rspamd_send_controller_command (c, outbuf, r, -1, err); + g_free (outbuf); + if (in == NULL) { + return FALSE; + } + + /* Search for string learn ok */ + if (in->len > sizeof (ok_str) - 1 && memcmp (in->str, ok_str, sizeof (ok_str) - 1) == 0) { + upstream_ok (&c->server->up, c->connection_time); + return TRUE; + } + else { + if (*err == NULL) { + *err = g_error_new (G_RSPAMD_ERROR, errno, "Bad reply: %s", in->str); + } + } + return FALSE; +} + +/* + * Learn message from file + */ +gboolean +rspamd_learn_spam_file (const guchar *filename, const gchar *classifier, gboolean is_spam, const gchar *password, GError **err) +{ + gint fd; + g_assert (client != NULL); + + /* Open file */ + if ((fd = open (filename, O_RDONLY)) == -1) { + if (*err == NULL) { + *err = g_error_new (G_RSPAMD_ERROR, errno, "Open error for file %s: %s", + filename, strerror (errno)); + } + return FALSE; + } + + return rspamd_learn_spam_fd (fd, classifier, is_spam, password, err); +} + +/* + * Learn message from fd + */ +gboolean +rspamd_learn_spam_fd (int fd, const gchar *classifier, gboolean is_spam, const gchar *password, GError **err) +{ + struct rspamd_connection *c; + GString *in; + gchar *outbuf; + guint r; + struct stat st; + static const gchar ok_str[] = "learn ok"; + + g_assert (client != NULL); + + /* Connect to server */ + c = rspamd_connect_random_server (TRUE, err); + + if (c == NULL) { + return FALSE; + } + + /* Read greeting */ + if (! rspamd_read_controller_greeting (c, err)) { + if (*err == NULL) { + *err = g_error_new (G_RSPAMD_ERROR, errno, "Invalid greeting"); + } + return FALSE; + } + /* Perform auth */ + if (! rspamd_controller_auth (c, password, err)) { + if (*err == NULL) { + *err = g_error_new (G_RSPAMD_ERROR, errno, "Authentication error"); + } + return FALSE; + } + + /* Get length */ + if (fstat (fd, &st) == -1) { + if (*err == NULL) { + *err = g_error_new (G_RSPAMD_ERROR, errno, "Stat error: %s", + strerror (errno)); + } + return FALSE; + } + if (st.st_size == 0) { + if (*err == NULL) { + *err = g_error_new (G_RSPAMD_ERROR, -1, "File has zero length"); + } + return FALSE; + } + r = sizeof ("learn_spam %s %uz\r\n") + strlen (classifier) + sizeof ("4294967296"); + outbuf = g_malloc (r); + r = snprintf (outbuf, r, "learn_%s %s %lu\r\n", is_spam ? "spam" : "ham", + classifier, (unsigned long)st.st_size); + in = rspamd_send_controller_command (c, outbuf, r, fd, err); + g_free (outbuf); + if (in == NULL) { + return FALSE; + } + + /* Search for string learn ok */ + if (in->len > sizeof (ok_str) - 1 && memcmp (in->str, ok_str, sizeof (ok_str) - 1) == 0) { + upstream_ok (&c->server->up, c->connection_time); + return TRUE; + } + else { + if (*err == NULL) { + *err = g_error_new (G_RSPAMD_ERROR, errno, "Bad reply: %s", in->str); + } + } + + return FALSE; +} + + +/* * Learn message fuzzy from memory */ gboolean @@ -1537,7 +1694,7 @@ rspamd_fuzzy_memory (const guchar *message, gsize length, const gchar *password, } /* Read greeting */ - if (! rspamd_read_controller_greeting(c, err)) { + if (! rspamd_read_controller_greeting (c, err)) { if (*err == NULL) { *err = g_error_new (G_RSPAMD_ERROR, errno, "Invalid greeting"); } @@ -1622,7 +1779,7 @@ rspamd_fuzzy_fd (int fd, const gchar *password, gint weight, gint flag, gboolean } /* Read greeting */ - if (! rspamd_read_controller_greeting(c, err)) { + if (! rspamd_read_controller_greeting (c, err)) { if (*err == NULL) { *err = g_error_new (G_RSPAMD_ERROR, errno, "Invalid greeting"); } diff --git a/lib/librspamdclient.h b/lib/librspamdclient.h index 72b7311a7..180566fe5 100644 --- a/lib/librspamdclient.h +++ b/lib/librspamdclient.h @@ -70,6 +70,21 @@ struct rspamd_result * rspamd_scan_fd (int fd, GHashTable *headers, GError **err /* * Learn message from memory */ +gboolean rspamd_learn_spam_memory (const guchar *message, gsize length, const gchar *classifier, gboolean is_spam, const gchar *password, GError **err); + +/* + * Learn message from file + */ +gboolean rspamd_learn_spam_file (const guchar *filename, const gchar *classifier, gboolean is_spam, const gchar *password, GError **err); + +/* + * Learn message from fd + */ +gboolean rspamd_learn_spam_fd (int fd, const gchar *classifier, gboolean is_spam, const gchar *password, GError **err); + +/* + * Learn message from memory + */ gboolean rspamd_learn_memory (const guchar *message, gsize length, const gchar *symbol, const gchar *password, GError **err); /* diff --git a/src/buffer.c b/src/buffer.c index ed39d9205..7280c9162 100644 --- a/src/buffer.c +++ b/src/buffer.c @@ -463,6 +463,13 @@ dispatcher_cb (gint fd, short what, void *arg) event_del (d->ev); event_set (d->ev, fd, EV_READ | EV_PERSIST, dispatcher_cb, (void *)d); event_add (d->ev, d->tv); + if (d->is_restored && d->write_callback) { + if (!d->write_callback (d->user_data)) { + debug_ip ("callback set wanna_die flag, terminating"); + return; + } + d->is_restored = FALSE; + } } else { /* Delayed write */ @@ -637,7 +644,9 @@ rspamd_dispatcher_pause (rspamd_io_dispatcher_t * d) void rspamd_dispatcher_restore (rspamd_io_dispatcher_t * d) { + event_set (d->ev, d->fd, EV_READ | EV_WRITE, dispatcher_cb, d); event_add (d->ev, d->tv); + d->is_restored = TRUE; } #undef debug_ip diff --git a/src/buffer.h b/src/buffer.h index fc92511b0..16394ae7d 100644 --- a/src/buffer.h +++ b/src/buffer.h @@ -51,6 +51,7 @@ typedef struct rspamd_io_dispatcher_s { gint sendfile_fd; gboolean in_sendfile; /**< whether buffer is in sendfile mode */ gboolean strip_eol; /**< strip or not line ends in BUFFER_LINE policy */ + gboolean is_restored; /**< call a callback when dispatcher is restored */ #ifndef HAVE_SENDFILE void *map; #endif diff --git a/src/classifiers/bayes.c b/src/classifiers/bayes.c index 44e9323a2..5998d0fdb 100644 --- a/src/classifiers/bayes.c +++ b/src/classifiers/bayes.c @@ -111,7 +111,7 @@ bayes_classify_callback (gpointer key, gpointer value, gpointer data) for (i = 0; i < cd->statfiles_num; i ++) { cur = &cd->statfiles[i]; - cur->value = statfile_pool_get_block (cd->pool, cur->file, node->h1, node->h2, cd->now) * cur->corr; + cur->value = statfile_pool_get_block (cd->pool, cur->file, node->h1, node->h2, cd->now); if (cur->value > 0) { cur->total_hits ++; cur->hits = cur->value; diff --git a/src/client/rspamc.c b/src/client/rspamc.c index 80798da05..a1d79dc96 100644 --- a/src/client/rspamc.c +++ b/src/client/rspamc.c @@ -31,13 +31,14 @@ #define DEFAULT_CONTROL_PORT 11334 static gchar *connect_str = "localhost"; -static gchar *password; -static gchar *statfile; -static gchar *ip; -static gchar *from; -static gchar *deliver_to; -static gchar *rcpt; -static gchar *user; +static gchar *password = NULL; +static gchar *statfile = NULL; +static gchar *ip = NULL; +static gchar *from = NULL; +static gchar *deliver_to = NULL; +static gchar *rcpt = NULL; +static gchar *user = NULL; +static gchar *classifier = NULL; static gint weight = 1; static gint flag; static gint timeout = 5; @@ -50,6 +51,7 @@ static GOptionEntry entries[] = { "connect", 'h', 0, G_OPTION_ARG_STRING, &connect_str, "Specify host and port", NULL }, { "password", 'P', 0, G_OPTION_ARG_STRING, &password, "Specify control password", NULL }, { "statfile", 's', 0, G_OPTION_ARG_STRING, &statfile, "Statfile to learn (symbol name)", NULL }, + { "classifier", 'c', 0, G_OPTION_ARG_STRING, &classifier, "Classifier to learn spam or ham", NULL }, { "weight", 'w', 0, G_OPTION_ARG_INT, &weight, "Weight for fuzzy operations", NULL }, { "flag", 'f', 0, G_OPTION_ARG_INT, &flag, "Flag for fuzzy operations", NULL }, { "pass", 'p', 0, G_OPTION_ARG_NONE, &pass_all, "Pass all filters", NULL }, @@ -67,6 +69,8 @@ enum rspamc_command { RSPAMC_COMMAND_UNKNOWN = 0, RSPAMC_COMMAND_SYMBOLS, RSPAMC_COMMAND_LEARN, + RSPAMC_COMMAND_LEARN_SPAM, + RSPAMC_COMMAND_LEARN_HAM, RSPAMC_COMMAND_FUZZY_ADD, RSPAMC_COMMAND_FUZZY_DEL, RSPAMC_COMMAND_STAT, @@ -111,6 +115,12 @@ check_rspamc_command (const gchar *cmd) else if (g_ascii_strcasecmp (cmd, "LEARN") == 0) { return RSPAMC_COMMAND_LEARN; } + else if (g_ascii_strcasecmp (cmd, "LEARN_SPAM") == 0) { + return RSPAMC_COMMAND_LEARN_SPAM; + } + else if (g_ascii_strcasecmp (cmd, "LEARN_HAM") == 0) { + return RSPAMC_COMMAND_LEARN_HAM; + } else if (g_ascii_strcasecmp (cmd, "FUZZY_ADD") == 0) { return RSPAMC_COMMAND_FUZZY_ADD; } @@ -379,14 +389,14 @@ scan_rspamd_file (const gchar *file) } static void -learn_rspamd_stdin () +learn_rspamd_stdin (gboolean is_spam) { gchar *in_buf; gint r = 0, len; GError *err = NULL; - if (password == NULL || statfile == NULL) { - fprintf (stderr, "cannot learn message without password and symbol name\n"); + if (password == NULL || (statfile == NULL && classifier == NULL)) { + fprintf (stderr, "cannot learn message without password and symbol/classifier name\n"); exit (EXIT_FAILURE); } /* Add server */ @@ -405,51 +415,94 @@ learn_rspamd_stdin () in_buf = g_realloc (in_buf, len); } } - if (!rspamd_learn_memory (in_buf, r, statfile, password, &err)) { - if (err != NULL) { - fprintf (stderr, "cannot learn message: %s\n", err->message); + if (statfile != NULL) { + if (!rspamd_learn_memory (in_buf, r, statfile, password, &err)) { + if (err != NULL) { + fprintf (stderr, "cannot learn message: %s\n", err->message); + } + else { + fprintf (stderr, "cannot learn message\n"); + } + exit (EXIT_FAILURE); } else { - fprintf (stderr, "cannot learn message\n"); + if (tty) { + printf ("\033[1m"); + } + PRINT_FUNC ("Results for host: %s: learn ok\n", connect_str); + if (tty) { + printf ("\033[0m"); + } } - exit (EXIT_FAILURE); } - else { - if (tty) { - printf ("\033[1m"); + else if (classifier != NULL) { + if (!rspamd_learn_spam_memory (in_buf, r, classifier, is_spam, password, &err)) { + if (err != NULL) { + fprintf (stderr, "cannot learn message: %s\n", err->message); + } + else { + fprintf (stderr, "cannot learn message\n"); + } + exit (EXIT_FAILURE); } - PRINT_FUNC ("Results for host: %s: learn ok\n", connect_str); - if (tty) { - printf ("\033[0m"); + else { + if (tty) { + printf ("\033[1m"); + } + PRINT_FUNC ("Results for host: %s: learn ok\n", connect_str); + if (tty) { + printf ("\033[0m"); + } } } } static void -learn_rspamd_file (const gchar *file) +learn_rspamd_file (gboolean is_spam, const gchar *file) { GError *err = NULL; - if (password == NULL || statfile == NULL) { - fprintf (stderr, "cannot learn message without password and symbol name\n"); + if (password == NULL || (statfile == NULL && classifier == NULL)) { + fprintf (stderr, "cannot learn message without password and symbol/classifier name\n"); exit (EXIT_FAILURE); } - if (!rspamd_learn_file (file, statfile, password, &err)) { - if (err != NULL) { - fprintf (stderr, "cannot learn message: %s\n", err->message); + if (statfile != NULL) { + if (!rspamd_learn_file (file, statfile, password, &err)) { + if (err != NULL) { + fprintf (stderr, "cannot learn message: %s\n", err->message); + } + else { + fprintf (stderr, "cannot learn message\n"); + } } else { - fprintf (stderr, "cannot learn message\n"); + if (tty) { + printf ("\033[1m"); + } + PRINT_FUNC ("learn ok\n"); + if (tty) { + printf ("\033[0m"); + } } } - else { - if (tty) { - printf ("\033[1m"); + else if (classifier != NULL) { + if (!rspamd_learn_spam_file (file, classifier, is_spam, password, &err)) { + if (err != NULL) { + fprintf (stderr, "cannot learn message: %s\n", err->message); + } + else { + fprintf (stderr, "cannot learn message\n"); + } } - PRINT_FUNC ("learn ok\n"); - if (tty) { - printf ("\033[0m"); + else { + if (tty) { + printf ("\033[1m"); + } + PRINT_FUNC ("learn ok\n"); + if (tty) { + printf ("\033[0m"); + } } } } @@ -615,7 +668,25 @@ main (gint argc, gchar **argv, gchar **env) scan_rspamd_stdin (); break; case RSPAMC_COMMAND_LEARN: - learn_rspamd_stdin (); + learn_rspamd_stdin (TRUE); + break; + case RSPAMC_COMMAND_LEARN_SPAM: + if (classifier != NULL) { + learn_rspamd_stdin (TRUE); + } + else { + fprintf (stderr, "no classifier specified\n"); + exit (EXIT_FAILURE); + } + break; + case RSPAMC_COMMAND_LEARN_HAM: + if (classifier != NULL) { + learn_rspamd_stdin (FALSE); + } + else { + fprintf (stderr, "no classifier specified\n"); + exit (EXIT_FAILURE); + } break; case RSPAMC_COMMAND_FUZZY_ADD: fuzzy_rspamd_stdin (FALSE); @@ -664,7 +735,25 @@ main (gint argc, gchar **argv, gchar **env) scan_rspamd_file (argv[i]); break; case RSPAMC_COMMAND_LEARN: - learn_rspamd_file (argv[i]); + learn_rspamd_file (TRUE, argv[i]); + break; + case RSPAMC_COMMAND_LEARN_SPAM: + if (classifier != NULL) { + learn_rspamd_file (TRUE, argv[i]); + } + else { + fprintf (stderr, "no classifier specified\n"); + exit (EXIT_FAILURE); + } + break; + case RSPAMC_COMMAND_LEARN_HAM: + if (classifier != NULL) { + learn_rspamd_file (FALSE, argv[i]); + } + else { + fprintf (stderr, "no classifier specified\n"); + exit (EXIT_FAILURE); + } break; case RSPAMC_COMMAND_FUZZY_ADD: fuzzy_rspamd_file (argv[i], FALSE); diff --git a/src/controller.c b/src/controller.c index a59bc3d32..8c9da3bb9 100644 --- a/src/controller.c +++ b/src/controller.c @@ -31,6 +31,8 @@ #include "cfg_file.h" #include "cfg_xml.h" #include "modules.h" +#include "map.h" +#include "dns.h" #include "tokenizers/tokenizers.h" #include "classifiers/classifiers.h" #include "binlog.h" @@ -74,6 +76,7 @@ struct custom_controller_command { struct rspamd_controller_ctx { char *password; guint32 timeout; + struct rspamd_dns_resolver *resolver; }; static struct controller_command commands[] = { @@ -558,7 +561,7 @@ process_command (struct controller_command *cmd, gchar **cmd_args, struct contro /* By default learn positive */ session->in_class = TRUE; rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size); - session->state = STATE_LEARN_SPAM; + session->state = STATE_LEARN_SPAM_PRE; } break; case COMMAND_LEARN_HAM: @@ -604,7 +607,7 @@ process_command (struct controller_command *cmd, gchar **cmd_args, struct contro /* By default learn positive */ session->in_class = FALSE; rspamd_set_dispatcher_policy (session->dispatcher, BUFFER_CHARACTER, size); - session->state = STATE_LEARN_SPAM; + session->state = STATE_LEARN_SPAM_PRE; } break; case COMMAND_LEARN: @@ -864,7 +867,8 @@ controller_read_socket (f_str_t * in, void *arg) if (session->state == STATE_COMMAND) { session->state = STATE_REPLY; } - if (session->state != STATE_LEARN && session->state != STATE_WEIGHTS && session->state != STATE_OTHER) { + if (session->state != STATE_LEARN && session->state != STATE_LEARN_SPAM_PRE + && session->state != STATE_WEIGHTS && session->state != STATE_OTHER) { if (!rspamd_dispatcher_write (session->dispatcher, END, sizeof (END) - 1, FALSE, TRUE)) { return FALSE; } @@ -922,6 +926,7 @@ controller_read_socket (f_str_t * in, void *arg) task->msg->begin = in->begin; task->msg->len = in->len; + task->resolver = session->resolver; r = process_message (task); if (r == -1) { @@ -1105,7 +1110,6 @@ controller_write_socket (void *arg) else { i = rspamd_snprintf (out_buf, sizeof (out_buf), "learn ok" CRLF END); } - learn_task_spam (session->learn_classifier, session->learn_task, session->in_class, &err); session->learn_task->dispatcher = NULL; free_task (session->learn_task, FALSE); session->state = STATE_REPLY; @@ -1162,6 +1166,7 @@ accept_socket (gint fd, short what, void *arg) new_session->cfg = worker->srv->cfg; new_session->state = STATE_COMMAND; new_session->session_pool = memory_pool_new (memory_pool_get_size () - 1); + new_session->resolver = ctx->resolver; worker->srv->stat->control_connections_count++; /* Set up dispatcher */ @@ -1198,8 +1203,11 @@ start_controller (struct rspamd_worker *worker) struct sigaction signals; gchar *hostbuf; gsize hostmax; + struct rspamd_controller_ctx *ctx; worker->srv->pid = getpid (); + ctx = worker->ctx; + event_init (); g_mime_init (0); @@ -1228,10 +1236,15 @@ start_controller (struct rspamd_worker *worker) event_set (&worker->bind_ev, worker->cf->listen_sock, EV_READ | EV_PERSIST, accept_socket, (void *)worker); event_add (&worker->bind_ev, NULL); + start_map_watch (); + ctx->resolver = dns_resolver_init (worker->srv->cfg); + gperf_profiler_init (worker->srv->cfg, "controller"); event_loop (0); + close_log (worker->srv->logger); + exit (EXIT_SUCCESS); } diff --git a/src/filter.c b/src/filter.c index 8321e6d21..0ad82f94b 100644 --- a/src/filter.c +++ b/src/filter.c @@ -290,12 +290,12 @@ end: /* Call post filters */ lua_call_post_filters (task); task->state = WRITE_REPLY; - /* XXX: ugly direct call */ + if (task->fin_callback) { task->fin_callback (task->fin_arg); } else { - task->dispatcher->write_callback (task); + rspamd_dispatcher_restore (task->dispatcher); } return 1; } diff --git a/src/main.h b/src/main.h index d8761617f..6ac24e7ad 100644 --- a/src/main.h +++ b/src/main.h @@ -166,6 +166,7 @@ struct controller_session { void *other_data; /**< and its data */ struct rspamd_async_session* s; /**< async session object */ struct worker_task *learn_task; + struct rspamd_dns_resolver *resolver; /**< DNS resolver */ }; typedef void (*controller_func_t)(gchar **args, struct controller_session *session); |