From 2427ce2a633b2d03851997fce5472e1c3913be72 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Mon, 17 Sep 2012 22:11:47 +0400 Subject: [PATCH] Rewrite controller's logic in librspamdclient and rspamc application. --- lib/client/librspamdclient.c | 1216 +++++++++++++--------------------- lib/client/librspamdclient.h | 124 ++-- src/client/rspamc.c | 205 +++--- 3 files changed, 605 insertions(+), 940 deletions(-) diff --git a/lib/client/librspamdclient.c b/lib/client/librspamdclient.c index 37b6aa847..e1fb0707b 100644 --- a/lib/client/librspamdclient.c +++ b/lib/client/librspamdclient.c @@ -36,12 +36,17 @@ #define CONNECT_TIMEOUT 3 #define G_RSPAMD_ERROR rspamd_error_quark () +#ifndef CRLF +# define CRLF "\r\n" +#endif + struct rspamd_server { struct upstream up; struct in_addr addr; guint16 client_port; guint16 controller_port; gchar *name; + gchar *controller_name; }; struct rspamd_client { @@ -57,9 +62,22 @@ struct rspamd_connection { struct rspamd_client *client; time_t connection_time; gint socket; - struct rspamd_result *result; + union { + struct { + struct rspamd_result *result; + struct rspamd_metric *cur_metric; + } normal; + struct { + struct rspamd_controller_result *result; + enum { + CONTROLLER_READ_REPLY, + CONTROLLER_READ_HEADER, + CONTROLLER_READ_DATA + } state; + } controller; + } res; + gboolean is_controller; GString *in_buf; - struct rspamd_metric *cur_metric; gint version; }; @@ -239,40 +257,19 @@ symbol_free_func (gpointer arg) } static struct rspamd_connection * -rspamd_connect_random_server (struct rspamd_client *client, gboolean is_control, GError **err) +rspamd_connect_specific_server (struct rspamd_client *client, gboolean is_control, GError **err, struct rspamd_server *serv) { - struct rspamd_server *selected = NULL; struct rspamd_connection *new; - time_t now; - if (client->servers_num == 0) { - errno = EINVAL; - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, 1, "No servers can be reached"); - } - return NULL; - } - /* Select random server */ - now = time (NULL); - selected = (struct rspamd_server *)get_random_upstream (client->servers, - client->servers_num, sizeof (struct rspamd_server), - now, DEFAULT_UPSTREAM_ERROR_TIME, DEFAULT_UPSTREAM_DEAD_TIME, DEFAULT_UPSTREAM_MAXERRORS); - if (selected == NULL) { - errno = EINVAL; - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, 1, "No servers can be reached"); - } - return NULL; - } /* Allocate connection */ - new = g_malloc (sizeof (struct rspamd_connection)); - new->server = selected; - new->connection_time = now; + new = g_malloc0 (sizeof (struct rspamd_connection)); + new->server = serv; + new->connection_time = time (NULL); new->client = client; /* Create socket */ - new->socket = lib_make_tcp_socket (&selected->addr, client->bind_addr, - is_control ? selected->controller_port : selected->client_port, - FALSE, TRUE); + new->socket = lib_make_tcp_socket (&serv->addr, client->bind_addr, + is_control ? serv->controller_port : serv->client_port, + FALSE, TRUE); if (new->socket == -1) { goto err; } @@ -283,19 +280,45 @@ rspamd_connect_random_server (struct rspamd_client *client, gboolean is_control, } new->in_buf = g_string_sized_new (BUFSIZ); - new->cur_metric = NULL; return new; -err: + err: if (*err == NULL) { *err = g_error_new (G_RSPAMD_ERROR, errno, "Could not connect to server %s: %s", - selected->name, strerror (errno)); + serv->name, strerror (errno)); } - upstream_fail (&selected->up, now); + upstream_fail (&serv->up, time (NULL)); g_free (new); return NULL; } +static struct rspamd_connection * +rspamd_connect_random_server (struct rspamd_client *client, gboolean is_control, GError **err) +{ + struct rspamd_server *selected = NULL; + + if (client->servers_num == 0) { + errno = EINVAL; + if (*err == NULL) { + *err = g_error_new (G_RSPAMD_ERROR, 1, "No servers can be reached"); + } + return NULL; + } + /* Select random server */ + selected = (struct rspamd_server *)get_random_upstream (client->servers, + client->servers_num, sizeof (struct rspamd_server), + time (NULL), DEFAULT_UPSTREAM_ERROR_TIME, DEFAULT_UPSTREAM_DEAD_TIME, DEFAULT_UPSTREAM_MAXERRORS); + if (selected == NULL) { + errno = EINVAL; + if (*err == NULL) { + *err = g_error_new (G_RSPAMD_ERROR, 1, "No servers can be reached"); + } + return NULL; + } + + return rspamd_connect_specific_server (client, is_control, err, selected); +} + static struct rspamd_metric * rspamd_create_metric (const gchar *begin, guint len) { @@ -324,6 +347,21 @@ rspamd_create_result (struct rspamd_connection *c) return new; } +static struct rspamd_controller_result * +rspamd_create_controller_result (struct rspamd_connection *c) +{ + struct rspamd_controller_result *new; + + new = g_malloc (sizeof (struct rspamd_controller_result)); + new->conn = c; + new->headers = g_hash_table_new_full (g_str_hash, g_str_equal, g_free, g_free); + new->data = NULL; + new->result = g_string_new (NULL); + new->code = 0; + + return new; +} + /* * Parse line like RSPAMD/{version} {code} {message} */ @@ -336,7 +374,7 @@ parse_rspamd_first_line (struct rspamd_connection *conn, guint len, GError **err p = b; c = p; - while (p - b < (gint)remain) { + while (p - b <= (gint)remain) { switch (state) { case 0: /* Read version */ @@ -357,7 +395,7 @@ parse_rspamd_first_line (struct rspamd_connection *conn, guint len, GError **err state = 99; next_state = 2; if (*c == '0') { - conn->result->is_ok = TRUE; + conn->res.normal.result->is_ok = TRUE; } } else if (!g_ascii_isdigit (*p)) { @@ -367,7 +405,7 @@ parse_rspamd_first_line (struct rspamd_connection *conn, guint len, GError **err break; case 2: /* Read message */ - if (g_ascii_isspace (*p) || p - b == (gint)remain - 1) { + if (g_ascii_isspace (*p) || p - b == (gint)remain) { state = 99; next_state = 3; } @@ -419,7 +457,7 @@ parse_rspamd_metric_line (struct rspamd_connection *conn, guint len, GError **er } c = p; - while (p - b < (gint)remain) { + while (p - b <= (gint)remain) { switch (state) { case 0: /* Read metric's name */ @@ -431,13 +469,13 @@ parse_rspamd_metric_line (struct rspamd_connection *conn, guint len, GError **er else { /* Create new metric */ new = rspamd_create_metric (c, p - c); - if (g_hash_table_lookup (conn->result->metrics, new->name) != NULL) { + if (g_hash_table_lookup (conn->res.normal.result->metrics, new->name) != NULL) { /* Duplicate metric */ metric_free_func (new); goto err; } - g_hash_table_insert (conn->result->metrics, new->name, new); - conn->cur_metric = new; + g_hash_table_insert (conn->res.normal.result->metrics, new->name, new); + conn->res.normal.cur_metric = new; state = 99; next_state = 1; } @@ -483,7 +521,7 @@ parse_rspamd_metric_line (struct rspamd_connection *conn, guint len, GError **er break; case 4: /* Read required score */ - if (g_ascii_isspace (*p) || p - b == (gint)remain - 1) { + if (g_ascii_isspace (*p) || p - b == (gint)remain) { new->required_score = strtod (c, &err_str); if (*err_str != *p && *err_str != *(p + 1)) { /* Invalid score */ @@ -507,7 +545,7 @@ parse_rspamd_metric_line (struct rspamd_connection *conn, guint len, GError **er break; case 6: /* Read reject score */ - if (g_ascii_isspace (*p) || p - b == (gint)remain - 1) { + if (g_ascii_isspace (*p) || p - b == (gint)remain) { new->reject_score = strtod (c, &err_str); if (*err_str != *p && *err_str != *(p + 1)) { /* Invalid score */ @@ -569,7 +607,7 @@ parse_rspamd_symbol_line (struct rspamd_connection *conn, guint len, GError **er goto err; } else { - if (p - b == (gint)remain - 1) { + if (p - b == (gint)remain) { l = p - c + 1; } else { @@ -586,14 +624,14 @@ parse_rspamd_symbol_line (struct rspamd_connection *conn, guint len, GError **er sym[l] = '\0'; memcpy (sym, c, l); - if (g_hash_table_lookup (conn->cur_metric->symbols, sym) != NULL) { + if (g_hash_table_lookup (conn->res.normal.cur_metric->symbols, sym) != NULL) { /* Duplicate symbol */ g_free (sym); goto err; } new = g_malloc0 (sizeof (struct rspamd_symbol)); new->name = sym; - g_hash_table_insert (conn->cur_metric->symbols, sym, new); + g_hash_table_insert (conn->res.normal.cur_metric->symbols, sym, new); state = 99; } } @@ -696,7 +734,7 @@ parse_rspamd_action_line (struct rspamd_connection *conn, guint len, GError **er p = b; c = b; - while (p - b < (gint)remain) { + while (p - b <= (gint)remain) { switch (state) { case 0: /* Read action */ @@ -709,7 +747,7 @@ parse_rspamd_action_line (struct rspamd_connection *conn, guint len, GError **er } break; case 1: - if (p - b == (gint)remain - 1) { + if (p - b == (gint)remain) { if (p - c <= 1) { /* Empty action name */ goto err; @@ -719,8 +757,7 @@ parse_rspamd_action_line (struct rspamd_connection *conn, guint len, GError **er sym = g_malloc (p - c + 2); sym[p - c + 1] = '\0'; memcpy (sym, c, p - c + 1); - - conn->cur_metric->action = sym; + conn->res.normal.cur_metric->action = sym; state = 99; } } @@ -760,11 +797,11 @@ static gboolean parse_rspamd_header_line (struct rspamd_connection *conn, guint len, GError **err) { gchar *b = conn->in_buf->str, *p, *c, *hname = NULL, *hvalue = NULL; - guint remain = len, state = 0, next_state = 0; + guint remain = len, state = 0, next_state = 0, clen; p = b; c = b; - while (p - b < (gint)remain) { + while (p - b <= (gint)remain) { switch (state) { case 0: /* Read header name */ @@ -785,7 +822,7 @@ parse_rspamd_header_line (struct rspamd_connection *conn, guint len, GError **er p ++; break; case 1: - if (p - b == (gint)remain - 1) { + if (p - b == (gint)remain) { if (p - c <= 1) { /* Empty action name */ goto err; @@ -795,7 +832,20 @@ parse_rspamd_header_line (struct rspamd_connection *conn, guint len, GError **er hvalue = g_malloc (p - c + 2); hvalue[p - c + 1] = '\0'; memcpy (hvalue, c, p - c + 1); - g_hash_table_replace (conn->result->headers, hname, hvalue); + if (conn->is_controller) { + if (g_ascii_strcasecmp (hname, "Content-Length") == 0) { + /* Preallocate data buffer */ + errno = 0; + clen = strtoul (hvalue, NULL, 10); + if (errno == 0 && clen > 0) { + conn->res.controller.result->data = g_string_sized_new (clen); + } + } + g_hash_table_replace (conn->res.controller.result->headers, hname, hvalue); + } + else { + g_hash_table_replace (conn->res.normal.result->headers, hname, hvalue); + } state = 99; } } @@ -819,8 +869,9 @@ parse_rspamd_header_line (struct rspamd_connection *conn, guint len, GError **er } return TRUE; - err: +err: if (*err == NULL) { + g_assert (0); *err = g_error_new (G_RSPAMD_ERROR, errno, "Invalid header line: %*s at pos: %d", remain, b, (int)(p - b)); } @@ -834,43 +885,153 @@ parse_rspamd_header_line (struct rspamd_connection *conn, guint len, GError **er return FALSE; } +static gboolean +parse_rspamd_controller_reply (struct rspamd_connection *conn, guint len, GError **err) +{ + gchar *b = conn->in_buf->str, *p, *c; + const gchar http_rep[] = "HTTP/1.0 "; + guint remain = len, state = 0, next_state = 0; + + /* First of all skip "HTTP/1.0 " line */ + if (len < sizeof (http_rep) || memcmp (b, http_rep, sizeof (http_rep) - 1) != 0) { + g_set_error (err, G_RSPAMD_ERROR, -1, "Invalid reply line"); + return FALSE; + } + b += sizeof (http_rep) - 1; + p = b; + c = b; + remain -= sizeof (http_rep) - 1; + + while (p - b <= (gint)remain) { + switch (state) { + case 0: + /* Try to get code */ + if (g_ascii_isdigit (*p)) { + p ++; + } + else if (g_ascii_isspace (*p)) { + conn->res.controller.result->code = atoi (c); + next_state = 1; + state = 99; + } + else { + goto err; + } + break; + case 1: + /* Get reply string */ + if (p - b == (gint)remain) { + if (p - c <= 0) { + /* Empty action name */ + goto err; + } + else { + /* Create header value */ + conn->res.controller.result->result = g_string_sized_new (p - c + 1); + g_string_append_len (conn->res.controller.result->result, c, p - c); + state = 99; + } + } + p ++; + break; + case 99: + /* Skip spaces */ + if (!g_ascii_isspace (*p)) { + state = next_state; + c = p; + } + else { + p ++; + } + break; + } + } + + if (state != 99) { + goto err; + } + + conn->res.controller.state = CONTROLLER_READ_HEADER; + + return TRUE; + +err: + if (*err == NULL) { + *err = g_error_new (G_RSPAMD_ERROR, errno, "Invalid reply line: %*s at pos: %d", + remain, b, (int)(p - b)); + } + return FALSE; +} + static gboolean parse_rspamd_reply_line (struct rspamd_connection *c, guint len, GError **err) { gchar *p = c->in_buf->str; - g_assert (len > 0); - - /* - * In fact we have 3 states of parsing: - * 1) we have current metric and parse symbols - * 2) we have no current metric and skip everything to headers hash - * 3) we have current metric but got not symbol but header -> put it into headers hash - * Line is parsed using specific state machine - */ - if (c->cur_metric == NULL) { - if (len > sizeof ("RSPAMD/") && memcmp (p, "RSPAMD/", sizeof ("RSPAMD/") - 1) == 0) { - return parse_rspamd_first_line (c, len, err); - } - else if (len > sizeof ("Metric:") && memcmp (p, "Metric:", sizeof("Metric:") - 1) == 0) { - return parse_rspamd_metric_line (c, len, err); - } - else { - return parse_rspamd_header_line (c, len, err); + if (c->is_controller) { + switch (c->res.controller.state) { + case CONTROLLER_READ_REPLY: + return parse_rspamd_controller_reply (c, len, err); + break; + case CONTROLLER_READ_HEADER: + if (len == 0) { + /* End of headers */ + c->res.controller.state = CONTROLLER_READ_DATA; + if (c->res.controller.result->data == NULL) { + /* Cannot detect size as controller didn't send Content-Length header, so guess it */ + c->res.controller.result->data = g_string_new (NULL); + } + return TRUE; + } + else { + return parse_rspamd_header_line (c, len, err); + } + break; + case CONTROLLER_READ_DATA: + g_string_append_len (c->res.controller.result->data, p, len); + g_string_append_len (c->res.controller.result->data, CRLF, 2); + return TRUE; + break; } } else { - if (len > sizeof ("Metric:") && memcmp (p, "Metric:", sizeof("Metric:") - 1) == 0) { - return parse_rspamd_metric_line (c, len, err); - } - else if (len > sizeof ("Symbol:") && memcmp (p, "Symbol:", sizeof("Symbol:") - 1) == 0) { - return parse_rspamd_symbol_line (c, len, err); - } - else if (len > sizeof ("Action:") && memcmp (p, "Action:", sizeof("Action:") - 1) == 0) { - return parse_rspamd_action_line (c, len, err); + /* + * In fact we have 3 states of parsing: + * 1) we have current metric and parse symbols + * 2) we have no current metric and skip everything to headers hash + * 3) we have current metric but got not symbol but header -> put it into headers hash + * Line is parsed using specific state machine + */ + if (len > 0) { + if (c->res.normal.cur_metric == NULL) { + if (len > sizeof ("RSPAMD/") && memcmp (p, "RSPAMD/", sizeof ("RSPAMD/") - 1) == 0) { + return parse_rspamd_first_line (c, len, err); + } + else if (len > sizeof ("Metric:") && memcmp (p, "Metric:", sizeof("Metric:") - 1) == 0) { + return parse_rspamd_metric_line (c, len, err); + } + else { + return parse_rspamd_header_line (c, len, err); + } + } + else { + if (len > sizeof ("Metric:") && memcmp (p, "Metric:", sizeof("Metric:") - 1) == 0) { + return parse_rspamd_metric_line (c, len, err); + } + else if (len > sizeof ("Symbol:") && memcmp (p, "Symbol:", sizeof("Symbol:") - 1) == 0) { + return parse_rspamd_symbol_line (c, len, err); + } + else if (len > sizeof ("Action:") && memcmp (p, "Action:", sizeof("Action:") - 1) == 0) { + return parse_rspamd_action_line (c, len, err); + } + else { + return parse_rspamd_header_line (c, len, err); + } + } } else { - return parse_rspamd_header_line (c, len, err); + /* TODO: here we should parse commands that contains data, like PROCESS */ + return TRUE; } } @@ -889,14 +1050,10 @@ read_rspamd_reply_line (struct rspamd_connection *c, GError **err) len = 0; while (len < (gint)c->in_buf->len) { p = c->in_buf->str[len]; - if (p == '\r' || p == '\n') { - if (parse_rspamd_reply_line (c, len, err)) { - /* Strip '\r\n' */ - while (len < (gint)c->in_buf->len && (p == '\r' || p == '\n')) { - p = c->in_buf->str[++len]; - } + if (p == '\n') { + if (parse_rspamd_reply_line (c, len - 1, err)) { /* Move remaining buffer to the begin of string */ - c->in_buf = g_string_erase (c->in_buf, 0, len); + c->in_buf = g_string_erase (c->in_buf, 0, len + 1); len = 0; } else { @@ -990,40 +1147,6 @@ err: return FALSE; } -#ifndef GLIB_HASH_COMPAT -static gboolean -rspamd_send_normal_command (struct rspamd_connection *c, const gchar *command, - gsize clen, GHashTable *headers, GError **err) -{ - gchar outbuf[16384]; - GHashTableIter it; - gpointer key, value; - gint r; - - /* Write command */ - r = snprintf (outbuf, sizeof (outbuf), "%s RSPAMC/1.3\r\n", command); - r += snprintf (outbuf + r, sizeof (outbuf) - r, "Content-Length: %lu\r\n", (unsigned long)clen); - /* Iterate through headers */ - if (headers != NULL) { - g_hash_table_iter_init (&it, headers); - while (g_hash_table_iter_next (&it, &key, &value)) { - r += snprintf (outbuf + r, sizeof (outbuf) - r, "%s: %s\r\n", (const gchar *)key, (const gchar *)value); - } - } - r += snprintf (outbuf + r, sizeof (outbuf) - r, "\r\n"); - - if ((r = write (c->socket, outbuf, r)) == -1) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Write error: %s", - strerror (errno)); - } - return FALSE; - } - - return TRUE; -} -#else -/* Compatible version */ struct hash_iter_cb { gchar *buf; gsize size; @@ -1070,7 +1193,6 @@ rspamd_send_normal_command (struct rspamd_connection *c, const gchar *command, return TRUE; } -#endif static void rspamd_free_connection (struct rspamd_connection *c) @@ -1082,181 +1204,117 @@ rspamd_free_connection (struct rspamd_connection *c) g_free (c); } -/* - * Send a single command to controller and get reply - */ -static GString * -rspamd_send_controller_command (struct rspamd_connection *c, const gchar *line, gsize len, gint fd, GError **err) -{ - GString *res = NULL; - gchar tmpbuf[BUFSIZ], *p; - gint r = 0; - static const gchar end_marker[] = "\r\nEND\r\n"; - /* Set blocking for writing */ - make_socket_blocking (c->socket); - if (write (c->socket, line, len) == -1) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Write error: %s", - strerror (errno)); + +static gboolean +rspamd_send_controller_command (struct rspamd_connection *c, const gchar *command, const gchar *password, GHashTable *in_headers, gint fd, GByteArray *data, GError **err) +{ + struct iovec iov[2]; + gchar outbuf[BUFSIZ]; + gint r; + struct stat st; + struct hash_iter_cb cbdata; + + /* Form a request */ + r = rspamd_snprintf (outbuf, sizeof (outbuf), "GET / HTTP/1.0" CRLF "Command: %s" CRLF, command); + /* Content length */ + if (fd != -1) { + if (fstat (fd, &st) == -1) { + g_set_error (err, G_RSPAMD_ERROR, errno, "Stat error: %s", strerror (errno)); + goto err; } - return NULL; + r += rspamd_snprintf (outbuf + r, sizeof (outbuf) - r, "Content-Length: %z" CRLF, st.st_size); + } + else if (data && data->len > 0) { + r += rspamd_snprintf (outbuf + r, sizeof (outbuf) - r, "Content-Length: %z" CRLF, data->len); + } + /* Password */ + if (password != NULL) { + r += rspamd_snprintf (outbuf + r, sizeof (outbuf) - r, "Password: %s" CRLF, password); + } + /* Other headers */ + if (in_headers != NULL) { + cbdata.size = sizeof (outbuf); + cbdata.pos = r; + cbdata.buf = outbuf; + g_hash_table_foreach (in_headers, rspamd_hash_iter_cb, &cbdata); + r = cbdata.pos; } + r += rspamd_snprintf (outbuf + r, sizeof (outbuf) - r, CRLF); + + + /* Assume that a socket is in blocking mode */ if (fd != -1) { +#ifdef LINUX + if (send (c->socket, outbuf, r, MSG_MORE) == -1) { + g_set_error (err, G_RSPAMD_ERROR, errno, "Send error: %s", strerror (errno)); + goto err; + } +#else + if (send (c->socket, outbuf, r, 0) == -1) { + g_set_error (err, G_RSPAMD_ERROR, errno, "Send error: %s", strerror (errno)); + goto err; + } +#endif if (!rspamd_sendfile (c->socket, fd, err)) { - return NULL; + goto err; } } - /* Now set non-blocking mode and read buffer till END\r\n marker */ - make_socket_nonblocking (c->socket); - /* Poll socket */ - do { - if ((r = poll_sync_socket (c->socket, c->client->read_timeout, POLL_IN)) <= 0) { - if (*err == NULL) { - if (r == 0) { - errno = ETIMEDOUT; - } - *err = g_error_new (G_RSPAMD_ERROR, errno, "Cannot read reply from controller %s: %s", - c->server->name, strerror (errno)); - } - upstream_fail (&c->server->up, c->connection_time); - return NULL; + else if (data && data->len > 0) { + /* Use iovec */ + iov[0].iov_base = outbuf; + iov[0].iov_len = r; + iov[1].iov_base = data->data; + iov[1].iov_len = data->len; + + if (writev (c->socket, iov, G_N_ELEMENTS (iov)) == -1) { + g_set_error (err, G_RSPAMD_ERROR, errno, "Writev error: %s", strerror (errno)); + goto err; } - if ((r = read (c->socket, tmpbuf, sizeof (tmpbuf) - 1)) > 0) { - /* Check the end of the buffer for END marker */ - tmpbuf[r] = '\0'; - /* Store data inside res */ - if (res == NULL) { - res = g_string_new_len (tmpbuf, r); - } - else { - /* Append data to string */ - res = g_string_append_len (res, tmpbuf, r); - } - /* Check for END marker */ - if (res->len > sizeof (end_marker) - 1 && (p = strstr (res->str, end_marker)) != NULL) { - *p = '\0'; - res->len = p - res->str; - return res; - } + } + else { + /* Just write request */ + if (send (c->socket, outbuf, r, 0) == -1) { + g_set_error (err, G_RSPAMD_ERROR, errno, "Send error: %s", strerror (errno)); + goto err; } - } while (r > 0); - - /* Incomplete reply, so store error */ - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Cannot read reply from controller %s: %s", - c->server->name, strerror (errno)); } - upstream_fail (&c->server->up, c->connection_time); - return NULL; + + return TRUE; +err: + return FALSE; } +/** Public API **/ + /* - * Authenticate on the controller + * Init rspamd client library */ -static gboolean -rspamd_controller_auth (struct rspamd_connection *c, const gchar *password, GError **err) +struct rspamd_client* +rspamd_client_init_binded (const struct in_addr *addr) { - gchar outbuf[BUFSIZ]; - static const gchar success_str[] = "password accepted"; - gint r; - GString *in; + struct rspamd_client *client; - r = snprintf (outbuf, sizeof (outbuf), "password %s\r\n", password); - in = rspamd_send_controller_command (c, outbuf, r, -1, err); + client = g_malloc0 (sizeof (struct rspamd_client)); + client->read_timeout = DEFAULT_READ_TIMEOUT; + client->connect_timeout = DEFAULT_CONNECT_TIMEOUT; - if (in == NULL) { - return FALSE; + if (addr != NULL) { + client->bind_addr = g_malloc (sizeof (struct in_addr)); + memcpy (client->bind_addr, addr, sizeof (struct in_addr)); } - if (in->len >= sizeof (success_str) - 1 && - memcmp (in->str, success_str, sizeof (success_str) - 1) == 0) { - g_string_free (in, TRUE); - return TRUE; - } + return client; +} - g_string_free (in, TRUE); - return FALSE; +struct rspamd_client* +rspamd_client_init (void) +{ + return rspamd_client_init_binded (NULL); } /* - * Read greeting from the controller - */ -static gboolean -rspamd_read_controller_greeting (struct rspamd_connection *c, GError **err) -{ - gchar inbuf[BUFSIZ], *pos; - gint r, got_greeting = FALSE; - static const gchar greeting_str[] = "Rspamd"; - - pos = inbuf; - - while (pos - inbuf < (gint)sizeof (inbuf)) { - if ((r = poll_sync_socket (c->socket, c->client->read_timeout, POLL_IN)) <= 0) { - if (*err == NULL) { - if (r == 0) { - errno = ETIMEDOUT; - } - *err = g_error_new (G_RSPAMD_ERROR, errno, "Cannot read reply from controller %s: %s", - c->server->name, strerror (errno)); - } - upstream_fail (&c->server->up, c->connection_time); - return FALSE; - } - if ((r = read (c->socket, pos, sizeof (inbuf) - (pos - inbuf))) > 0) { - if (r >= (gint)sizeof (greeting_str) - 1 && - memcmp (inbuf, greeting_str, sizeof (greeting_str) - 1) == 0) { - got_greeting = TRUE; - } - } - else { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Cannot read reply from controller %s: %s", - c->server->name, strerror (errno)); - } - upstream_fail (&c->server->up, c->connection_time); - return FALSE; - } - pos += r; - if (got_greeting && *(pos - 1) == '\n') { - /* Got the complete greeting */ - return TRUE; - } - } - - return FALSE; -} - -/** Public API **/ - -/* - * Init rspamd client library - */ -struct rspamd_client* -rspamd_client_init_binded (const struct in_addr *addr) -{ - struct rspamd_client *client; - - client = g_malloc0 (sizeof (struct rspamd_client)); - client->read_timeout = DEFAULT_READ_TIMEOUT; - client->connect_timeout = DEFAULT_CONNECT_TIMEOUT; - - if (addr != NULL) { - client->bind_addr = g_malloc (sizeof (struct in_addr)); - memcpy (client->bind_addr, addr, sizeof (struct in_addr)); - } - - return client; -} - -struct rspamd_client* -rspamd_client_init (void) -{ - return rspamd_client_init_binded (NULL); -} - -/* - * Add rspamd server + * Add rspamd server */ gboolean rspamd_add_server (struct rspamd_client *client, const gchar *host, guint16 port, @@ -1264,6 +1322,7 @@ rspamd_add_server (struct rspamd_client *client, const gchar *host, guint16 port { struct rspamd_server *new; struct hostent *hent; + gint nlen; g_assert (client != NULL); if (client->servers_num >= MAX_RSPAMD_SERVERS) { @@ -1289,7 +1348,11 @@ rspamd_add_server (struct rspamd_client *client, const gchar *host, guint16 port } new->client_port = port; new->controller_port = controller_port; - new->name = g_strdup (host); + nlen = strlen (host) + sizeof ("65535") + 1; + new->name = g_malloc (nlen); + new->controller_name = g_malloc (nlen); + rspamd_snprintf (new->name, nlen, "%s:%d", host, (gint)port); + rspamd_snprintf (new->controller_name, nlen, "%s:%d", host, (gint)controller_port); client->servers_num ++; return TRUE; @@ -1348,7 +1411,8 @@ rspamd_scan_memory (struct rspamd_client *client, const guchar *message, gsize l /* Create result structure */ res = rspamd_create_result (c); - c->result = res; + c->res.normal.result = res; + c->is_controller = FALSE; /* Restore non-blocking mode for reading operations */ make_socket_nonblocking (c->socket); @@ -1427,7 +1491,8 @@ rspamd_scan_fd (struct rspamd_client *client, int fd, GHashTable *headers, GErro /* Create result structure */ res = rspamd_create_result (c); - c->result = res; + c->is_controller = FALSE; + c->res.normal.result = res; /* Restore non-blocking mode for reading operations */ make_socket_nonblocking (c->socket); @@ -1439,398 +1504,139 @@ rspamd_scan_fd (struct rspamd_client *client, int fd, GHashTable *headers, GErro } /* - * Learn message from memory + * Send a common controller command to all servers */ -gboolean -rspamd_learn_memory (struct rspamd_client *client, const guchar *message, gsize length, const gchar *symbol, const gchar *password, GError **err) +static void +rspamd_controller_command_single (struct rspamd_client* client, const gchar *command, const gchar *password, + GHashTable *in_headers, GByteArray *mem, gint fd, GError **err, + struct rspamd_controller_result *res, struct rspamd_server *serv) { struct rspamd_connection *c; - GString *in; - gchar *outbuf; - guint r; - static const gchar ok_str[] = "learn ok"; - - g_assert (client != NULL); - g_assert (length > 0); /* Connect to server */ - c = rspamd_connect_random_server (client, TRUE, err); + c = rspamd_connect_specific_server (client, FALSE, err, serv); if (c == NULL) { - return FALSE; - } - - /* Read greeting */ - if (! rspamd_read_controller_greeting (c, err)) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Invalid greeting"); - } - return FALSE; - } - if (password != NULL) { - /* Perform auth */ - if (! rspamd_controller_auth (c, password, err)) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Authentication error"); - } - return FALSE; - } - } - - r = length + sizeof ("learn %s %uz\r\n") + strlen (symbol) + sizeof ("4294967296"); - outbuf = g_malloc (r); - r = snprintf (outbuf, r, "learn %s %lu\r\n%s", symbol, (unsigned long)length, message); - in = rspamd_send_controller_command (c, outbuf, r, -1, err); - g_free (outbuf); - if (in == NULL) { - return FALSE; - } - - /* Search for string learn ok */ - if (in->len > sizeof (ok_str) - 1 && memcmp (in->str, ok_str, sizeof (ok_str) - 1) == 0) { - upstream_ok (&c->server->up, c->connection_time); - return TRUE; - } - else { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Bad reply: %s", in->str); - } - } - return FALSE; -} - -/* - * Learn message from file - */ -gboolean -rspamd_learn_file (struct rspamd_client *client, const guchar *filename, const gchar *symbol, const gchar *password, GError **err) -{ - gint fd; - g_assert (client != NULL); - - /* Open file */ - if ((fd = open (filename, O_RDONLY)) == -1) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Open error for file %s: %s", - filename, strerror (errno)); - } - return FALSE; + return; } - return rspamd_learn_fd (client, fd, symbol, password, err); -} - -/* - * Learn message from fd - */ -gboolean -rspamd_learn_fd (struct rspamd_client *client, int fd, const gchar *symbol, const gchar *password, GError **err) -{ - struct rspamd_connection *c; - GString *in; - gchar *outbuf; - guint r; - struct stat st; - static const gchar ok_str[] = "learn ok"; - - g_assert (client != NULL); - - /* Connect to server */ - c = rspamd_connect_random_server (client, TRUE, err); - - if (c == NULL) { - return FALSE; - } + res->conn = c; + /* Set socket blocking for writing */ + make_socket_blocking (c->socket); - /* Read greeting */ - if (! rspamd_read_controller_greeting (c, err)) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Invalid greeting"); - } - return FALSE; - } - if (password != NULL) { - /* Perform auth */ - if (! rspamd_controller_auth (c, password, err)) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Authentication error"); - } - return FALSE; - } + /* Send command */ + if (!rspamd_send_controller_command (c, command, password, in_headers, fd, mem, err)) { + res->result = g_string_new (*err != NULL ? (*err)->message : "unknown error"); + res->code = 500; + return; } + c->is_controller = TRUE; + c->res.controller.result = res; - /* Get length */ - if (fstat (fd, &st) == -1) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Stat error: %s", - strerror (errno)); - } - return FALSE; - } - if (st.st_size == 0) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, -1, "File has zero length"); - } - return FALSE; - } - r = sizeof ("learn %s %uz\r\n") + strlen (symbol) + sizeof ("4294967296"); - outbuf = g_malloc (r); - r = snprintf (outbuf, r, "learn %s %lu\r\n", symbol, (unsigned long)st.st_size); - in = rspamd_send_controller_command (c, outbuf, r, fd, err); - g_free (outbuf); - if (in == NULL) { - return FALSE; - } + /* Restore non-blocking mode for reading operations */ + make_socket_nonblocking (c->socket); - /* Search for string learn ok */ - if (in->len > sizeof (ok_str) - 1 && memcmp (in->str, ok_str, sizeof (ok_str) - 1) == 0) { - upstream_ok (&c->server->up, c->connection_time); - return TRUE; - } - else { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Bad reply: %s", in->str); - } - } + /* Read result cycle */ + while (read_rspamd_reply_line (c, err)); - return FALSE; + upstream_ok (&c->server->up, c->connection_time); } /* - * Learn message from memory + * Send a common controller command to all servers */ -gboolean -rspamd_learn_spam_memory (struct rspamd_client *client, const guchar *message, gsize length, const gchar *classifier, gboolean is_spam, const gchar *password, GError **err) +static GList* +rspamd_controller_command_common (struct rspamd_client* client, const gchar *command, const gchar *password, + GHashTable *in_headers, GByteArray *mem, gint fd, GError **err) { - struct rspamd_connection *c; - GString *in; - gchar *outbuf; - guint r; - static const gchar ok_str[] = "learn ok"; + struct rspamd_server *serv; + struct rspamd_controller_result *res; + GList *res_list = NULL; + guint i; g_assert (client != NULL); - g_assert (length > 0); - - /* Connect to server */ - c = rspamd_connect_random_server (client, TRUE, err); - if (c == NULL) { - return FALSE; + for (i = 0; i < client->servers_num; i ++) { + serv = &client->servers[i]; + res = rspamd_create_controller_result (NULL); + res->server_name = serv->controller_name; + res_list = g_list_prepend (res_list, res); + /* Fill result */ + rspamd_controller_command_single (client, command, password, in_headers, mem, fd, err, res, serv); } - /* Read greeting */ - if (! rspamd_read_controller_greeting (c, err)) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Invalid greeting"); - } - return FALSE; - } - if (password != NULL) { - /* Perform auth */ - if (! rspamd_controller_auth (c, password, err)) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Authentication error"); - } - return FALSE; - } - } - - r = length + sizeof ("learn_spam %s %uz\r\n") + strlen (classifier) + sizeof ("4294967296"); - outbuf = g_malloc (r); - r = snprintf (outbuf, r, "learn_%s %s %lu\r\n%s", is_spam ? "spam" : "ham", - classifier, (unsigned long)length, message); - in = rspamd_send_controller_command (c, outbuf, r, -1, err); - g_free (outbuf); - if (in == NULL) { - return FALSE; - } - - /* Search for string learn ok */ - if (in->len > sizeof (ok_str) - 1 && memcmp (in->str, ok_str, sizeof (ok_str) - 1) == 0) { - upstream_ok (&c->server->up, c->connection_time); - return TRUE; - } - else { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Bad reply: %s", in->str); - } - } - return FALSE; + return res_list; } -/* - * Learn message from file +/** + * Perform a simple controller command on all rspamd servers + * @param client rspamd client + * @param command command to send + * @param password password (NULL if no password required) + * @param in_headers custom in headers, specific for this command (or NULL) + * @param err error object (should be pointer to NULL object) + * @return list of rspamd_controller_result structures for each server */ -gboolean -rspamd_learn_spam_file (struct rspamd_client *client, const guchar *filename, const gchar *classifier, gboolean is_spam, const gchar *password, GError **err) +GList* +rspamd_controller_command_simple (struct rspamd_client* client, const gchar *command, const gchar *password, + GHashTable *in_headers, GError **err) { - gint fd; - g_assert (client != NULL); - - /* Open file */ - if ((fd = open (filename, O_RDONLY)) == -1) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Open error for file %s: %s", - filename, strerror (errno)); - } - return FALSE; - } - - return rspamd_learn_spam_fd (client, fd, classifier, is_spam, password, err); + return rspamd_controller_command_common (client, command, password, in_headers, NULL, -1, err); } -/* - * Learn message from fd +/** + * Perform a controller command on all rspamd servers with in memory argument + * @param client rspamd client + * @param command command to send + * @param password password (NULL if no password required) + * @param in_headers custom in headers, specific for this command (or NULL) + * @param message data to pass to the controller + * @param length its length + * @param err error object (should be pointer to NULL object) + * @return list of rspamd_controller_result structures for each server */ -gboolean -rspamd_learn_spam_fd (struct rspamd_client *client, int fd, const gchar *classifier, gboolean is_spam, const gchar *password, GError **err) +GList* +rspamd_controller_command_memory (struct rspamd_client* client, const gchar *command, const gchar *password, + GHashTable *in_headers, const guchar *message, gsize length, GError **err) { - struct rspamd_connection *c; - GString *in; - gchar *outbuf; - guint r; - struct stat st; - static const gchar ok_str[] = "learn ok"; - - g_assert (client != NULL); - - /* Connect to server */ - c = rspamd_connect_random_server (client, TRUE, err); - - if (c == NULL) { - return FALSE; - } - - /* Read greeting */ - if (! rspamd_read_controller_greeting (c, err)) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Invalid greeting"); - } - return FALSE; - } - if (password != NULL) { - /* Perform auth */ - if (! rspamd_controller_auth (c, password, err)) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Authentication error"); - } - return FALSE; - } - } - - /* Get length */ - if (fstat (fd, &st) == -1) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Stat error: %s", - strerror (errno)); - } - return FALSE; - } - if (st.st_size == 0) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, -1, "File has zero length"); - } - return FALSE; - } - r = sizeof ("learn_spam %s %uz\r\n") + strlen (classifier) + sizeof ("4294967296"); - outbuf = g_malloc (r); - r = snprintf (outbuf, r, "learn_%s %s %lu\r\n", is_spam ? "spam" : "ham", - classifier, (unsigned long)st.st_size); - in = rspamd_send_controller_command (c, outbuf, r, fd, err); - g_free (outbuf); - if (in == NULL) { - return FALSE; - } - - /* Search for string learn ok */ - if (in->len > sizeof (ok_str) - 1 && memcmp (in->str, ok_str, sizeof (ok_str) - 1) == 0) { - upstream_ok (&c->server->up, c->connection_time); - return TRUE; - } - else { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Bad reply: %s", in->str); - } - } - - return FALSE; + GByteArray ba; + ba.data = (guint8 *)message; + ba.len = length; + return rspamd_controller_command_common (client, command, password, in_headers, &ba, -1, err); } - -/* - * Learn message fuzzy from memory +/** + * Perform a controller command on all rspamd servers with descriptor argument + * @param client rspamd client + * @param command command to send + * @param password password (NULL if no password required) + * @param in_headers custom in headers, specific for this command (or NULL) + * @param fd file descriptor of data + * @param err error object (should be pointer to NULL object) + * @return list of rspamd_controller_result structures for each server */ -gboolean -rspamd_fuzzy_memory (struct rspamd_client *client, const guchar *message, gsize length, const gchar *password, gint weight, gint flag, gboolean delete, GError **err) +GList* +rspamd_controller_command_fd (struct rspamd_client* client, const gchar *command, const gchar *password, + GHashTable *in_headers, gint fd, GError **err) { - struct rspamd_connection *c; - GString *in; - gchar *outbuf; - guint r; - static const gchar ok_str[] = "OK"; - - g_assert (client != NULL); - g_assert (length > 0); - - /* Connect to server */ - c = rspamd_connect_random_server (client, TRUE, err); - - if (c == NULL) { - return FALSE; - } - - /* Read greeting */ - if (! rspamd_read_controller_greeting (c, err)) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Invalid greeting"); - } - return FALSE; - } - if (password != NULL) { - /* Perform auth */ - if (! rspamd_controller_auth (c, password, err)) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Authentication error"); - } - return FALSE; - } - } - - r = length + sizeof ("fuzzy_add %uz %d %d\r\n") + sizeof ("4294967296") * 3; - outbuf = g_malloc (r); - if (delete) { - r = snprintf (outbuf, r, "fuzzy_del %lu %d %d\r\n%s", (unsigned long)length, weight, flag, message); - } - else { - r = snprintf (outbuf, r, "fuzzy_add %lu %d %d\r\n%s", (unsigned long)length, weight, flag, message); - } - in = rspamd_send_controller_command (c, outbuf, r, -1, err); - g_free (outbuf); - if (in == NULL) { - return FALSE; - } - - /* Search for string "OK" */ - if (in->len > sizeof (ok_str) - 1 && memcmp (in->str, ok_str, sizeof (ok_str) - 1) == 0) { - upstream_ok (&c->server->up, c->connection_time); - return TRUE; - } - else { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Bad reply: %s", in->str); - } - } - return FALSE; + return rspamd_controller_command_common (client, command, password, in_headers, NULL, fd, err); } -/* - * Learn message fuzzy from file +/** + * Perform a controller command on all rspamd servers with descriptor argument + * @param client rspamd client + * @param command command to send + * @param password password (NULL if no password required) + * @param in_headers custom in headers, specific for this command (or NULL) + * @param filename filename of data + * @param err error object (should be pointer to NULL object) + * @return list of rspamd_controller_result structures for each server */ -gboolean -rspamd_fuzzy_file (struct rspamd_client *client, const guchar *filename, const gchar *password, gint weight, gint flag, gboolean delete, GError **err) +GList* +rspamd_controller_command_file (struct rspamd_client* client, const gchar *command, const gchar *password, + GHashTable *in_headers, const gchar *filename, GError **err) { - gint fd; - g_assert (client != NULL); + gint fd; /* Open file */ if ((fd = open (filename, O_RDONLY)) == -1) { @@ -1838,133 +1644,11 @@ rspamd_fuzzy_file (struct rspamd_client *client, const guchar *filename, const g *err = g_error_new (G_RSPAMD_ERROR, errno, "Open error for file %s: %s", filename, strerror (errno)); } - return FALSE; - } - - return rspamd_fuzzy_fd (client, fd, password, weight, flag, delete, err); -} - -/* - * Learn message fuzzy from fd - */ -gboolean -rspamd_fuzzy_fd (struct rspamd_client *client, int fd, const gchar *password, gint weight, gint flag, gboolean delete, GError **err) -{ - struct rspamd_connection *c; - GString *in; - gchar *outbuf; - guint r; - struct stat st; - static const gchar ok_str[] = "OK"; - - g_assert (client != NULL); - - /* Connect to server */ - c = rspamd_connect_random_server (client, TRUE, err); - - if (c == NULL) { - return FALSE; - } - - /* Read greeting */ - if (! rspamd_read_controller_greeting (c, err)) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Invalid greeting"); - } - return FALSE; - } - if (password != NULL) { - /* Perform auth */ - if (! rspamd_controller_auth (c, password, err)) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Authentication error"); - } - return FALSE; - } - } - /* Get length */ - if (fstat (fd, &st) == -1) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Stat error: %s", - strerror (errno)); - } - return FALSE; - } - if (st.st_size == 0) { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, -1, "File has zero length"); - } - return FALSE; - } - r = sizeof ("fuzzy_add %uz %d %d\r\n") + sizeof ("4294967296") * 3; - outbuf = g_malloc (r); - if (delete) { - r = snprintf (outbuf, r, "fuzzy_del %lu %d %d\r\n", (unsigned long)st.st_size, weight, flag); - } - else { - r = snprintf (outbuf, r, "fuzzy_add %lu %d %d\r\n", (unsigned long)st.st_size, weight, flag); - } - in = rspamd_send_controller_command (c, outbuf, r, fd, err); - - g_free (outbuf); - if (in == NULL) { - return FALSE; - } - - /* Search for string "OK" */ - if (in->len > sizeof (ok_str) - 1 && memcmp (in->str, ok_str, sizeof (ok_str) - 1) == 0) { - upstream_ok (&c->server->up, c->connection_time); - return TRUE; - } - else { - if (*err == NULL) { - *err = g_error_new (G_RSPAMD_ERROR, errno, "Bad reply: %s", in->str); - } - } - return FALSE; -} - -GString * -rspamd_get_stat (struct rspamd_client *client, GError **err) -{ - struct rspamd_connection *c; - GString *res; - static const gchar outcmd[] = "stat\r\n"; - - g_assert (client != NULL); - - /* Connect to server */ - c = rspamd_connect_random_server (client, TRUE, err); - - if (c == NULL) { return NULL; } - - res = rspamd_send_controller_command (c, outcmd, strlen (outcmd), -1, err); - - return res; + return rspamd_controller_command_common (client, command, password, in_headers, NULL, fd, err); } -GString * -rspamd_get_uptime (struct rspamd_client *client, GError **err) -{ - struct rspamd_connection *c; - GString *res; - static const gchar outcmd[] = "uptime\r\n"; - - g_assert (client != NULL); - - /* Connect to server */ - c = rspamd_connect_random_server (client, TRUE, err); - - if (c == NULL) { - return NULL; - } - - res = rspamd_send_controller_command (c, outcmd, strlen (outcmd), -1, err); - - return res; -} /* * Free results @@ -1979,14 +1663,38 @@ rspamd_free_result (struct rspamd_result *result) rspamd_free_connection (result->conn); } +void +rspamd_free_controller_result (struct rspamd_controller_result *result) +{ + g_assert (result != NULL); + + g_hash_table_destroy (result->headers); + g_string_free (result->result, TRUE); + if (result->data) { + g_string_free (result->data, TRUE); + } + rspamd_free_connection (result->conn); +} + /* * Close library and free associated resources */ void rspamd_client_close (struct rspamd_client *client) { + struct rspamd_server *serv; + guint i; + if (client->bind_addr) { g_free (client->bind_addr); } + + /* Cleanup servers */ + for (i = 0; i < client->servers_num; i ++) { + serv = &client->servers[i]; + g_free (serv->name); + g_free (serv->controller_name); + } + g_free (client); } diff --git a/lib/client/librspamdclient.h b/lib/client/librspamdclient.h index 6eea8627e..828bcd422 100644 --- a/lib/client/librspamdclient.h +++ b/lib/client/librspamdclient.h @@ -40,101 +40,115 @@ struct rspamd_result { GHashTable *headers; }; -/* +/** + * Result of controller command + */ +struct rspamd_controller_result { + struct rspamd_connection *conn; + const gchar *server_name; + gint code; + GString *result; + GHashTable *headers; + GString *data; +}; + +/** * Init rspamd client library */ struct rspamd_client* rspamd_client_init (void); -/* +/** * Init rspamd client library and bind it */ struct rspamd_client* rspamd_client_init_binded (const struct in_addr *local_addr); -/* +/** * Add rspamd server */ gboolean rspamd_add_server (struct rspamd_client* client, const gchar *host, guint16 port, guint16 controller_port, GError **err); -/* +/** * Set timeouts (values in milliseconds) */ void rspamd_set_timeout (struct rspamd_client* client, guint connect_timeout, guint read_timeout); -/* +/** * Scan message from memory */ struct rspamd_result * rspamd_scan_memory (struct rspamd_client* client, const guchar *message, gsize length, GHashTable *headers, GError **err); -/* +/** * Scan message from file */ struct rspamd_result * rspamd_scan_file (struct rspamd_client* client, const guchar *filename, GHashTable *headers, GError **err); -/* +/** * Scan message from fd */ struct rspamd_result * rspamd_scan_fd (struct rspamd_client* client, int fd, GHashTable *headers, GError **err); -/* - * Learn message from memory - */ -gboolean rspamd_learn_spam_memory (struct rspamd_client* client, const guchar *message, gsize length, const gchar *classifier, gboolean is_spam, const gchar *password, GError **err); - -/* - * Learn message from file - */ -gboolean rspamd_learn_spam_file (struct rspamd_client* client, const guchar *filename, const gchar *classifier, gboolean is_spam, const gchar *password, GError **err); - -/* - * Learn message from fd - */ -gboolean rspamd_learn_spam_fd (struct rspamd_client* client, int fd, const gchar *classifier, gboolean is_spam, const gchar *password, GError **err); - -/* - * Learn message from memory - */ -gboolean rspamd_learn_memory (struct rspamd_client* client, const guchar *message, gsize length, const gchar *symbol, const gchar *password, GError **err); - -/* - * Learn message from file - */ -gboolean rspamd_learn_file (struct rspamd_client* client, const guchar *filename, const gchar *symbol, const gchar *password, GError **err); - -/* - * Learn message from fd - */ -gboolean rspamd_learn_fd (struct rspamd_client* client, int fd, const gchar *symbol, const gchar *password, GError **err); - -/* - * Learn message fuzzy from memory +/** + * Perform a simple controller command on all rspamd servers + * @param client rspamd client + * @param command command to send + * @param password password (NULL if no password required) + * @param in_headers custom in headers, specific for this command (or NULL) + * @param err error object (should be pointer to NULL object) + * @return list of rspamd_controller_result structures for each server */ -gboolean rspamd_fuzzy_memory (struct rspamd_client* client, const guchar *message, gsize length, const gchar *password, gint weight, gint flag, gboolean delete, GError **err); +GList* rspamd_controller_command_simple (struct rspamd_client* client, const gchar *command, const gchar *password, + GHashTable *in_headers, GError **err); -/* - * Learn message fuzzy from file - */ -gboolean rspamd_fuzzy_file (struct rspamd_client* client, const guchar *filename, const gchar *password, gint weight, gint flag, gboolean delete, GError **err); +/** + * Perform a controller command on all rspamd servers with in memory argument + * @param client rspamd client + * @param command command to send + * @param password password (NULL if no password required) + * @param in_headers custom in headers, specific for this command (or NULL) + * @param message data to pass to the controller + * @param length its length + * @param err error object (should be pointer to NULL object) + * @return list of rspamd_controller_result structures for each server + */ +GList* rspamd_controller_command_memory (struct rspamd_client* client, const gchar *command, const gchar *password, + GHashTable *in_headers, const guchar *message, gsize length, GError **err); -/* - * Learn message fuzzy from fd - */ -gboolean rspamd_fuzzy_fd (struct rspamd_client* client, int fd, const gchar *password, gint weight, gint flag, gboolean delete, GError **err); +/** + * Perform a controller command on all rspamd servers with descriptor argument + * @param client rspamd client + * @param command command to send + * @param password password (NULL if no password required) + * @param in_headers custom in headers, specific for this command (or NULL) + * @param fd file descriptor of data + * @param err error object (should be pointer to NULL object) + * @return list of rspamd_controller_result structures for each server + */ +GList* rspamd_controller_command_fd (struct rspamd_client* client, const gchar *command, const gchar *password, + GHashTable *in_headers, gint fd, GError **err); -/* - * Get statistic from server +/** + * Perform a controller command on all rspamd servers with descriptor argument + * @param client rspamd client + * @param command command to send + * @param password password (NULL if no password required) + * @param in_headers custom in headers, specific for this command (or NULL) + * @param filename filename of data + * @param err error object (should be pointer to NULL object) + * @return list of rspamd_controller_result structures for each server */ -GString *rspamd_get_stat (struct rspamd_client* client, GError **err); +GList* rspamd_controller_command_file (struct rspamd_client* client, const gchar *command, const gchar *password, + GHashTable *in_headers, const gchar *filename, GError **err); /* - * Get uptime from server + * Free results */ -GString *rspamd_get_uptime (struct rspamd_client* client, GError **err); +void rspamd_free_result (struct rspamd_result *result); /* - * Free results + * Free controller results */ -void rspamd_free_result (struct rspamd_result *result); +void rspamd_free_controller_result (struct rspamd_controller_result *result); /* * Close library and free associated resources diff --git a/src/client/rspamc.c b/src/client/rspamc.c index 91674f779..7d88ea65c 100644 --- a/src/client/rspamc.c +++ b/src/client/rspamc.c @@ -52,7 +52,6 @@ static GOptionEntry entries[] = { { "connect", 'h', 0, G_OPTION_ARG_STRING, &connect_str, "Specify host and port", NULL }, { "password", 'P', 0, G_OPTION_ARG_STRING, &password, "Specify control password", NULL }, - { "statfile", 's', 0, G_OPTION_ARG_STRING, &statfile, "Statfile to learn (symbol name)", NULL }, { "classifier", 'c', 0, G_OPTION_ARG_STRING, &classifier, "Classifier to learn spam or ham", NULL }, { "weight", 'w', 0, G_OPTION_ARG_INT, &weight, "Weight for fuzzy operations", NULL }, { "flag", 'f', 0, G_OPTION_ARG_INT, &flag, "Flag for fuzzy operations", NULL }, @@ -71,7 +70,6 @@ static GOptionEntry entries[] = enum rspamc_command { RSPAMC_COMMAND_UNKNOWN = 0, RSPAMC_COMMAND_SYMBOLS, - RSPAMC_COMMAND_LEARN, RSPAMC_COMMAND_LEARN_SPAM, RSPAMC_COMMAND_LEARN_HAM, RSPAMC_COMMAND_FUZZY_ADD, @@ -115,9 +113,6 @@ check_rspamc_command (const gchar *cmd) /* These all are symbols, don't use other commands */ return RSPAMC_COMMAND_SYMBOLS; } - else if (g_ascii_strcasecmp (cmd, "LEARN") == 0) { - return RSPAMC_COMMAND_LEARN; - } else if (g_ascii_strcasecmp (cmd, "LEARN_SPAM") == 0) { return RSPAMC_COMMAND_LEARN_SPAM; } @@ -396,8 +391,11 @@ learn_rspamd_stdin (gboolean is_spam) gchar *in_buf; gint r = 0, len; GError *err = NULL; + GHashTable *params; + GList *results, *cur; + struct rspamd_controller_result *res; - if ((statfile == NULL && classifier == NULL)) { + if (classifier == NULL) { fprintf (stderr, "cannot learn message without password and symbol/classifier name\n"); exit (EXIT_FAILURE); } @@ -417,45 +415,36 @@ learn_rspamd_stdin (gboolean is_spam) in_buf = g_realloc (in_buf, len); } } - if (statfile != NULL) { - if (!rspamd_learn_memory (client, in_buf, r, statfile, password, &err)) { - if (err != NULL) { - fprintf (stderr, "cannot learn message: %s\n", err->message); - } - else { - fprintf (stderr, "cannot learn message\n"); - } - exit (EXIT_FAILURE); + + params = g_hash_table_new (g_str_hash, g_str_equal); + g_hash_table_insert (params, "Classifier", classifier); + + results = rspamd_controller_command_memory (client, is_spam ? "learn_spam" : "learn_ham", password, params, in_buf, r, &err); + g_hash_table_destroy (params); + if (results == NULL) { + if (err != NULL) { + fprintf (stderr, "cannot learn message: %s\n", err->message); } else { - if (tty) { - printf ("\033[1m"); - } - PRINT_FUNC ("Results for host: %s: learn ok\n", connect_str); - if (tty) { - printf ("\033[0m"); - } + fprintf (stderr, "cannot learn message\n"); } + exit (EXIT_FAILURE); } - else if (classifier != NULL) { - if (!rspamd_learn_spam_memory (client, in_buf, r, classifier, is_spam, password, &err)) { - if (err != NULL) { - fprintf (stderr, "cannot learn message: %s\n", err->message); - } - else { - fprintf (stderr, "cannot learn message\n"); - } - exit (EXIT_FAILURE); - } - else { + else { + cur = results; + while (cur) { + res = cur->data; if (tty) { printf ("\033[1m"); } - PRINT_FUNC ("Results for host: %s: learn ok\n", connect_str); + PRINT_FUNC ("Results for host: %s: %d, %s\n", res->server_name, res->code, res->result->str); if (tty) { printf ("\033[0m"); } + rspamd_free_controller_result (res); + cur = g_list_next (cur); } + g_list_free (results); } } @@ -463,49 +452,44 @@ static void learn_rspamd_file (gboolean is_spam, const gchar *file) { GError *err = NULL; + GHashTable *params; + GList *results, *cur; + struct rspamd_controller_result *res; - if ((statfile == NULL && classifier == NULL)) { + if (classifier == NULL) { fprintf (stderr, "cannot learn message without password and symbol/classifier name\n"); exit (EXIT_FAILURE); } - if (statfile != NULL) { - if (!rspamd_learn_file (client, file, statfile, password, &err)) { - if (err != NULL) { - fprintf (stderr, "cannot learn message: %s\n", err->message); - } - else { - fprintf (stderr, "cannot learn message\n"); - } + params = g_hash_table_new (g_str_hash, g_str_equal); + g_hash_table_insert (params, "Classifier", classifier); + + results = rspamd_controller_command_file (client, is_spam ? "learn_spam" : "learn_ham", password, params, file, &err); + g_hash_table_destroy (params); + if (results == NULL) { + if (err != NULL) { + fprintf (stderr, "cannot learn message: %s\n", err->message); } else { - if (tty) { - printf ("\033[1m"); - } - PRINT_FUNC ("learn ok\n"); - if (tty) { - printf ("\033[0m"); - } + fprintf (stderr, "cannot learn message\n"); } + exit (EXIT_FAILURE); } - else if (classifier != NULL) { - if (!rspamd_learn_spam_file (client, file, classifier, is_spam, password, &err)) { - if (err != NULL) { - fprintf (stderr, "cannot learn message: %s\n", err->message); - } - else { - fprintf (stderr, "cannot learn message\n"); - } - } - else { + else { + cur = results; + while (cur) { + res = cur->data; if (tty) { printf ("\033[1m"); } - PRINT_FUNC ("learn ok\n"); + PRINT_FUNC ("Results for host: %s: %d, %s\n", res->server_name, res->code, res->result->str); if (tty) { printf ("\033[0m"); } + rspamd_free_controller_result (res); + cur = g_list_next (cur); } + g_list_free (results); } } @@ -532,31 +516,15 @@ fuzzy_rspamd_stdin (gboolean delete) in_buf = g_realloc (in_buf, len); } } - if (!rspamd_fuzzy_memory (client, in_buf, r, password, weight, flag, delete, &err)) { - if (err != NULL) { - fprintf (stderr, "cannot learn message: %s\n", err->message); - } - else { - fprintf (stderr, "cannot learn message\n"); - } - exit (EXIT_FAILURE); - } - else { - if (tty) { - printf ("\033[1m"); - } - PRINT_FUNC ("Results for host: %s: learn ok\n", connect_str); - if (tty) { - printf ("\033[0m"); - } - } + /* TODO: write this function */ } static void fuzzy_rspamd_file (const gchar *file, gboolean delete) { GError *err = NULL; - + /* TODO: write this function */ +#if 0 if (!rspamd_fuzzy_file (client, file, password, weight, flag, delete, &err)) { if (err != NULL) { fprintf (stderr, "cannot learn message: %s\n", err->message); @@ -574,68 +542,49 @@ fuzzy_rspamd_file (const gchar *file, gboolean delete) printf ("\033[0m"); } } +#endif } static void -rspamd_do_stat (void) +rspamd_do_controller_simple_command (gchar *command) { GError *err = NULL; - GString *res; + GList *results, *cur; + struct rspamd_controller_result *res; /* Add server */ add_rspamd_server (TRUE); - res = rspamd_get_stat (client, &err); - if (res == NULL) { + results = rspamd_controller_command_simple (client, command, password, NULL, &err); + if (results == NULL) { if (err != NULL) { - fprintf (stderr, "cannot stat: %s\n", err->message); + fprintf (stderr, "cannot perform command: %s\n", err->message); } else { - fprintf (stderr, "cannot stat\n"); + fprintf (stderr, "cannot perform command:\n"); } exit (EXIT_FAILURE); } - if (tty) { - printf ("\033[1m"); - } - PRINT_FUNC ("Results for host: %s\n\n", connect_str); - if (tty) { - printf ("\033[0m"); - } - res = g_string_append_c (res, '\0'); - printf ("%s\n", res->str); -} - -static void -rspamd_do_uptime (void) -{ - GError *err = NULL; - GString *res; - - /* Add server */ - add_rspamd_server (TRUE); - - res = rspamd_get_uptime (client, &err); - if (res == NULL) { - if (err != NULL) { - fprintf (stderr, "cannot uptime: %s\n", err->message); - } - else { - fprintf (stderr, "cannot uptime\n"); + else { + cur = results; + while (cur) { + res = cur->data; + if (tty) { + printf ("\033[1m"); + } + PRINT_FUNC ("Results for host: %s: %d, %s\n", res->server_name, res->code, res->result->str); + if (tty) { + printf ("\033[0m"); + } + PRINT_FUNC ("%*s\n", (gint)res->data->len, res->data->str); + rspamd_free_controller_result (res); + cur = g_list_next (cur); } - exit (EXIT_FAILURE); - } - if (tty) { - printf ("\033[1m"); + g_list_free (results); } - PRINT_FUNC ("Results for host: %s\n\n", connect_str); - if (tty) { - printf ("\033[0m"); - } - res = g_string_append_c (res, '\0'); - printf ("%s\n", res->str); } + gint main (gint argc, gchar **argv, gchar **env) { @@ -674,9 +623,6 @@ main (gint argc, gchar **argv, gchar **env) case RSPAMC_COMMAND_SYMBOLS: scan_rspamd_stdin (); break; - case RSPAMC_COMMAND_LEARN: - learn_rspamd_stdin (TRUE); - break; case RSPAMC_COMMAND_LEARN_SPAM: if (classifier != NULL) { learn_rspamd_stdin (TRUE); @@ -702,10 +648,10 @@ main (gint argc, gchar **argv, gchar **env) fuzzy_rspamd_stdin (TRUE); break; case RSPAMC_COMMAND_STAT: - rspamd_do_stat (); + rspamd_do_controller_simple_command ("stat"); break; case RSPAMC_COMMAND_UPTIME: - rspamd_do_uptime (); + rspamd_do_controller_simple_command ("uptime"); break; default: fprintf (stderr, "invalid arguments\n"); @@ -741,9 +687,6 @@ main (gint argc, gchar **argv, gchar **env) case RSPAMC_COMMAND_SYMBOLS: scan_rspamd_file (argv[i]); break; - case RSPAMC_COMMAND_LEARN: - learn_rspamd_file (TRUE, argv[i]); - break; case RSPAMC_COMMAND_LEARN_SPAM: if (classifier != NULL) { learn_rspamd_file (TRUE, argv[i]); -- 2.39.5