diff options
-rw-r--r-- | test/lua/unit/kann.lua | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/test/lua/unit/kann.lua b/test/lua/unit/kann.lua new file mode 100644 index 000000000..bb6930203 --- /dev/null +++ b/test/lua/unit/kann.lua @@ -0,0 +1,43 @@ +-- Simple kann test (xor function vs 2 layer MLP) + +context("Kann test", function() + local kann = require "rspamd_kann" + local k + local inputs = { + {0, 0}, + {0, 1}, + {1, 0}, + {1, 1} + } + + local outputs = { + {0}, + {1}, + {1}, + {0} + } + + local t = kann.layer.input(2) + t = kann.transform.relu(t) + t = kann.transform.tanh(kann.layer.dense(t, 2)); + t = kann.layer.cost(t, 1, kann.cost.mse) + k = kann.new.kann(t) + + local iters = 500 + local niter = k:train1(inputs, outputs, { + lr = 0.01, + max_epoch = iters, + mini_size = 80, + }) + + for i,inp in ipairs(inputs) do + test(string.format("Check XOR MLP %s ^ %s == %s", inp[1], inp[2], outputs[i][1]), + function() + local res = math.floor(k:apply1(inp)[1] + 0.5) + assert_equal(outputs[i][1], res, + tostring(outputs[i][1]) .. " but test returned " .. tostring(res)) + end) + end + + +end)
\ No newline at end of file |