/*
 * This program is an RSPAMD agent for use with the
 * exim (http://www.exim.org) MTA via its local_scan feature.
 *
 * To enable exim local scan please copy this file to exim source tree
 * Local/local_scan.c, edit Local/Makefile to add
 *
 * LOCAL_SCAN_SOURCE=Local/local_scan.c
 * LOCAL_SCAN_HAS_OPTIONS=yes
 *
 * and compile exim.
 *
 * For exim compilation with local scan feature details please visit
 * http://www.exim.org/exim-html-current/doc/html/spec_html/ch42.html
 *
 * For RSPAMD details please visit
 * https://bitbucket.org/vstakhov/rspamd/
 *
 * Example configuration:
 * **********************
 *
 * local_scan_timeout = 50s
 *
 * begin local_scan
 *       rspam_ip = 127.0.0.1
 *       rspam_port = 11333
 *       rspam_skip_sasl_authenticated = true
 *       # don't reject the message if one of its recipients is in this list
 *       rspam_skip_rcpt = postmaster@example.com : some_user@example.com
 *       rspam_message = "Spam rejected; If this is not spam, please contact <postmaster@example.com>"
 *
 *
 * $Id: local_scan.c 646 2010-08-11 11:49:36Z ayuzhaninov $
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/uio.h>

#include <netinet/in.h>
#include <arpa/inet.h>

#include <errno.h>
#include <math.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>

#include "local_scan.h"

#define REQUEST_LINES	64
#define REPLY_BUF_SIZE	16384
#define HEADER_STATUS	"X-Rspam-Status"
#define HEADER_METRIC	"X-Rspam-Metric"
#define HEADER_SCORE	"X-Rspam-Score"

/*
 * Configuration options.
 *
 * Each variable holds its built-in default and is bound to an "rspam_*"
 * name in local_scan_options[] below (see the example configuration in the
 * file header).
 */
/* IPv4 address of the rspamd daemon, in dotted form (fed to inet_addr()) */
static uschar *daemon_ip = US"127.0.0.1";
/* messages whose data is larger than this many bytes are accepted unscanned */
static int max_scan_size = 4 * 1024 * 1024;
/* text handed back through *return_text by local_scan() */
static uschar *reject_message = US"Spam message rejected";
/* TCP port of the rspamd daemon */
static int daemon_port = 11333;
/* when TRUE, messages with sender_host_authenticated set are not scanned */
static BOOL skip_authenticated = TRUE;
/* address list; a recipient matching it turns a reject into an accept */
static uschar *want_spam_rcpt_list = US"";

/* the entries must appear in alphabetical order */
optionlist local_scan_options[] = {
	{ "rspam_ip", opt_stringptr, &daemon_ip },
	{ "rspam_max_scan_size", opt_mkint, &max_scan_size },
	{ "rspam_message", opt_stringptr, &reject_message },
	{ "rspam_port", opt_int, &daemon_port },
	{ "rspam_skip_rcpt", opt_stringptr, &want_spam_rcpt_list },
	{ "rspam_skip_sasl_authenticated", opt_bool, &skip_authenticated },
};

/* number of entries in local_scan_options[] */
int local_scan_options_count = sizeof(local_scan_options) / sizeof(optionlist);

/*
 * Format a line printf-style and store it in iov[i].
 * Returns 0 on success, or -1 when i is not below REQUEST_LINES or the
 * formatted line does not fit the function's internal buffer.
 */
int push_line(struct iovec *iov, int i, const char *fmt, ...);

/*
 * Exim local_scan() entry point: submit the message to an rspamd daemon
 * and accept, mark or reject it according to the reply.
 *
 * fd          - open descriptor of the message data file; its size minus
 *               SPOOL_DATA_START_OFFSET is taken as the body size and the
 *               body is read from it
 * return_text - receives the configured rejection text (reject_message)
 *
 * Returns LOCAL_SCAN_REJECT only when rspamd answered "Action: reject" and
 * no recipient matches want_spam_rcpt_list.  Every other outcome, including
 * any local, network or protocol error, returns LOCAL_SCAN_ACCEPT so that a
 * scanner failure never blocks mail.
 *
 * Headers written: X-Rspam-Status (skip_big, skip_authenticated,
 * check_error, spam or ham), X-Rspam-Metric for each parsed metric line and
 * X-Rspam-Score (rounded score of the "default" metric).
 */
int
local_scan(int fd, uschar **return_text)
{
	struct stat sb;
	struct sockaddr_in server_in;
	int s, i, r, request_p = 0, headers_count = 0, headers_v_count = 0;
	int is_spam = 0, is_reject = 0;
	off_t message_size;
	struct iovec request_v[REQUEST_LINES], *headers_v = NULL;
	/*
	 * FreeBSD's sendfile(2) can send an iovec of headers ahead of the file
	 * data; other systems fall back to writev() plus a read/write loop.
	 * The compiler-predefined __FreeBSD__ macro is tested because comparing
	 * string literals in #if is not a valid preprocessor expression.
	 */
#if defined(__FreeBSD__)
	struct sf_hdtr headers_sf;
#else
	char io_buf[BUFSIZ];
#endif
	uschar *helo, *log_buf;
	header_line *header_p;
	char reply_buf[REPLY_BUF_SIZE];
	ssize_t size;
	char *tok_ptr, *str;
	char metric[128], result[8];
	float score, required_score;

	*return_text = reject_message;

	/*
	 * One message can pass through exim+rspamd twice;
	 * remove the headers left by a previous pass.
	 */
	header_remove(0, US HEADER_STATUS);
	header_remove(0, US HEADER_METRIC);
	header_remove(0, US HEADER_SCORE);

	/* check message size */
	if (fstat(fd, &sb) < 0) {
		log_write(0, LOG_MAIN, "rspam: fstat (%d: %s)", errno, strerror(errno));
		header_add(' ', HEADER_STATUS ": check_error\n");
		return LOCAL_SCAN_ACCEPT;
	}
	message_size = sb.st_size - SPOOL_DATA_START_OFFSET;
	if (message_size > max_scan_size) {
		header_add(' ', HEADER_STATUS ": skip_big\n");
		log_write (0, LOG_MAIN, "rspam: message larger than rspam_max_scan_size, accept");
		return LOCAL_SCAN_ACCEPT;
	}

	/* don't scan mail from authenticated hosts */
	if (skip_authenticated && sender_host_authenticated != NULL) {
		header_add(' ', HEADER_STATUS ": skip_authenticated\n");
		log_write(0, LOG_MAIN, "rspam: from=<%s> ip=%s authenticated (%s), skip check\n",
				sender_address,
				sender_host_address == NULL ? US"localhost" : sender_host_address,
				sender_host_authenticated);
		return LOCAL_SCAN_ACCEPT;
	}

	/*
	 * Add a status header meaning that the message was not scanned;
	 * it is replaced below once the scan has completed.
	 */
	header_add(' ', HEADER_STATUS ": check_error\n");

	/* create socket */
	memset(&server_in, 0, sizeof(server_in));
	server_in.sin_family = AF_INET;
	server_in.sin_port = htons(daemon_port);
	server_in.sin_addr.s_addr = inet_addr((char *) daemon_ip);
	if ((s = socket(PF_INET, SOCK_STREAM, 0)) < 0) {
		log_write(0, LOG_MAIN, "rspam: socket (%d: %s)", errno, strerror(errno));
		return LOCAL_SCAN_ACCEPT;
	}
	if (connect(s, (struct sockaddr *) &server_in, sizeof(server_in)) < 0) {
		/* log before close(): close() may overwrite errno */
		log_write(0, LOG_MAIN, "rspam: can't connect to %s:%d (%d: %s)", daemon_ip, daemon_port, errno, strerror(errno));
		close(s);
		return LOCAL_SCAN_ACCEPT;
	}

	/* count message headers */
	for (header_p = header_list; header_p != NULL; header_p = header_p->next) {
		/* header type '*' is used for replaced or deleted header */
		if (header_p->type == '*')
			continue;
		headers_count++;
	}

	/*
	 * Write message headers to a vector, followed by the blank line that
	 * separates headers from body.  message_size grows accordingly because
	 * it is announced to rspamd as Content-length.
	 */
	if (headers_count > 0) {
		headers_v = store_get((headers_count + 1) * sizeof(*headers_v));
		i = 0;
		for (header_p = header_list; header_p != NULL; header_p = header_p->next) {
			if (header_p->type == '*')
				continue;
			headers_v[i].iov_base = header_p->text;
			headers_v[i].iov_len = header_p->slen;
			i++;
			message_size += header_p->slen;
		}
		headers_v[i].iov_base = "\n";
		headers_v[i].iov_len = strlen("\n");
		message_size += strlen("\n");
		/* the separator line is part of what must be sent */
		headers_v_count = headers_count + 1;
	}
#if defined(__FreeBSD__)
	memset(&headers_sf, 0, sizeof(headers_sf));
	headers_sf.headers = headers_v;
	headers_sf.hdr_cnt = headers_v_count;
#endif

	/* write request to vector; push_line() returns -1 on failure */
	r = 0;
	r += push_line(request_v, request_p++, "SYMBOLS RSPAMC/1.1\r\n");
	r += push_line(request_v, request_p++, "Content-length: " OFF_T_FMT "\r\n", message_size);
	r += push_line(request_v, request_p++, "Queue-Id: %s\r\n", message_id);
	r += push_line(request_v, request_p++, "From: %s\r\n", sender_address);
	r += push_line(request_v, request_p++, "Recipient-Number: %d\r\n", recipients_count);
	for (i = 0; i < recipients_count; i ++)
		r += push_line(request_v, request_p++, "Rcpt: %s\r\n", recipients_list[i].address);
	if ((helo = expand_string(US"$sender_helo_name")) != NULL && *helo != '\0')
		r += push_line(request_v, request_p++, "Helo: %s\r\n", helo);
	if (sender_host_address != NULL)
		r += push_line(request_v, request_p++, "IP: %s\r\n", sender_host_address);
	r += push_line(request_v, request_p++, "\r\n");

	/* at least one line could not be built; push_line() already logged it */
	if (r < 0) {
		close(s);
		return LOCAL_SCAN_ACCEPT;
	}

	/* send request */
	if (writev(s, request_v, request_p) < 0) {
		log_write(0, LOG_MAIN, "rspam: can't send request to %s:%d (%d: %s)", daemon_ip, daemon_port, errno, strerror(errno));
		close(s);
		return LOCAL_SCAN_ACCEPT;
	}

#if defined(__FreeBSD__)
	/* send headers (from iovec) and message body (from file) */
	if (sendfile(fd, s, SPOOL_DATA_START_OFFSET, 0, &headers_sf, NULL, 0) < 0) {
		log_write(0, LOG_MAIN, "rspam: can't send message to %s:%d (%d: %s)", daemon_ip, daemon_port, errno, strerror(errno));
		close(s);
		return LOCAL_SCAN_ACCEPT;
	}
#else
	/* send headers together with the trailing separator line */
	if (headers_v_count > 0 && writev(s, headers_v, headers_v_count) < 0) {
		log_write(0, LOG_MAIN, "rspam: can't send headers to %s:%d (%d: %s)", daemon_ip, daemon_port, errno, strerror(errno));
		close(s);
		return LOCAL_SCAN_ACCEPT;
	}

	/* send message body */
	while ((r = read (fd, io_buf, sizeof (io_buf))) > 0) {
		if (write (s, io_buf, r) < 0) {
			log_write(0, LOG_MAIN, "rspam: can't send message to %s:%d (%d: %s)", daemon_ip, daemon_port, errno, strerror(errno));
			close(s);
			return LOCAL_SCAN_ACCEPT;
		}
	}
	/* a short body would leave rspamd waiting for the announced length */
	if (r < 0) {
		log_write(0, LOG_MAIN, "rspam: can't read message body (%d: %s)", errno, strerror(errno));
		close(s);
		return LOCAL_SCAN_ACCEPT;
	}
#endif

	/* read reply from rspamd, keeping one byte for the terminating NUL */
	reply_buf[0] = '\0';
	size = 0;
	r = 0;
	while ((size_t) size < sizeof(reply_buf) - 1) {
		r = read(s, reply_buf + size, sizeof(reply_buf) - 1 - size);
		if (r <= 0)
			break;
		size += r;
	}

	if (r < 0) {
		log_write(0, LOG_MAIN, "rspam: can't read from %s:%d (%d: %s)", daemon_ip, daemon_port, errno, strerror(errno));
		close(s);
		return LOCAL_SCAN_ACCEPT;
	}
	reply_buf[size] = '\0';
	close(s);

	if (size >= REPLY_BUF_SIZE - 1) {
		log_write(0, LOG_MAIN, "rspam: buffer is full, reply may be truncated");
	}

	/* parse reply */
	tok_ptr = reply_buf;

	/*
	 * The reply may carry several "Metric:" and "Action:" lines:
	 * any "Action: reject" rejects the message,
	 * any "Action: add header" marks it as spam.
	 */

	/* First line is: <PROTOCOL>/<VERSION> <ERROR_CODE> <ERROR_REPLY> */
	str = strsep(&tok_ptr, "\r\n");
	if (str != NULL && sscanf(str, "%*s %d %*s", &i) == 1) {
		if (i != 0) {
			log_write(0, LOG_MAIN, "rspam: server error: %s", str);
			return LOCAL_SCAN_ACCEPT;
		}
	} else {
		log_write(0, LOG_MAIN, "rspam: bad reply from server: %s", str);
		return LOCAL_SCAN_ACCEPT;
	}

	while ((str = strsep(&tok_ptr, "\r\n")) != NULL) {
		/* skip empty tokens (strsep yields one between '\r' and '\n') */
		if (*str == '\0')
			continue;
		if (strncmp(str, "Metric:", strlen("Metric:")) == 0) {
			/*
			 * parse line like
			 * Metric: default; False; 27.00 / 30.00
			 *
			 * Field widths keep the untrusted reply from overflowing
			 * metric[128] and result[8].
			 */
			if (sscanf(str, "Metric: %127s %7s %f / %f",
				metric, result, &score, &required_score) == 4) {
				log_write(0, LOG_MAIN, "rspam: metric %s %s %.2f / %.2f",
					metric, result, score, required_score);
				header_add(' ', HEADER_METRIC ": %s %s %.2f / %.2f\n",
					metric, result, score, required_score);
				/* integer score for use in sieve ascii-numeric comparator */
				if (strcmp(metric, "default;") == 0)
					 header_add(' ', HEADER_SCORE ": %d\n",
						(int)round(score));
			} else {
				log_write(0, LOG_MAIN, "rspam: can't parse: %s", str);
				return LOCAL_SCAN_ACCEPT;
			}
		} else if (strncmp(str, "Action:", strlen("Action:")) == 0) {
			/*
			 * line like Action: add header
			 * Skip the key, then any blanks, so that a bare "Action:"
			 * never moves the pointer past the token's terminator.
			 */
			str += strlen("Action:");
			while (*str == ' ')
				str++;
			if (strncmp(str, "reject", strlen("reject")) == 0) {
				is_reject = 1;
				is_spam = 1;
			} else if (strncmp(str, "add header", strlen("add header")) == 0) {
				is_spam = 1;
			}
		}
	}

	/* XXX many allocs by string_sprintf()
	 * better to sprintf() to single buffer allocated by store_get()
	 */
	log_buf = string_sprintf("message to");
	for (i = 0; i < recipients_count; i ++) {
		log_buf = string_sprintf("%s %s", log_buf, recipients_list[i].address);
		/* one recipient that wants spam is enough to cancel the reject */
		if (is_reject && lss_match_address(recipients_list[i].address, want_spam_rcpt_list, TRUE) == OK) {
			is_reject = 0;
			log_write(0, LOG_MAIN, "rspam: %s want spam, don't reject this message", recipients_list[i].address);
		}
	}

	if (is_reject) {
		log_write(0, LOG_MAIN, "rspam: reject %s", log_buf);
		return LOCAL_SCAN_REJECT;
	}

	/* scan completed: replace the provisional check_error status */
	header_remove(0, US HEADER_STATUS);
	if (is_spam) {
		header_add(' ', HEADER_STATUS ": spam\n");
		log_write(0, LOG_MAIN, "rspam: message marked as spam");
	} else {
		header_add(' ', HEADER_STATUS ": ham\n");
		log_write(0, LOG_MAIN, "rspam: message marked as ham");
	}

	return LOCAL_SCAN_ACCEPT;
}

/*
 * Format one request line printf-style and store a copy in iov[i].
 *
 * iov - vector with room for REQUEST_LINES entries
 * i   - index of the entry to fill
 * fmt - printf-style format, followed by its arguments
 *
 * Returns 0 on success.  Returns -1, after logging, when i is outside
 * 0..REQUEST_LINES-1, when formatting fails, or when the line does not fit
 * the 512-byte work buffer.  For a valid index iov[i] is always filled and
 * iov_len never exceeds the length of the copied string, so the entry is
 * safe to use even after a failure.
 */
int
push_line(struct iovec *iov, const int i, const char *fmt, ...)
{
	va_list ap;
	int len, rc = 0;
	char buf[512];

	if (i < 0 || i >= REQUEST_LINES) {
		log_write(0, LOG_MAIN, "rspam: %s: index out of bounds", __func__);
		return (-1);
	}

	va_start(ap, fmt);
	len = vsnprintf(buf, sizeof(buf), fmt, ap);
	va_end(ap);

	if (len < 0) {
		/* vsnprintf() failed; buf contents are unspecified */
		log_write(0, LOG_MAIN, "rspam: %s: formatting error", __func__);
		buf[0] = '\0';
		len = 0;
		rc = -1;
	} else if ((size_t) len >= sizeof(buf)) {
		/* output was truncated; report the length actually stored */
		log_write(0, LOG_MAIN, "rspam: %s: error, string was longer than %d", __func__, (int) sizeof(buf));
		len = (int) sizeof(buf) - 1;
		rc = -1;
	}

	iov[i].iov_base = string_copy(US buf);
	iov[i].iov_len = (size_t) len;

	return rc;
}