diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-07-01 13:50:21 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-07-01 13:50:21 +0100 |
commit | f8004be4c94ca214cd399cdb18aa0d085abf0fd7 (patch) | |
tree | ca5efdb84dc5676f6600a8cda91720c0e5cedcaa /test/lua | |
parent | c5ef059e0d0ea41b8a490c2f838a819e1363d0dd (diff) | |
download | rspamd-f8004be4c94ca214cd399cdb18aa0d085abf0fd7.tar.gz rspamd-f8004be4c94ca214cd399cdb18aa0d085abf0fd7.zip |
[Test] Add unit test for kann
Diffstat (limited to 'test/lua')
-rw-r--r-- | test/lua/unit/kann.lua | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/test/lua/unit/kann.lua b/test/lua/unit/kann.lua new file mode 100644 index 000000000..bb6930203 --- /dev/null +++ b/test/lua/unit/kann.lua @@ -0,0 +1,43 @@ +-- Simple kann test (xor function vs 2 layer MLP) + +context("Kann test", function() + local kann = require "rspamd_kann" + local k + local inputs = { + {0, 0}, + {0, 1}, + {1, 0}, + {1, 1} + } + + local outputs = { + {0}, + {1}, + {1}, + {0} + } + + local t = kann.layer.input(2) + t = kann.transform.relu(t) + t = kann.transform.tanh(kann.layer.dense(t, 2)); + t = kann.layer.cost(t, 1, kann.cost.mse) + k = kann.new.kann(t) + + local iters = 500 + local niter = k:train1(inputs, outputs, { + lr = 0.01, + max_epoch = iters, + mini_size = 80, + }) + + for i,inp in ipairs(inputs) do + test(string.format("Check XOR MLP %s ^ %s == %s", inp[1], inp[2], outputs[i][1]), + function() + local res = math.floor(k:apply1(inp)[1] + 0.5) + assert_equal(outputs[i][1], res, + tostring(outputs[i][1]) .. " but test returned " .. tostring(res)) + end) + end + + +end)
\ No newline at end of file |