aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/torch/optim/Logger.lua
blob: 3c1da54d312a1cd2cb8b22cc360580709318950f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
--[[ Logger: a simple class to log symbols during training,
        and automate plot generation

Example:
    logger = optim.Logger('somefile.log')    -- file to save stuff

    for i = 1,N do                           -- log some symbols during
        train_error = ...                     -- training/testing
        test_error = ...
        logger:add{['training error'] = train_error,
            ['test error'] = test_error}
    end

    logger:style{['training error'] = '-',   -- define styles for plots
                 ['test error'] = '-'}
    logger:plot()                            -- and plot

---- OR ---

    logger = optim.Logger('somefile.log')    -- file to save stuff
    logger:setNames{'training error', 'test error'}

    for i = 1,N do                           -- log some symbols during
       train_error = ...                     -- training/testing
       test_error = ...
       logger:add{train_error, test_error}
    end

    logger:style{'-', '-'}                   -- define styles for plots
    logger:plot()                            -- and plot

-----------

    logger:setlogscale(true)                 -- enable logscale on Y-axis
    logger:plot()                            -- and plot
]]
-- Declare the 'optim.Logger' class through torch.class; every method defined
-- below in this file is attached to the table returned here.
local Logger = torch.class('optim.Logger')

--- Create a logger writing either to a file or to standard output.
-- @param filename  (optional) path of the log file. Its parent directory is
--                  created with a shell `mkdir` call. When omitted, rows are
--                  written to io.stdout and no EPS file is produced by plot().
-- @param timestamp (optional) when truthy, '-' followed by
--                  os.date("%Y_%m_%d_%X") is appended to the log file name.
-- Raises an error when the log file cannot be opened for writing.
function Logger:__init(filename, timestamp)
   if filename then
      self.name = filename
      -- '-p' is only passed on non-Windows systems
      os.execute('mkdir ' .. (sys.uname() ~= 'windows' and '-p ' or '') .. ' "' .. paths.dirname(filename) .. '"')
      if timestamp then
         -- append timestamp to create unique log file
         filename = filename .. '-'..os.date("%Y_%m_%d_%X")
      end
      -- fail right away with the real cause instead of leaving self.file nil
      -- and crashing on the first write
      local file, err = io.open(filename,'w')
      if not file then
         error('<Logger> cannot open log file "' .. filename .. '": ' .. tostring(err))
      end
      self.file = file
      -- the EPS name is derived from the un-timestamped name
      self.epsfile = self.name .. '.eps'
   else
      self.file = io.stdout
      self.name = 'stdout'
      print('<Logger> warning: no path provided, logging to std out')
   end
   self.empty = true          -- true until a header row has been written
   self.symbols = {}          -- key -> list of logged values
   self.styles = {}           -- key -> list of plot styles
   self.names = {}            -- key -> displayed name
   self.idx = {}              -- displayed name -> key
   self.figure = nil          -- figure handle reused across plot() calls
   self.showPlot = true       -- toggled by display()
   self.plotRawCmd = nil      -- optional string forwarded to gnuplot.raw
   self.defaultStyle = '+'
   self.logscale = false      -- toggled by setlogscale()
end

--- Declare the column names up front and write the header row.
-- After this call, symbols are addressed by position (1..#names) in add()
-- and by name in plot().
-- @param names  sequence of column names (strings)
-- @return self, to allow call chaining
function Logger:setNames(names)
   self.names = names
   self.empty = false
   self.nsymbols = #names
   -- ipairs guarantees that the header is emitted in positional order;
   -- pairs gives no ordering guarantee
   for k,key in ipairs(names) do
      self.file:write(key .. '\t')
      self.symbols[k] = {}
      self.styles[k] = {self.defaultStyle}
      self.idx[key] = k
   end
   self.file:write('\n')
   self.file:flush()
   return self
end

--- Log one row of symbols.
-- On the very first call (when no header has been written yet) the keys of
-- `symbols` become the column names and a header row is written first.
-- @param symbols  table mapping each symbol key to a number or a string.
--                 Numbers are written with the '%11.4e' format.
-- Raises an error, before anything is written, when a value is neither a
-- number nor a string, or when a key was not declared by the header.
-- NOTE(review): columns are emitted in pairs() order, which Lua leaves
-- unspecified; rows line up with the header only as long as every call
-- passes a table that traverses its keys in the same order.
function Logger:add(symbols)
   -- (0) validate the whole row first, so that a bad value cannot leave a
   --     half-written line in the log file
   for k,val in pairs(symbols) do
      local kind = type(val)
      if kind ~= 'number' and kind ~= 'string' then
         xlua.error('can only log numbers and strings', 'Logger')
      end
      if not self.empty and self.symbols[k] == nil then
         xlua.error('unknown symbol: ' .. tostring(k), 'Logger')
      end
   end
   -- (1) first time ? print symbols' names on first row
   if self.empty then
      self.empty = false
      local count = 0
      for k in pairs(symbols) do
         self.file:write(k .. '\t')
         self.symbols[k] = {}
         self.styles[k] = {self.defaultStyle}
         self.names[k] = k
         count = count + 1
      end
      -- '#' would report 0 for a table keyed by names, so count explicitly
      self.nsymbols = count
      self.idx = self.names
      self.file:write('\n')
   end
   -- (2) print all symbols on one row
   for _,val in pairs(symbols) do
      if type(val) == 'number' then
         self.file:write(string.format('%11.4e',val) .. '\t')
      else
         self.file:write(val .. '\t')
      end
   end
   self.file:write('\n')
   self.file:flush()
   -- (3) save symbols in internal table
   for k,val in pairs(symbols) do
      local history = self.symbols[k]
      history[#history + 1] = val
   end
end

--- Set the plot style(s) of one or more symbols.
-- @param symbols  table mapping a symbol to either a single style string or
--                 a table of style strings (one curve is drawn per style).
--                 A symbol may be given by its key or, when it is a string
--                 known to the logger, by its name; this makes styling by
--                 name work after setNames(), the same way plot() resolves
--                 names.
-- @return self, to allow call chaining
-- Raises an error when a style is neither a string nor a table.
function Logger:style(symbols)
   for name,style in pairs(symbols) do
      -- translate a known name into the key used by self.styles; numeric
      -- (positional) keys are always used as given
      local key = name
      if type(name) == 'string' and self.idx[name] ~= nil then
         key = self.idx[name]
      end
      if type(style) == 'string' then
         self.styles[key] = {style}
      elseif type(style) == 'table' then
         self.styles[key] = style
      else
         xlua.error('style should be a string or a table of strings','Logger')
      end
   end
   return self
end

--- Enable or disable the log scale requested from gnuplot in plot().
-- @param state  truthy to request the log scale on the next plot() call
-- @return self, to allow call chaining like setNames() and style()
function Logger:setlogscale(state)
   self.logscale = state
   return self
end

--- Choose whether plot() draws an on-screen figure.
-- The EPS file, when the logger has one, is written regardless of this flag.
-- @param state  truthy to draw the on-screen figure
-- @return self, to allow call chaining like setNames() and style()
function Logger:display(state)
   self.showPlot = state
   return self
end

--- Plot logged symbols with gnuplot.
-- @param ...  optional symbol names; without arguments every symbol is
--             plotted. Unknown names are reported and skipped.
-- Only symbols with more than one logged value are drawn, and a symbol whose
-- values are not all convertible to numbers is skipped. When at least one
-- curve is available, the plot is drawn on screen (unless disabled through
-- display()) and, when the logger writes to a file, also saved to the
-- logger's EPS file.
function Logger:plot(...)
   if not xlua.require('gnuplot') then
      if not self.warned then
         print('<Logger> warning: cannot plot with this version of Torch')
         self.warned = true
      end
      return
   end

   local plots = {}

   -- Queue one curve per configured style for the symbol stored under `key`.
   local function plotsymbol(key, list)
      if not list or #list <= 1 then
         return
      end
      local nelts = #list
      local plot_y = torch.Tensor(nelts)
      for i = 1,nelts do
         local y = tonumber(list[i])
         if not y then
            -- add() accepts arbitrary strings, which cannot be plotted
            return
         end
         plot_y[i] = y
      end
      for _,style in ipairs(self.styles[key]) do
         table.insert(plots, {self.names[key], plot_y, style})
      end
   end

   -- Issue the gnuplot commands shared by the on-screen and EPS outputs.
   local function render()
      if self.logscale then gnuplot.logscale('on') end
      gnuplot.plot(plots)
      if self.plotRawCmd then gnuplot.raw(self.plotRawCmd) end
      gnuplot.grid('on')
      gnuplot.title('<Logger::' .. self.name .. '>')
   end

   local args = {...}
   if not args[1] then -- plot all symbols
      for key,list in pairs(self.symbols) do
         plotsymbol(key,list)
      end
   else -- plot given symbols
      for _,name in ipairs(args) do
         local key = self.idx[name]
         if key == nil then
            print('<Logger> warning: unknown symbol "' .. tostring(name) .. '", skipping')
         else
            plotsymbol(key, self.symbols[key])
         end
      end
   end

   if #plots == 0 then
      return
   end
   if self.showPlot then
      self.figure = gnuplot.figure(self.figure)
      render()
   end
   if self.epsfile then
      -- drop any previous output; os.remove needs no shell and simply
      -- returns nil when the file does not exist
      os.remove(self.epsfile)
      local epsfig = gnuplot.epsfigure(self.epsfile)
      render()
      gnuplot.plotflush()
      gnuplot.close(epsfig)
   end
end