aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/optim/lswolfe.lua
blob: 0afbdbe8b29bad9f3aa780db844cbfda026428d2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
--[[ A Line Search satisfying the Wolfe conditions

The search runs in two phases: a bracketing phase that grows the step until
it either satisfies the conditions or brackets a point that does, followed
by a zoom phase that shrinks the bracket by polynomial interpolation.

ARGS:
- `opfunc` : a function (the objective) that takes a single input (X),
         the point of evaluation, and returns f(X) and df/dX
         (df/dX is presumably a 1D tensor: it is stored as one row of a
         2 x g:size(1) tensor and multiplied with `d` -- confirm with caller)
- `x`          : initial point / starting location; modified IN PLACE and
                 left at x+t*d on return
- `t`          : initial step size
- `d`          : descent direction
- `f`          : initial function value
- `g`          : gradient at initial location
- `gtd`        : directional derivative at starting location
- `options.c1` : sufficient decrease parameter (default 1e-4)
- `options.c2` : curvature parameter (default 0.9)
- `options.tolX`    : zoom-phase tolerance: the search stops once
                      |bracket width * directional derivative| is below
                      this value (default 1e-9)
- `options.maxIter` : maximum nb of iterations, shared by the bracketing
                      and zoom phases (default 20)
- `options.verbose` : if true, print a message when maxIter is reached
                      (default false)

RETURN:
- `f`          : function value at x+t*d
- `g`          : gradient value at x+t*d
- `x`          : the next x (=x+t*d), same tensor object as the `x` argument
- `t`          : the step length
- `lsFuncEval` : the number of function evaluations (calls to `opfunc`)
]]
function optim.lswolfe(opfunc,x,t,d,f,g,gtd,options)
   -- options
   options = options or {}
   local c1 = options.c1 or 1e-4
   local c2 = options.c2 or 0.9
   local tolX = options.tolX or 1e-9
   local maxIter = options.maxIter or 20
   local isverbose = options.verbose or false

   -- some shortcuts
   local abs = torch.abs
   local min = math.min

   -- verbose function
   local function verbose(...)
      if isverbose then print('<optim.lswolfe> ', ...) end
   end

   -- build a two-point bracket: step lengths, function values and
   -- gradients (one gradient per row), all allocated with x's tensor type
   local function newBracket(tA,fA,gA,tB,fB,gB)
      local steps = x.new{tA,tB}
      local fvals = x.new{fA,fB}
      local gvals = x.new(2,gB:size(1))
      gvals[1] = gA
      gvals[2] = gB
      return steps,fvals,gvals
   end

   -- evaluate objective and gradient using initial step
   -- (x is overwritten, so keep a copy of the starting point)
   local x_init = x:clone()
   x:add(t,d)
   local f_new,g_new = opfunc(x)
   local lsFuncEval = 1
   local gtd_new = g_new * d

   -- bracket an interval containing a point satisfying the Wolfe
   -- criteria
   local LSiter,t_prev,done = 0,0,false
   local f_prev,g_prev,gtd_prev = f,g:clone(),gtd
   local bracket,bracketFval,bracketGval
   while LSiter < maxIter do
      -- check conditions:
      if (f_new > (f + c1*t*gtd)) or (LSiter > 1 and f_new >= f_prev) then
         -- sufficient decrease violated, or the function value stopped
         -- decreasing: a suitable point lies between t_prev and t
         bracket,bracketFval,bracketGval =
            newBracket(t_prev,f_prev,g_prev,t,f_new,g_new)
         break

      elseif abs(gtd_new) <= -c2*gtd then
         -- curvature condition holds as well: t is accepted as is
         bracket = x.new{t}
         bracketFval = x.new{f_new}
         bracketGval = x.new(1,g_new:size(1))
         bracketGval[1] = g_new
         done = true
         break

      elseif gtd_new >= 0 then
         -- slope is no longer negative: a suitable point lies between
         -- t_prev and t
         bracket,bracketFval,bracketGval =
            newBracket(t_prev,f_prev,g_prev,t,f_new,g_new)
         break

      end

      -- interpolate: pick the next, larger, trial step within
      -- [minStep, maxStep]
      local tmp = t_prev
      t_prev = t
      local minStep = t + 0.01*(t-tmp)
      local maxStep = t*10
      t = optim.polyinterp(x.new{{tmp,f_prev,gtd_prev},
                                  {t,f_new,gtd_new}},
                           minStep, maxStep)

      -- next step:
      f_prev = f_new
      g_prev = g_new:clone()
      gtd_prev = gtd_new
      x[{}] = x_init
      x:add(t,d)
      f_new,g_new = opfunc(x)
      lsFuncEval = lsFuncEval + 1
      gtd_new = g_new * d
      LSiter = LSiter + 1
   end

   -- reached max nb of iterations?
   -- (the zoom loop below is then skipped, and the better of the starting
   -- point and the last trial point is returned)
   if LSiter == maxIter then
      bracket,bracketFval,bracketGval = newBracket(0,f,g,t,f_new,g_new)
   end

   -- zoom phase: we now have a point satisfying the criteria, or
   -- a bracket around it. We refine the bracket until we find the
   -- exact point satisfying the criteria
   local insufProgress = false
   while not done and LSiter < maxIter do
      -- find high and low points in bracket
      local f_LO,LOpos = bracketFval:min(1)
      LOpos = LOpos[1] f_LO = f_LO[1]
      local HIpos = -LOpos+3   -- the other index: 1 <-> 2

      -- compute new trial value
      t = optim.polyinterp(x.new{{bracket[1],bracketFval[1],bracketGval[1]*d},
                                  {bracket[2],bracketFval[2],bracketGval[2]*d}})

      -- test that we are making sufficient progress: if the trial step
      -- is within 10% of a bracket end twice in a row (or is outside the
      -- bracket), move it to 10% inside the nearest end
      local bMax,bMin = bracket:max(),bracket:min()
      local bWidth = bMax-bMin
      if min(bMax-t,t-bMin)/bWidth < 0.1 then
         if insufProgress or t>=bMax or t <= bMin then
            if abs(t-bMax) < abs(t-bMin) then
               t = bMax-0.1*bWidth
            else
               t = bMin+0.1*bWidth
            end
            insufProgress = false
         else
            insufProgress = true
         end
      else
         insufProgress = false
      end

      -- Evaluate new point
      x[{}] = x_init
      x:add(t,d)
      f_new,g_new = opfunc(x)
      lsFuncEval = lsFuncEval + 1
      gtd_new = g_new * d
      LSiter = LSiter + 1
      if f_new > f + c1*t*gtd or f_new >= f_LO then
         -- Armijo condition not satisfied or not lower than lowest point
         bracket[HIpos] = t
         bracketFval[HIpos] = f_new
         bracketGval[HIpos] = g_new
      else
         if abs(gtd_new) <= - c2*gtd then
            -- Wolfe conditions satisfied
            done = true
         elseif gtd_new*(bracket[HIpos]-bracket[LOpos]) >= 0 then
            -- Old HI becomes new LO
            bracket[HIpos] = bracket[LOpos]
            bracketFval[HIpos] = bracketFval[LOpos]
            bracketGval[HIpos] = bracketGval[LOpos]
         end
         -- New point becomes new LO
         bracket[LOpos] = t
         bracketFval[LOpos] = f_new
         bracketGval[LOpos] = g_new
      end

      -- done?
      if not done and abs((bracket[1]-bracket[2])*gtd_new) < tolX then
         break
      end
   end

   -- be verbose
   if LSiter == maxIter then
      verbose('reached max number of iterations')
   end

   -- return stuff: pick the bracket point with the lowest function value
   -- and move x there
   local _,LOpos = bracketFval:min(1)
   LOpos = LOpos[1]
   t = bracket[LOpos]
   f_new = bracketFval[LOpos]
   g_new = bracketGval[LOpos]
   x[{}] = x_init
   x:add(t,d)
   return f_new,g_new,x,t,lsFuncEval
end