-- path: contrib/torch/optim/lbfgs.lua
-- blob: d850fcbb3361a544d667bd77b892562127d13759
--[[ An implementation of L-BFGS, heavily inspired by minFunc (Mark Schmidt)

This implementation of L-BFGS relies on a user-provided line
search function (config.lineSearch). If this function is not
provided, then a simple learningRate is used to produce fixed
size steps. Fixed size steps are much less costly than line
searches, and can be useful for stochastic problems.

The learning rate is used even when a line search is provided: it
sets the initial step size handed to the line search. This is also
useful for large-scale stochastic problems, where opfunc is a noisy
approximation of f(x). In that case, the learning rate allows a
reduction of confidence in the step size.

ARGS:

- `opfunc` : a function that takes a single input (X), the point of
         evaluation, and returns f(X) and df/dX
- `x` : the initial point (updated in place by the fixed-step rule)
- `config` : a table with configuration parameters for the optimizer
- `config.maxIter` : Maximum number of iterations allowed (default 20)
- `config.maxEval` : Maximum number of function evaluations
         (default maxIter*1.25)
- `config.tolFun` : Termination tolerance on the first-order optimality,
         measured as the sum of absolute gradient entries (default 1e-5)
- `config.tolX` : Termination tol on progress in terms of func/param
         changes (default 1e-9)
- `config.nCorrection` : Number of correction pairs kept in the
         limited-memory history (default 100)
- `config.lineSearch` : A line search function, called as
         lineSearch(opfunc,x,t,d,f,g,gtd,config.lineSearchOptions)
         and expected to return f,g,x,t,nbFuncEvals
- `config.lineSearchOptions` : Extra options forwarded to the line search
- `config.learningRate` : If no line search provided, then a fixed step
         size is used (default 1)
- `config.verbose` : Print the reason for termination (default false)
- `state` : a table describing the state of the optimizer; after each
         call the state is modified (defaults to `config`)

RETURN:
- `x*` : the new `x` vector, at the optimal point
- `f`  : a table of all function values:
     `f[1]` is the value of the function before any optimization and
     `f[#f]` is the final fully optimized value, at `x*`
- `currentFuncEval` : number of function evaluations made by this call

(Clement Farabet, 2012)
]]
function optim.lbfgs(opfunc, x, config, state)
   -- get/update state
   config = config or {}
   state = state or config
   local maxIter = tonumber(config.maxIter) or 20
   local maxEval = tonumber(config.maxEval) or maxIter*1.25
   local tolFun = config.tolFun or 1e-5
   local tolX = config.tolX or 1e-9
   local nCorrection = config.nCorrection or 100
   local lineSearch = config.lineSearch
   local lineSearchOpts = config.lineSearchOptions
   local learningRate = config.learningRate or 1
   local isverbose = config.verbose or false

   -- counters accumulate across calls when the same state table is reused
   state.funcEval = state.funcEval or 0
   state.nIter = state.nIter or 0

   -- verbose function: a no-op unless config.verbose is set
   local verbose
   if isverbose then
      verbose = function(...) print('<optim.lbfgs> ', ...) end
   else
      verbose = function() end
   end

   -- import some functions
   local abs = math.abs
   local min = math.min

   -- evaluate initial f(x) and df/dx
   local f,g = opfunc(x)
   local f_hist = {f}
   local currentFuncEval = 1
   state.funcEval = state.funcEval + 1
   local p = g:size(1)

   -- check optimality of initial point (sum of |g_i| at or below tolFun)
   state.tmp1 = state.tmp1 or g.new(g:size()):zero(); local tmp1 = state.tmp1
   tmp1:copy(g):abs()
   if tmp1:sum() <= tolFun then
      -- optimality condition below tolFun
      verbose('optimality condition below tolFun')
      -- report the evaluation count too, like the normal exit path does
      return x,f_hist,currentFuncEval
   end

   if not state.dir_bufs then
      -- reusable buffers for y's and s's, and their histories; the history
      -- never holds more than nCorrection pairs, so with nCorrection+1 rows
      -- there is always a spare buffer to pop in the update below
      verbose('creating recyclable direction/step/history buffers')
      state.dir_bufs = state.dir_bufs or g.new(nCorrection+1, p):split(1)
      state.stp_bufs = state.stp_bufs or g.new(nCorrection+1, p):split(1)
      for i=1,#state.dir_bufs do
         state.dir_bufs[i] = state.dir_bufs[i]:squeeze(1)
         state.stp_bufs[i] = state.stp_bufs[i]:squeeze(1)
      end
   end

   -- variables cached in state (for tracing, and to resume across calls)
   local d = state.d
   local t = state.t
   local old_dirs = state.old_dirs
   local old_stps = state.old_stps
   local Hdiag = state.Hdiag
   local g_old = state.g_old
   local f_old = state.f_old

   -- optimize for a max of maxIter iterations
   local nIter = 0
   while nIter < maxIter do
      -- keep track of nb of iterations
      nIter = nIter + 1
      state.nIter = state.nIter + 1

      ------------------------------------------------------------
      -- compute gradient descent direction
      ------------------------------------------------------------
      if state.nIter == 1 then
         d = g:clone():mul(-1) -- -g
         old_dirs = {}
         old_stps = {}
         Hdiag = 1
      else
         -- do lbfgs update (update memory)
         local y = table.remove(state.dir_bufs)  -- pop
         local s = table.remove(state.stp_bufs)
         y:add(g, -1, g_old)  -- g - g_old
         s:mul(d, t)          -- d*t
         local ys = y:dot(s)  -- y*s
         if ys > 1e-10 then
            -- updating memory
            if #old_dirs == nCorrection then
               -- shift history by one (limited-memory); all buffers have the
               -- same size, so the evicted ones can go back to either pool
               local removed1 = table.remove(old_dirs, 1)
               local removed2 = table.remove(old_stps, 1)
               table.insert(state.dir_bufs, removed1)
               table.insert(state.stp_bufs, removed2)
            end

            -- store new direction/step; note that, despite the names,
            -- old_dirs holds the s vectors and old_stps holds the y vectors
            table.insert(old_dirs, s)
            table.insert(old_stps, y)

            -- update scale of initial Hessian approximation
            Hdiag = ys / y:dot(y)  -- (y*s) / (y*y)
         else
            -- curvature too small: put y and s back into the buffer pool
            table.insert(state.dir_bufs, y)
            table.insert(state.stp_bufs, s)
         end

         -- compute the approximate (L-BFGS) inverse Hessian
         -- multiplied by the gradient
         local k = #old_dirs

         -- need to be accessed element-by-element, so don't re-type tensor:
         state.ro = state.ro or torch.Tensor(nCorrection); local ro = state.ro
         for i = 1,k do
            ro[i] = 1 / old_stps[i]:dot(old_dirs[i])  -- 1 / (y_i*s_i)
         end

         -- iteration in L-BFGS loop collapsed to use just one buffer
         local q = tmp1  -- reuse tmp1 for the q buffer
         -- need to be accessed element-by-element, so don't re-type tensor:
         state.al = state.al or torch.zeros(nCorrection); local al = state.al

         -- first loop (newest to oldest pair)
         q:mul(g, -1)  -- -g
         for i = k,1,-1 do
            al[i] = old_dirs[i]:dot(q) * ro[i]
            q:add(-al[i], old_stps[i])
         end

         -- multiply by initial Hessian; r shares d's buffer, since the old d
         -- is no longer needed (must be local: a bare `r = d` leaks a global)
         local r = d
         r:mul(q, Hdiag)  -- q[1] * Hdiag
         -- second loop (oldest to newest pair)
         for i = 1,k do
            local be_i = old_stps[i]:dot(r) * ro[i]
            r:add(al[i]-be_i, old_dirs[i])
         end
         -- final direction is in r/d (same object)
      end
      g_old = g_old or g:clone()
      g_old:copy(g)
      f_old = f

      ------------------------------------------------------------
      -- compute step length
      ------------------------------------------------------------
      -- directional derivative
      local gtd = g:dot(d)  -- g * d

      -- check that progress can be made along that direction
      if gtd > -tolX then
         break
      end

      -- reset initial guess for step size
      if state.nIter == 1 then
         tmp1:copy(g):abs()
         t = min(1,1/tmp1:sum()) * learningRate
      else
         t = learningRate
      end

      -- optional line search: user function
      local lsFuncEval = 0
      if type(lineSearch) == 'function' then
         -- perform line search, using user function
         f,g,x,t,lsFuncEval = lineSearch(opfunc,x,t,d,f,g,gtd,lineSearchOpts)
         table.insert(f_hist, f)
      else
         -- no line search, simply move with fixed-step
         x:add(t,d)
         if nIter ~= maxIter then
            -- re-evaluate function only if not in last iteration
            -- the reason we do this: in a stochastic setting,
            -- no use to re-evaluate that function here
            f,g = opfunc(x)
            lsFuncEval = 1
            table.insert(f_hist, f)
         end
      end

      -- update func eval
      currentFuncEval = currentFuncEval + lsFuncEval
      state.funcEval = state.funcEval + lsFuncEval

      ------------------------------------------------------------
      -- check conditions
      ------------------------------------------------------------
      if nIter == maxIter then
         -- no use to run tests
         verbose('reached max number of iterations')
         break
      end

      if currentFuncEval >= maxEval then
         -- max nb of function evals
         verbose('max nb of function evals')
         break
      end

      tmp1:copy(g):abs()
      if tmp1:sum() <= tolFun then
         -- check optimality
         verbose('optimality condition below tolFun')
         break
      end

      tmp1:copy(d):mul(t):abs()
      if tmp1:sum() <= tolX then
         -- step size below tolX
         verbose('step size below tolX')
         break
      end

      if abs(f-f_old) < tolX then
         -- function value changing less than tolX
         verbose('function value changing less than tolX')
         break
      end
   end

   -- save state so a later call with the same state table can resume
   state.old_dirs = old_dirs
   state.old_stps = old_stps
   state.Hdiag = Hdiag
   state.g_old = g_old
   state.f_old = f_old
   state.t = t
   state.d = d

   -- return optimal x, and history of f(x)
   return x,f_hist,currentFuncEval
end