author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-05-23 18:14:15 +0100
---|---|---
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-05-23 18:14:15 +0100
commit | 714eb56e1760fdfb26afccde92664d3a2f1e8435 (patch) |
tree | 84d1399acbb92f852b4bd64f9ea5412680b0c6ab /contrib/lua-torch/optim/adadelta.lua |
parent | 220a51ff68013dd668a45b78c60a7b8bfc10f074 (diff) |
download | rspamd-714eb56e1760fdfb26afccde92664d3a2f1e8435.tar.gz, rspamd-714eb56e1760fdfb26afccde92664d3a2f1e8435.zip |
[Minor] Move lua contrib libraries to lua- prefix
Diffstat (limited to 'contrib/lua-torch/optim/adadelta.lua')
-rw-r--r-- | contrib/lua-torch/optim/adadelta.lua | 55
1 file changed, 55 insertions, 0 deletions
diff --git a/contrib/lua-torch/optim/adadelta.lua b/contrib/lua-torch/optim/adadelta.lua
new file mode 100644
index 000000000..7cc058d29
--- /dev/null
+++ b/contrib/lua-torch/optim/adadelta.lua
@@ -0,0 +1,55 @@
+--[[ ADADELTA implementation for SGD http://arxiv.org/abs/1212.5701
+
+ARGS:
+- `opfunc` : a function that takes a single input (X), the point of
+             evaluation, and returns f(X) and df/dX
+- `x` : the initial point
+- `config` : a table of hyper-parameters
+- `config.rho` : interpolation parameter
+- `config.eps` : for numerical stability
+- `config.weightDecay` : weight decay
+- `state` : a table describing the state of the optimizer; after each
+            call the state is modified
+- `state.paramVariance` : vector of temporal variances of parameters
+- `state.accDelta` : vector of accumulated deltas of gradients
+RETURN:
+- `x` : the new x vector
+- `f(x)` : the function, evaluated before the update
+]]
+function optim.adadelta(opfunc, x, config, state)
+   -- (0) get/update state
+   if config == nil and state == nil then
+      print('no state table, ADADELTA initializing')
+   end
+   local config = config or {}
+   local state = state or config
+   local rho = config.rho or 0.9
+   local eps = config.eps or 1e-6
+   local wd = config.weightDecay or 0
+   state.evalCounter = state.evalCounter or 0
+   -- (1) evaluate f(x) and df/dx
+   local fx, dfdx = opfunc(x)
+
+   -- (2) weight decay
+   if wd ~= 0 then
+      dfdx:add(wd, x)
+   end
+
+   -- (3) parameter update
+   if not state.paramVariance then
+      state.paramVariance = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
+      state.paramStd = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
+      state.delta = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
+      state.accDelta = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
+   end
+   state.paramVariance:mul(rho):addcmul(1-rho, dfdx, dfdx)
+   state.paramStd:resizeAs(state.paramVariance):copy(state.paramVariance):add(eps):sqrt()
+   state.delta:resizeAs(state.paramVariance):copy(state.accDelta):add(eps):sqrt():cdiv(state.paramStd):cmul(dfdx)
+   x:add(-1, state.delta)
+   state.accDelta:mul(rho):addcmul(1-rho, state.delta, state.delta)
+   -- (4) update evaluation counter
+   state.evalCounter = state.evalCounter + 1
+
+   -- return x*, f(x) before optimization
+   return x, {fx}
+end
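For reference, the three `state` tensors track the running averages from the ADADELTA paper linked above: `paramVariance` holds E[g^2], `accDelta` holds E[dx^2], and each call applies

E[g^2] <- rho * E[g^2] + (1 - rho) * g^2
dx     =  -(sqrt(E[dx^2] + eps) / sqrt(E[g^2] + eps)) * g
E[dx^2] <- rho * E[dx^2] + (1 - rho) * dx^2

A minimal usage sketch follows; it is not part of the commit. It assumes the classic Torch7 `torch` and `optim` packages (which this bundled copy mirrors) are on the Lua path, and the names `target` and `opfunc` plus the quadratic objective are illustrative:

```lua
-- Hypothetical usage sketch: minimize f(x) = ||x - target||^2 with adadelta.
-- Assumes `require 'torch'` and `require 'optim'` resolve to the classic
-- Torch7 packages that this bundled copy mirrors.
require 'torch'
require 'optim'

local target = torch.Tensor{3, 1, -2}  -- illustrative target vector
local x = torch.zeros(3)               -- the initial point

-- opfunc returns f(x) and df/dx, matching the ARGS contract above
local function opfunc(x)
   local diff = x - target
   local fx = diff:dot(diff)  -- f(x)  = ||x - target||^2
   local dfdx = diff * 2      -- df/dx = 2 * (x - target)
   return fx, dfdx
end

local state = {rho = 0.9, eps = 1e-6}
for _ = 1, 1000 do
   -- passing `state` as the config argument works because the function
   -- falls back to `state = state or config`; x is updated in place
   optim.adadelta(opfunc, x, state)
end
print(x)  -- moves toward `target` as the accumulated deltas grow
```

Note the design choice visible in the sketch: since `state` accumulates E[g^2] and E[dx^2] across calls, the same table must be passed on every iteration; a fresh table per call would reset the adaptation.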