aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/optim/lbfgs.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/optim/lbfgs.lua')
-rw-r--r--contrib/lua-torch/optim/lbfgs.lua268
1 files changed, 268 insertions, 0 deletions
diff --git a/contrib/lua-torch/optim/lbfgs.lua b/contrib/lua-torch/optim/lbfgs.lua
new file mode 100644
index 000000000..d850fcbb3
--- /dev/null
+++ b/contrib/lua-torch/optim/lbfgs.lua
@@ -0,0 +1,268 @@
+--[[ An implementation of L-BFGS, heavily inspired by minFunc (Mark Schmidt)
+
+This implementation of L-BFGS relies on a user-provided line
+search function (state.lineSearch). If this function is not
+provided, then a simple learningRate is used to produce fixed
+size steps. Fixed size steps are much less costly than line
+searches, and can be useful for stochastic problems.
+
+The learning rate is used even when a line search is provided.
+This is also useful for large-scale stochastic problems, where
+opfunc is a noisy approximation of f(x). In that case, the learning
+rate allows a reduction of confidence in the step size.
+
+ARGS:
+
+- `opfunc` : a function that takes a single input (X), the point of
+ evaluation, and returns f(X) and df/dX
+- `x` : the initial point
+- `state` : a table describing the state of the optimizer; after each
+ call the state is modified
+- `state.maxIter` : Maximum number of iterations allowed
+- `state.maxEval` : Maximum number of function evaluations
+- `state.tolFun` : Termination tolerance on the first-order optimality
+- `state.tolX` : Termination tol on progress in terms of func/param changes
+- `state.lineSearch` : A line search function
+- `state.learningRate` : If no line search provided, then a fixed step size is used
+
+RETURN:
+- `x*` : the new `x` vector, at the optimal point
+- `f` : a table of all function values:
+ `f[1]` is the value of the function before any optimization and
+ `f[#f]` is the final fully optimized value, at `x*`
+
+(Clement Farabet, 2012)
+]]
+function optim.lbfgs(opfunc, x, config, state)
+ -- get/update state
+ local config = config or {}
+ local state = state or config
+ local maxIter = tonumber(config.maxIter) or 20
+ local maxEval = tonumber(config.maxEval) or maxIter*1.25
+ local tolFun = config.tolFun or 1e-5
+ local tolX = config.tolX or 1e-9
+ local nCorrection = config.nCorrection or 100
+ local lineSearch = config.lineSearch
+ local lineSearchOpts = config.lineSearchOptions
+ local learningRate = config.learningRate or 1
+ local isverbose = config.verbose or false
+
+ state.funcEval = state.funcEval or 0
+ state.nIter = state.nIter or 0
+
+ -- verbose function
+ local verbose
+ if isverbose then
+ verbose = function(...) print('<optim.lbfgs> ', ...) end
+ else
+ verbose = function() end
+ end
+
+ -- import some functions
+ local abs = math.abs
+ local min = math.min
+
+ -- evaluate initial f(x) and df/dx
+ local f,g = opfunc(x)
+ local f_hist = {f}
+ local currentFuncEval = 1
+ state.funcEval = state.funcEval + 1
+ local p = g:size(1)
+
+ -- check optimality of initial point
+ state.tmp1 = state.tmp1 or g.new(g:size()):zero(); local tmp1 = state.tmp1
+ tmp1:copy(g):abs()
+ if tmp1:sum() <= tolFun then
+ -- optimality condition below tolFun
+ verbose('optimality condition below tolFun')
+ return x,f_hist
+ end
+
+ if not state.dir_bufs then
+ -- reusable buffers for y's and s's, and their histories
+ verbose('creating recyclable direction/step/history buffers')
+ state.dir_bufs = state.dir_bufs or g.new(nCorrection+1, p):split(1)
+ state.stp_bufs = state.stp_bufs or g.new(nCorrection+1, p):split(1)
+ for i=1,#state.dir_bufs do
+ state.dir_bufs[i] = state.dir_bufs[i]:squeeze(1)
+ state.stp_bufs[i] = state.stp_bufs[i]:squeeze(1)
+ end
+ end
+
+ -- variables cached in state (for tracing)
+ local d = state.d
+ local t = state.t
+ local old_dirs = state.old_dirs
+ local old_stps = state.old_stps
+ local Hdiag = state.Hdiag
+ local g_old = state.g_old
+ local f_old = state.f_old
+
+ -- optimize for a max of maxIter iterations
+ local nIter = 0
+ while nIter < maxIter do
+ -- keep track of nb of iterations
+ nIter = nIter + 1
+ state.nIter = state.nIter + 1
+
+ ------------------------------------------------------------
+ -- compute gradient descent direction
+ ------------------------------------------------------------
+ if state.nIter == 1 then
+ d = g:clone():mul(-1) -- -g
+ old_dirs = {}
+ old_stps = {}
+ Hdiag = 1
+ else
+ -- do lbfgs update (update memory)
+ local y = table.remove(state.dir_bufs) -- pop
+ local s = table.remove(state.stp_bufs)
+ y:add(g, -1, g_old) -- g - g_old
+ s:mul(d, t) -- d*t
+ local ys = y:dot(s) -- y*s
+ if ys > 1e-10 then
+ -- updating memory
+ if #old_dirs == nCorrection then
+ -- shift history by one (limited-memory)
+ local removed1 = table.remove(old_dirs, 1)
+ local removed2 = table.remove(old_stps, 1)
+ table.insert(state.dir_bufs, removed1)
+ table.insert(state.stp_bufs, removed2)
+ end
+
+ -- store new direction/step
+ table.insert(old_dirs, s)
+ table.insert(old_stps, y)
+
+ -- update scale of initial Hessian approximation
+ Hdiag = ys / y:dot(y) -- (y*y)
+ else
+ -- put y and s back into the buffer pool
+ table.insert(state.dir_bufs, y)
+ table.insert(state.stp_bufs, s)
+ end
+
+ -- compute the approximate (L-BFGS) inverse Hessian
+ -- multiplied by the gradient
+ local k = #old_dirs
+
+ -- need to be accessed element-by-element, so don't re-type tensor:
+ state.ro = state.ro or torch.Tensor(nCorrection); local ro = state.ro
+ for i = 1,k do
+ ro[i] = 1 / old_stps[i]:dot(old_dirs[i])
+ end
+
+ -- iteration in L-BFGS loop collapsed to use just one buffer
+ local q = tmp1 -- reuse tmp1 for the q buffer
+ -- need to be accessed element-by-element, so don't re-type tensor:
+ state.al = state.al or torch.zeros(nCorrection) local al = state.al
+
+ q:mul(g, -1) -- -g
+ for i = k,1,-1 do
+ al[i] = old_dirs[i]:dot(q) * ro[i]
+ q:add(-al[i], old_stps[i])
+ end
+
+ -- multiply by initial Hessian
+ r = d -- share the same buffer, since we don't need the old d
+ r:mul(q, Hdiag) -- q[1] * Hdiag
+ for i = 1,k do
+ local be_i = old_stps[i]:dot(r) * ro[i]
+ r:add(al[i]-be_i, old_dirs[i])
+ end
+ -- final direction is in r/d (same object)
+ end
+ g_old = g_old or g:clone()
+ g_old:copy(g)
+ f_old = f
+
+ ------------------------------------------------------------
+ -- compute step length
+ ------------------------------------------------------------
+ -- directional derivative
+ local gtd = g:dot(d) -- g * d
+
+ -- check that progress can be made along that direction
+ if gtd > -tolX then
+ break
+ end
+
+ -- reset initial guess for step size
+ if state.nIter == 1 then
+ tmp1:copy(g):abs()
+ t = min(1,1/tmp1:sum()) * learningRate
+ else
+ t = learningRate
+ end
+
+ -- optional line search: user function
+ local lsFuncEval = 0
+ if lineSearch and type(lineSearch) == 'function' then
+ -- perform line search, using user function
+ f,g,x,t,lsFuncEval = lineSearch(opfunc,x,t,d,f,g,gtd,lineSearchOpts)
+ table.insert(f_hist, f)
+ else
+ -- no line search, simply move with fixed-step
+ x:add(t,d)
+ if nIter ~= maxIter then
+ -- re-evaluate function only if not in last iteration
+ -- the reason we do this: in a stochastic setting,
+ -- no use to re-evaluate that function here
+ f,g = opfunc(x)
+ lsFuncEval = 1
+ table.insert(f_hist, f)
+ end
+ end
+
+ -- update func eval
+ currentFuncEval = currentFuncEval + lsFuncEval
+ state.funcEval = state.funcEval + lsFuncEval
+
+ ------------------------------------------------------------
+ -- check conditions
+ ------------------------------------------------------------
+ if nIter == maxIter then
+ -- no use to run tests
+ verbose('reached max number of iterations')
+ break
+ end
+
+ if currentFuncEval >= maxEval then
+ -- max nb of function evals
+ verbose('max nb of function evals')
+ break
+ end
+
+ tmp1:copy(g):abs()
+ if tmp1:sum() <= tolFun then
+ -- check optimality
+ verbose('optimality condition below tolFun')
+ break
+ end
+
+ tmp1:copy(d):mul(t):abs()
+ if tmp1:sum() <= tolX then
+ -- step size below tolX
+ verbose('step size below tolX')
+ break
+ end
+
+ if abs(f-f_old) < tolX then
+ -- function value changing less than tolX
+ verbose('function value changing less than tolX')
+ break
+ end
+ end
+
+ -- save state
+ state.old_dirs = old_dirs
+ state.old_stps = old_stps
+ state.Hdiag = Hdiag
+ state.g_old = g_old
+ state.f_old = f_old
+ state.t = t
+ state.d = d
+
+ -- return optimal x, and history of f(x)
+ return x,f_hist,currentFuncEval
+end