/*
 * Copyright (c) 2009-2012, Vsevolod Stakhov
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY AUTHOR ''AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL AUTHOR BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

/*
 * Winnow classifier
 */

#include "classifiers.h"
#include "tokenizers/tokenizers.h"
#include "main.h"
#include "filter.h"
#include "cfg_file.h"
#include "lua/lua_common.h"

/* Factor used to raise a token weight when a message is learned into its class */
#define WINNOW_PROMOTION 1.23
/* Factor used to lower a token weight when it is demoted */
#define WINNOW_DEMOTION 0.83

#define MEDIAN_WINDOW_SIZE 5

/*
 * Limit on the growth of a single token weight. The replacement list is
 * parenthesized so that the macro can be used safely inside any larger
 * expression (e.g. as a divisor).
 */
#define MAX_WEIGHT (G_MAXDOUBLE / 2.)

/* Upper bound on the number of extra learning passes done by winnow_learn */
#define MAX_LEARN_ITERATIONS 100

static inline GQuark
winnow_error_quark (void)
{
	return g_quark_from_static_string ("winnow-error");
}

/*
 * State shared between the classify/learn entry points and the GTree
 * traversal callbacks that visit every token of a message.
 */
struct winnow_callback_data {
	/* Statfile pool handed to statfile_pool_get_block/statfile_pool_set_block */
	statfile_pool_t                *pool;
	/* Classifier context, filled in by the caller */
	struct classifier_ctx          *ctx;
	/* Statfile that is currently being read or updated */
	stat_file_t                    *file;
	/* Statfile that the message is being learned into (learn only) */
	stat_file_t                    *learn_file;
	/* Accumulated token weights for the current statfile */
	long double                     sum;
	/* Accumulated token weights as they were before the learn pass changed them */
	long double 					start;
	/* Extra scale applied to promotion and inversely to demotion */
	double                          multiplier;
	/* Number of tokens visited for the current statfile */
	guint32                         count;
	/* Number of tokens that had no stored weight in the current statfile */
	guint32                         new_blocks;
	/* TRUE when tokens are promoted in learn_file, FALSE when they are demoted */
	gboolean                        in_class;
	/* TRUE when the current (non-target) statfile must have its tokens demoted */
	gboolean                        do_demote;
	/* TRUE while the first statfile of a learn pass is processed */
	gboolean                        fresh_run;
	/* Timestamp passed to the statfile block accessors */
	time_t                          now;
};

/*
 * Threshold used by the learn callback for tokens seen in several statfiles:
 * a stored weight above this value is reset to zero.
 */
static const double max_common_weight = MAX_WEIGHT * WINNOW_DEMOTION;



/*
 * GTree traversal callback used for classification.
 *
 * key is the token node, data is a struct winnow_callback_data. The stored
 * weight of the token in cbdata->file is added to cbdata->sum; a token whose
 * stored weight is not above ALPHA in magnitude is counted as a new block
 * and contributes the neutral weight 1. Always returns FALSE so that the
 * traversal visits every token.
 */
static                          gboolean
winnow_classify_callback (gpointer key, gpointer value, gpointer data)
{
	struct winnow_callback_data    *cbdata = data;
	token_node_t                   *token = key;
	double                          weight;
	gboolean                        known;

	weight = statfile_pool_get_block (cbdata->pool, cbdata->file, token->h1, token->h2, cbdata->now);
	known = fabs (weight) > ALPHA;

	if (!known) {
		/* No usable weight stored for this token: treat it as 1 */
		cbdata->new_blocks ++;
		cbdata->sum += 1.0;
	}
	else {
		cbdata->sum += weight;
	}

	cbdata->count ++;

	return FALSE;
}

/*
 * GTree traversal callback used for learning.
 *
 * key is the token node, data is a struct winnow_callback_data. The callback
 * reads the weight of the token in cd->file, adds the old weight to
 * cd->start, updates the stored weight depending on whether cd->file is the
 * statfile being learned (cd->learn_file), a statfile marked for demotion
 * (cd->do_demote), or a statfile that shares the token with others, and
 * finally adds node->value to cd->sum. Always returns FALSE so that the
 * traversal visits every token.
 */
static                          gboolean
winnow_learn_callback (gpointer key, gpointer value, gpointer data)
{
	token_node_t                   *node = key;
	struct winnow_callback_data    *cd = data;
	double                          v, c;
	
	/* Promotion factor when learning in class, demotion factor otherwise, both scaled by the multiplier */
	c = (cd->in_class) ? WINNOW_PROMOTION * cd->multiplier : WINNOW_DEMOTION / cd->multiplier;

	/* Consider that not found blocks have value 1 */
	v = statfile_pool_get_block (cd->pool, cd->file, node->h1, node->h2, cd->now);
	if (fabs (v) < ALPHA) {
		/* Block not found, insert new */
		cd->start += 1;
		/* New blocks are created only in the statfile that is being learned */
		if (cd->file == cd->learn_file) {
			statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, c);
			node->value = c;
			cd->new_blocks ++;
		}
		/*
		 * NOTE(review): for any other statfile node->value is left untouched
		 * here, so the sum below uses whatever value the node already held -
		 * confirm that this is intended.
		 */
	}
	else {
		cd->start += v;
		/* Here we just increase the extra value of block */
		/* extra counts in how many statfiles of this pass the token was found after the first one */
		if (cd->fresh_run) {
			node->extra = 0;
		}
		else {
			node->extra ++;
		}
		node->value = v;
		
		if (node->extra > 1) {
			/*
			 * Assume that this node is common for several statfiles, so
			 * decrease its weight proportionally
			 */
			if (node->value > max_common_weight) {
				/* Static fluctuation */
				statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, 0.);
				node->value = 0.;
			}
			else if (node->value > WINNOW_PROMOTION * cd->multiplier) {
				/* Try to decrease its value */
				/* XXX: it is more intelligent to add some adaptive filter here */
				if (cd->file == cd->learn_file) {
					if (node->value > max_common_weight / 2.) {
						node->value *= c;
					}
					else {
						/*
						 * Too high token value that exists also in other
						 * statfiles, may be statistic error, so decrease it
						 * slightly
						 */
						node->value *= WINNOW_DEMOTION;
					}
				}
				else {
					/* In other statfiles the weight is replaced by the scaled demotion factor */
					node->value = WINNOW_DEMOTION / cd->multiplier;
				}
				statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, node->value);
			} 
			/* Weights not above the scaled promotion factor are left as they are */
		}
		else if (cd->file == cd->learn_file) {
			/* New block or block that is in only one statfile */
			/* Set some limit on growing */
			if (v > MAX_WEIGHT) {
				node->value = v;
			}
			else {
				node->value *= c;
			}
			statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, node->value);
		}
		else if (cd->do_demote) {
			/* Demote blocks in file */
			node->value *= WINNOW_DEMOTION / cd->multiplier;
			statfile_pool_set_block (cd->pool, cd->file, node->h1, node->h2, cd->now, node->value);
		}
	}


	/* Accumulate the weight after the update and count the token */
	cd->sum += node->value;

	cd->count++;

	return FALSE;
}

/*
 * Allocate a classifier context from the given memory pool and bind it to
 * that pool and to the classifier configuration.
 */
struct classifier_ctx          *
winnow_init (memory_pool_t * pool, struct classifier_config *cfg)
{
	struct classifier_ctx          *new_ctx;

	new_ctx = memory_pool_alloc (pool, sizeof (*new_ctx));
	new_ctx->cfg = cfg;
	new_ctx->pool = pool;

	return new_ctx;
}

/*
 * Classify a tokenized message against the statfiles of the classifier.
 *
 * For every statfile that can be opened the average token weight is computed;
 * the statfile with the highest average wins and its symbol is inserted into
 * the task result with tanh of that average as the weight.
 *
 * Returns FALSE only when the "min_tokens" option is set and the message has
 * too few tokens, TRUE otherwise (even if no statfile was selected).
 */
gboolean
winnow_classify (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * input, struct worker_task *task, lua_State *L)
{
	struct winnow_callback_data     cbdata;
	struct statfile                *statfile, *best = NULL;
	long double                     score = 0., best_score = 0.;
	GList                          *iter, *params;
	char                           *opt, *score_str;
	int                             ntokens, min_tokens;

	g_assert (pool != NULL);
	g_assert (ctx != NULL);

	cbdata.pool = pool;
	cbdata.now = time (NULL);
	cbdata.ctx = ctx;

	/* Optional lower bound on the (window adjusted) number of tokens */
	opt = ctx->cfg->opts ? g_hash_table_lookup (ctx->cfg->opts, "min_tokens") : NULL;
	if (opt != NULL) {
		min_tokens = strtol (opt, NULL, 10);
		ntokens = g_tree_nnodes (input);
		if (ntokens > FEATURE_WINDOW_SIZE) {
			ntokens = ntokens / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE;
		}
		if (ntokens < min_tokens) {
			msg_info ("do not classify message as it has too few tokens: %d, while %d min", ntokens, min_tokens);
			return FALSE;
		}
	}

	/* Pre callbacks may supply their own statfile list; otherwise use the configured one */
	iter = call_classifier_pre_callbacks (ctx->cfg, task, FALSE, FALSE, L);
	if (iter == NULL) {
		iter = ctx->cfg->statfiles;
	}
	else {
		memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, iter);
	}

	for (; iter != NULL; iter = g_list_next (iter)) {
		statfile = iter->data;
		cbdata.sum = 0;
		cbdata.count = 0;
		cbdata.new_blocks = 0;

		cbdata.file = statfile_pool_is_open (pool, statfile->path);
		if (cbdata.file == NULL) {
			cbdata.file = statfile_pool_open (pool, statfile->path, statfile->size, FALSE);
		}
		if (cbdata.file == NULL) {
			msg_warn ("cannot open %s, skip it", statfile->path);
			continue;
		}

		g_tree_foreach (input, winnow_classify_callback, &cbdata);

		/* Average weight of the tokens in this statfile */
		score = (cbdata.count != 0) ? cbdata.sum / (double)cbdata.count : 0;
		if (score > best_score) {
			best_score = score;
			best = statfile;
		}
	}

	if (best == NULL) {
		/* Nothing scored above zero, so there is no symbol to insert */
		return TRUE;
	}

#ifdef WITH_LUA
	best_score = call_classifier_post_callbacks (ctx->cfg, task, best_score, L);
#endif
#ifdef HAVE_TANHL
	best_score = tanhl (best_score);
#else
	/* Fall back to tanh when libm does not provide tanhl */
	best_score = tanh ((double) best_score);
#endif
	score_str = memory_pool_alloc (task->task_pool, 32);
	rspamd_snprintf (score_str, 32, "%.2F", best_score);
	params = g_list_prepend (NULL, score_str);
	insert_result (task, best->symbol, best_score, params);

	return TRUE;
}

/*
 * Compute the per-statfile weights of a tokenized message.
 *
 * For every configured statfile that can be opened, a struct classify_weight
 * is allocated from the task pool; its name is the statfile symbol and its
 * weight is the average token weight (0 when no token was visited).
 *
 * Returns the list of weights, or NULL when the "min_tokens" option is set
 * and the message has too few tokens, or when no statfile could be opened.
 * The list itself is released by a destructor registered on the task pool,
 * so the caller must not free it.
 */
GList *
winnow_weights (struct classifier_ctx *ctx, statfile_pool_t * pool, GTree * input, struct worker_task *task)
{
	/* Zero all fields so that the callback never touches indeterminate values */
	struct winnow_callback_data     data = {
		.file = NULL
	};
	long double                     res = 0.;
	GList                          *cur, *resl = NULL;
	struct statfile                *st;
	struct classify_weight         *w;
	char                           *value;
	int                             nodes, minnodes;

	g_assert (pool != NULL);
	g_assert (ctx != NULL);

	data.pool = pool;
	data.now = time (NULL);
	data.ctx = ctx;

	if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
		minnodes = strtol (value, NULL, 10);
		nodes = g_tree_nnodes (input);
		if (nodes > FEATURE_WINDOW_SIZE) {
			nodes = nodes / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE;
		}
		if (nodes < minnodes) {
			msg_info ("do not classify message as it has too few tokens: %d, while %d min", nodes, minnodes);
			return NULL;
		}
	}

	cur = ctx->cfg->statfiles;
	while (cur) {
		st = cur->data;
		/* Reset the per-statfile accumulators, including the new block counter */
		data.sum = 0;
		data.count = 0;
		data.new_blocks = 0;
		if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
			if ((data.file = statfile_pool_open (pool, st->path, st->size, FALSE)) == NULL) {
				msg_warn ("cannot open %s, skip it", st->path);
				cur = g_list_next (cur);
				continue;
			}
		}

		g_tree_foreach (input, winnow_classify_callback, &data);

		w = memory_pool_alloc0 (task->task_pool, sizeof (struct classify_weight));
		if (data.count != 0) {
			res = data.sum / (double)data.count;
		}
		else {
			res = 0;
		}
		w->name = st->symbol;
		w->weight = res;
		resl = g_list_prepend (resl, w);
		cur = g_list_next (cur);
	}

	if (resl != NULL) {
		/* The list nodes are freed together with the task pool */
		memory_pool_add_destructor (task->task_pool, (pool_destruct_func)g_list_free, resl);
	}

	return resl;

}


/*
 * Learn a tokenized message into the statfile that belongs to symbol.
 *
 * ctx        - classifier context (its configuration supplies the statfiles
 *              and the "min_tokens" / "learn_threshold" options)
 * pool       - statfile pool used to open, create and update statfiles
 * symbol     - symbol of the statfile to learn
 * input      - tree of message tokens
 * in_class   - non-zero to promote the tokens in the target statfile,
 *              zero to demote them
 * sum        - optional output: tanh of the best average weight, or 0 when
 *              the message has too few tokens
 * multiplier - initial scale of the promotion/demotion factors; it grows by
 *              WINNOW_PROMOTION after every learning pass
 * err        - receives the reason whenever FALSE is returned
 *
 * The message is first classified against the target statfile and against
 * every other statfile; statfiles that beat the target by more than
 * 1 - learn_threshold are scheduled for demotion. Learning passes are then
 * repeated until the target statfile is selected (or, when in_class is zero,
 * is no longer selected) or MAX_LEARN_ITERATIONS is exceeded.
 *
 * Returns TRUE when the message was learned or was already classified
 * correctly, FALSE on any failure.
 */
gboolean
winnow_learn (struct classifier_ctx *ctx, statfile_pool_t *pool, const char *symbol,
		GTree * input, int in_class, double *sum, double multiplier, GError **err)
{
	struct winnow_callback_data     data = {
		.file = NULL,
		.multiplier = multiplier
	};
	char                           *value;
	int                             nodes, minnodes, iterations = 0;
	struct statfile                *st, *sel_st = NULL;
	stat_file_t                    *sel = NULL, *to_learn = NULL;
	long double                     res = 0., max = 0., start_value = 0., end_value = 0.;
	double                          learn_threshold = 0.0;
	GList                          *cur, *to_demote = NULL;
	gboolean                        force_learn = FALSE, learned = FALSE, ret = FALSE;

	g_assert (pool != NULL);
	g_assert (ctx != NULL);

	data.pool = pool;
	data.in_class = in_class;
	data.now = time (NULL);
	data.ctx = ctx;

	if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "min_tokens")) != NULL) {
		minnodes = strtol (value, NULL, 10);
		nodes = g_tree_nnodes (input);
		if (nodes > FEATURE_WINDOW_SIZE) {
			nodes = nodes / FEATURE_WINDOW_SIZE + FEATURE_WINDOW_SIZE;
		}
		if (nodes < minnodes) {
			msg_info ("do not learn message as it has too few tokens: %d, while %d min", nodes, minnodes);
			if (sum != NULL) {
				*sum = 0;
			}
			g_set_error (err,
					winnow_error_quark(),		/* error domain */
					1,            				/* error code */
					"message contains too few tokens: %d, while min is %d",
					nodes, minnodes);
			return FALSE;
		}
	}
	if (ctx->cfg->opts && (value = g_hash_table_lookup (ctx->cfg->opts, "learn_threshold")) != NULL) {
		learn_threshold = strtod (value, NULL);
	}

	/* Written as a negated range check so that NaN is rejected as well */
	if (!(learn_threshold <= 1.0 && learn_threshold >= 0)) {
		msg_err ("learn threshold is more than 1 or less than 0, so cannot do learn, please check your configuration");
		g_set_error (err,
				winnow_error_quark(),		/* error domain */
				1,            				/* error code */
				"bad learn_threshold setting: %.2f",
				learn_threshold);
		return FALSE;
	}

	/* Open or create all statfiles inside classifier and find the target one */
	cur = ctx->cfg->statfiles;
	while (cur) {
		st = cur->data;
		if (statfile_pool_is_open (pool, st->path) == NULL) {
			if (statfile_pool_open (pool, st->path, st->size, FALSE) == NULL) {
				msg_warn ("cannot open %s", st->path);
				if (statfile_pool_create (pool, st->path, st->size) == -1) {
					msg_err ("cannot create statfile %s", st->path);
					g_set_error (err,
							winnow_error_quark(),		/* error domain */
							1,            				/* error code */
							"cannot create statfile: %s",
							st->path);
					return FALSE;
				}
				if (statfile_pool_open (pool, st->path, st->size, FALSE) == NULL) {
					g_set_error (err,
							winnow_error_quark(),		/* error domain */
							1,            				/* error code */
							"open statfile %s after creation",
							st->path);
					msg_err ("cannot open statfile %s after creation", st->path);
					return FALSE;
				}
			}
		}
		if (strcmp (st->symbol, symbol) == 0) {
			sel_st = st;
		}
		cur = g_list_next (cur);
	}

	if (sel_st == NULL) {
		g_set_error (err,
				winnow_error_quark(),		/* error domain */
				1,            				/* error code */
				"cannot find statfile for symbol %s",
				symbol);
		msg_err ("cannot find statfile for symbol %s", symbol);
		return FALSE;
	}

	to_learn = statfile_pool_is_open (pool, sel_st->path);
	if (to_learn == NULL) {
		g_set_error (err,
				winnow_error_quark(),		/* error domain */
				1,            				/* error code */
				"statfile %s is not opened this maybe if your statfile pool is too small to handle all statfiles",
				sel_st->path);
		return FALSE;
	}

	/* Check target statfile */
	data.file = to_learn;
	data.sum = 0;
	data.count = 0;
	data.new_blocks = 0;
	g_tree_foreach (input, winnow_classify_callback, &data);
	if (data.count > 0) {
		max = data.sum / (double)data.count;
	}
	else {
		max = 0;
	}
	/* If most of blocks are not presented in targeted statfile do forced learn */
	if (max < 1 + learn_threshold) {
		force_learn = TRUE;
	}

	/*
	 * Check other statfiles. The list iterator was exhausted by the loop
	 * above, so it has to be rewound here; otherwise no statfile would ever
	 * be scheduled for demotion.
	 */
	cur = ctx->cfg->statfiles;
	while (cur) {
		st = cur->data;
		data.sum = 0;
		data.count = 0;
		if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
			g_set_error (err,
					winnow_error_quark(),		/* error domain */
					1,            				/* error code */
					"statfile %s is not opened this maybe if your statfile pool is too small to handle all statfiles",
					st->path);
			goto cleanup;
		}
		g_tree_foreach (input, winnow_classify_callback, &data);
		if (data.count != 0) {
			res = data.sum / data.count;
		}
		else {
			res = 0;
		}
		if (to_learn != data.file && res - max > 1 - learn_threshold) {
			/* Demote tokens in this statfile */
			to_demote = g_list_prepend (to_demote, data.file);
		}
		cur = g_list_next (cur);
	}

	/* If to_demote list is empty this message is already classified correctly */
	if (max > WINNOW_PROMOTION && to_demote == NULL && !force_learn) {
		msg_info ("this message is already of class %s with threshold %.2f and weight %.2F",
				sel_st->symbol, learn_threshold, max);
		goto end;
	}

	data.learn_file = to_learn;
	end_value = max;
	do {
		cur = ctx->cfg->statfiles;
		data.fresh_run = TRUE;
		while (cur) {
			st = cur->data;
			data.sum = 0;
			data.count = 0;
			data.new_blocks = 0;
			data.start = 0;
			if ((data.file = statfile_pool_is_open (pool, st->path)) == NULL) {
				/* Report the reason instead of failing silently */
				g_set_error (err,
						winnow_error_quark(),		/* error domain */
						1,            				/* error code */
						"statfile %s is not opened this maybe if your statfile pool is too small to handle all statfiles",
						st->path);
				goto cleanup;
			}
			if (to_demote != NULL && g_list_find (to_demote, data.file) != NULL) {
				data.do_demote = TRUE;
			}
			else {
				data.do_demote = FALSE;
			}

			/* The statfile is locked only for the duration of the update */
			statfile_pool_lock_file (pool, data.file);
			g_tree_foreach (input, winnow_learn_callback, &data);
			statfile_pool_unlock_file (pool, data.file);
			if (data.count != 0) {
				res = data.sum / data.count;
			}
			else {
				res = 0;
			}
			if (res > max) {
				max = res;
				sel = data.file;
			}
			if (data.file == to_learn) {
				if (data.count > 0) {
					start_value = data.start / data.count;
				}
				end_value = res;
			}
			cur = g_list_next (cur);
			data.fresh_run = FALSE;
		}

		data.multiplier *= WINNOW_PROMOTION;
		msg_info ("learn iteration %d for statfile %s: %G -> %G, multiplier: %.2f", iterations + 1, symbol,
				start_value, end_value, data.multiplier);

		/*
		 * Learning is done when the target statfile is selected (in class)
		 * or is not selected (out of class). Tracking this explicitly keeps
		 * a pass that converges exactly at the iteration limit from being
		 * reported as a failure.
		 */
		learned = in_class ? (sel == to_learn) : (sel != to_learn);
	} while (!learned && iterations ++ < MAX_LEARN_ITERATIONS);

	if (!learned) {
		msg_warn ("learning statfile %s  was not fully successfull: iterations count is limited to %d, final sum is %G",
				sel_st->symbol, MAX_LEARN_ITERATIONS, max);
		g_set_error (err,
				winnow_error_quark(),		/* error domain */
				1,            				/* error code */
				"learning statfile %s  was not fully successfull: iterations count is limited to %d",
				sel_st->symbol, MAX_LEARN_ITERATIONS);
		goto cleanup;
	}

	msg_info ("learned statfile %s successfully with %d iterations and sum %G", sel_st->symbol, iterations + 1, max);

end:
	if (sum) {
#ifdef HAVE_TANHL
		*sum = (double)tanhl (max);
#else
		/*
		 * As some implementations of libm does not support tanhl, try to use
		 * tanh
		 */
		*sum = tanh ((double) max);
#endif
	}
	ret = TRUE;

cleanup:
	/* The demotion list only borrows the statfile pointers, so free just the nodes */
	g_list_free (to_demote);

	return ret;
}

/*
 * Spam/ham style learning entry point. The winnow classifier does not
 * implement it: the call always sets err and returns FALSE, and none of the
 * other arguments is used.
 */
gboolean
winnow_learn_spam (struct classifier_ctx* ctx, statfile_pool_t *pool,
		GTree *input, struct worker_task *task, gboolean is_spam, lua_State *L, GError **err)
{
	const gboolean                  supported = FALSE;

	g_set_error (err, winnow_error_quark (), 1, "learn spam is not supported for winnow");

	return supported;
}