From f8004be4c94ca214cd399cdb18aa0d085abf0fd7 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Mon, 1 Jul 2019 13:50:21 +0100 Subject: [Test] Add unit test for kann --- test/lua/unit/kann.lua | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 test/lua/unit/kann.lua (limited to 'test/lua') diff --git a/test/lua/unit/kann.lua b/test/lua/unit/kann.lua new file mode 100644 index 000000000..bb6930203 --- /dev/null +++ b/test/lua/unit/kann.lua @@ -0,0 +1,43 @@ +-- Simple kann test (xor function vs 2 layer MLP) + +context("Kann test", function() + local kann = require "rspamd_kann" + local k + local inputs = { + {0, 0}, + {0, 1}, + {1, 0}, + {1, 1} + } + + local outputs = { + {0}, + {1}, + {1}, + {0} + } + + local t = kann.layer.input(2) + t = kann.transform.relu(t) + t = kann.transform.tanh(kann.layer.dense(t, 2)); + t = kann.layer.cost(t, 1, kann.cost.mse) + k = kann.new.kann(t) + + local iters = 500 + local niter = k:train1(inputs, outputs, { + lr = 0.01, + max_epoch = iters, + mini_size = 80, + }) + + for i,inp in ipairs(inputs) do + test(string.format("Check XOR MLP %s ^ %s == %s", inp[1], inp[2], outputs[i][1]), + function() + local res = math.floor(k:apply1(inp)[1] + 0.5) + assert_equal(outputs[i][1], res, + tostring(outputs[i][1]) .. " but test returned " .. tostring(res)) + end) + end + + +end) \ No newline at end of file -- cgit v1.2.3