aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/optim/cg.lua
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/optim/cg.lua')
-rw-r--r--contrib/lua-torch/optim/cg.lua208
1 files changed, 208 insertions, 0 deletions
diff --git a/contrib/lua-torch/optim/cg.lua b/contrib/lua-torch/optim/cg.lua
new file mode 100644
index 000000000..842a7d569
--- /dev/null
+++ b/contrib/lua-torch/optim/cg.lua
@@ -0,0 +1,208 @@
+--[[
+
+This cg implementation is a rewrite of minimize.m written by Carl
+E. Rasmussen. It is supposed to produce exactly same results (give
+or take numerical accuracy due to some changed order of
+operations). You can compare the result on rosenbrock with minimize.m.
+http://www.gatsby.ucl.ac.uk/~edward/code/minimize/example.html
+
+ [x fx c] = minimize([0 0]', 'rosenbrock', -25)
+
+Note that we limit the number of function evaluations only, it seems much
+more important in practical use.
+
+ARGS:
+
+- `opfunc` : a function that takes a single input, the point of evaluation.
+- `x` : the initial point
+- `state` : a table of parameters and temporary allocations.
+- `state.maxEval` : max number of function evaluations
+- `state.maxIter` : max number of iterations
+- `state.df[0,1,2,3]` : if you pass torch.Tensor they will be used for temp storage
+- `state.[s,x0]` : if you pass torch.Tensor they will be used for temp storage
+
+RETURN:
+
+- `x*` : the new x vector, at the optimal point
+- `f` : a table of all function values where
+ `f[1]` is the value of the function before any optimization and
+ `f[#f]` is the final fully optimized value, at x*
+
+(Koray Kavukcuoglu, 2012)
+--]]
+function optim.cg(opfunc, x, config, state)
+ -- parameters
+ local config = config or {}
+ local state = state or config
+ local rho = config.rho or 0.01
+ local sig = config.sig or 0.5
+ local int = config.int or 0.1
+ local ext = config.ext or 3.0
+ local maxIter = config.maxIter or 20
+ local ratio = config.ratio or 100
+ local maxEval = config.maxEval or maxIter*1.25
+ local red = 1
+
+ local verbose = config.verbose or 0
+
+ local i = 0
+ local ls_failed = 0
+ local fx = {}
+
+ -- we need three points for the interpolation/extrapolation stuff
+ local z1,z2,z3 = 0,0,0
+ local d1,d2,d3 = 0,0,0
+ local f1,f2,f3 = 0,0,0
+
+ local df1 = state.df1 or x.new()
+ local df2 = state.df2 or x.new()
+ local df3 = state.df3 or x.new()
+ local tdf
+
+ df1:resizeAs(x)
+ df2:resizeAs(x)
+ df3:resizeAs(x)
+
+ -- search direction
+ local s = state.s or x.new()
+ s:resizeAs(x)
+
+ -- we need a temp storage for X
+ local x0 = state.x0 or x.new()
+ local f0 = 0
+ local df0 = state.df0 or x.new()
+ x0:resizeAs(x)
+ df0:resizeAs(x)
+
+ -- evaluate at initial point
+ f1,tdf = opfunc(x)
+ fx[#fx+1] = f1
+ df1:copy(tdf)
+ i=i+1
+
+ -- initial search direction
+ s:copy(df1):mul(-1)
+
+ d1 = -s:dot(s ) -- slope
+ z1 = red/(1-d1) -- initial step
+
+ while i < math.abs(maxEval) do
+
+ x0:copy(x)
+ f0 = f1
+ df0:copy(df1)
+
+ x:add(z1,s)
+ f2,tdf = opfunc(x)
+ df2:copy(tdf)
+ i=i+1
+ d2 = df2:dot(s)
+ f3,d3,z3 = f1,d1,-z1 -- init point 3 equal to point 1
+ local m = math.min(maxIter,maxEval-i)
+ local success = 0
+ local limit = -1
+
+ while true do
+ while (f2 > f1+z1*rho*d1 or d2 > -sig*d1) and m > 0 do
+ limit = z1
+ if f2 > f1 then
+ z2 = z3 - (0.5*d3*z3*z3)/(d3*z3+f2-f3)
+ else
+ local A = 6*(f2-f3)/z3+3*(d2+d3)
+ local B = 3*(f3-f2)-z3*(d3+2*d2)
+ z2 = (math.sqrt(B*B-A*d2*z3*z3)-B)/A
+ end
+ if z2 ~= z2 or z2 == math.huge or z2 == -math.huge then
+ z2 = z3/2;
+ end
+ z2 = math.max(math.min(z2, int*z3),(1-int)*z3);
+ z1 = z1 + z2;
+ x:add(z2,s)
+ f2,tdf = opfunc(x)
+ df2:copy(tdf)
+ i=i+1
+ m = m - 1
+ d2 = df2:dot(s)
+ z3 = z3-z2;
+ end
+ if f2 > f1+z1*rho*d1 or d2 > -sig*d1 then
+ break
+ elseif d2 > sig*d1 then
+ success = 1;
+ break;
+ elseif m == 0 then
+ break;
+ end
+ local A = 6*(f2-f3)/z3+3*(d2+d3);
+ local B = 3*(f3-f2)-z3*(d3+2*d2);
+ z2 = -d2*z3*z3/(B+math.sqrt(B*B-A*d2*z3*z3))
+
+ if z2 ~= z2 or z2 == math.huge or z2 == -math.huge or z2 < 0 then
+ if limit < -0.5 then
+ z2 = z1 * (ext -1)
+ else
+ z2 = (limit-z1)/2
+ end
+ elseif (limit > -0.5) and (z2+z1) > limit then
+ z2 = (limit-z1)/2
+ elseif limit < -0.5 and (z2+z1) > z1*ext then
+ z2 = z1*(ext-1)
+ elseif z2 < -z3*int then
+ z2 = -z3*int
+ elseif limit > -0.5 and z2 < (limit-z1)*(1-int) then
+ z2 = (limit-z1)*(1-int)
+ end
+ f3=f2; d3=d2; z3=-z2;
+ z1 = z1+z2;
+ x:add(z2,s)
+
+ f2,tdf = opfunc(x)
+ df2:copy(tdf)
+ i=i+1
+ m = m - 1
+ d2 = df2:dot(s)
+ end
+ if success == 1 then
+ f1 = f2
+ fx[#fx+1] = f1;
+ local ss = (df2:dot(df2)-df2:dot(df1)) / df1:dot(df1)
+ s:mul(ss)
+ s:add(-1,df2)
+ local tmp = df1:clone()
+ df1:copy(df2)
+ df2:copy(tmp)
+ d2 = df1:dot(s)
+ if d2> 0 then
+ s:copy(df1)
+ s:mul(-1)
+ d2 = -s:dot(s)
+ end
+
+ z1 = z1 * math.min(ratio, d1/(d2-1e-320))
+ d1 = d2
+ ls_failed = 0
+ else
+ x:copy(x0)
+ f1 = f0
+ df1:copy(df0)
+ if ls_failed or i>maxEval then
+ break
+ end
+ local tmp = df1:clone()
+ df1:copy(df2)
+ df2:copy(tmp)
+ s:copy(df1)
+ s:mul(-1)
+ d1 = -s:dot(s)
+ z1 = 1/(1-d1)
+ ls_failed = 1
+ end
+ end
+ state.df0 = df0
+ state.df1 = df1
+ state.df2 = df2
+ state.df3 = df3
+ state.x0 = x0
+ state.s = s
+ return x,fx,i
+end