@@ -57,7 +57,6 @@ OPTION(ENABLE_JEMALLOC "Build rspamd with jemalloc allocator [default: OFF]
OPTION(ENABLE_COVERAGE "Build rspamd with code coverage options [default: OFF]" OFF)
OPTION(ENABLE_FULL_DEBUG "Build rspamd with all possible debug [default: OFF]" OFF)
OPTION(ENABLE_UTILS "Build rspamd internal utils [default: OFF]" OFF)
OPTION(ENABLE_TORCH "Install torch7 with Rspamd [default: ON]" ON)
OPTION(ENABLE_LIBUNWIND "Use libunwind to print crash traces [default: OFF]" OFF)
OPTION(ENABLE_LUA_TRACE "Trace all Lua C API invocations [default: OFF]" OFF)
@@ -1232,19 +1231,6 @@ IF(ENABLE_CLANG_PLUGIN MATCHES "ON")
    ADD_SUBDIRECTORY(clang-plugin)
ENDIF()
IF(ENABLE_TORCH MATCHES "ON")
    IF(WITH_LUAJIT)
        ADD_SUBDIRECTORY(contrib/lua-torch/paths)
        ADD_SUBDIRECTORY(contrib/lua-torch/torch7)
        ADD_SUBDIRECTORY(contrib/lua-torch/nn)
        ADD_SUBDIRECTORY(contrib/lua-torch/optim)
        ADD_SUBDIRECTORY(contrib/lua-torch/decisiontree)
        SET(WITH_TORCH 1)
    ELSE()
        MESSAGE(FATAL_ERROR "Cannot enable torch without luajit")
    ENDIF()
ENDIF()
ADD_SUBDIRECTORY(src)
ADD_SUBDIRECTORY(test)
ADD_SUBDIRECTORY(utils)
@@ -1337,10 +1323,6 @@ INSTALL(FILES "contrib/lua-tableshape/tableshape.lua" DESTINATION ${LUALIBDIR})
INSTALL(FILES "contrib/lua-lupa/lupa.lua" DESTINATION ${LUALIBDIR})
INSTALL(FILES "contrib/lua-lpeg/lpegre.lua" DESTINATION ${LUALIBDIR})
IF(ENABLE_TORCH MATCHES "ON")
    INSTALL(FILES "contrib/lua-moses/moses.lua" DESTINATION ${LUALIBDIR})
ENDIF()
# systemd unit
IF(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND WANT_SYSTEMD_UNITS MATCHES "ON")
INSTALL(FILES "rspamd.service" DESTINATION ${SYSTEMDDIR})
@@ -1,20 +0,0 @@ | |||
Copyright (c) 2012-2014 Roland Yonaba | |||
Permission is hereby granted, free of charge, to any person obtaining a | |||
copy of this software and associated documentation files (the | |||
"Software"), to deal in the Software without restriction, including | |||
without limitation the rights to use, copy, modify, merge, publish, | |||
distribute, sublicense, and/or sell copies of the Software, and to | |||
permit persons to whom the Software is furnished to do so, subject to | |||
the following conditions: | |||
The above copyright notice and this permission notice shall be included | |||
in all copies or substantial portions of the Software. | |||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS | |||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | |||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. | |||
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY | |||
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, | |||
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE | |||
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
@@ -1,364 +0,0 @@
local _ba='1.6.1'local aba,bba,cba,dba=next,type,select,pcall;local _ca,aca=setmetatable,getmetatable
local bca,cca=table.insert,table.sort;local dca,_da=table.remove,table.concat
local ada,bda,cda=math.randomseed,math.random,math.huge;local dda,__b,a_b=math.floor,math.max,math.min;local b_b=rawget
local c_b=table.unpack or unpack;local d_b,_ab=pairs,ipairs;local aab=os.clock;local bab={}
local function cab(dcb,_db)return dcb>_db end;local function dab(dcb,_db)return dcb<_db end
local function _bb(dcb,_db,adb)return(dcb<_db)and _db or
(dcb>adb and adb or dcb)end;local function abb(dcb,_db)return _db and true end
local function bbb(dcb)return not dcb end
local function cbb(dcb)local _db=0;for adb,bdb in d_b(dcb)do _db=_db+1 end;return _db end
local function dbb(dcb,_db,adb,...)local bdb;local cdb=adb or bab.identity;for ddb,__c in d_b(dcb)do
if not bdb then bdb=cdb(__c,...)else
local a_c=cdb(__c,...)bdb=_db(bdb,a_c)and bdb or a_c end end;return bdb end
local function _cb(dcb,_db,adb,bdb)for i=0,#dcb,_db do local cdb=bab.slice(dcb,i+1,i+_db)
if#cdb>0 then while
(#cdb<_db and bdb)do cdb[#cdb+1]=bdb end;adb(cdb)end end end
local function acb(dcb,_db,adb,bdb)
for i=0,#dcb,_db-1 do local cdb=bab.slice(dcb,i+1,i+_db)if
#cdb>0 and i+1 <#dcb then while(#cdb<_db and bdb)do cdb[#cdb+1]=bdb end
adb(cdb)end end end
local function bcb(dcb,_db,adb)if _db==0 then adb(dcb)end
for i=1,_db do dcb[_db],dcb[i]=dcb[i],dcb[_db]bcb(dcb,_db-
1,adb)dcb[_db],dcb[i]=dcb[i],dcb[_db]end end;local ccb=-1
function bab.clear(dcb)for _db in d_b(dcb)do dcb[_db]=nil end;return dcb end
function bab.each(dcb,_db,...)for adb,bdb in d_b(dcb)do _db(adb,bdb,...)end end
function bab.eachi(dcb,_db,...)
local adb=bab.sort(bab.select(bab.keys(dcb),function(bdb,cdb)return bab.isInteger(cdb)end))for bdb,cdb in _ab(adb)do _db(cdb,dcb[cdb],...)end end
function bab.at(dcb,...)local _db={}for adb,bdb in _ab({...})do
if bab.has(dcb,bdb)then _db[#_db+1]=dcb[bdb]end end;return _db end
function bab.count(dcb,_db)if bab.isNil(_db)then return bab.size(dcb)end;local adb=0
bab.each(dcb,function(bdb,cdb)if
bab.isEqual(cdb,_db)then adb=adb+1 end end)return adb end
function bab.countf(dcb,_db,...)return bab.count(bab.map(dcb,_db,...),true)end
function bab.cycle(dcb,_db)_db=_db or 1;if _db<=0 then return bab.noop end;local adb,bdb;local cdb=0
while true do
return
function()adb=adb and
aba(dcb,adb)or aba(dcb)
bdb=not bdb and adb or bdb;if _db then cdb=(adb==bdb)and cdb+1 or cdb
if cdb>_db then return end end;return adb,dcb[adb]end end end
function bab.map(dcb,_db,...)local adb={}
for bdb,cdb in d_b(dcb)do local ddb,__c,a_c=bdb,_db(bdb,cdb,...)adb[a_c and __c or ddb]=
a_c or __c end;return adb end;function bab.reduce(dcb,_db,adb)
for bdb,cdb in d_b(dcb)do if adb==nil then adb=cdb else adb=_db(adb,cdb)end end;return adb end;function bab.reduceby(dcb,_db,adb,bdb,...)return
bab.reduce(bab.select(dcb,bdb,...),_db,adb)end;function bab.reduceRight(dcb,_db,adb)return
bab.reduce(bab.reverse(dcb),_db,adb)end
function bab.mapReduce(dcb,_db,adb)
local bdb={}for cdb,ddb in d_b(dcb)do bdb[cdb]=not adb and ddb or _db(adb,ddb)
adb=bdb[cdb]end;return bdb end;function bab.mapReduceRight(dcb,_db,adb)
return bab.mapReduce(bab.reverse(dcb),_db,adb)end
function bab.include(dcb,_db)local adb=
bab.isFunction(_db)and _db or bab.isEqual;for bdb,cdb in d_b(dcb)do if adb(cdb,_db)then
return true end end;return false end
function bab.detect(dcb,_db)
local adb=bab.isFunction(_db)and _db or bab.isEqual;for bdb,cdb in d_b(dcb)do if adb(cdb,_db)then return bdb end end end
function bab.where(dcb,_db)
local adb=bab.select(dcb,function(bdb,cdb)
for ddb in d_b(_db)do if cdb[ddb]~=_db[ddb]then return false end end;return true end)return#adb>0 and adb or nil end
function bab.findWhere(dcb,_db)
local adb=bab.detect(dcb,function(bdb)for cdb in d_b(_db)do
if _db[cdb]~=bdb[cdb]then return false end end;return true end)return adb and dcb[adb]end
function bab.select(dcb,_db,...)local adb={}for bdb,cdb in d_b(dcb)do
if _db(bdb,cdb,...)then adb[#adb+1]=cdb end end;return adb end
function bab.reject(dcb,_db,...)local adb=bab.map(dcb,_db,...)local bdb={}for cdb,ddb in d_b(adb)do if not ddb then
bdb[#bdb+1]=dcb[cdb]end end;return bdb end
function bab.all(dcb,_db,...)return( (#bab.select(bab.map(dcb,_db,...),abb))==
cbb(dcb))end
function bab.invoke(dcb,_db,...)local adb={...}
return
bab.map(dcb,function(bdb,cdb)
if bab.isTable(cdb)then
if bab.has(cdb,_db)then
if
bab.isCallable(cdb[_db])then return cdb[_db](cdb,c_b(adb))else return cdb[_db]end else
if bab.isCallable(_db)then return _db(cdb,c_b(adb))end end elseif bab.isCallable(_db)then return _db(cdb,c_b(adb))end end)end
function bab.pluck(dcb,_db)return
bab.reject(bab.map(dcb,function(adb,bdb)return bdb[_db]end),bbb)end;function bab.max(dcb,_db,...)return dbb(dcb,cab,_db,...)end;function bab.min(dcb,_db,...)return
dbb(dcb,dab,_db,...)end
function bab.shuffle(dcb,_db)if _db then ada(_db)end
local adb={}
bab.each(dcb,function(bdb,cdb)local ddb=dda(bda()*bdb)+1;adb[bdb]=adb[ddb]
adb[ddb]=cdb end)return adb end
function bab.same(dcb,_db)
return
bab.all(dcb,function(adb,bdb)return bab.include(_db,bdb)end)and
bab.all(_db,function(adb,bdb)return bab.include(dcb,bdb)end)end;function bab.sort(dcb,_db)cca(dcb,_db)return dcb end
function bab.sortBy(dcb,_db,adb)
local bdb=_db or bab.identity
if bab.isString(_db)then bdb=function(ddb)return ddb[_db]end end;adb=adb or dab;local cdb={}
bab.each(dcb,function(ddb,__c)
cdb[#cdb+1]={value=__c,transform=bdb(__c)}end)
cca(cdb,function(ddb,__c)return adb(ddb.transform,__c.transform)end)return bab.pluck(cdb,'value')end
function bab.groupBy(dcb,_db,...)local adb={...}local bdb={}
bab.each(dcb,function(cdb,ddb)local __c=_db(cdb,ddb,c_b(adb))
if
bdb[__c]then bdb[__c][#bdb[__c]+1]=ddb else bdb[__c]={ddb}end end)return bdb end
function bab.countBy(dcb,_db,...)local adb={...}local bdb={}
bab.each(dcb,function(cdb,ddb)local __c=_db(cdb,ddb,c_b(adb))bdb[__c]=(
bdb[__c]or 0)+1 end)return bdb end
function bab.size(...)local dcb={...}local _db=dcb[1]if bab.isTable(_db)then return cbb(dcb[1])else
return cbb(dcb)end end;function bab.containsKeys(dcb,_db)
for adb in d_b(_db)do if not dcb[adb]then return false end end;return true end
function bab.sameKeys(dcb,_db)for adb in
d_b(dcb)do if not _db[adb]then return false end end;for adb in
d_b(_db)do if not dcb[adb]then return false end end
return true end
function bab.sample(dcb,_db,adb)_db=_db or 1;if _db<1 then return end;if _db==1 then if adb then ada(adb)end;return
dcb[bda(1,#dcb)]end;return
bab.slice(bab.shuffle(dcb,adb),1,_db)end
function bab.sampleProb(dcb,_db,adb)if adb then ada(adb)end;return
bab.select(dcb,function(bdb,cdb)return bda()<_db end)end;function bab.toArray(...)return{...}end
function bab.find(dcb,_db,adb)for i=adb or 1,#dcb do if
bab.isEqual(dcb[i],_db)then return i end end end
function bab.reverse(dcb)local _db={}for i=#dcb,1,-1 do _db[#_db+1]=dcb[i]end;return _db end;function bab.fill(dcb,_db,adb,bdb)bdb=bdb or bab.size(dcb)
for i=adb or 1,bdb do dcb[i]=_db end;return dcb end
function bab.selectWhile(dcb,_db,...)
local adb={}
for bdb,cdb in _ab(dcb)do if _db(bdb,cdb,...)then adb[bdb]=cdb else break end end;return adb end
function bab.dropWhile(dcb,_db,...)local adb
for bdb,cdb in _ab(dcb)do if not _db(bdb,cdb,...)then adb=bdb;break end end;if bab.isNil(adb)then return{}end;return bab.rest(dcb,adb)end
function bab.sortedIndex(dcb,_db,adb,bdb)local cdb=adb or dab;if bdb then bab.sort(dcb,cdb)end;for i=1,#dcb do if not
cdb(dcb[i],_db)then return i end end
return#dcb+1 end
function bab.indexOf(dcb,_db)for k=1,#dcb do if dcb[k]==_db then return k end end end
function bab.lastIndexOf(dcb,_db)local adb=bab.indexOf(bab.reverse(dcb),_db)if adb then return
#dcb-adb+1 end end;function bab.findIndex(dcb,_db,...)
for k=1,#dcb do if _db(k,dcb[k],...)then return k end end end
function bab.findLastIndex(dcb,_db,...)
local adb=bab.findIndex(bab.reverse(dcb),_db,...)if adb then return#dcb-adb+1 end end;function bab.addTop(dcb,...)
bab.each({...},function(_db,adb)bca(dcb,1,adb)end)return dcb end;function bab.push(dcb,...)bab.each({...},function(_db,adb)
dcb[#dcb+1]=adb end)
return dcb end
function bab.pop(dcb,_db)
_db=a_b(_db or 1,#dcb)local adb={}
for i=1,_db do local bdb=dcb[1]adb[#adb+1]=bdb;dca(dcb,1)end;return c_b(adb)end
function bab.unshift(dcb,_db)_db=a_b(_db or 1,#dcb)local adb={}for i=1,_db do local bdb=dcb[#dcb]
adb[#adb+1]=bdb;dca(dcb)end;return c_b(adb)end
function bab.pull(dcb,...)
for _db,adb in _ab({...})do for i=#dcb,1,-1 do
if bab.isEqual(dcb[i],adb)then dca(dcb,i)end end end;return dcb end
function bab.removeRange(dcb,_db,adb)local bdb=bab.clone(dcb)local cdb,ddb=(aba(bdb)),#bdb
if ddb<1 then return bdb end;_db=_bb(_db or cdb,cdb,ddb)
adb=_bb(adb or ddb,cdb,ddb)if adb<_db then return bdb end;local __c=adb-_db+1;local a_c=_db;while __c>0 do
dca(bdb,a_c)__c=__c-1 end;return bdb end
function bab.chunk(dcb,_db,...)if not bab.isArray(dcb)then return dcb end;local adb,bdb,cdb={},0
local ddb=bab.map(dcb,_db,...)
bab.each(ddb,function(__c,a_c)cdb=(cdb==nil)and a_c or cdb;bdb=(
(a_c~=cdb)and(bdb+1)or bdb)
if not adb[bdb]then adb[bdb]={dcb[__c]}else adb[bdb][
#adb[bdb]+1]=dcb[__c]end;cdb=a_c end)return adb end
function bab.slice(dcb,_db,adb)return
bab.select(dcb,function(bdb)return
(bdb>= (_db or aba(dcb))and bdb<= (adb or#dcb))end)end;function bab.first(dcb,_db)local adb=_db or 1
return bab.slice(dcb,1,a_b(adb,#dcb))end
function bab.initial(dcb,_db)
if _db and _db<0 then return end;return
bab.slice(dcb,1,_db and#dcb- (a_b(_db,#dcb))or#dcb-1)end;function bab.last(dcb,_db)if _db and _db<=0 then return end
return bab.slice(dcb,_db and
#dcb-a_b(_db-1,#dcb-1)or 2,#dcb)end;function bab.rest(dcb,_db)if _db and
_db>#dcb then return{}end
return bab.slice(dcb,
_db and __b(1,a_b(_db,#dcb))or 1,#dcb)end;function bab.nth(dcb,_db)
return dcb[_db]end;function bab.compact(dcb)return
bab.reject(dcb,function(_db,adb)return not adb end)end
function bab.flatten(dcb,_db)local adb=
_db or false;local bdb;local cdb={}
for ddb,__c in d_b(dcb)do
if bab.isTable(__c)then bdb=adb and __c or
bab.flatten(__c)
bab.each(bdb,function(a_c,b_c)cdb[#cdb+1]=b_c end)else cdb[#cdb+1]=__c end end;return cdb end
function bab.difference(dcb,_db)if not _db then return bab.clone(dcb)end;return
bab.select(dcb,function(adb,bdb)return not
bab.include(_db,bdb)end)end
function bab.union(...)return bab.uniq(bab.flatten({...}))end
function bab.intersection(dcb,...)local _db={...}local adb={}
for bdb,cdb in _ab(dcb)do if
bab.all(_db,function(ddb,__c)return bab.include(__c,cdb)end)then bca(adb,cdb)end end;return adb end
function bab.symmetricDifference(dcb,_db)return
bab.difference(bab.union(dcb,_db),bab.intersection(dcb,_db))end
function bab.unique(dcb)local _db={}for i=1,#dcb do if not bab.find(_db,dcb[i])then
_db[#_db+1]=dcb[i]end end;return _db end
function bab.isunique(dcb)return bab.isEqual(dcb,bab.unique(dcb))end
function bab.zip(...)local dcb={...}
local _db=bab.max(bab.map(dcb,function(bdb,cdb)return#cdb end))local adb={}for i=1,_db do adb[i]=bab.pluck(dcb,i)end;return adb end
function bab.append(dcb,_db)local adb={}for bdb,cdb in _ab(dcb)do adb[bdb]=cdb end;for bdb,cdb in _ab(_db)do
adb[#adb+1]=cdb end;return adb end
function bab.interleave(...)return bab.flatten(bab.zip(...))end;function bab.interpose(dcb,_db)return
bab.flatten(bab.zip(_db,bab.rep(dcb,#_db-1)))end
function bab.range(...)
local dcb={...}local _db,adb,bdb
if#dcb==0 then return{}elseif#dcb==1 then adb,_db,bdb=dcb[1],0,1 elseif#dcb==2 then
_db,adb,bdb=dcb[1],dcb[2],1 elseif#dcb==3 then _db,adb,bdb=dcb[1],dcb[2],dcb[3]end;if(bdb and bdb==0)then return{}end;local cdb={}
local ddb=__b(dda((adb-_db)/bdb),0)for i=1,ddb do cdb[#cdb+1]=_db+bdb*i end;if#cdb>0 then
bca(cdb,1,_db)end;return cdb end
function bab.rep(dcb,_db)local adb={}for i=1,_db do adb[#adb+1]=dcb end;return adb end;function bab.partition(dcb,_db,adb)if _db<=0 then return end
return coroutine.wrap(function()
_cb(dcb,_db or 1,coroutine.yield,adb)end)end;function bab.sliding(dcb,_db,adb)if
_db<=1 then return end
return coroutine.wrap(function()
acb(dcb,_db or 2,coroutine.yield,adb)end)end
function bab.permutation(dcb)return
coroutine.wrap(function()bcb(dcb,
#dcb,coroutine.yield)end)end;function bab.invert(dcb)local _db={}
bab.each(dcb,function(adb,bdb)_db[bdb]=adb end)return _db end
function bab.concat(dcb,_db,adb,bdb)
local cdb=bab.map(dcb,function(ddb,__c)return
tostring(__c)end)return _da(cdb,_db,adb or 1,bdb or#dcb)end;function bab.noop()return end;function bab.identity(dcb)return dcb end;function bab.constant(dcb)return
function()return dcb end end
function bab.memoize(dcb,_db)
local adb=_ca({},{__mode='kv'})local bdb=_db or bab.identity;return
function(...)local cdb=bdb(...)local ddb=adb[cdb]if not ddb then
adb[cdb]=dcb(...)end;return adb[cdb]end end;function bab.once(dcb)local _db=0;local adb={}
return function(...)_db=_db+1;if _db<=1 then adb={...}end
return dcb(c_b(adb))end end
function bab.before(dcb,_db)
local adb=0;local bdb={}return
function(...)adb=adb+1;if adb<=_db then bdb={...}end;return dcb(c_b(bdb))end end
function bab.after(dcb,_db)local adb,bdb=_db,0;return
function(...)bdb=bdb+1;if bdb>=adb then return dcb(...)end end end
function bab.compose(...)local dcb=bab.reverse{...}
return function(...)local _db,adb=true
for bdb,cdb in _ab(dcb)do if _db then _db=false
adb=cdb(...)else adb=cdb(adb)end end;return adb end end
function bab.pipe(dcb,...)return bab.compose(...)(dcb)end
function bab.complement(dcb)return function(...)return not dcb(...)end end;function bab.juxtapose(dcb,...)local _db={}
bab.each({...},function(adb,bdb)_db[#_db+1]=bdb(dcb)end)return c_b(_db)end
function bab.wrap(dcb,_db)return function(...)return
_db(dcb,...)end end
function bab.times(dcb,_db,...)local adb={}for i=1,dcb do adb[i]=_db(i,...)end;return adb end
function bab.bind(dcb,_db)return function(...)return dcb(_db,...)end end;function bab.bind2(dcb,_db)
return function(adb,...)return dcb(adb,_db,...)end end;function bab.bindn(dcb,...)local _db={...}
return function(...)return
dcb(c_b(bab.append(_db,{...})))end end
function bab.bindAll(dcb,...)local _db={...}
for adb,bdb in
_ab(_db)do local cdb=dcb[bdb]if cdb then dcb[bdb]=bab.bind(cdb,dcb)end end;return dcb end
function bab.uniqueId(dcb,...)ccb=ccb+1
if dcb then if bab.isString(dcb)then return dcb:format(ccb)elseif
bab.isFunction(dcb)then return dcb(ccb,...)end end;return ccb end
function bab.iterator(dcb,_db)return function()_db=dcb(_db)return _db end end
function bab.array(...)local dcb={}for _db in...do dcb[#dcb+1]=_db end;return dcb end;function bab.flip(dcb)return
function(...)return dcb(c_b(bab.reverse({...})))end end;function bab.over(...)
local dcb={...}
return function(...)local _db={}for adb,bdb in _ab(dcb)do _db[#_db+1]=bdb(...)end
return _db end end;function bab.overEvery(...)
local dcb=bab.over(...)
return function(...)return
bab.reduce(dcb(...),function(_db,adb)return _db and adb end)end end;function bab.overSome(...)
local dcb=bab.over(...)
return function(...)return
bab.reduce(dcb(...),function(_db,adb)return _db or adb end)end end
function bab.overArgs(dcb,...)
local _db={...}return
function(...)local adb={...}for i=1,#_db do local bdb=_db[i]
if adb[i]then adb[i]=bdb(adb[i])end end;return dcb(c_b(adb))end end
function bab.partial(dcb,...)local _db={...}
return
function(...)local adb={...}local bdb={}for cdb,ddb in _ab(_db)do bdb[cdb]=
(ddb=='_')and bab.pop(adb)or ddb end;return
dcb(c_b(bab.append(bdb,adb)))end end
function bab.partialRight(dcb,...)local _db={...}
return
function(...)local adb={...}local bdb={}
for k=1,#_db do bdb[k]=
(_db[k]=='_')and bab.pop(adb)or _db[k]end;return dcb(c_b(bab.append(adb,bdb)))end end
function bab.curry(dcb,_db)_db=_db or 2;local adb={}
local function bdb(cdb)if _db==1 then return dcb(cdb)end;if cdb~=nil then
adb[#adb+1]=cdb end;if#adb<_db then return bdb else local ddb={dcb(c_b(adb))}adb={}return
c_b(ddb)end end;return bdb end
function bab.time(dcb,...)local _db=aab()local adb={dcb(...)}return aab()-_db,c_b(adb)end;function bab.keys(dcb)local _db={}
bab.each(dcb,function(adb)_db[#_db+1]=adb end)return _db end;function bab.values(dcb)local _db={}
bab.each(dcb,function(adb,bdb)_db[
#_db+1]=bdb end)return _db end;function bab.kvpairs(dcb)local _db={}
bab.each(dcb,function(adb,bdb)_db[
#_db+1]={adb,bdb}end)return _db end
function bab.toObj(dcb)local _db={}for adb,bdb in
_ab(dcb)do _db[bdb[1]]=bdb[2]end;return _db end
function bab.property(dcb)return function(_db)return _db[dcb]end end
function bab.propertyOf(dcb)return function(_db)return dcb[_db]end end;function bab.toBoolean(dcb)return not not dcb end
function bab.extend(dcb,...)local _db={...}
bab.each(_db,function(adb,bdb)if
bab.isTable(bdb)then
bab.each(bdb,function(cdb,ddb)dcb[cdb]=ddb end)end end)return dcb end
function bab.functions(dcb,_db)dcb=dcb or bab;local adb={}
bab.each(dcb,function(cdb,ddb)if bab.isFunction(ddb)then
adb[#adb+1]=cdb end end)if not _db then return bab.sort(adb)end;local bdb=aca(dcb)
if
bdb and bdb.__index then local cdb=bab.functions(bdb.__index)bab.each(cdb,function(ddb,__c)
adb[#adb+1]=__c end)end;return bab.sort(adb)end
function bab.clone(dcb,_db)if not bab.isTable(dcb)then return dcb end;local adb={}
bab.each(dcb,function(bdb,cdb)if
bab.isTable(cdb)then
if not _db then adb[bdb]=bab.clone(cdb,_db)else adb[bdb]=cdb end else adb[bdb]=cdb end end)return adb end;function bab.tap(dcb,_db,...)_db(dcb,...)return dcb end;function bab.has(dcb,_db)return
dcb[_db]~=nil end
function bab.pick(dcb,...)local _db=bab.flatten{...}
local adb={}
bab.each(_db,function(bdb,cdb)
if not bab.isNil(dcb[cdb])then adb[cdb]=dcb[cdb]end end)return adb end
function bab.omit(dcb,...)local _db=bab.flatten{...}local adb={}
bab.each(dcb,function(bdb,cdb)if
not bab.include(_db,bdb)then adb[bdb]=cdb end end)return adb end;function bab.template(dcb,_db)
bab.each(_db or{},function(adb,bdb)if not dcb[adb]then dcb[adb]=bdb end end)return dcb end
function bab.isEqual(dcb,_db,adb)
local bdb=bba(dcb)local cdb=bba(_db)if bdb~=cdb then return false end
if bdb~='table'then return(dcb==_db)end;local ddb=aca(dcb)local __c=aca(_db)if adb then
if
(ddb or __c)and(ddb.__eq or __c.__eq)then return
ddb.__eq(dcb,_db)or __c.__eq(_db,dcb)or(dcb==_db)end end;if bab.size(dcb)~=
bab.size(_db)then return false end;for a_c,b_c in d_b(dcb)do local c_c=_db[a_c]
if
bab.isNil(c_c)or not bab.isEqual(b_c,c_c,adb)then return false end end
for a_c,b_c in d_b(_db)do
local c_c=dcb[a_c]if bab.isNil(c_c)then return false end end;return true end
function bab.result(dcb,_db,...)
if dcb[_db]then if bab.isCallable(dcb[_db])then return dcb[_db](dcb,...)else return
dcb[_db]end end;if bab.isCallable(_db)then return _db(dcb,...)end end;function bab.isTable(dcb)return bba(dcb)=='table'end
function bab.isCallable(dcb)return
(
bab.isFunction(dcb)or
(bab.isTable(dcb)and aca(dcb)and aca(dcb).__call~=nil)or false)end
function bab.isArray(dcb)if not bab.isTable(dcb)then return false end;local _db=0
for adb in
d_b(dcb)do _db=_db+1;if bab.isNil(dcb[_db])then return false end end;return true end
function bab.isIterable(dcb)return bab.toBoolean((dba(d_b,dcb)))end
function bab.isEmpty(dcb)if bab.isNil(dcb)then return true end;if bab.isString(dcb)then
return#dcb==0 end
if bab.isTable(dcb)then return aba(dcb)==nil end;return true end;function bab.isString(dcb)return bba(dcb)=='string'end;function bab.isFunction(dcb)return
bba(dcb)=='function'end;function bab.isNil(dcb)
return dcb==nil end
function bab.isNumber(dcb)return bba(dcb)=='number'end
function bab.isNaN(dcb)return bab.isNumber(dcb)and dcb~=dcb end
function bab.isFinite(dcb)if not bab.isNumber(dcb)then return false end;return
dcb>-cda and dcb<cda end;function bab.isBoolean(dcb)return bba(dcb)=='boolean'end
function bab.isInteger(dcb)return
bab.isNumber(dcb)and dda(dcb)==dcb end
do bab.forEach=bab.each;bab.forEachi=bab.eachi;bab.loop=bab.cycle
bab.collect=bab.map;bab.inject=bab.reduce;bab.foldl=bab.reduce
bab.injectr=bab.reduceRight;bab.foldr=bab.reduceRight;bab.mapr=bab.mapReduce
bab.maprr=bab.mapReduceRight;bab.any=bab.include;bab.some=bab.include;bab.contains=bab.include
bab.filter=bab.select;bab.discard=bab.reject;bab.every=bab.all
bab.takeWhile=bab.selectWhile;bab.rejectWhile=bab.dropWhile;bab.shift=bab.pop;bab.remove=bab.pull
bab.rmRange=bab.removeRange;bab.chop=bab.removeRange;bab.sub=bab.slice;bab.head=bab.first
bab.take=bab.first;bab.tail=bab.rest;bab.skip=bab.last;bab.without=bab.difference
bab.diff=bab.difference;bab.symdiff=bab.symmetricDifference;bab.xor=bab.symmetricDifference
bab.uniq=bab.unique;bab.isuniq=bab.isunique;bab.transpose=bab.zip;bab.part=bab.partition
bab.perm=bab.permutation;bab.mirror=bab.invert;bab.join=bab.concat;bab.cache=bab.memoize
bab.juxt=bab.juxtapose;bab.uid=bab.uniqueId;bab.iter=bab.iterator;bab.methods=bab.functions
bab.choose=bab.pick;bab.drop=bab.omit;bab.defaults=bab.template;bab.compare=bab.isEqual end
do local dcb={}local _db={}_db.__index=dcb;local function adb(bdb)local cdb={_value=bdb,_wrapped=true}
return _ca(cdb,_db)end
_ca(_db,{__call=function(bdb,cdb)return adb(cdb)end,__index=function(bdb,cdb,...)return
dcb[cdb]end})function _db.chain(bdb)return adb(bdb)end
function _db:value()return self._value end;dcb.chain,dcb.value=_db.chain,_db.value
for bdb,cdb in d_b(bab)do
dcb[bdb]=function(ddb,...)local __c=bab.isTable(ddb)and
ddb._wrapped or false
if __c then
local a_c=ddb._value;local b_c=cdb(a_c,...)return adb(b_c)else return cdb(ddb,...)end end end
dcb.import=function(bdb,cdb)bdb=bdb or _ENV or _G;local ddb=bab.functions()
bab.each(ddb,function(__c,a_c)
if
b_b(bdb,a_c)then if not cdb then bdb[a_c]=bab[a_c]end else bdb[a_c]=bab[a_c]end end)return bdb end;_db._VERSION='Moses v'.._ba
_db._URL='http://github.com/Yonaba/Moses'
_db._LICENSE='MIT <http://raw.githubusercontent.com/Yonaba/Moses/master/LICENSE>'_db._DESCRIPTION='utility-belt library for functional programming in Lua'return
_db end
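The deleted contrib/lua-moses/moses.lua above is the upstream minified build, so its identifiers are obfuscated (bab is the module table, d_b is pairs, _ca is setmetatable, and so on). For orientation, a readable sketch of two of its core combinators, recovered directly from the minified source:

local M = {}

-- map: calls f(key, value, ...) for every pair; f may return a single new
-- value, or a (newKey, newValue) pair, exactly as bab.map does above.
function M.map(t, f, ...)
  local results = {}
  for k, v in pairs(t) do
    local r1, r2 = f(k, v, ...)
    if r2 then results[r1] = r2 else results[k] = r1 end
  end
  return results
end

-- reduce: left fold; with no initial state, the first visited value seeds
-- the accumulator, mirroring bab.reduce above.
function M.reduce(t, f, state)
  for _, v in pairs(t) do
    if state == nil then state = v else state = f(state, v) end
  end
  return state
end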
@@ -1,51 +0,0 @@ | |||
LIST(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}") | |||
SET(src | |||
init.c | |||
hash_map.c | |||
) | |||
SET(luasrc | |||
_env.lua | |||
benchmark.lua | |||
CartNode.lua | |||
CartTrainer.lua | |||
CartTree.lua | |||
DataSet.lua | |||
DecisionForest.lua | |||
DecisionForestTrainer.lua | |||
DecisionTree.lua | |||
DFD.lua | |||
GiniState.lua | |||
GradientBoostState.lua | |||
GradientBoostTrainer.lua | |||
init.lua | |||
LogitBoostCriterion.lua | |||
math.lua | |||
MSECriterion.lua | |||
RandomForestTrainer.lua | |||
Sparse2Dense.lua | |||
SparseTensor.lua | |||
test.lua | |||
TreeState.lua | |||
utils.lua | |||
WorkPool.lua | |||
) | |||
IF (WITH_OPENMP) | |||
FIND_PACKAGE(OpenMP) | |||
IF(OPENMP_FOUND) | |||
MESSAGE(STATUS "Compiling with OpenMP support") | |||
SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") | |||
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") | |||
SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") | |||
ENDIF(OPENMP_FOUND) | |||
ENDIF (WITH_OPENMP) | |||
ADD_TORCH_PACKAGE(decisiontree "${src}" "${luasrc}" "A decision tree library, for Torch") | |||
INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}) | |||
### Torch packages supposes libraries prefix is "lib" | |||
SET_TARGET_PROPERTIES(decisiontree PROPERTIES | |||
PREFIX "lib" | |||
IMPORT_PREFIX "lib") | |||
TARGET_LINK_LIBRARIES(decisiontree ${TH_LIBRARIES}) | |||
INSTALL(TARGETS decisiontree DESTINATION ${RSPAMD_LIBDIR}) |
@@ -1,42 +0,0 @@ | |||
local dt = require 'decisiontree._env' | |||
local CartNode = torch.class("dt.CartNode", dt) | |||
function CartNode:__init(nodeId, leftChild, rightChild, splitFeatureId, splitFeatureValue, score, splitGain) | |||
self.nodeId = nodeId or 0 | |||
self.leftChild = leftChild | |||
self.rightChild = rightChild | |||
self.splitFeatureId = splitFeatureId or -1 | |||
self.splitFeatureValue = splitFeatureValue or 0 | |||
self.score = score or 0 | |||
self.splitGain = splitGain | |||
end | |||
function CartNode:__tostring__() | |||
return self:recursivetostring() | |||
end | |||
function CartNode:recursivetostring(indent) | |||
indent = indent or ' ' | |||
-- Is this a leaf node? | |||
local res = '' | |||
if not (self.leftChild or self.rightChild) then | |||
res = res .. self.score .. '\n' | |||
else | |||
-- Print the criteria | |||
res = res .. 'input[' .. self.splitFeatureId .. '] <' .. self.splitFeatureValue .. '?\n' | |||
-- Print the branches | |||
if self.leftChild then | |||
res = res .. indent .. 'True->' .. self.leftChild:recursivetostring(indent .. ' ') | |||
end | |||
if self.rightChild then | |||
res = res .. indent .. 'False->' .. self.rightChild:recursivetostring(indent .. ' ') | |||
end | |||
end | |||
return res | |||
end | |||
function CartNode:clone() | |||
return CartNode(self.nodeId, self.leftChild, self.rightChild, self.splitFeatureId, self.splitFeatureValue, self.score, self.splitGain) | |||
end |
@@ -1,180 +0,0 @@ | |||
local dt = require "decisiontree._env" | |||
local _ = require "moses" | |||
local CartTrainer = torch.class("dt.CartTrainer", dt) | |||
function CartTrainer:__init(dataset, minLeafSize, maxLeafNodes) | |||
assert(torch.isTypeOf(dataset, 'dt.DataSet')) | |||
self.dataset = dataset | |||
self.minLeafSize = assert(minLeafSize) -- min examples per leaf | |||
self.maxLeafNodes = assert(maxLeafNodes) -- max leaf nodes in tree | |||
-- by default, single thread | |||
self.parallelMode = 'singlethread' | |||
end | |||
function CartTrainer:train(rootTreeState, activeFeatures) | |||
assert(torch.isTypeOf(rootTreeState, 'dt.TreeState')) | |||
assert(torch.isTensor(activeFeatures)) | |||
local root = dt.CartNode() | |||
root.id = 0 | |||
root.score = rootTreeState:score(self.dataset) | |||
local nleaf = 1 | |||
-- TODO : nodeparallel: parallelize here. The queue is a workqueue. | |||
local queue = {} | |||
table.insert(queue, 1, {cartNode=root, treeState=rootTreeState}) | |||
while #queue > 0 and nleaf < self.maxLeafNodes do | |||
local treeGrowerArgs = table.remove(queue, #queue) | |||
local currentTreeState = treeGrowerArgs.treeState | |||
-- Note: if minLeafSize = 1 and maxLeafNode = inf, then each example will be its own leaf... | |||
if self:hasEnoughTrainingExamplesToSplit(currentTreeState.exampleIds:size(1)) then | |||
nleaf = self:processNode(nleaf, queue, treeGrowerArgs.cartNode, currentTreeState, activeFeatures) | |||
end | |||
end | |||
-- CartTree with random branching (when feature is missing) | |||
local branchleft = function() return math.random() < 0.5 end | |||
return dt.CartTree(root, branchleft), nleaf | |||
end | |||
function CartTrainer:processNode(nleaf, queue, node, treeState, activeFeatures) | |||
local bestSplit | |||
if self.parallelMode == 'singlethread' then | |||
bestSplit = self:findBestSplitForAllFeatures(treeState, activeFeatures) | |||
elseif self.parallelMode == 'featureparallel' then | |||
bestSplit = self:findBestSplitForAllFeaturesFP(treeState, activeFeatures) | |||
else | |||
error("Unrecognized parallel mode: " .. self.parallelMode) | |||
end | |||
if bestSplit then | |||
local leftTreeState, rightTreeState = treeState:branch(bestSplit, self.dataset) | |||
assert(bestSplit.leftChildSize + bestSplit.rightChildSize == leftTreeState.exampleIds:size(1) + rightTreeState.exampleIds:size(1), "The left and right subtrees don't match the split found!") | |||
self:setValuesAndCreateChildrenForNode(node, bestSplit, leftTreeState, rightTreeState, nleaf) | |||
table.insert(queue, 1, {cartNode=node.leftChild, treeState=leftTreeState}) | |||
table.insert(queue, 1, {cartNode=node.rightChild, treeState=rightTreeState}) | |||
return nleaf + 1 | |||
end | |||
return nleaf | |||
end | |||
function CartTrainer:findBestSplitForAllFeatures(treeState, activeFeatures) | |||
local timer = torch.Timer() | |||
local bestSplit = treeState:findBestSplit(self.dataset, activeFeatures, self.minLeafSize, -1, -1) | |||
if bestSplit then | |||
assert(torch.type(bestSplit) == 'table') | |||
end | |||
if dt.PROFILE then | |||
print("findBestSplitForAllFeatures time="..timer:time().real) | |||
end | |||
return bestSplit | |||
end | |||
function CartTrainer:setValuesAndCreateChildrenForNode(parentNode, bestSplit, leftState, rightState, nleaf) | |||
assert(torch.isTypeOf(parentNode, 'dt.CartNode')) | |||
assert(torch.type(bestSplit) == 'table') | |||
assert(torch.isTypeOf(leftState, 'dt.TreeState')) | |||
assert(torch.isTypeOf(rightState, 'dt.TreeState')) | |||
assert(torch.type(nleaf) == 'number') | |||
local leftChild = dt.CartNode() | |||
leftChild.score = leftState:score(self.dataset) | |||
leftChild.nodeId = 2 * nleaf - 1 | |||
local rightChild = dt.CartNode() | |||
rightChild.score = rightState:score(self.dataset) | |||
rightChild.nodeId = 2 * nleaf | |||
parentNode.splitFeatureId = bestSplit.splitId | |||
parentNode.splitFeatureValue = bestSplit.splitValue | |||
parentNode.leftChild = leftChild | |||
parentNode.rightChild = rightChild | |||
parentNode.splitGain = bestSplit.splitGain | |||
end | |||
function CartTrainer:hasEnoughTrainingExamplesToSplit(count) | |||
return count >= 2 * self.minLeafSize | |||
end | |||
function CartTrainer:featureParallel(workPool) | |||
assert(self.parallelMode == 'singlethread', self.parallelMode) | |||
self.parallelMode = 'featureparallel' | |||
self.workPool = torch.type(workPool) == 'number' and dt.WorkPool(workPool) or workPool | |||
assert(torch.isTypeOf(self.workPool, 'dt.WorkPool')) | |||
-- this deletes all SparseTensor hash maps so that they aren't serialized | |||
self.dataset:deleteIndex() | |||
-- require the dt package | |||
self.workPool:update('require', {libname='decisiontree',varname='dt'}) | |||
-- setup worker store (each worker will have its own copy) | |||
local store = { | |||
dataset=self.dataset, | |||
minLeafSize=self.minLeafSize | |||
} | |||
self.workPool:update('storeKeysValues', store) | |||
end | |||
function CartTrainer:findBestSplitForAllFeaturesFP(treeState, activeFeatures) | |||
local timer = torch.Timer() | |||
local bestSplit | |||
if treeState.findBestSplitFP then | |||
bestSplit = treeState:findBestSplitFP(self.dataset, activeFeatures, self.minLeafSize, self.workPool.nThread) | |||
end | |||
if not bestSplit then | |||
for i=1,self.workPool.nThread do | |||
-- upvalues | |||
local treeState = treeState | |||
local shardId = i | |||
local nShard = self.workPool.nThread | |||
local featureIds = activeFeatures | |||
-- closure | |||
local task = function(store) | |||
assert(store.dataset) | |||
assert(store.minLeafSize) | |||
if treeState.threadInitialize then | |||
treeState:threadInitialize() | |||
end | |||
local bestSplit = treeState:findBestSplit(store.dataset, featureIds, store.minLeafSize, shardId, nShard) | |||
return bestSplit | |||
end | |||
self.workPool:writeup('execute', task) | |||
end | |||
for i=1,self.workPool.nThread do | |||
local taskname, candidateSplit = self.workPool:read() | |||
assert(taskname == 'execute') | |||
if candidateSplit then | |||
if ((not bestSplit) or candidateSplit.splitGain < bestSplit.splitGain) then | |||
bestSplit = candidateSplit | |||
end | |||
end | |||
end | |||
end | |||
if bestSplit then | |||
assert(torch.type(bestSplit) == 'table') | |||
end | |||
if dt.PROFILE then | |||
print("findBestSplitForAllFeaturesFP time="..timer:time().real) | |||
end | |||
return bestSplit | |||
end |
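End-to-end, the trainer above is driven roughly as follows (a sketch with made-up sizes, assuming the package's init.lua wires up the classes this diff deletes; GiniState and DataSet appear further down):

local dt = require 'decisiontree'
local input = torch.randn(100, 10)            -- 100 examples, 10 dense features
local target = torch.Tensor(100):random(0, 1) -- binary labels
local dataset = dt.DataSet(input, target)
local trainer = dt.CartTrainer(dataset, 5, 8) -- minLeafSize=5, maxLeafNodes=8
local rootState = dt.GiniState(dataset:getExampleIds())
local activeFeatures = torch.LongTensor():range(1, 10)
local cartTree, nleaf = trainer:train(rootState, activeFeatures)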
@@ -1,90 +0,0 @@ | |||
local _ = require "moses" | |||
local dt = require 'decisiontree._env' | |||
local CartTree = torch.class("dt.CartTree", "dt.DecisionTree", dt) | |||
function CartTree:__init(root, branchleft) | |||
assert(torch.isTypeOf(root, 'dt.CartNode')) | |||
self.root = root | |||
self.branchleft = branchleft or function() return true end | |||
end | |||
function CartTree:score(input, stack, optimized) | |||
if optimized == true and stack == nil and torch.isTensor(input) and input.isContiguous and input:isContiguous() and input:nDimension() == 2 then | |||
return input.nn.CartTreeFastScore(input, self.root, input.new()) | |||
end | |||
return self:recursivescore(self.root, input, stack) | |||
end | |||
function CartTree:recursivescore(node, input, stack) | |||
assert(torch.isTypeOf(node, 'dt.CartNode')) | |||
if stack then | |||
stack = torch.type(stack) == 'table' and stack or {} | |||
table.insert(stack, node) | |||
end | |||
if not (node.leftChild or node.rightChild) then | |||
return node.score, node.nodeId, stack | |||
elseif not node.leftChild then | |||
return self:recursivescore(node.rightChild, input, stack) | |||
elseif not node.rightChild then | |||
return self:recursivescore(node.leftChild, input, stack) | |||
end | |||
local splitId = node.splitFeatureId | |||
local splitVal = node.splitFeatureValue | |||
if input[splitId] then -- if has key | |||
local featureVal = input[splitId] | |||
local nextNode = featureVal < splitVal and node.leftChild or node.rightChild | |||
return self:recursivescore(nextNode, input, stack) | |||
end | |||
-- if feature is missing, branch left | |||
local nextNode = self.branchleft() and node.leftChild or node.rightChild | |||
return self:recursivescore(nextNode, input, stack) | |||
end | |||
function CartTree:__tostring__() | |||
return self.root:recursivetostring() | |||
end | |||
function CartTree:stackToString(stack, input) | |||
assert(torch.type(stack) == 'table') | |||
assert(torch.isTypeOf(stack[1], 'dt.CartNode')) | |||
local res = 'Stack nodes from root to leaf\n' | |||
for i,node in ipairs(stack) do | |||
if not (node.leftChild or node.rightChild) then | |||
res = res .. "score="..node.score .. '\n' | |||
else | |||
local istr = '' | |||
if input then | |||
istr = '=' .. (input[node.splitFeatureId] or 'nil') | |||
end | |||
res = res .. 'input[' .. node.splitFeatureId .. ']' .. istr ..' < ' .. node.splitFeatureValue .. ' ? ' | |||
res = res .. '(' .. ((node.leftChild and node.rightChild) and 'LR' or node.leftChild and 'L' or node.rightChild and 'R' or 'WAT?') .. ') ' | |||
if node.leftChild == stack[i+1] then | |||
res = res .. 'Left\n' | |||
elseif node.rightChild == stack[i+1] then | |||
res = res .. 'Right\n' | |||
else | |||
error"stackToString error" | |||
end | |||
end | |||
end | |||
return res .. #stack .. " nodes" | |||
end | |||
function CartTree:clone() | |||
return CartTree(self.root:clone(), self.branchleft) | |||
end | |||
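Scoring follows the recursion above; reusing the hand-built root from the CartNode sketch, a sparse example is just a Lua table keyed by feature id (a missing feature falls back to branchleft):

local cart = dt.CartTree(root)                    -- default branchleft: always true
local score, leafId = cart:score({ [4] = 0.1 })   -- 0.1 < 0.5, goes left
local _, _, stack = cart:score({ [4] = 0.9 }, {}) -- also collect the visited nodes
print(cart:stackToString(stack, { [4] = 0.9 }))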
@@ -1,182 +0,0 @@ | |||
local DFD, parent = torch.class("nn.DFD", "nn.Module") | |||
function DFD:__init(df, onlyLastNode) | |||
parent.__init(self) | |||
if torch.type(df) == 'table' then | |||
self:reconstructFromInfo(df) | |||
else | |||
assert(torch.type(df) == 'dt.DecisionForest') | |||
self.rootIds = torch.LongTensor() | |||
-- nodeId of left and right child nodes | |||
self.leftChild = torch.LongTensor() | |||
self.rightChild = torch.LongTensor() | |||
-- index and value of the feature that splits this node | |||
self.splitFeatureId = torch.LongTensor() | |||
self.splitFeatureValue = torch.Tensor() | |||
-- initialize state given df | |||
self:convertForest2Tensors(df) | |||
self:clearState() | |||
end | |||
self.onlyLastNode = onlyLastNode | |||
self.nTrees = self.rootIds:size(1) | |||
end | |||
function DFD:convertForest2Tensors(df) | |||
self.rootIds:resize(#df.trees) | |||
-- nodeId will map to featureId | |||
local nodeId = 0 | |||
-- sets nodeIds of all subnodes | |||
-- and measures the maximum depth over all trees | |||
local function recursiveTree(node, depth) | |||
depth = (depth or 0) + 1 | |||
local rdepth = depth | |||
nodeId = nodeId + 1 | |||
node._nodeId = nodeId | |||
if node.leftChild then | |||
rdepth = math.max(rdepth, recursiveTree(node.leftChild, depth)) | |||
end | |||
if node.rightChild then | |||
rdepth = math.max(rdepth, recursiveTree(node.rightChild, depth)) | |||
end | |||
return rdepth | |||
end | |||
-- sum over trees of max depth | |||
self.depth = 0 | |||
for i,tree in ipairs(df.trees) do | |||
assert(torch.isTypeOf(tree.root, 'dt.CartNode')) | |||
self.depth = self.depth + recursiveTree(tree.root) | |||
end | |||
-- remove roots from depth | |||
self.depth = self.depth - self.rootIds:size(1) | |||
-- total number of nodes in all trees | |||
self.nNode = nodeId | |||
-- nodeId of left and right child nodes | |||
self.leftChild:resize(self.nNode):fill(-1) | |||
self.rightChild:resize(self.nNode):fill(-1) | |||
-- index and value of the feature that splits this node | |||
self.splitFeatureId:resize(self.nNode):fill(-1) | |||
self.splitFeatureValue:resize(self.nNode):fill(-1) | |||
-- aggregates CartNode attributes to an efficient tensor representation | |||
local function recursiveTree2(node) | |||
local nodeId = assert(node._nodeId) | |||
assert(self.splitFeatureId[nodeId] == -1) | |||
if node.leftChild then | |||
self.leftChild[nodeId] = assert(node.leftChild._nodeId) | |||
recursiveTree2(node.leftChild) | |||
else | |||
self.leftChild[nodeId] = 0 | |||
end | |||
if node.rightChild then | |||
self.rightChild[nodeId] = assert(node.rightChild._nodeId) | |||
recursiveTree2(node.rightChild) | |||
else | |||
self.rightChild[nodeId] = 0 | |||
end | |||
-- each node splits the dataset on a feature id-value pair | |||
self.splitFeatureId[nodeId] = assert(node.splitFeatureId) | |||
self.splitFeatureValue[nodeId] = assert(node.splitFeatureValue) | |||
end | |||
for i,tree in ipairs(df.trees) do | |||
self.rootIds[i] = assert(tree.root._nodeId) | |||
recursiveTree2(tree.root) | |||
end | |||
assert(self.leftChild:min() >= 0) | |||
assert(self.rightChild:min() >= 0) | |||
end | |||
function DFD:updateOutput(input) | |||
assert(torch.isTensor(input)) | |||
assert(input:dim() == 2) | |||
input = input:contiguous() | |||
local batchsize, inputsize = input:size(1), input:size(2) | |||
local size = self.onlyLastNode and self.nTree or self.depth | |||
-- each sample's output keys is resized to maxdepth, which is the maximum size that it can take on | |||
self.outputkeys = self.outputkeys or torch.LongTensor() | |||
self.outputkeys:resize(batchsize, size) | |||
-- values are 1 | |||
self.outputvalues = self.outputvalues or input.new() | |||
self.outputvalues:resize(batchsize, size):fill(1) | |||
self.output = input.nn.DFD_computeOutput(self.outputkeys, self.outputvalues, self.rootIds, self.leftChild, self.rightChild, self.splitFeatureId, self.splitFeatureValue, input, self.onlyLastNode) | |||
return self.output | |||
end | |||
function DFD:type(type, tensorCache) | |||
if type then | |||
local info = self:getReconstructionInfo() | |||
for k, v in pairs(info) do | |||
if torch.type(v) ~= 'torch.LongTensor' then | |||
info[k] = nil | |||
end | |||
end | |||
parent.type(self, type, tensorCache) | |||
self:reconstructFromInfo(info) | |||
return self | |||
else | |||
return parent.type(self) | |||
end | |||
end | |||
function DFD:updateGradInput() | |||
error"Not Implemented" | |||
end | |||
function DFD:clearState() | |||
self.output = {{},{}} | |||
self.taskbuffer = {} | |||
self.outputkeys = nil | |||
self.outputvalues = nil | |||
self._range = nil | |||
self._indices = nil | |||
self._mask = nil | |||
end | |||
function DFD:reconstructFromInfo(DFDinfo) | |||
for k,v in pairs(DFDinfo) do | |||
self[k] = v | |||
end | |||
assert(self.leftChild:nDimension() == 1) | |||
assert(self.rightChild:nDimension() == 1) | |||
assert(self.leftChild:size(1) == self.nNode) | |||
assert(self.rightChild:size(1) == self.nNode) | |||
assert(self.leftChild:min() >= 0) | |||
assert(self.rightChild:min() >= 0) | |||
assert(self.splitFeatureId:nDimension() == 1) | |||
assert(self.splitFeatureValue:nDimension() == 1) | |||
assert(self.splitFeatureId:size(1) == self.splitFeatureValue:size(1)) | |||
end | |||
function DFD:getReconstructionInfo() | |||
local DFDinfo = { | |||
nNode = self.nNode, | |||
rootIds = self.rootIds, | |||
leftChild = self.leftChild, | |||
rightChild = self.rightChild, | |||
splitFeatureId = self.splitFeatureId, | |||
splitFeatureValue = self.splitFeatureValue, | |||
depth = self.depth | |||
} | |||
return DFDinfo | |||
end |
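A sketch of how nn.DFD is meant to be used (it assumes the package's C kernel nn.DFD_computeOutput is available): wrap a forest in the module and forward a dense batch to get, per example, the ids of the tree nodes it visits as sparse {keys, values} features. Reusing cart from the CartTree sketch:

local forest = dt.DecisionForest({ cart }, torch.Tensor{ 1 }, 0)
local dfd = nn.DFD(forest)
local keysAndValues = dfd:forward(torch.randn(32, 10)) -- {outputkeys, outputvalues}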
@@ -1,142 +0,0 @@ | |||
local dt = require "decisiontree._env" | |||
local DataSet = torch.class("dt.DataSet", dt) | |||
function DataSet:__init(input, target, nThreads) | |||
if torch.type(input) == 'table' then | |||
assert(torch.isTypeOf(input[1], 'torch.SparseTensor')) | |||
else | |||
assert(torch.isTensor(input)) | |||
end | |||
self.input = input | |||
assert(torch.isTensor(target)) | |||
self.target = target | |||
self.nThreads = nThreads or 1 | |||
self.sortedFeatureValues, self.featureIds = self:sortFeatureValues(input) | |||
end | |||
function DataSet:sortFeatureValues(inputs) | |||
local isSparse = torch.typename(inputs[1]):match('torch.*SparseTensor') | |||
assert(isSparse or torch.isTensor(inputs)) | |||
local featureIds = torch.LongTensor() | |||
local dataset = {} -- TODO use tds.Hash (will require SparseTensor to be userdata) | |||
if isSparse then | |||
local proto = inputs[1].values | |||
-- get list of featureIds | |||
local featureMap = {} | |||
for i,input in ipairs(inputs) do | |||
input.keys:apply(function(key) | |||
featureMap[key] = (featureMap[key] or 0) + 1 | |||
end) | |||
end | |||
local _ = require "moses" | |||
featureIds = featureIds.new(_.keys(featureMap)) | |||
local featureCounts = torch.LongTensor(featureIds:size(1)) | |||
for i=1,featureIds:size(1) do | |||
featureCounts[i] = featureMap[featureIds[i]] | |||
end | |||
for i=1,featureIds:size(1) do | |||
local featureId = featureIds[i] | |||
local featureCount = featureCounts[i] | |||
dataset[featureId] = { | |||
values=proto.new(featureCount), | |||
examples=torch.LongTensor(featureCount), | |||
i=0 | |||
} | |||
end | |||
for exampleId,input in ipairs(inputs) do | |||
local sparseIdx = 0 | |||
input.keys:apply(function(key) | |||
sparseIdx = sparseIdx + 1 | |||
local f = dataset[key] | |||
f.i = f.i + 1 | |||
f.values[f.i] = input.values[sparseIdx] | |||
f.examples[f.i] = exampleId | |||
end) | |||
end | |||
local sortVal, sortIdx = proto.new(), torch.LongTensor() | |||
for featureId,f in pairs(dataset) do | |||
assert(f.values:size(1) == f.i) | |||
sortVal:sort(sortIdx, f.values, 1, false) | |||
local sortedExampleIds = torch.LongTensor(f.i) | |||
sortedExampleIds:index(f.examples, 1, sortIdx) | |||
dataset[featureId] = sortedExampleIds | |||
end | |||
else | |||
assert(torch.isTensor(inputs)) | |||
featureIds:range(1,inputs:size(2)) | |||
for i=1,inputs:size(2) do | |||
local featureId = i | |||
local values = inputs:select(2, i) | |||
local _, sortedFeatureExampleIds = values:sort(1, false) | |||
dataset[featureId] = sortedFeatureExampleIds | |||
end | |||
end | |||
return dataset, featureIds | |||
end | |||
function DataSet:getSortedFeature(featureId) | |||
assert(self.sortedFeatureValues) | |||
return self.sortedFeatureValues[featureId] | |||
end | |||
function DataSet:size() | |||
return self.target:size(1) | |||
end | |||
function DataSet:getExampleIds() | |||
if not self.exampleIds then | |||
self.exampleIds = torch.LongTensor():range(1,self:size()) | |||
end | |||
return self.exampleIds | |||
end | |||
function DataSet:countPositive(exampleIds) | |||
assert(torch.type(exampleIds) == 'torch.LongTensor') | |||
local dt = require 'decisiontree' | |||
local buffer = dt.getBufferTable('DataSet') | |||
buffer.tensor = buffer.tensor or self.target.new() | |||
buffer.tensor:index(self.target, 1, exampleIds) | |||
local nPositive = 0 | |||
buffer.tensor:apply(function(x) | |||
if x > 0 then nPositive = nPositive + 1 end | |||
end) | |||
return nPositive | |||
end | |||
function DataSet:initScore() | |||
self.score = self.score or torch.Tensor() | |||
self.score:resize(self:size()):fill(0) | |||
end | |||
function DataSet:buildIndex() | |||
if torch.type(self.input) == 'table' then | |||
for exampleId,input in ipairs(self.input) do | |||
if torch.isTypeOf(input, 'torch.SparseTensor') then | |||
input:buildIndex() | |||
end | |||
end | |||
end | |||
end | |||
function DataSet:deleteIndex() | |||
if torch.type(self.input) == 'table' then | |||
for exampleId,input in ipairs(self.input) do | |||
if torch.isTypeOf(input, 'torch.SparseTensor') then | |||
input:deleteIndex() | |||
end | |||
end | |||
end | |||
end |
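Sketch of the sorting contract above: for each feature, getSortedFeature returns the example ids ordered by increasing value of that feature, which is what the split search iterates over.

local ds = dt.DataSet(torch.randn(6, 3), torch.Tensor(6):random(0, 1))
local byFeature2 = ds:getSortedFeature(2) -- LongTensor of the 6 example ids,
                                          -- sorted by the value of feature 2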
@@ -1,82 +0,0 @@ | |||
local dt = require "decisiontree._env" | |||
local DecisionForest = torch.class("dt.DecisionForest", "dt.DecisionTree", dt) | |||
function DecisionForest:__init(trees, weight, bias) | |||
assert(torch.type(trees) == 'table') | |||
self.trees = trees | |||
if #trees == 0 then | |||
self.weight = weight or torch.Tensor() | |||
assert(torch.isTensor(self.weight)) | |||
assert(self.weight:nElement() == 0) | |||
else | |||
assert(torch.isTypeOf(trees[1], 'dt.DecisionTree')) | |||
self.weight = weight or torch.Tensor(#trees):fill(1) | |||
assert(torch.isTensor(self.weight)) | |||
assert(self.weight:dim() == 1) | |||
assert(self.weight:min() >= 0, "Expecting positive weights") | |||
assert(#trees == self.weight:size(1)) | |||
end | |||
self.bias = bias or 0 | |||
assert(torch.type(self.bias) == 'number') | |||
end | |||
function DecisionForest:score(input, incrementalId) | |||
assert(torch.isTensor(input)) | |||
local buffer = {} | |||
if incrementalId then | |||
self.buffers = self.buffers or {} | |||
self.buffers[incrementalId] = self.buffers[incrementalId] or {} | |||
buffer = self.buffers[incrementalId] | |||
end | |||
buffer.initialCounter = buffer.initialCounter or 0 | |||
-- TODO: score in parallel | |||
local output | |||
if torch.isTensor(input) and input.isContiguous and input:isContiguous() and input:nDimension() == 2 then | |||
buffer.output = buffer.output or input.new() | |||
output = buffer.output | |||
assert(output:nElement() == 0 or output:size(1) == input:size(1)) | |||
if output:nElement() == 0 then | |||
output:resize(input:size(1)):fill(self.bias) | |||
end | |||
for i,tree in ipairs(self.trees) do | |||
if i > buffer.initialCounter then | |||
local score = tree:score(input, nil, true) | |||
output:add(self.weight[i], score) | |||
end | |||
end | |||
else | |||
output = buffer.output or self.bias | |||
for i,tree in ipairs(self.trees) do | |||
if i > buffer.initialCounter then | |||
output = output + tree:score(input) * self.weight[i] | |||
end | |||
end | |||
buffer.output = output | |||
end | |||
buffer.initialCounter = #self.trees | |||
return output | |||
end | |||
function DecisionForest:add(tree, weight) | |||
assert(torch.type(weight) == 'number') | |||
assert(weight > 0) | |||
table.insert(self.trees, tree) | |||
self.weight:resize(#self.trees) | |||
self.weight[#self.trees] = weight | |||
return self | |||
end | |||
function DecisionForest:clone() | |||
local trees = {} | |||
for i, tree in ipairs(self.trees) do | |||
trees[i] = tree:clone() | |||
end | |||
return DecisionForest(trees, self.weight:clone(), self.bias) | |||
end |
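The incrementalId buffering above exists for boosting-style loops where the same batch is re-scored as trees are appended: only trees added since the last call for that id are evaluated. A sketch (the dense fast path additionally assumes the C kernel nn.CartTreeFastScore; cart is from the CartTree sketch):

local x = torch.randn(4, 10)
local forest2 = dt.DecisionForest({}, torch.Tensor(), 0)
forest2:add(cart, 0.5)
local s1 = forest2:score(x, 'batch1') -- evaluates tree 1
forest2:add(cart, 0.5)
local s2 = forest2:score(x, 'batch1') -- evaluates only the newly added tree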
@@ -1,22 +0,0 @@ | |||
local dt = require "decisiontree._env" | |||
local DFT = torch.class("dt.DecisionForestTrainer", dt) | |||
function DFT:train(examples, validFeatureIds, dataset) | |||
assert(torch.type(examples) == "table") | |||
assert(torch.isTypeOf(examples[1], "dt.LabeledExample")) | |||
assert(torch.type(validFeatureIds) == 'table') | |||
assert(torch.type(dataset) == 'table') | |||
for k,v in pairs(dataset) do | |||
assert(torch.type(v) == 'table') | |||
assert(torch.isTypeOf(v[1], 'dt.LabeledExample')) | |||
break | |||
end | |||
-- dataset is a table mapping featureIds to sorted lists of LabeledExamples | |||
-- e.g. {featureId={example1,example2,example3}} | |||
error"Not Implemented" | |||
end |
@@ -1,12 +0,0 @@ | |||
local dt = require "decisiontree._env" | |||
local DecisionTree = torch.class("dt.DecisionTree", dt) | |||
function DecisionTree:score(input) | |||
error"Not Implemented" | |||
return score, nodeId | |||
end |
@@ -1,106 +0,0 @@ | |||
#include "khash.h" | |||
#include <pthread.h> | |||
#define computeGradientBoostLoss(g, h) (-(g)*(g)/(h)) | |||
// we use khash to make iteration faster than lua tables | |||
KHASH_SET_INIT_INT64(long) | |||
// defines the data we need for running an instance of thet and its constructor/destructor | |||
typedef struct { | |||
khash_t(long)* exampleMap; | |||
THLongTensor *exampleIdsWithFeature_cache; | |||
long minLeafSize; | |||
} GBRunData; | |||
// allocates data that cannot be shared between threads | |||
static void gb_local_create_run_data(GBRunData *run_data) { | |||
run_data->exampleMap = kh_init(long); | |||
run_data->exampleIdsWithFeature_cache = THLongTensor_new(); | |||
} | |||
static void gb_create_run_data(GBRunData *run_data, int minLeafSize) { | |||
gb_local_create_run_data(run_data); | |||
run_data->minLeafSize = minLeafSize; | |||
} | |||
static void gb_destroy_run_data(GBRunData *run_data) { | |||
THLongTensor_free(run_data->exampleIdsWithFeature_cache); | |||
kh_destroy(long, run_data->exampleMap); | |||
} | |||
// initializes the data required by the optimizer for the given feature. | |||
static THLongTensor *gb_internal_prepare(lua_State *L, THLongTensor *exampleIds, | |||
THLongTensor *exampleIdsWithFeature_cache, int input_index, long feature_id, | |||
khash_t(long)* exampleMap) { | |||
long *exampleIds_data = THLongTensor_data(exampleIds); | |||
long exampleIds_size = THLongTensor_size(exampleIds, 0); | |||
int ret = 0; | |||
// if the the input is a table, then we have a sparse dataset | |||
if (lua_istable(L, input_index)) { | |||
if (exampleIds_size == 0) { | |||
return NULL; | |||
} | |||
else { | |||
// loops over the examples' ids that this node has to evaluate and, if they have the feature | |||
// we're looking for, marks them as present and stores them in the order provided by the | |||
// dataset | |||
THLongTensor_resize1d(exampleIdsWithFeature_cache, exampleIds_size); | |||
kh_clear(long, exampleMap); | |||
kh_resize(long, exampleMap, exampleIds_size*8); | |||
long *exampleIdsWithFeature_data = THLongTensor_data(exampleIdsWithFeature_cache); | |||
long j = 0; | |||
// for each sample to be evaluated | |||
for (long i = 0; i < exampleIds_size; i++) { | |||
// gets the representation for the example | |||
lua_pushinteger(L, exampleIds_data[i]); | |||
lua_gettable(L, input_index); | |||
// builds the index, which happens only once per thread for efficiency | |||
lua_pushstring(L, "buildIndex"); | |||
lua_gettable(L, -2); | |||
lua_pushvalue(L, -2); | |||
lua_call(L, 1, 0); | |||
// tries to get the feature for this sample | |||
lua_pushinteger(L, feature_id); | |||
lua_gettable(L, -2); | |||
// if present, then... | |||
if (!lua_isnil(L, -1)) { | |||
// saves the example | |||
exampleIdsWithFeature_data[j] = exampleIds_data[i]; | |||
j++; | |||
// marks it as present in the hash table | |||
kh_put(long, exampleMap, exampleIds_data[i], &ret); | |||
} | |||
lua_pop(L, 2); | |||
} | |||
// resizes to fit only the samples that have the feature | |||
THLongTensor_resize1d(exampleIdsWithFeature_cache, j); | |||
kh_resize(long, exampleMap, j*8); | |||
return exampleIdsWithFeature_cache; | |||
} | |||
} | |||
else { | |||
// if the input isn't a table, then it's dense and we cannot have exampleIds missing, so it | |||
// depends on feature_id | |||
// since exampleIds is fixed between calls and this is going to store the same values to the | |||
// same position, we can cache it between calls | |||
if (kh_size(exampleMap) == 0) { | |||
kh_resize(long, exampleMap, exampleIds_size*8); | |||
for (long i = 0; i < exampleIds_size; i++) { | |||
kh_put(long, exampleMap, exampleIds_data[i], &ret); | |||
} | |||
} | |||
// notice that we just return the given tensor of ids instead of copying it. the rest of the | |||
// code handles this transparently | |||
return exampleIds; | |||
} | |||
} | |||
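In Lua terms, the sparse branch of gb_internal_prepare computes something like the following (an illustrative rendering, not code from this patch): keep only the node's example ids whose sparse input actually carries feature_id, in dataset order, and remember membership in a set (the khash above).

local function prepare(inputs, exampleIds, featureId)
  local kept, present = {}, {}
  for i = 1, exampleIds:size(1) do
    local id = exampleIds[i]
    local example = inputs[id]  -- a sparse example
    example:buildIndex()        -- the C version also builds this once per thread
    if example[featureId] ~= nil then
      kept[#kept + 1] = id      -- ordered, like exampleIdsWithFeature_cache
      present[id] = true        -- membership set, like the khash exampleMap
    end
  end
  return torch.LongTensor(kept), present
end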
@@ -1,54 +0,0 @@ | |||
local dt = require 'decisiontree._env' | |||
local GiniState, parent = torch.class("dt.GiniState", "dt.TreeState", dt) | |||
function GiniState:__init(exampleIds) | |||
parent.__init(self, exampleIds) | |||
self.nPositiveInLeftBranch = 0 | |||
self.nPositiveInRightBranch = 0 | |||
end | |||
function GiniState:score(dataset) | |||
local dt = require 'decisiontree' | |||
local nPositive = dataset:countPositive(self.exampleIds) | |||
return dt.calculateLogitScore(nPositive, self.exampleIds:size(1)) | |||
end | |||
function GiniState:initialize(exampleIdsWithFeature, dataset) | |||
assert(torch.type(exampleIdsWithFeature) == 'torch.LongTensor') | |||
assert(torch.isTypeOf(dataset, 'dt.DataSet')) | |||
self.nPositiveInLeftBranch = dataset:countPositive(exampleIdsWithFeature) | |||
self.nPositiveInRightBranch = 0 | |||
self.nExampleInLeftBranch = exampleIdsWithFeature:size(1) | |||
self.nExampleInRightBranch = 0 | |||
end | |||
function GiniState:update(exampleId, dataset) | |||
assert(torch.type(exampleId) == 'number') | |||
assert(torch.isTypeOf(dataset, 'dt.DataSet')) | |||
if dataset.target[exampleId] > 0 then | |||
self.nPositiveInLeftBranch = self.nPositiveInLeftBranch - 1 | |||
self.nPositiveInRightBranch = self.nPositiveInRightBranch + 1 | |||
end | |||
self.nExampleInLeftBranch = self.nExampleInLeftBranch - 1 | |||
self.nExampleInRightBranch = self.nExampleInRightBranch + 1 | |||
end | |||
function GiniState:computeSplitInfo(splitFeatureId, splitFeatureValue) | |||
local dt = require 'decisiontree' | |||
local gini = dt.computeGini(self.nExampleInLeftBranch, self.nPositiveInLeftBranch, self.nExampleInRightBranch, self.nPositiveInRightBranch) | |||
local splitInfo = { | |||
splitId = assert(splitFeatureId), | |||
splitValue = assert(splitFeatureValue), | |||
leftChildSize = assert(self.nExampleInLeftBranch), | |||
leftPositiveCount = assert(self.nPositiveInLeftBranch), | |||
rightChildSize = assert(self.nExampleInRightBranch), | |||
rightPositiveCount = assert(self.nPositiveInRightBranch), | |||
gini = assert(gini), | |||
splitGain = gini | |||
} | |||
return splitInfo | |||
end |
@@ -1,57 +0,0 @@ | |||
local dt = require 'decisiontree._env' | |||
local GradientBoostState, parent = torch.class("dt.GradientBoostState", "dt.TreeState", dt) | |||
function GradientBoostState:__init(exampleIds, gradInput, hessInput) | |||
parent.__init(self, exampleIds) | |||
self.gradInput = gradInput | |||
self.hessInput = hessInput | |||
end | |||
function GradientBoostState:score(dataset) | |||
local dt = require 'decisiontree' | |||
local gradInput = self.gradInput:index(1, self.exampleIds) | |||
local hessInput = self.hessInput:index(1, self.exampleIds) | |||
return dt.computeNewtonScore(gradInput:sum(), hessInput:sum()) | |||
end | |||
function GradientBoostState:branch(splitInfo, dataset) | |||
local leftExampleIds, rightExampleIds = self:_branch(splitInfo, dataset) | |||
return self.new(leftExampleIds, self.gradInput, self.hessInput), self.new(rightExampleIds, self.gradInput, self.hessInput) | |||
end | |||
function GradientBoostState:_branch(splitInfo, dataset) | |||
local input = dataset.input | |||
-- if the input is dense, we can use the optimized version | |||
if torch.isTensor(input) and input.isContiguous and input:isContiguous() and input:nDimension() == 2 then | |||
return input.nn.GBDT_branch(splitInfo, input, self.exampleIds) | |||
end | |||
return parent._branch(self, splitInfo, dataset) | |||
end | |||
function GradientBoostState:findBestFeatureSplit(dataset, featureId, minLeafSize) | |||
local ret = self.hessInput.nn.GBDT_findBestFeatureSplit(self.exampleIds, dataset, featureId, minLeafSize, self.gradInput, self.hessInput) | |||
return ret | |||
end | |||
function GradientBoostState:findBestSplit(dataset, featureIds, minLeafSize, shardId, nShard) | |||
local ret = self.hessInput.nn.GBDT_findBestSplit(self.exampleIds, dataset, featureIds, minLeafSize, shardId, nShard, self.gradInput, self.hessInput) | |||
return ret | |||
end | |||
function GradientBoostState:findBestSplitFP(dataset, featureIds, minLeafSize, nThread) | |||
local input = dataset.input | |||
if torch.isTensor(input) and input.isContiguous and input:isContiguous() and input:nDimension() == 2 then | |||
local ret = self.hessInput.nn.GBDT_findBestSplitFP(self.exampleIds, dataset, featureIds, minLeafSize, self.gradInput, self.hessInput, nThread) | |||
return ret | |||
end | |||
end |
@@ -1,244 +0,0 @@ | |||
local dt = require "decisiontree._env" | |||
local GradientBoostTrainer = torch.class("dt.GradientBoostTrainer", "dt.DecisionForestTrainer", dt) | |||
function GradientBoostTrainer:__init(opt) | |||
assert(torch.type(opt) == 'table') | |||
assert(torch.isTypeOf(opt.treeTrainer, 'dt.CartTrainer')) | |||
self.treeTrainer = opt.treeTrainer | |||
assert(torch.isTypeOf(opt.lossFunction, 'nn.Criterion')) | |||
self.lossFunction = opt.lossFunction | |||
assert(torch.type(opt.shrinkage) == 'number') | |||
assert(opt.shrinkage > 0) | |||
self.shrinkage = opt.shrinkage | |||
assert(torch.type(opt.downsampleRatio) == 'number') | |||
assert(opt.downsampleRatio > 0) | |||
self.downsampleRatio = opt.downsampleRatio | |||
assert(torch.type(opt.nTree) == 'number') | |||
assert(opt.nTree > 0) | |||
self.nTree = opt.nTree | |||
opt.evalFreq = opt.evalFreq or -1
assert(torch.type(opt.evalFreq) == 'number') | |||
assert(torch.round(opt.evalFreq) == opt.evalFreq) | |||
self.evalFreq = opt.evalFreq | |||
-- when non-positive, no early-stopping | |||
opt.earlyStop = opt.earlyStop or (opt.evalFreq-1)
assert(torch.type(opt.earlyStop) == 'number') | |||
self.earlyStop = opt.earlyStop | |||
-- when non-positive, defaults to sqrt(#feature) | |||
assert(torch.type(opt.featureBaggingSize) == 'number') | |||
self.featureBaggingSize = opt.featureBaggingSize | |||
if opt.decisionForest then | |||
assert(torch.isTypeOf(opt.decisionForest, 'dt.DecisionForest')) | |||
end | |||
self.decisionForest = opt.decisionForest | |||
self.useInitBias = opt.useInitBias | |||
end | |||
function GradientBoostTrainer:computeBias(trainSet, verbose) | |||
assert(torch.isTypeOf(trainSet, 'dt.DataSet')) | |||
if verbose then print("Use new bias generated from the training examples.") end | |||
return -0.5 * self.gradInput:sum() / self.hessInput:sum() | |||
end | |||
function GradientBoostTrainer:initialize(trainSet, verbose) | |||
assert(torch.isTypeOf(trainSet, 'dt.DataSet')) | |||
trainSet:initScore() | |||
self.gradInput, self.hessInput = self.lossFunction:backward2(trainSet.score, trainSet.target) | |||
-- used for early-stopping (see validate()) | |||
self.stopCount = 0 | |||
self.prevTrainLoss = math.huge | |||
self.prevTestLoss = math.huge | |||
if verbose then print("Processing initial decision forest") end | |||
local decisionForest
if self.decisionForest then | |||
local bias = self.useInitBias and self.decisionForest.bias or self:computeBias(trainSet, verbose) | |||
decisionForest = dt.DecisionForest(self.decisionForest.trees, self.decisionForest.weight, bias) | |||
local input = trainSet.input
local score = trainSet.score
if torch.isTensor(input) and input.isContiguous and input:isContiguous() then
score:copy(decisionForest:score(input))
else
score:resize(trainSet:size())
for exampleId=1,trainSet:size() do
score[exampleId] = decisionForest:score(input[exampleId])
end
end
else | |||
local bias = self:computeBias(trainSet, verbose) | |||
decisionForest = dt.DecisionForest({}, torch.Tensor(), bias) | |||
trainSet.score:fill(bias) | |||
end | |||
if verbose then print("Finish loading initial decision forest") end | |||
return decisionForest | |||
end | |||
function GradientBoostTrainer:train(trainSet, featureIds, validSet, verbose) | |||
assert(torch.isTypeOf(trainSet, 'dt.DataSet')) | |||
assert(torch.type(featureIds) == 'torch.LongTensor') | |||
assert(torch.isTypeOf(validSet, 'dt.DataSet')) | |||
local decisionForest = self:initialize(trainSet, verbose) | |||
local bestDecisionForest | |||
if verbose then print(string.format("Get %d featureIds.", featureIds:size(1))) end | |||
local baggingSize = self.featureBaggingSize > 0 and self.featureBaggingSize or torch.round(math.sqrt(featureIds:size(1))) | |||
local trainExampleIds = trainSet:getExampleIds() | |||
local baggingIndices, activeFeatures | |||
local treeExampleIds | |||
local timer = torch.Timer() | |||
for treeId = 1,self.nTree do | |||
timer:reset() | |||
if verbose then print(string.format("Begin processing tree number %d of %d", treeId, self.nTree)) end | |||
-- Get active features | |||
activeFeatures = activeFeatures or torch.LongTensor() | |||
if baggingSize < featureIds:size(1) then | |||
if verbose then print(string.format("Tree %d: Bagging %d from %d features", treeId, baggingSize, featureIds:size(1))) end | |||
baggingIndices = baggingIndices or torch.LongTensor() | |||
baggingIndices:randperm(featureIds:size(1)) | |||
activeFeatures:index(featureIds, 1, baggingIndices:narrow(1,1,baggingSize)) | |||
else | |||
activeFeatures = featureIds | |||
end | |||
-- Get data samples | |||
if self.downsampleRatio < 0.99 then | |||
local sampleSize = torch.round(trainSet:size() * self.downsampleRatio) | |||
if verbose then print(string.format("Tree %d: Downsampling %d of %d samples", treeId, sampleSize, trainSet:size())) end | |||
baggingIndices = baggingIndices or torch.LongTensor() | |||
baggingIndices:randperm(trainSet:size()) | |||
treeExampleIds = treeExampleIds or torch.LongTensor() | |||
treeExampleIds:index(trainExampleIds, 1, baggingIndices:narrow(1,1,sampleSize)) | |||
else | |||
treeExampleIds = trainExampleIds | |||
end | |||
if verbose then print(string.format("Tree %d: training CART tree", treeId)) end | |||
local rootTreeState = dt.GradientBoostState(treeExampleIds, self.gradInput, self.hessInput) | |||
local cartTree = self.treeTrainer:train(rootTreeState, activeFeatures) | |||
if verbose then print(string.format("Tree %d: finished training CART tree in %f seconds", treeId, timer:time().real)) end | |||
decisionForest:add(cartTree, self.shrinkage) | |||
-- update score | |||
local predictionScore | |||
local input = trainSet.input | |||
if torch.isTensor(input) and input:isContiguous() then | |||
predictionScore = cartTree:score(trainSet.input, nil, true) | |||
else | |||
local size = trainSet:size() | |||
predictionScore = torch.Tensor(size) | |||
for exampleId=1,size do | |||
predictionScore[exampleId] = cartTree:score(trainSet.input[exampleId]) | |||
end | |||
end | |||
trainSet.score:add(self.shrinkage, predictionScore) | |||
self.gradInput, self.hessInput = self.lossFunction:backward2(trainSet.score, trainSet.target) | |||
if verbose then print(string.format("Tree %d: training complete in %f seconds", treeId, timer:time().real)) end | |||
-- cross-validation/early-stopping | |||
if self.evalFreq > 0 and treeId % self.evalFreq == 0 then | |||
timer:reset() | |||
local stop, validLoss
stop, validLoss, bestDecisionForest = self:validate(trainSet, validSet, decisionForest, bestDecisionForest)
if dt.PROFILE then print("validate tree time: "..timer:time().real) end | |||
if verbose then print(string.format("Loss: train=%7.4f, valid=%7.4f", 0, validLoss)) end | |||
if stop then | |||
if verbose then print(string.format("GBDT early stopped on tree %d", treeId)) end | |||
break | |||
end | |||
end | |||
end | |||
return bestDecisionForest or decisionForest | |||
end | |||
function GradientBoostTrainer:validate(trainSet, validSet, decisionForest, bestDecisionForest)
assert(torch.isTypeOf(trainSet, 'dt.DataSet')) | |||
assert(torch.isTypeOf(validSet, 'dt.DataSet')) | |||
assert(torch.isTypeOf(decisionForest, 'dt.DecisionForest')) | |||
assert(not bestDecisionForest or torch.isTypeOf(bestDecisionForest, 'dt.DecisionForest'))
-- buffer | |||
local buffer = dt.getBufferTable('GradientBoost') | |||
buffer.tensor = buffer.tensor or trainSet.score.new() | |||
local score = buffer.tensor | |||
-- per thread loss function (tensors are shared) | |||
local lossname = torch.typename(self.lossFunction) | |||
buffer[lossname] = buffer[lossname] or self.lossFunction:clone() | |||
local lossFunction = buffer[lossname] | |||
-- TODO batch this for large datasets | |||
local input = validSet.input | |||
if torch.isTensor(input) and input.isContiguous and input:isContiguous() then | |||
score = decisionForest:score(input, 'val') | |||
else | |||
score:resize(validSet:size()) | |||
for exampleId=1,validSet:size() do | |||
score[exampleId] = decisionForest:score(input[exampleId], 'val') | |||
end | |||
end | |||
local validLoss = lossFunction:forward(score, validSet.target) | |||
-- early stop is not enabled when earlyStop=0 | |||
local stop = false | |||
if self.earlyStop > 0 then | |||
-- Track test loss and detect early stop | |||
if self.prevTestLoss - validLoss < 0 then | |||
self.stopCount = self.stopCount + 1 | |||
else | |||
bestDecisionForest = decisionForest:clone() | |||
self.stopCount = 0 | |||
end | |||
stop = self.stopCount >= self.earlyStop | |||
end | |||
self.prevTestLoss = validLoss | |||
return stop, validLoss, bestDecisionForest | |||
end | |||
function GradientBoostTrainer:getName() | |||
return string.format( | |||
"gbdt-dRatio-%s-maxLeaf-%s-minExample-%s-nTree-%s-shrinkage-%s", | |||
self.downsampleRatio, self.treeTrainer.maxLeafNodes, self.treeTrainer.minLeafSize, self.nTree, self.shrinkage
) | |||
end |
@@ -1,201 +0,0 @@ | |||
Apache License | |||
Version 2.0, January 2004 | |||
http://www.apache.org/licenses/ | |||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | |||
1. Definitions. | |||
"License" shall mean the terms and conditions for use, reproduction, | |||
and distribution as defined by Sections 1 through 9 of this document. | |||
"Licensor" shall mean the copyright owner or entity authorized by | |||
the copyright owner that is granting the License. | |||
"Legal Entity" shall mean the union of the acting entity and all | |||
other entities that control, are controlled by, or are under common | |||
control with that entity. For the purposes of this definition, | |||
"control" means (i) the power, direct or indirect, to cause the | |||
direction or management of such entity, whether by contract or | |||
otherwise, or (ii) ownership of fifty percent (50%) or more of the | |||
outstanding shares, or (iii) beneficial ownership of such entity. | |||
"You" (or "Your") shall mean an individual or Legal Entity | |||
exercising permissions granted by this License. | |||
"Source" form shall mean the preferred form for making modifications, | |||
including but not limited to software source code, documentation | |||
source, and configuration files. | |||
"Object" form shall mean any form resulting from mechanical | |||
transformation or translation of a Source form, including but | |||
not limited to compiled object code, generated documentation, | |||
and conversions to other media types. | |||
"Work" shall mean the work of authorship, whether in Source or | |||
Object form, made available under the License, as indicated by a | |||
copyright notice that is included in or attached to the work | |||
(an example is provided in the Appendix below). | |||
"Derivative Works" shall mean any work, whether in Source or Object | |||
form, that is based on (or derived from) the Work and for which the | |||
editorial revisions, annotations, elaborations, or other modifications | |||
represent, as a whole, an original work of authorship. For the purposes | |||
of this License, Derivative Works shall not include works that remain | |||
separable from, or merely link (or bind by name) to the interfaces of, | |||
the Work and Derivative Works thereof. | |||
"Contribution" shall mean any work of authorship, including | |||
the original version of the Work and any modifications or additions | |||
to that Work or Derivative Works thereof, that is intentionally | |||
submitted to Licensor for inclusion in the Work by the copyright owner | |||
or by an individual or Legal Entity authorized to submit on behalf of | |||
the copyright owner. For the purposes of this definition, "submitted" | |||
means any form of electronic, verbal, or written communication sent | |||
to the Licensor or its representatives, including but not limited to | |||
communication on electronic mailing lists, source code control systems, | |||
and issue tracking systems that are managed by, or on behalf of, the | |||
Licensor for the purpose of discussing and improving the Work, but | |||
excluding communication that is conspicuously marked or otherwise | |||
designated in writing by the copyright owner as "Not a Contribution." | |||
"Contributor" shall mean Licensor and any individual or Legal Entity | |||
on behalf of whom a Contribution has been received by Licensor and | |||
subsequently incorporated within the Work. | |||
2. Grant of Copyright License. Subject to the terms and conditions of | |||
this License, each Contributor hereby grants to You a perpetual, | |||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable | |||
copyright license to reproduce, prepare Derivative Works of, | |||
publicly display, publicly perform, sublicense, and distribute the | |||
Work and such Derivative Works in Source or Object form. | |||
3. Grant of Patent License. Subject to the terms and conditions of | |||
this License, each Contributor hereby grants to You a perpetual, | |||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable | |||
(except as stated in this section) patent license to make, have made, | |||
use, offer to sell, sell, import, and otherwise transfer the Work, | |||
where such license applies only to those patent claims licensable | |||
by such Contributor that are necessarily infringed by their | |||
Contribution(s) alone or by combination of their Contribution(s) | |||
with the Work to which such Contribution(s) was submitted. If You | |||
institute patent litigation against any entity (including a | |||
cross-claim or counterclaim in a lawsuit) alleging that the Work | |||
or a Contribution incorporated within the Work constitutes direct | |||
or contributory patent infringement, then any patent licenses | |||
granted to You under this License for that Work shall terminate | |||
as of the date such litigation is filed. | |||
4. Redistribution. You may reproduce and distribute copies of the | |||
Work or Derivative Works thereof in any medium, with or without | |||
modifications, and in Source or Object form, provided that You | |||
meet the following conditions: | |||
(a) You must give any other recipients of the Work or | |||
Derivative Works a copy of this License; and | |||
(b) You must cause any modified files to carry prominent notices | |||
stating that You changed the files; and | |||
(c) You must retain, in the Source form of any Derivative Works | |||
that You distribute, all copyright, patent, trademark, and | |||
attribution notices from the Source form of the Work, | |||
excluding those notices that do not pertain to any part of | |||
the Derivative Works; and | |||
(d) If the Work includes a "NOTICE" text file as part of its | |||
distribution, then any Derivative Works that You distribute must | |||
include a readable copy of the attribution notices contained | |||
within such NOTICE file, excluding those notices that do not | |||
pertain to any part of the Derivative Works, in at least one | |||
of the following places: within a NOTICE text file distributed | |||
as part of the Derivative Works; within the Source form or | |||
documentation, if provided along with the Derivative Works; or, | |||
within a display generated by the Derivative Works, if and | |||
wherever such third-party notices normally appear. The contents | |||
of the NOTICE file are for informational purposes only and | |||
do not modify the License. You may add Your own attribution | |||
notices within Derivative Works that You distribute, alongside | |||
or as an addendum to the NOTICE text from the Work, provided | |||
that such additional attribution notices cannot be construed | |||
as modifying the License. | |||
You may add Your own copyright statement to Your modifications and | |||
may provide additional or different license terms and conditions | |||
for use, reproduction, or distribution of Your modifications, or | |||
for any such Derivative Works as a whole, provided Your use, | |||
reproduction, and distribution of the Work otherwise complies with | |||
the conditions stated in this License. | |||
5. Submission of Contributions. Unless You explicitly state otherwise, | |||
any Contribution intentionally submitted for inclusion in the Work | |||
by You to the Licensor shall be under the terms and conditions of | |||
this License, without any additional terms or conditions. | |||
Notwithstanding the above, nothing herein shall supersede or modify | |||
the terms of any separate license agreement you may have executed | |||
with Licensor regarding such Contributions. | |||
6. Trademarks. This License does not grant permission to use the trade | |||
names, trademarks, service marks, or product names of the Licensor, | |||
except as required for reasonable and customary use in describing the | |||
origin of the Work and reproducing the content of the NOTICE file. | |||
7. Disclaimer of Warranty. Unless required by applicable law or | |||
agreed to in writing, Licensor provides the Work (and each | |||
Contributor provides its Contributions) on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
implied, including, without limitation, any warranties or conditions | |||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A | |||
PARTICULAR PURPOSE. You are solely responsible for determining the | |||
appropriateness of using or redistributing the Work and assume any | |||
risks associated with Your exercise of permissions under this License. | |||
8. Limitation of Liability. In no event and under no legal theory, | |||
whether in tort (including negligence), contract, or otherwise, | |||
unless required by applicable law (such as deliberate and grossly | |||
negligent acts) or agreed to in writing, shall any Contributor be | |||
liable to You for damages, including any direct, indirect, special, | |||
incidental, or consequential damages of any character arising as a | |||
result of this License or out of the use or inability to use the | |||
Work (including but not limited to damages for loss of goodwill, | |||
work stoppage, computer failure or malfunction, or any and all | |||
other commercial damages or losses), even if such Contributor | |||
has been advised of the possibility of such damages. | |||
9. Accepting Warranty or Additional Liability. While redistributing | |||
the Work or Derivative Works thereof, You may choose to offer, | |||
and charge a fee for, acceptance of support, warranty, indemnity, | |||
or other liability obligations and/or rights consistent with this | |||
License. However, in accepting such obligations, You may act only | |||
on Your own behalf and on Your sole responsibility, not on behalf | |||
of any other Contributor, and only if You agree to indemnify, | |||
defend, and hold each Contributor harmless for any liability | |||
incurred by, or claims asserted against, such Contributor by reason | |||
of your accepting any such warranty or additional liability. | |||
END OF TERMS AND CONDITIONS | |||
APPENDIX: How to apply the Apache License to your work. | |||
To apply the Apache License to your work, attach the following | |||
boilerplate notice, with the fields enclosed by brackets "{}" | |||
replaced with your own identifying information. (Don't include | |||
the brackets!) The text should be enclosed in the appropriate | |||
comment syntax for the file format. We also recommend that a | |||
file or class name and description of purpose be included on the | |||
same "printed page" as the copyright notice for easier | |||
identification within third-party archives. | |||
Copyright {yyyy} {name of copyright owner} | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. |
@@ -1,45 +0,0 @@ | |||
local dt = require "decisiontree._env" | |||
local LogitBoostCriterion, parent = torch.class("nn.LogitBoostCriterion", "nn.Criterion") | |||
function LogitBoostCriterion:__init(sizeAverage) | |||
parent.__init(self) | |||
self.sizeAverage = sizeAverage | |||
self.hessInput = self.gradInput.new() | |||
self._output = torch.Tensor() | |||
end | |||
function LogitBoostCriterion:updateOutput(input, target) | |||
input.nn.LogitBoostCriterion_updateOutput(input, target, self._output, self.sizeAverage) | |||
self.output = self._output[1] | |||
return self.output | |||
end | |||
function LogitBoostCriterion:updateGradInput(input, target) | |||
input.nn.LogitBoostCriterion_updateGradInput(input, target, self.gradInput) | |||
return self.gradInput | |||
end | |||
function LogitBoostCriterion:updateHessInput(input, target) | |||
input.nn.LogitBoostCriterion_updateHessInput(input, target, self.hessInput) | |||
return self.hessInput | |||
end | |||
function LogitBoostCriterion:backward2(input, target) | |||
return self:updateGradInput(input, target), self:updateHessInput(input, target) | |||
end | |||
local gradWrapper = function(input, target, grad) | |||
input.nn.LogitBoostCriterion_updateGradInput(input, target, grad) | |||
end | |||
local hessianWrapper = function(input, target, hessian) | |||
input.nn.LogitBoostCriterion_updateHessInput(input, target, hessian) | |||
end | |||
function LogitBoostCriterion:getWrappers() | |||
return gradWrapper, hessianWrapper | |||
end |
@@ -1,13 +0,0 @@ | |||
local dt = require "decisiontree._env" | |||
function nn.MSECriterion.updateHessInput(self, input, target)
self.hessInput = self.hessInput or input.new() | |||
self.hessInput:resize(input:size()):fill(2) | |||
return self.hessInput | |||
end | |||
function nn.MSECriterion.backward2(self, input, target) | |||
return self:updateGradInput(input, target), self:updateHessInput(input, target) | |||
end | |||
@@ -1,386 +0,0 @@ | |||
# Torch decision tree library | |||
```lua | |||
local dt = require 'decisiontree' | |||
``` | |||
This project implements random forests and gradient boosted decision trees (GBDT).
Both use ensemble learning to produce ensembles of decision trees (that is, forests).
## `nn.DFD` | |||
One practical application for decision forests is to *discretize* an input feature space into a richer output feature space. | |||
The `nn.DFD` Module can be used as a decision forest discretizer (DFD): | |||
```lua | |||
local dfd = nn.DFD(df, onlyLastNode) | |||
``` | |||
where `df` is a `dt.DecisionForest` instance or the table returned by the method `getReconstructionInfo()` on another `nn.DFD` module, and `onlyLastNode` is a boolean indicating that the module should return only the id of the last node visited on each tree (by default, it outputs all traversed nodes except for the roots).
The `nn.DFD` module requires dense `input` tensors. | |||
Sparse `input` tensors (tables of tensors) are not supported. | |||
The `output` returned by a call to `updateOutput` is a batch of sparse tensors. | |||
In this `output`, `output[1]` and `output[2]` are respectively a list of key tensors and a list of value tensors:
```lua | |||
{ | |||
{ [torch.LongTensor], ... , [torch.LongTensor] }, | |||
{ [torch.Tensor], ... , [torch.Tensor] } | |||
} | |||
``` | |||
This module doesn't support CUDA. | |||
### Example | |||
As a concrete example, let us first train a Random Forest on a dummy dense dataset: | |||
```lua | |||
local nExample = 100 | |||
local batchsize = 2 | |||
local inputsize = 10 | |||
local trainSet = dt.getDenseDummyData(nExample, nil, inputsize) | |||
local opt = { | |||
activeRatio=0.5, | |||
featureBaggingSize=5, | |||
nTree=4, | |||
maxLeafNodes=nExample/2, | |||
minLeafSize=nExample/10, | |||
} | |||
local trainer = dt.RandomForestTrainer(opt) | |||
local df = trainer:train(trainSet, trainSet.featureIds) | |||
assert(#df.trees == opt.nTree)
``` | |||
Now that we have `df`, a `dt.DecisionForest` instance, we can use it to initialize `nn.DFD`: | |||
```lua | |||
local dfd = nn.DFD(df) | |||
``` | |||
The `dfd` instance holds no reference to `df`; instead, it extracts the relevant attributes from `df`.
These attributes are stored in tensors for batching and efficiency. | |||
We can discretize a hypothetical `input` by calling `forward`: | |||
```lua | |||
local input = trainSet.input:sub(1,batchsize) | |||
local output = dfd:forward(input) | |||
``` | |||
The resulting output is a table consisting of two tables: keys and values. | |||
The keys and values tables each contain `batchsize` tensors:
```lua | |||
print(output) | |||
{ | |||
1 : | |||
{ | |||
1 : LongTensor - size: 14 | |||
2 : LongTensor - size: 16 | |||
3 : LongTensor - size: 15 | |||
4 : LongTensor - size: 13 | |||
} | |||
2 : | |||
{ | |||
1 : DoubleTensor - size: 14 | |||
2 : DoubleTensor - size: 16 | |||
3 : DoubleTensor - size: 15 | |||
4 : DoubleTensor - size: 13 | |||
} | |||
} | |||
``` | |||
An example's feature keys (`LongTensor`) and corresponding values (`DoubleTensor`) have the same number of elements.
Each example has a variable number of key-value pairs representing the nodes traversed in the trees.
The output feature space has as many dimensions (that is, possible feature keys) as there are nodes in the forest.
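To make the format concrete, here is a minimal sketch (not part of the library) that decodes the `output` shown above into readable pairs:
```lua
-- a sketch: walk the sparse batch returned by dfd:forward()
local keys, values = output[1], output[2]
for i=1,#keys do
for j=1,keys[i]:size(1) do
-- keys[i][j] is a node id in the forest; values[i][j] is its value
print(string.format("example %d: node %d = %f", i, keys[i][j], values[i][j]))
end
end
```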
## `torch.SparseTensor` | |||
Suppose you have a set of `keys` mapped to `values`: | |||
```lua | |||
local keys = torch.LongTensor{1,3,4,7,2} | |||
local values = torch.Tensor{0.1,0.3,0.4,0.7,0.2} | |||
``` | |||
You can use a `SparseTensor` to encapsulate these into a read-only tensor: | |||
```lua | |||
local st = torch.SparseTensor(keys, values)
``` | |||
The _decisiontree_ library uses `SparseTensors` to simulate the `__index` method of the `torch.Tensor`. | |||
For example, one can obtain the value associated to key 3 of the above `st` instance: | |||
```lua | |||
local value = st[3] | |||
assert(value == 0.3) | |||
``` | |||
When a key is missing, `nil` is returned instead:
```lua | |||
local value = st[6]
assert(value == nil) | |||
``` | |||
The default implementation for this kind of indexing is slow (it uses a sequential scan of the `keys`).
To speed up indexing, one can call the `buildIndex()` method beforehand:
```lua | |||
st:buildIndex() | |||
``` | |||
The `buildIndex()` method creates a hash map (a Lua table) from keys to their corresponding indices in the `values` tensor.
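As a rough sketch of the effect (the index layout is an internal detail, shown here as an assumption):
```lua
st:buildIndex()
-- internally, roughly: st._map = {[1]=1, [3]=2, [4]=3, [7]=4, [2]=5},
-- so indexing becomes a hash lookup instead of a sequential scan
assert(st[4] == 0.4)
```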
## `dt.DataSet` | |||
The `CartTrainer`, `RandomForestTrainer` and `GradientBoostTrainer` require that data sets be encapsulated into a `DataSet`. | |||
Suppose you have a dataset of dense inputs and targets: | |||
```lua | |||
local nExample = 10 | |||
local nFeature = 5 | |||
local input = torch.randn(nExample, nFeature) | |||
local target = torch.Tensor(nExample):random(0,1) | |||
``` | |||
these can be encapsulated into a `DataSet` object: | |||
```lua | |||
local dataset = dt.DataSet(input, target) | |||
``` | |||
Now suppose you have a dataset where the `input` is a table of `SparseTensor` instances: | |||
```lua | |||
local input = {} | |||
for i=1,nExample do | |||
local nKeyVal = math.random(2,nFeature) | |||
local keys = torch.LongTensor(nKeyVal):random(1,nFeature) | |||
local values = torch.randn(nKeyVal) | |||
input[i] = torch.SparseTensor(keys, values) | |||
end | |||
``` | |||
You can still use a `DataSet` to encapsulate the sparse dataset: | |||
```lua | |||
local dataset = dt.DataSet(input, target) | |||
``` | |||
The main purpose of the `DataSet` class is to sort each feature by value. | |||
This is captured by the `sortFeatureValues(input)` method, which is called in the constructor: | |||
```lua | |||
local sortedFeatureValues, featureIds = self:sortFeatureValues(input) | |||
``` | |||
The `featureIds` is a `torch.LongTensor` of all available feature IDs. | |||
For a dense `input` tensor, this is just `torch.LongTensor():range(1,input:size(2))`. | |||
But for a sparse `input` tensor, the `featureIds` tensor only contains the feature IDs present in the dataset. | |||
The resulting `sortedFeatureValues` is a table mapping `featureIds` to `exampleIds` sorted by `featureValues`. | |||
For each `featureId`, examples are sorted by `featureValue` in ascending order. | |||
For example, the table might look like `{featureId=exampleIds}` where `exampleIds={1,3,2}`.
The `CartTrainer` accesses the `sortedFeatureValues` tensor by calling `getSortedFeature(featureId)`: | |||
```lua | |||
local exampleIdsWithFeature = dataset:getSortedFeature(featureId) | |||
``` | |||
The ability to access example IDs sorted by feature value, given a feature ID, is the main purpose of the `DataSet`.
The `CartTrainer` relies on these sorted lists to find the best way to split a set of examples between two tree nodes. | |||
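To make this concrete, a small sketch with made-up values:
```lua
-- 3 examples, 1 feature; feature values 0.5, 0.1, 0.9
local input = torch.Tensor{{0.5},{0.1},{0.9}}
local target = torch.Tensor{1,0,1}
local dataset = dt.DataSet(input, target)
-- ascending feature values 0.1 < 0.5 < 0.9 give example ids 2, 1, 3
local sortedIds = dataset:getSortedFeature(1)
print(sortedIds)
```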
## `dt.CartTrainer` | |||
```lua | |||
local trainer = dt.CartTrainer(dataset, minLeafSize, maxLeafNodes) | |||
``` | |||
The `CartTrainer` is used by the `RandomForestTrainer` and `GradientBoostTrainer` to train individual trees. | |||
CART stands for classification and regression trees. | |||
However, only binary classifiers are unit tested. | |||
The constructor takes the following arguments: | |||
* `dataset` is a `dt.DataSet` instance representing the training set. | |||
* `minLeafSize` is the minimum number of examples per leaf node in a tree. The larger the value, the more regularization.
* `maxLeafNodes` is the maximum number of leaf nodes in the tree. The lower the value, the more regularization.
Training is initiated by calling the `train()` method: | |||
```lua | |||
local trainSet = dt.DataSet(input, target) | |||
local rootTreeState = dt.GiniState(trainSet:getExampleIds()) | |||
local activeFeatures = trainSet.featureIds | |||
local tree = trainer:train(rootTreeState, activeFeatures) | |||
``` | |||
The resulting `tree` is a `CartTree` instance. | |||
The `rootTreeState` is a `TreeState` instance like `GiniState` (used by `RandomForestTrainer`) or `GradientBoostState` (used by `GradientBoostTrainer`). | |||
The `activeFeatures` is a `LongTensor` of the feature IDs used to build the tree.
Every other feature ID is ignored during training. This is useful for feature bagging. | |||
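For instance, a sketch of restricting training to a subset of features (reusing the objects above):
```lua
-- a sketch: build a tree that only considers the first three feature ids
local someFeatures = trainSet.featureIds:narrow(1, 1, 3)
local freshState = dt.GiniState(trainSet:getExampleIds())
local smallTree = trainer:train(freshState, someFeatures)
```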
By default, the `CartTrainer` runs in a single thread.
The `featureParallel(nThread)` method can be called before calling `train()` to parallelize training using `nThread` workers: | |||
```lua | |||
local nThread = 3 | |||
trainer:featureParallel(nThread) | |||
trainer:train(rootTreeState, activeFeatures) | |||
``` | |||
Feature parallelization assigns a set of feature IDs to each thread.
The `CartTrainer` can be used as a stand-alone tree trainer. | |||
But it is recommended to use it within the context of a `RandomForestTrainer` or `GradientBoostTrainer` instead. | |||
The latter typically generalize better. | |||
## RandomForestTrainer | |||
The `RandomForestTrainer` is used to train a random forest: | |||
```lua | |||
local nExample = trainSet:size() | |||
local opt = { | |||
activeRatio=0.5, | |||
featureBaggingSize=5, | |||
nTree=14, | |||
maxLeafNodes=nExample/2, | |||
minLeafSize=nExample/10, | |||
} | |||
local trainer = dt.RandomForestTrainer(opt) | |||
local forest = trainer:train(trainSet, trainSet.featureIds) | |||
``` | |||
The returned `forest` is a `DecisionForest` instance. | |||
A `DecisionForest` has a similar interface to the `CartTree`. | |||
Indeed, they both sub-class the `DecisionTree` abstract class. | |||
The constructor takes a single `opt` table argument, which contains the actual arguments: | |||
* `activeRatio` is the ratio of active examples per tree. This is used for bootstrap sampling.
* `featureBaggingSize` is the number of features per tree. This is used for feature bagging.
* `nTree` is the number of trees to be trained. | |||
* `maxLeafNodes` and `minLeafSize` are passed to the underlying `CartTrainer` constructor (controls regularization). | |||
Internally, the `RandomForestTrainer` passes a `GiniState` to the `CartTrainer:train()` method.
Training can be parallelized by calling `treeParallel(nThread)`: | |||
```lua | |||
local nThread = 3 | |||
trainer:treeParallel(nThread) | |||
local forest = trainer:train(trainSet, trainSet.featureIds) | |||
``` | |||
Each tree is then trained in its own worker thread.
## GradientBoostTrainer | |||
References: | |||
* A. [Boosted Tree presentation](https://homes.cs.washington.edu/~tqchen/pdf/BoostedTree.pdf) | |||
Gradient boosted decision trees (GBDT) can be trained as follows:
```lua | |||
local nExample = trainSet:size() | |||
local maxLeafNode, minLeafSize = nExample/2, nExample/10 | |||
local cartTrainer = dt.CartTrainer(trainSet, minLeafSize, maxLeafNode) | |||
local opt = { | |||
lossFunction=nn.LogitBoostCriterion(false), | |||
treeTrainer=cartTrainer, | |||
shrinkage=0.1, | |||
downsampleRatio=0.8, | |||
featureBaggingSize=-1, | |||
nTree=14, | |||
evalFreq=8, | |||
earlyStop=0 | |||
} | |||
local trainer = dt.GradientBoostTrainer(opt) | |||
local forest = trainer:train(trainSet, trainSet.featureIds, validSet) | |||
``` | |||
The above code snippet uses the `LogitBoostCriterion` outlined in reference A. | |||
It is used for training binary classification trees. | |||
The returned `forest` is a `DecisionForest` instance. | |||
A `DecisionForest` has a similar interface to the `CartTree`. | |||
Indeed, they both sub-class the `DecisionTree` abstract class. | |||
The constructor takes a single `opt` table argument, which contains the actual arguments: | |||
* `lossFunction` is a `nn.Criterion` instance extended with `updateHessInput(input, target)` and `backward2(input, target)` methods; the former returns the Hessian of the loss with respect to the `input`, and the latter returns both the gradient and the Hessian.
* `treeTrainer` is a `CartTrainer` instance. Its `featureParallel()` method can be called to implement feature parallelization. | |||
* `shrinkage` is the weight of each additional tree. | |||
* `downsampleRatio` is the ratio of examples to be sampled for each tree. Used for bootstrap sampling. | |||
* `featureBaggingSize` is the number of features to sample per tree. Used for feature bagging. `-1` defaults to `torch.round(math.sqrt(featureIds:size(1)))` | |||
* `nTree` is the maximum number of trees. | |||
* `evalFreq` is the number of trees between calls to `validate()` for cross-validation and early-stopping.
* `earlyStop` is the number of consecutive validations without improvement after which training stops early (non-positive disables early-stopping).
Internally, the `GradientBoostTrainer` passes a `GradientBoostState` to the `CartTrainer:train()` method. | |||
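As a sketch of the `lossFunction` contract (reusing `nn.LogitBoostCriterion` from the snippet above):
```lua
-- backward2 returns both first- and second-order derivatives of the loss
local loss = nn.LogitBoostCriterion(false)
local score = torch.zeros(4) -- current model scores
local target = torch.Tensor{1,0,1,0}
local grad, hess = loss:backward2(score, target)
assert(grad:nElement() == 4 and hess:nElement() == 4)
```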
## TreeState | |||
An abstract class that holds the state of a subtree during decision tree training. | |||
It also manages the state of candidate splits. | |||
```lua | |||
local treeState = dt.TreeState(exampleIds) | |||
``` | |||
The `exampleIds` argument is a `LongTensor` containing the example IDs that make up the sub-tree. | |||
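A subclass fills in the abstract methods; a skeletal sketch with a hypothetical `MyState` (the split-info field names follow `computeSplitInfo` in `GiniState`):
```lua
-- skeletal sketch of a TreeState subclass (hypothetical, for illustration only)
local MyState, parent = torch.class("dt.MyState", "dt.TreeState", dt)
function MyState:score(dataset) return 0 end
function MyState:initialize(exampleIdsWithFeature, dataset) end
function MyState:update(exampleId, dataset) end
function MyState:computeSplitInfo(splitFeatureId, splitFeatureValue)
-- must provide at least splitId, splitValue, splitGain and the two child sizes
return {splitId=splitFeatureId, splitValue=splitFeatureValue, splitGain=0,
leftChildSize=0, rightChildSize=0}
end
```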
## GiniState | |||
A `TreeState` subclass used internally by the `RandomForestTrainer`. | |||
Uses Gini impurity to determine how to split trees. | |||
```lua | |||
local treeState = dt.GiniState(exampleIds) | |||
``` | |||
The `exampleIds` argument is a `LongTensor` containing the example IDs that make up the sub-tree. | |||
## GradientBoostState | |||
A `TreeState` subclass used internally by the `GradientBoostTrainer`. | |||
It implements the GBDT splitting algorithm, which uses a loss function.
```lua | |||
local treeState = dt.GradientBoostState(exampleIds, gradInput, hessInput)
``` | |||
The `exampleIds` argument is a `LongTensor` containing the example IDs that make up the sub-tree. | |||
The `gradInput` and `hessInput` arguments are the gradient and Hessian tensors of the loss function (see `GradientBoostTrainer` and `nn.LogitBoostCriterion`).
## WorkPool | |||
Utility class that simplifies construction of a pool of daemon threads with which to execute tasks in parallel. | |||
```lua | |||
local workpool = dt.WorkPool(nThread) | |||
``` | |||
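A minimal usage sketch, mirroring how `RandomForestTrainer` drives the pool:
```lua
-- broadcast a require, run one task in a worker thread, then shut down
local workpool = dt.WorkPool(2)
workpool:update('require', {libname='decisiontree', varname='dt'})
workpool:writeup('execute', function(store, myId)
return 'hello from worker '..myId
end)
local taskname, result = workpool:read()
assert(taskname == 'execute')
workpool:terminate()
```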
## CartTree | |||
Implements a trained CART decision tree: | |||
```lua | |||
local tree = dt.CartTree(rootNode)
``` | |||
The `rootNode` is a `CartNode` instance. | |||
Each `CartNode` contains pointers to left and right branches, which are themselves `CartNode` instances. | |||
For inference, use the `score(input)` method: | |||
```lua | |||
local score = tree:score(input) | |||
``` |
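Assuming the tree was trained on sparse data, the same method accepts a `torch.SparseTensor` (a sketch):
```lua
local keys = torch.LongTensor{1,3}
local values = torch.Tensor{0.5,0.2}
local sparseScore = tree:score(torch.SparseTensor(keys, values))
```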
@@ -1,159 +0,0 @@ | |||
local dt = require "decisiontree._env" | |||
local RandomForestTrainer = torch.class("dt.RandomForestTrainer", dt) | |||
function RandomForestTrainer:__init(opt) | |||
assert(torch.type(opt.nTree) == 'number') | |||
assert(opt.nTree > 0) | |||
self.nTree = opt.nTree | |||
-- max number of leaf nodes per tree | |||
assert(torch.type(opt.maxLeafNodes) == 'number') | |||
assert(opt.maxLeafNodes > 0) | |||
self.maxLeafNodes = opt.maxLeafNodes | |||
-- min number of examples per leaf | |||
assert(torch.type(opt.minLeafSize) == 'number') | |||
assert(opt.minLeafSize > 0) | |||
self.minLeafSize = opt.minLeafSize | |||
-- when non-positive, defaults to sqrt(#feature) | |||
assert(torch.type(opt.featureBaggingSize) == 'number') | |||
self.featureBaggingSize = opt.featureBaggingSize | |||
assert(torch.type(opt.activeRatio) == 'number') | |||
assert(opt.activeRatio > 0) | |||
self.activeRatio = opt.activeRatio | |||
-- default parallelization is singlethread | |||
self.parallelMode = 'singlethread' | |||
end | |||
function RandomForestTrainer:train(trainSet, featureIds, verbose) | |||
assert(torch.isTypeOf(trainSet, 'dt.DataSet')) | |||
assert(torch.type(featureIds) == 'torch.LongTensor') | |||
if verbose then print(string.format("Begin training Decision Forest with %d trees", self.nTree)) end | |||
local weight = torch.Tensor(self.nTree):fill(1 / self.nTree) -- RF uses uniform weights | |||
local trees | |||
if self.parallelMode == 'singlethread' then | |||
trees = self:trainTrees(trainSet, featureIds, verbose) | |||
elseif self.parallelMode == 'treeparallel' then | |||
trainSet:deleteIndex() -- prevents serialization bottleneck | |||
trees = self:trainTreesTP(trainSet, featureIds, verbose) | |||
else | |||
error("Unrecognized parallel mode: " .. self.parallelMode) | |||
end | |||
if verbose then print(string.format("Successfully trained %d trees", #trees)) end | |||
-- set bias | |||
local bias = 0
for i, tree in ipairs(trees) do | |||
bias = bias + tree.root.score * weight[i] | |||
end | |||
return dt.DecisionForest(trees, weight, bias) | |||
end | |||
function RandomForestTrainer:trainTrees(trainSet, featureIds, verbose) | |||
-- the same CartTrainer will be used for each tree | |||
local cartTrainer = dt.CartTrainer(trainSet, self.minLeafSize, self.maxLeafNodes) | |||
local trees = {} | |||
for treeId=1,self.nTree do | |||
-- Train a CartTree | |||
local tree = self.trainTree(cartTrainer, featureIds, self.featureBaggingSize, self.activeRatio, treeId, verbose) | |||
table.insert(trees, tree) | |||
end | |||
return trees | |||
end | |||
function RandomForestTrainer.trainTree(cartTrainer, featureIds, baggingSize, activeRatio, treeId, verbose) | |||
assert(torch.isTypeOf(cartTrainer, 'dt.CartTrainer')) | |||
assert(torch.type(featureIds) == 'torch.LongTensor') | |||
local baggingSize = baggingSize > 0 and baggingSize or torch.round(math.sqrt(featureIds:size(1))) | |||
if verbose then | |||
print(string.format("Tree %d: Creating features bootstrap sample with baggingSize %d, nFeatures %d", treeId, baggingSize, featureIds:size(1))) | |||
end | |||
local trainSet = cartTrainer.dataset | |||
-- sample bootstrap features
local baggingIndices = torch.LongTensor(baggingSize):random(1,featureIds:size(1)) | |||
local activeFeatures = featureIds:index(1, baggingIndices) | |||
-- sample bootstrap examples
local sampleSize = torch.round(trainSet:size() * activeRatio) | |||
if verbose then print(string.format("Creating bootstrap sample created of size %d", sampleSize)) end | |||
baggingIndices:resize(sampleSize):random(1,trainSet:size()) | |||
local bootStrapExampleIds = torch.LongTensor() | |||
bootStrapExampleIds:index(trainSet:getExampleIds(), 1, baggingIndices) | |||
local cartTree = cartTrainer:train(dt.GiniState(bootStrapExampleIds), activeFeatures) | |||
if verbose then print(string.format("Complete processing tree number %d", treeId)) end | |||
return cartTree | |||
end | |||
function RandomForestTrainer:treeParallel(workPool) | |||
assert(self.parallelMode == 'singlethread', self.parallelMode) | |||
self.parallelMode = 'treeparallel' | |||
self.workPool = torch.type(workPool) == 'number' and dt.WorkPool(workPool) or workPool | |||
assert(torch.isTypeOf(self.workPool, 'dt.WorkPool')) | |||
-- require the dt package | |||
self.workPool:update('require', {libname='decisiontree',varname='dt'}) | |||
end | |||
function RandomForestTrainer:trainTreesTP(trainSet, featureIds, verbose) | |||
assert(torch.isTypeOf(trainSet, 'dt.DataSet')) | |||
assert(torch.type(featureIds) == 'torch.LongTensor') | |||
local minLeafSize = self.minLeafSize | |||
local maxLeafNodes = self.maxLeafNodes | |||
-- setup worker store (each worker will have its own cartTrainer) | |||
self.workPool:updateup('execute', function(store) | |||
local dt = require 'decisiontree' | |||
store.cartTrainer = dt.CartTrainer(trainSet, minLeafSize, maxLeafNodes) | |||
store.featureIds = featureIds | |||
end) | |||
for treeId=1,self.nTree do | |||
-- upvalues | |||
local baggingSize = self.featureBaggingSize | |||
local activeRatio = self.activeRatio | |||
-- task closure that will be executed in worker-thread | |||
local function trainTreeTask(store) | |||
local dt = require 'decisiontree' | |||
return dt.RandomForestTrainer.trainTree(store.cartTrainer, store.featureIds, baggingSize, activeRatio, treeId, verbose) | |||
end | |||
self.workPool:writeup('execute', trainTreeTask) | |||
end | |||
local trees = {} | |||
for treeId=1,self.nTree do | |||
local taskname, tree = self.workPool:read() | |||
assert(taskname=='execute') | |||
assert(torch.isTypeOf(tree, 'dt.CartTree')) | |||
table.insert(trees, tree) | |||
end | |||
return trees | |||
end | |||
function RandomForestTrainer:getName() | |||
return string.format( | |||
"randomforest-aRatio-%4.2f-maxLeaf-%d-minExample-%d-nTree-%d", | |||
self.activeRatio, self.maxLeafNodes, self.minLeafSize, self.nTree | |||
) | |||
end | |||
@@ -1,88 +0,0 @@ | |||
local S2D, parent = torch.class("nn.Sparse2Dense", "nn.Module") | |||
local dt = require 'decisiontree._env' | |||
function S2D:__init(features) | |||
parent.__init(self) | |||
if torch.type(features) == 'table' then | |||
assert(#features > 0) | |||
features = torch.LongTensor(features) | |||
end | |||
assert(torch.isTensor(features)) | |||
self.features = features | |||
self.featureMap = nil | |||
self.masks = {} | |||
self.mappedKeys = {} | |||
end | |||
function S2D:updateOutput(input) | |||
if not self.featureMap then | |||
self.featureMap = dt.HashMap() | |||
self.featureMap:fill(self.features) | |||
end | |||
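-- accept either a single sparse example {keys, values} or a batch {key tensors, value tensors}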
local batched, keys, values | |||
if torch.isTensor(input[1]) then | |||
keys = {input[1]} | |||
values = {input[2]} | |||
batched = false | |||
else | |||
keys = input[1] | |||
values = input[2] | |||
batched = true | |||
end | |||
assert(#keys == #values) | |||
local masks = self.masks | |||
local mappedKeys = self.mappedKeys | |||
local nKeys = #keys | |||
local nMasks = #masks | |||
if nMasks < nKeys then | |||
for i=nMasks+1,nKeys do | |||
masks[i] = torch.ByteTensor() | |||
mappedKeys[i] = torch.LongTensor() | |||
end | |||
elseif nMasks > nKeys then | |||
for i=nKeys+1,nMasks do | |||
masks[i] = nil | |||
mappedKeys[i] = nil | |||
end | |||
end | |||
self.featureMap:get(keys, mappedKeys, masks) | |||
self.output = self.output or torch.Tensor():type(self._type) | |||
self.output.nn.S2D_computeOutput(self.output, mappedKeys, values, masks, self.features) | |||
if not batched then | |||
self.output = self.output:view(-1) | |||
end | |||
return self.output | |||
end | |||
function S2D:type(type, tensorCache) | |||
if type then | |||
local features = self.features | |||
self.features = nil | |||
parent.type(self, type, tensorCache) | |||
self.features = features | |||
return self | |||
else | |||
return parent.type(self) | |||
end | |||
end | |||
function S2D:updateGradInput(input, gradOutput) | |||
error"Not Implemented" | |||
end | |||
function S2D:reset() | |||
parent.reset(self) | |||
self.featureMap = nil | |||
end | |||
function S2D:write(file) | |||
self.featureMap = nil | |||
parent.write(self, file) | |||
end | |||
function S2D:read(file) | |||
self.featureMap = nil | |||
parent.read(self, file) | |||
end |
@@ -1,54 +0,0 @@ | |||
local SparseTensor = torch.class("torch.SparseTensor") | |||
function SparseTensor:__init(keys, values) | |||
if keys and values then | |||
assert(torch.typename(keys):find('torch%..*LongTensor')) | |||
assert(torch.isTensor(values)) | |||
assert(keys:nElement() == values:nElement(), "Expecting key and value tensors of same size") | |||
self.keys = keys | |||
self.values = values | |||
elseif not (keys or values) then | |||
self.keys = torch.LongTensor() | |||
self.values = torch.Tensor() | |||
else | |||
error"Expecting zero or two args" | |||
end | |||
end | |||
function SparseTensor:buildIndex(overwrite) | |||
if self._map and not overwrite then return end | |||
assert(self.keys and self.keys:dim() == 1) | |||
assert(self.values and self.values:dim() == 1) | |||
-- hash table | |||
self._map = {} | |||
for i=1,self.keys:size(1) do | |||
self._map[self.keys[i]] = i | |||
end | |||
end | |||
function SparseTensor:deleteIndex() | |||
self._map = nil | |||
end | |||
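-- override __index: numeric keys perform a sparse lookup, while any other key falls
-- back to the class metatable (saved below) so that method calls keep working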
local __index = SparseTensor.__index | |||
function SparseTensor:__index(key) | |||
if key == nil then | |||
error"Attempt to index using a nil key" | |||
elseif torch.type(key) ~= 'number' then | |||
return __index(self, key) | |||
end | |||
if self._map then | |||
assert(torch.type(self._map) == 'table') | |||
local idx = self._map[key] | |||
return idx and self.values[idx] or nil | |||
elseif self.keys:nElement() > 0 then | |||
for i=1,self.keys:size(1) do | |||
if self.keys[i] == key then | |||
return self.values[i] | |||
end | |||
end | |||
end | |||
return nil | |||
end |
@@ -1,191 +0,0 @@ | |||
local dt = require "decisiontree._env" | |||
local TreeState = torch.class("dt.TreeState", dt) | |||
function TreeState:__init(exampleIds) | |||
assert(torch.type(exampleIds) == 'torch.LongTensor') | |||
self.exampleIds = exampleIds | |||
self.nExampleInLeftBranch = 0 | |||
self.nExampleInRightBranch = 0 | |||
end | |||
function TreeState:score(dataset) | |||
error"NotImplemented" | |||
end | |||
function TreeState:initialize(exampleIdsWithFeature, dataset) | |||
error"NotImplemented" | |||
end | |||
function TreeState:update(exampleId, dataset) | |||
error"NotImplemented" | |||
end | |||
function TreeState:computeSplitInfo(splitFeatureId, splitFeatureValue) | |||
error"NotImplemented" | |||
end | |||
function TreeState:findBestFeatureSplit(dataset, featureId, minLeafSize) | |||
local dt = require "decisiontree" | |||
assert(torch.isTypeOf(dataset, 'dt.DataSet')) | |||
assert(torch.type(featureId) == 'number') | |||
assert(torch.type(minLeafSize) == 'number') | |||
-- all dataset examples having this feature, sorted by value
local featureExampleIds = dataset:getSortedFeature(featureId) | |||
local buffer = dt.getBufferTable('TreeState') | |||
buffer.longtensor = buffer.longtensor or torch.LongTensor() | |||
local exampleIdsWithFeature = buffer.longtensor | |||
-- map and tensor of examples containing feature: | |||
local exampleMap = {} | |||
local getExampleFeatureValue | |||
local j = 0 | |||
if torch.type(dataset.input) == 'table' then | |||
exampleIdsWithFeature:resize(self.exampleIds:size()) | |||
self.exampleIds:apply(function(exampleId) | |||
local input = dataset.input[exampleId] | |||
input:buildIndex()-- only builds index first time | |||
if input[featureId] then | |||
j = j + 1 | |||
exampleIdsWithFeature[j] = exampleId | |||
exampleMap[exampleId] = j | |||
end | |||
end) | |||
if j == 0 then | |||
return | |||
end | |||
exampleIdsWithFeature:resize(j) | |||
getExampleFeatureValue = function(exampleId) return dataset.input[exampleId][featureId] end | |||
else | |||
exampleIdsWithFeature = self.exampleIds | |||
self.exampleIds:apply(function(exampleId) | |||
j = j + 1 | |||
exampleMap[exampleId] = j | |||
end) | |||
local featureValues = dataset.input:select(2,featureId) | |||
getExampleFeatureValue = function(exampleId) return featureValues[exampleId] end | |||
end | |||
self:initialize(exampleIdsWithFeature, dataset) | |||
-- bottleneck | |||
local bestSplit, previousSplitValue, _tictoc | |||
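-- walk the examples from highest to lowest feature value, moving each one from the
-- left branch to the right and evaluating a candidate split whenever the value changes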
for i=featureExampleIds:size(1),1,-1 do -- loop over examples sorted (desc) by feature value | |||
local exampleId = featureExampleIds[i] | |||
local exampleIdx = exampleMap[exampleId] | |||
if exampleIdx then | |||
local splitValue = getExampleFeatureValue(exampleId) | |||
if previousSplitValue and math.abs(splitValue - previousSplitValue) > dt.EPSILON then | |||
local splitInfo = self:computeSplitInfo(featureId, previousSplitValue, _tictoc) | |||
if (splitInfo.leftChildSize >= minLeafSize) and (splitInfo.rightChildSize >= minLeafSize) then | |||
if (not bestSplit) or (splitInfo.splitGain < bestSplit.splitGain) then | |||
_tictoc = bestSplit or {} -- reuse table | |||
bestSplit = splitInfo | |||
end | |||
end | |||
end | |||
previousSplitValue = splitValue | |||
-- bottleneck | |||
self:update(exampleId, dataset, exampleIdx) | |||
end | |||
end | |||
return bestSplit | |||
end | |||
function TreeState:findBestSplit(dataset, featureIds, minLeafSize, shardId, nShard) | |||
assert(torch.isTypeOf(dataset, 'dt.DataSet')) | |||
assert(torch.type(featureIds) == 'torch.LongTensor') | |||
assert(torch.type(minLeafSize) == 'number') | |||
assert(torch.type(shardId) == 'number') | |||
assert(torch.type(nShard) == 'number') | |||
local bestSplit | |||
for i=1,featureIds:size(1) do | |||
local featureId = featureIds[i] | |||
if (nShard <= 1) or ( (featureId % nShard) + 1 == shardId ) then -- feature sharded | |||
local splitCandidate = self:findBestFeatureSplit(dataset, featureId, minLeafSize) | |||
if splitCandidate and ((not bestSplit) or (splitCandidate.splitGain < bestSplit.splitGain)) then | |||
bestSplit = splitCandidate | |||
end | |||
end | |||
end | |||
return bestSplit | |||
end | |||
function TreeState:_branch(splitInfo, dataset) | |||
local leftIdx, rightIdx = 0, 0 | |||
local nExample = self.exampleIds:size(1) | |||
local splitExampleIds = torch.LongTensor(nExample) | |||
for i=1,self.exampleIds:size(1) do | |||
local exampleId = self.exampleIds[i] | |||
local input = dataset.input[exampleId] | |||
local val = input[splitInfo.splitId] | |||
-- Note: when the feature is not present in the example, the example is dropped from all
-- sub-trees, which means that for most sparse data, a tree cannot reach 100% accuracy...
if val then | |||
if val < splitInfo.splitValue then | |||
leftIdx = leftIdx + 1 | |||
splitExampleIds[leftIdx] = exampleId | |||
else | |||
rightIdx = rightIdx + 1 | |||
splitExampleIds[nExample-rightIdx+1] = exampleId | |||
end | |||
end | |||
end | |||
local leftExampleIds = splitExampleIds:narrow(1,1,leftIdx) | |||
local rightExampleIds = splitExampleIds:narrow(1,nExample-rightIdx+1,rightIdx) | |||
assert(leftExampleIds:size(1) + rightExampleIds:size(1) <= self.exampleIds:size(1), "Left and right branches contain more data than the parent!") | |||
return leftExampleIds, rightExampleIds | |||
end | |||
function TreeState:branch(splitInfo, dataset) | |||
local leftExampleIds, rightExampleIds = self:_branch(splitInfo, dataset) | |||
return self.new(leftExampleIds), self.new(rightExampleIds) | |||
end | |||
function TreeState:size() | |||
return self.exampleIds:size(1) | |||
end | |||
function TreeState:contains(exampleId) | |||
local found = false | |||
self.exampleIds:apply(function(x) | |||
if x == exampleId then | |||
found = true | |||
end | |||
end) | |||
return found | |||
end | |||
@@ -1,156 +0,0 @@ | |||
local dt = require "decisiontree._env" | |||
local WorkPool = torch.class("dt.WorkPool", dt) | |||
function WorkPool:__init(nThread) | |||
self.nThread = nThread or 16 | |||
assert(torch.type(self.nThread) == 'number') | |||
assert(self.nThread > 0) | |||
self:initialize() | |||
end | |||
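-- creates a shared main queue plus one private queue per worker, then spawns the
-- workers; each worker loops on its current queue until it reads a nil taskname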
function WorkPool:initialize() | |||
local ipc = require 'libipc' | |||
self.queuename = os.tmpname() | |||
self.queue = ipc.workqueue(self.queuename) | |||
self.queues = {} | |||
for i=1,self.nThread do | |||
self.queues[i] = ipc.workqueue(self.queuename.."/"..i) | |||
end | |||
-- spawn thread workers | |||
ipc.map(self.nThread, function(queuename, nThread, myId) | |||
assert(nThread) | |||
assert(myId) | |||
local ipc = require 'libipc' | |||
-- Open the queue by name (the main thread already created it) | |||
local mainqueue = ipc.workqueue(queuename) | |||
local workqueue = ipc.workqueue(queuename.."/"..myId) | |||
local taskname, task
local store = {} | |||
local queue = mainqueue | |||
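-- the worker alternates between the shared main queue and its private queue;
-- 'readWorkerQueue' and 'readMainQueue' messages switch which one is read next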
repeat | |||
local msg = queue:read() | |||
assert(torch.type(msg) == 'table') | |||
taskname, task = unpack(msg) | |||
if taskname == nil then | |||
break | |||
elseif torch.type(taskname) ~= 'string' then | |||
error("Expecting taskname string. Got "..torch.type(taskname)) | |||
elseif taskname == 'storeKeyValue' then | |||
assert(torch.type(task) == 'table') | |||
assert(queue == workqueue) | |||
store[task.key] = task.value | |||
queue:write({taskname}) | |||
elseif taskname == 'storeKeysValues' then | |||
assert(torch.type(task) == 'table') | |||
assert(queue == workqueue) | |||
for key,value in pairs(task) do | |||
store[key] = value | |||
end | |||
queue:write({taskname}) | |||
elseif taskname == 'require' then | |||
assert(torch.type(task) == 'table') | |||
assert(torch.type(task.libname) == 'string') | |||
assert(torch.type(task.varname) == 'string') | |||
_G[task.varname] = require(task.libname) | |||
assert(queue == workqueue) | |||
queue:write({taskname}) | |||
elseif taskname == 'storeReset' then | |||
store = {} | |||
mainqueue:write({taskname}) | |||
elseif taskname == 'echo' then | |||
mainqueue:write({taskname, task}) | |||
elseif taskname == 'readWorkerQueue' then | |||
queue = workqueue | |||
elseif taskname == 'readMainQueue' then | |||
queue = mainqueue | |||
elseif taskname == 'execute' then | |||
if torch.type(task) == 'table' then | |||
assert(task.func and task.args) | |||
queue:write({taskname, task.func(store, task.args, myId)}) | |||
else | |||
assert(torch.type(task) == 'function') | |||
queue:write({taskname, task(store, myId)}) | |||
end | |||
else | |||
error("Unknown taskname: "..taskname) | |||
end | |||
until taskname == nil | |||
end, self.queuename, self.nThread) | |||
end | |||
function WorkPool:terminate() | |||
for i=1,self.nThread do | |||
self.queue:write({}) | |||
end | |||
end | |||
function WorkPool:_update(taskname, task, upval) | |||
assert(torch.type(taskname) == 'string') | |||
local _ = require 'moses' | |||
assert(_.contains({'storeKeyValue','storeKeysValues','require','execute'}, taskname)) | |||
assert(torch.type(task) == 'table' or torch.type(task) == 'function') | |||
-- tell the workers to read their individual queue | |||
for i=1,self.nThread do | |||
self.queue:write({'readWorkerQueue'}) | |||
end | |||
-- write to individual worker queues | |||
for i=1,self.nThread do | |||
if upval then | |||
self.queues[i]:writeup({taskname, task}) | |||
else | |||
self.queues[i]:write({taskname, task}) | |||
end | |||
end | |||
-- TODO use ipc.mutex:barrier(nThread+1) | |||
-- barrier: make sure that every worker has completed the task by reading their queues
for i=1,self.nThread do | |||
assert(self.queues[i]:read()[1] == taskname) | |||
end | |||
-- finally, tell them to read the main queue | |||
for i=1,self.nThread do | |||
self.queues[i]:write({'readMainQueue'}) | |||
end | |||
end | |||
function WorkPool:update(taskname, task) | |||
return self:_update(taskname, task, false) | |||
end | |||
function WorkPool:updateup(taskname, task) | |||
return self:_update(taskname, task, true) | |||
end | |||
function WorkPool:write(taskname, task) | |||
assert(torch.type(taskname) == 'string') | |||
assert(taskname ~= 'storeKeyValue' and taskname ~= 'storeKeysValues')
self.queue:write({taskname, task}) | |||
end | |||
function WorkPool:writeup(taskname, task) | |||
assert(torch.type(taskname) == 'string') | |||
assert(taskname ~= 'storeKeyValue' and taskname ~= 'storeKeysValues')
self.queue:writeup({taskname, task}) | |||
end | |||
function WorkPool:read() | |||
local res = self.queue:read() | |||
assert(torch.type(res) == 'table') | |||
assert(torch.type(res[1]) == 'string')
return unpack(res) | |||
end | |||
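-- Usage sketch (illustrative, not part of the original file): broadcast setup
-- tasks to every worker with :update()/:updateup(), stream work through the
-- shared queue with :write(), and collect results with :read(). For example:
--
--   local pool = dt.WorkPool(4)
--   pool:update('require', {libname='decisiontree', varname='dt'})
--   pool:write('execute', function(store, myId) return myId end)
--   local taskname, result = pool:read()
--   pool:terminate()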
@@ -1,5 +0,0 @@
local dl = {} | |||
return dl |
@@ -1,171 +0,0 @@
local dt = require "decisiontree._env" | |||
local bm = {} | |||
function bm.CartTrainer(opt) | |||
local timer = torch.Timer() | |||
local trainSet, validSet = dt.getSparseDummyData(opt) | |||
print(string.format("CartTrainer: sparse dataset create: %f samples/sec; %f sec", opt.nExample/timer:time().real, timer:time().real)) | |||
local cartTrainer = dt.CartTrainer(trainSet, opt.minLeafSize, opt.maxLeafNodes) | |||
local treeState = dt.GiniState(trainSet:getExampleIds()) | |||
timer:reset() | |||
local cartTree, nleaf = cartTrainer:train(treeState, trainSet.featureIds) | |||
print(string.format("CartTrainer: train single-thread : %f samples/sec; %f sec", opt.nExample/timer:time().real, timer:time().real)) | |||
timer:reset() | |||
cartTrainer:featureParallel(opt.nThread) | |||
print(string.format("CartTrainer: setup feature-parallel : %f samples/sec; %f sec", opt.nExample/timer:time().real, timer:time().real)) | |||
timer:reset() | |||
local cartTree, nleaf = cartTrainer:train(treeState, trainSet.featureIds) | |||
print(string.format("CartTrainer: train feature-parallel : %f samples/sec; %f sec", opt.nExample/timer:time().real, timer:time().real)) | |||
end | |||
function bm.GradientBoostState(opt) | |||
local trainSet, validSet = dt.getSparseDummyData(opt) | |||
trainSet:initScore() | |||
local treeState = dt.GradientBoostState(trainSet:getExampleIds(), nn.LogitBoostCriterion(false)) | |||
local timer = torch.Timer() -- first step also calls SparseTensor:buildIndex() | |||
treeState:findBestSplit(trainSet, trainSet.featureIds, 10, 1, 3) | |||
print(string.format("GradientBoostState: findBestSplit (first) : %f sec", timer:time().real)) | |||
timer:reset() | |||
treeState:findBestSplit(trainSet, trainSet.featureIds, 10, 1, 3) | |||
print(string.format("GradientBoostState: findBestSplit (second) : %f sec", timer:time().real)) | |||
end | |||
local function file_exists(name) | |||
local f=io.open(name,"r") | |||
if f~=nil then io.close(f) return true else return false end | |||
end | |||
function bm.GradientBoostTrainer(opt) | |||
local trainSet, validSet | |||
if file_exists("/tmp/train.bin") and file_exists("/tmp/valid.bin") then | |||
trainSet = torch.load("/tmp/train.bin") | |||
validSet = torch.load("/tmp/valid.bin") | |||
else | |||
if opt.sparse then | |||
trainSet, validSet = dt.getSparseDummyData(opt) | |||
else | |||
trainSet, validSet = dt.getDenseDummyData(opt) | |||
end | |||
torch.save("/tmp/train.bin", trainSet) | |||
torch.save("/tmp/valid.bin", validSet) | |||
end | |||
local cartTrainer = dt.CartTrainer(trainSet, opt.minLeafSize, opt.maxLeafNodes) | |||
opt.lossFunction = nn.LogitBoostCriterion(false) | |||
opt.treeTrainer = cartTrainer | |||
local forestTrainer = dt.GradientBoostTrainer(opt) | |||
local timer = torch.Timer() | |||
local decisionForest = forestTrainer:train(trainSet, trainSet.featureIds, validSet) | |||
local time = timer:time().real | |||
print(string.format("GradientBoostTrainer: train single-thread : %f samples/sec; %f sec/tree, %f sec", opt.nExample/time, time/opt.nTree, time)) | |||
cartTrainer:featureParallel(opt.nThread) | |||
timer:reset() | |||
local decisionForest = forestTrainer:train(trainSet, trainSet.featureIds, validSet) | |||
local time = timer:time().real | |||
print(string.format("GradientBoostTrainer: train feature-parallel : %f samples/sec; %f sec/tree, %f sec", opt.nExample/time, time/opt.nTree, time)) | |||
end | |||
function bm.RandomForestTrainer(opt) | |||
local trainSet, validSet = dt.getSparseDummyData(opt) | |||
local forestTrainer = dt.RandomForestTrainer(opt) | |||
local decisionForest = forestTrainer:train(trainSet, trainSet.featureIds) | |||
local timer = torch.Timer() | |||
local decisionForest = forestTrainer:train(trainSet, trainSet.featureIds) | |||
local time = timer:time().real | |||
print(string.format("RandomForestTrainer: train single-thread : %f samples/sec; %f sec/tree, %f sec", opt.nExample/time, time/opt.nTree, time)) | |||
timer:reset() | |||
forestTrainer:treeParallel(opt.nThread) | |||
print(string.format("RandomForestTrainer: setup tree-parallel : %f samples/sec; %f sec", opt.nExample/timer:time().real, timer:time().real)) | |||
timer:reset() | |||
local decisionForest = forestTrainer:train(trainSet, trainSet.featureIds) | |||
local time = timer:time().real | |||
print(string.format("RandomForestTrainer: train tree-parallel : %f samples/sec; %f sec/tree, %f sec", opt.nExample/time, time/opt.nTree, time)) | |||
end | |||
function bm.DFD(opt) | |||
local _ = require 'moses' | |||
local opt = _.clone(opt) | |||
opt.nExample = 200 | |||
local trainSet, validSet = dt.getDenseDummyData(opt) | |||
local forestTrainer = dt.RandomForestTrainer(opt) | |||
forestTrainer:treeParallel(opt.nThread) | |||
local timer = torch.Timer() | |||
local decisionForest = forestTrainer:train(trainSet, trainSet.featureIds) | |||
local time = timer:time().real | |||
print(string.format("DFD: train random forest in parallel : %f samples/sec; %f sec/tree, %f sec", opt.nExample/time, time/opt.nTree, time)) | |||
-- benchmark nn.DFD | |||
local input = trainSet.input:sub(1,opt.batchsize) | |||
local dfd = nn.DFD(decisionForest) | |||
dfd:forward(input) | |||
timer:reset() | |||
for i=1,opt.nloop do | |||
dfd:forward(input) | |||
end | |||
print(string.format("DFD: updateOutput : %f samples/sec; %f sec", opt.nloop*opt.batchsize/timer:time().real, timer:time().real)) | |||
end | |||
function bm.Sparse2Dense(opt) | |||
local _ = require 'moses' | |||
local opt = _.clone(opt) | |||
opt.nExample = opt.batchsize | |||
local trainSet = dt.getSparseDummyData(opt) | |||
local input = {{},{}} | |||
for i=1,opt.batchsize do | |||
input[1][i] = trainSet.input[i].keys | |||
input[2][i] = trainSet.input[i].values | |||
end | |||
assert(#input[1] == opt.batchsize) | |||
-- benchmark nn.Sparse2Dense | |||
local s2d = nn.Sparse2Dense(torch.LongTensor():range(1,opt.nFeature)) | |||
s2d:forward(input) | |||
local timer = torch.Timer() | |||
for i=1,opt.nloop do | |||
s2d:forward(input) | |||
end | |||
print(string.format("Sparse2Dense: updateOutput : %f samples/sec; %f sec", opt.nloop*opt.batchsize/timer:time().real, timer:time().real)) | |||
end | |||
function dt.benchmark(benchmarks, opt2) | |||
local opt = { | |||
nExample=10000, nCluster=2, nFeature=1000, overlap=0, nValid=100, -- getSparseDummyData | |||
nTree=20, featureBaggingSize=-1, sparse=true, -- GradientBoostTrainer and RandomForestTrainer | |||
nThread=2, shrinkage=0.1, downsampleRatio=0.1, evalFreq=5, earlyStop=0, -- GradientBoostTrainer | |||
activeRatio=0.5, -- RandomForestTrainer | |||
batchsize=32, nloop=10 | |||
} | |||
local _ = require 'moses' | |||
benchmarks = benchmarks or _.keys(bm) | |||
assert(torch.type(benchmarks) == 'table') | |||
for i,benchmark in ipairs(benchmarks) do | |||
local opt1 = _.clone(opt) | |||
for key, value in pairs(opt2 or {}) do | |||
opt1[key] = value | |||
end | |||
opt1.nActive = opt1.nActive or torch.round(opt1.nFeature/10) | |||
opt1.maxLeafNodes = opt1.maxLeafNodes or (opt1.nExample/10) | |||
opt1.minLeafSize = opt1.minLeafSize or (opt1.nExample/100) | |||
assert(torch.type(benchmark) == 'string', benchmark) | |||
assert(bm[benchmark], benchmark) | |||
bm[benchmark](opt1) | |||
end | |||
end |
@@ -1,291 +0,0 @@
# Benchmarks | |||
This file outlines the roadmap (and the accompanying benchmarks) of optimizations and refactorings over time.
## Baseline | |||
The baseline implementation is very slow. | |||
We converted the Twitter decision tree library (used internally) from Java to Lua. | |||
The objective was to replicate the GBDT and Random Forest implementations as-is (more or less).
The Java library is very good and reasonably fast. The same code in Lua is slow.
The point of this Lua baseline was not to obtain the same computational performance as the Java library.
Instead, we wanted the training and inference algorithms of the Lua lib to match those of the Java lib.
As such, the training/validation error of the baseline Lua lib should match that of the Java lib.
The unit tests seem to validate this claim, as performance on both the training and validation sets is unit tested.
We also used the conversion exercise as a way to learn about decision tree implementation (our background is deep learning). | |||
That being said, the baseline performance is terrible: | |||
``` | |||
th -e "dt = require 'decisiontree'; dt.benchmark()" | |||
CartTrainer: sparse dataset create: 2963.192386 samples/sec; 0.337479 sec | |||
CartTrainer: train single-thread : 14.165438 samples/sec; 70.594361 sec | |||
CartTrainer: setup feature-parallel : 5.129034 samples/sec; 194.968478 sec | |||
CartTrainer: train feature-parallel : 9.736592 samples/sec; 102.705344 sec | |||
``` | |||
The original Java lib had approximately 43 classes. | |||
The baseline has about 24. | |||
This reduction is due to obvious merging of classes, but also to converting classes to functions.
The next patches continue this process of reducing the number of classes. | |||
## Patch 1 (complete): | |||
This patch further reduces the number of classes, but adds the DataSet class.
The code is much simpler to read. Examples are batched.
* [x] examples are batched in dt.DataSet: {input, target, score} (see the sketch after this list)
* [x] deprecate dt.LabeledExample | |||
* [x] list of examples are replaced with torch.LongTensors of exampleIds | |||
* [x] merge TreeBrancher into TreeState | |||
* [x] merge BestSplitFinder and SplitStateUpdater into TreeState | |||
* [x] TreeState subclasses: GradientBoostState and GiniState | |||
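A rough sketch of the batched representation described in the list (the `dt.DataSet` constructor signature is an assumption for illustration; `getExampleIds` is the accessor used by the benchmark code):
```
-- three aligned containers instead of a list of LabeledExample objects
local input = torch.Tensor(nExample, nFeature) -- or a table of SparseTensors for sparse data
local target = torch.Tensor(nExample)
local dataset = dt.DataSet(input, target) -- assumed constructor
-- tree nodes now reference examples through LongTensors of ids
local exampleIds = dataset:getExampleIds() -- torch.LongTensor
```
The benchmarks after this patch: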
``` | |||
th -e "dt = require 'decisiontree'; dt.benchmark()" | |||
CartTrainer: sparse dataset create: 3597.392294 samples/sec; 0.277984 sec | |||
CartTrainer: train single-thread : 35.763255 samples/sec; 27.961663 sec | |||
CartTrainer: setup feature-parallel : 36759.250495 samples/sec; 0.027220 sec | |||
CartTrainer: train feature-parallel : 72.523658 samples/sec; 13.788606 sec | |||
``` | |||
The setup time for feature-parallelization shows the biggest improvement.
The feature-parallel run-time is also about half that of the single-thread run.
Since it is using 2 threads, this means the parallelization is working quite well.
We also added benchmarks for the `RandomForestTrainer` and `GradientBoostTrainer`: | |||
``` | |||
GradientBoostTrainer: train single-thread : 599.895105 samples/sec; 0.083348 sec/tree, 1.666958 sec | |||
GradientBoostTrainer: train feature-parallel : 974.235273 samples/sec; 0.051322 sec/tree, 1.026446 sec | |||
RandomForestTrainer: train single-thread : 134.781044 samples/sec; 0.370972 sec/tree, 7.419441 sec | |||
RandomForestTrainer: setup tree-parallel : 73341.097064 samples/sec; 0.013649 sec | |||
RandomForestTrainer: train tree-parallel : 262.975891 samples/sec; 0.190131 sec/tree, 3.802630 sec | |||
``` | |||
Looks good. | |||
## Patch 2 (complete): | |||
* [x] dt.LossFunction -> nn.Criterion (LogitBoost is done, missing MSE) | |||
* [x] use SparseTensor:buildIndex() to accelerate TreeState:findBestSplit() | |||
* [x] benchmarks use 10000 instead of 1000 examples | |||
The benchmarks indicate good improvements. Most improvements were made possible by the use of `buildIndex`: | |||
``` | |||
th -e "dt = require 'decisiontree'; dt.benchmark()" | |||
GradientBoostState: findBestSplit (first) : 11.415645 sec | |||
GradientBoostState: findBestSplit (second) : 11.246336 sec | |||
CartTrainer: sparse dataset create: 3284.803629 samples/sec; 3.044327 sec | |||
CartTrainer: train single-thread : 239.544758 samples/sec; 41.745858 sec | |||
CartTrainer: setup feature-parallel : 10996.443063 samples/sec; 0.909390 sec | |||
CartTrainer: train feature-parallel : 473.888592 samples/sec; 21.102011 sec | |||
RandomForestTrainer: train single-thread : 892.985186 samples/sec; 0.559920 sec/tree, 11.198394 sec | |||
RandomForestTrainer: setup tree-parallel : 176806.252266 samples/sec; 0.056569 sec | |||
RandomForestTrainer: train tree-parallel : 1377.849291 samples/sec; 0.362884 sec/tree, 7.257688 sec | |||
GradientBoostTrainer: train single-thread : 2685.485128 samples/sec; 0.186186 sec/tree, 3.723722 sec | |||
GradientBoostTrainer: train feature-parallel : 3712.313215 samples/sec; 0.134687 sec/tree, 2.693738 sec | |||
``` | |||
The main bottleneck now is in serializing the SparseTensor hash maps. We temporarily overcame this bottleneck by
deleting the indexes when calling `CartTrainer:featureParallel()` and `RandomForestTrainer:treeParallel()`.
In this way, the indexes are recreated for each thread. Ideally, we would use a C hash map so that a pointer
could be serialized instead, but `tds.Hash` does not serialize well. For now, we use Lua tables instead.
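A minimal sketch of this workaround (the `deleteIndex` method and the `dropIndexes` helper are assumptions for illustration; only `buildIndex` is confirmed elsewhere in this file):
```
-- drop per-feature indexes so the dataset serializes cheaply to worker threads;
-- each thread then rebuilds its own index lazily via SparseTensor:buildIndex()
local function dropIndexes(dataset) -- hypothetical helper
   for _, sparseInput in ipairs(dataset.input) do
      sparseInput:deleteIndex() -- assumed inverse of buildIndex()
   end
end
```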
This is the benchmark for `GradientBoostTrainer` on a large dataset of dense inputs: | |||
``` | |||
th -e "dt = require 'decisiontree'; dt.benchmark({'GradientBoostTrainer'}, {nExample=100000, sparse=false, nFeature=836, nTree=5, downsampleRatio=1, minLeafSize=1000, maxLeafNodes=8})" | |||
GradientBoostTrainer: train single-thread : 152.463989 samples/sec; 131.178517 sec/tree, 655.892584 sec | |||
GradientBoostTrainer: train feature-parallel : 224.288488 samples/sec; 89.170872 sec/tree, 445.854358 sec | |||
[tw-mbp-nleonard decisiontree]$ th -e "dt = require 'decisiontree'; dt.benchmark({'GradientBoostTrainer'}, {nExample=100000, sparse=false, nFeature=836, nTree=5, downsampleRatio=1, minLeafSize=1000, maxLeafNodes=8,nThread=4})" | |||
GradientBoostTrainer: train single-thread : 163.836896 samples/sec; 122.072625 sec/tree, 610.363126 sec | |||
GradientBoostTrainer: train feature-parallel : 407.981442 samples/sec; 49.021838 sec/tree, 245.109188 sec | |||
``` | |||
## Patch 3 : | |||
Optimize GBDT for large datasets consisting of dense inputs. The benchmarks: | |||
``` | |||
th -e "dt = require 'decisiontree'; dt.benchmark({'GradientBoostTrainer'}, {nExample=100000, sparse=false, nFeature=836, nTree=5, downsampleRatio=1, minLeafSize=1000, maxLeafNodes=8})" | |||
GradientBoostTrainer: train single-thread : 547.553407 samples/sec; 36.526117 sec/tree, 182.630587 sec | |||
GradientBoostTrainer: train feature-parallel : 792.964678 samples/sec; 25.221804 sec/tree, 126.109022 sec | |||
[tw-mbp-nleonard decisiontree]$ th -e "dt = require 'decisiontree'; dt.benchmark({'GradientBoostTrainer'}, {nExample=100000, sparse=false, nFeature=836, nTree=5, downsampleRatio=1, minLeafSize=1000, maxLeafNodes=8,nThread=4})" | |||
GradientBoostTrainer: train single-thread : 555.793759 samples/sec; 35.984571 sec/tree, 179.922855 sec | |||
GradientBoostTrainer: train feature-parallel : 1289.977846 samples/sec; 15.504142 sec/tree, 77.520711 sec | |||
``` | |||
For 1, 2 and 4 threads, the speedups of patch 3 over patch 2 are respectively: 3.39, 3.53, and 3.18. | |||
For this patch, the multi-threading speedup of 2 and 4 threads over a single thread are respectively: 1.42 and 2.33. | |||
Improvements over the previous patch were obtained by optimizing two aspects: | |||
1. Optimizing `TreeState.findBestFeatureSplit` for dense datasets (for example: `if dense, then ...`); | |||
2. Removing `assert` clauses in `GradientBoostState.update`. The `update` method is called for every (example, feature), making it a major bottleneck. | |||
Converting the `update` to C could lead to further optimizations. | |||
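For reference, this is roughly what the assert-free hot-loop `update` looks like (a sketch; the field names follow the C port shown later in this diff, and are assumptions for the Lua side):
```
-- moves one example from the left branch to the right branch; no type checks,
-- since this runs once per (example, feature) pair
function GradientBoostState:update(exampleId, dataset, exampleIdx)
   local g = self.grad[exampleIdx]
   local h = self.hess[exampleIdx]
   self.leftGradientSum = self.leftGradientSum - g
   self.rightGradientSum = self.rightGradientSum + g
   self.leftHessianSum = self.leftHessianSum - h
   self.rightHessianSum = self.rightHessianSum + h
   self.nExampleInLeftBranch = self.nExampleInLeftBranch - 1
   self.nExampleInRightBranch = self.nExampleInRightBranch + 1
end
```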
This patch also improves the benchmark on sparse datasets: | |||
``` | |||
$ th -e "dt = require 'decisiontree'; dt.benchmark()" | |||
RandomForestTrainer: train single-thread : 1121.311196 samples/sec; 0.445907 sec/tree, 8.918131 sec | |||
RandomForestTrainer: setup tree-parallel : 168773.323354 samples/sec; 0.059256 sec | |||
RandomForestTrainer: train tree-parallel : 1701.280938 samples/sec; 0.293896 sec/tree, 5.877924 sec | |||
GradientBoostState: findBestSplit (first) : 8.250646 sec | |||
GradientBoostState: findBestSplit (second) : 7.952077 sec | |||
GradientBoostTrainer: train single-thread : 3355.248596 samples/sec; 0.149020 sec/tree, 2.980405 sec | |||
GradientBoostTrainer: train feature-parallel : 4399.133369 samples/sec; 0.113659 sec/tree, 2.273175 sec | |||
CartTrainer: sparse dataset create: 3428.105601 samples/sec; 2.917069 sec | |||
CartTrainer: train single-thread : 282.172416 samples/sec; 35.439331 sec | |||
CartTrainer: setup feature-parallel : 9455.440801 samples/sec; 1.057598 sec | |||
CartTrainer: train feature-parallel : 594.054049 samples/sec; 16.833491 sec | |||
DFD: train random forest in parallel : 346.831378 samples/sec; 0.288325 sec/tree, 5.766491 sec | |||
DFD: updateOutput : 831.105546 samples/sec; 0.038509 sec | |||
``` | |||
## Patch 4 : | |||
This patch improves `nn.DFD` from | |||
``` | |||
th -e "dt = require 'decisiontree'; dt.benchmark({'DFD'}, {nTree=500,maxLeafNodes=8,minLeafSize=1})" | |||
DFD: train random forest in parallel : 10.527251 samples/sec; 0.037997 sec/tree, 18.998313 sec | |||
DFD: updateOutput : 32.442950 samples/sec; 9.863472 sec | |||
``` | |||
to | |||
``` | |||
th -e "dt = require 'decisiontree'; dt.benchmark({'DFD'}, {nTree=500,maxLeafNodes=8,minLeafSize=1})" | |||
DFD: train random forest in parallel : 10.839547 samples/sec; 0.036902 sec/tree, 18.450956 sec | |||
DFD: updateOutput : 359.158353 samples/sec; 0.890975 sec | |||
Sparse2Dense: updateOutput : 15395.648952 samples/sec; 0.020791 sec | |||
``` | |||
That is a 10x speedup for `nn.DFD`. | |||
The patch also adds a benchmark for `nn.Sparse2Dense`: | |||
``` | |||
th -e "dt = require 'decisiontree'; dt.benchmark({'Sparse2Dense'}, {nTree=500,maxLeafNodes=8,minLeafSize=1})" | |||
Sparse2Dense: updateOutput : 17158.126406 samples/sec; 0.018653 sec | |||
``` | |||
Indeed, `nn.Sparse2Dense` is not the bottleneck; `nn.DFD` is. | |||
## Patch 5 : | |||
This patch improves `nn.DFD` inference from | |||
``` | |||
for i in `seq 3`; do th -e "dt = require 'decisiontree'; dt.benchmark({'DFD'}, {nTree=500,maxLeafNodes=8,minLeafSize=1,batchsize=16,nActive=1200,nFeature=1300,nloop=100})"; done | |||
DFD: train random forest in parallel : 8.452295 samples/sec; 0.047324 sec/tree, 23.662212 sec | |||
DFD: updateOutput : 176.617872 samples/sec; 9.059109 sec | |||
DFD: train random forest in parallel : 8.350019 samples/sec; 0.047904 sec/tree, 23.952042 sec | |||
DFD: updateOutput : 183.508204 samples/sec; 8.718962 sec | |||
DFD: train random forest in parallel : 8.525779 samples/sec; 0.046917 sec/tree, 23.458266 sec | |||
DFD: updateOutput : 178.877077 samples/sec; 8.944692 sec | |||
``` | |||
to | |||
``` | |||
for i in `seq 3`; do th -e "dt = require 'decisiontree'; dt.benchmark({'DFD'}, {nTree=500,maxLeafNodes=8,minLeafSize=1,batchsize=16,nActive=1200,nFeature=1300,nloop=100})"; done | |||
DFD: train random forest in parallel : 8.434502 samples/sec; 0.047424 sec/tree, 23.712129 sec | |||
DFD: updateOutput : 6479.597179 samples/sec; 0.246933 sec | |||
DFD: train random forest in parallel : 8.334543 samples/sec; 0.047993 sec/tree, 23.996518 sec | |||
DFD: updateOutput : 6663.641184 samples/sec; 0.240114 sec | |||
DFD: train random forest in parallel : 8.353265 samples/sec; 0.047885 sec/tree, 23.942735 sec | |||
DFD: updateOutput : 6882.607456 samples/sec; 0.232475 sec | |||
``` | |||
That is a 37x speedup for `nn.DFD`. | |||
## Patch 6: | |||
This patch improves `nn.DFD` from the previous result to | |||
``` | |||
for i in `seq 5`; do th -e "dt = require 'decisiontree'; dt.benchmark({'DFD'}, {nTree=500,maxLeafNodes=8,minLeafSize=1,batchsize=16,nActive=1200,nFeature=1300,nloop=10000})"; done | |||
DFD: train random forest in parallel : 8.353504 samples/sec; 0.047884 sec/tree, 23.942050 sec | |||
DFD: updateOutput : 91967.342339 samples/sec; 1.739753 sec | |||
DFD: train random forest in parallel : 8.528141 samples/sec; 0.046904 sec/tree, 23.451770 sec | |||
DFD: updateOutput : 91405.321702 samples/sec; 1.750451 sec | |||
DFD: train random forest in parallel : 8.184562 samples/sec; 0.048872 sec/tree, 24.436250 sec | |||
DFD: updateOutput : 91623.388867 samples/sec; 1.746284 sec | |||
DFD: train random forest in parallel : 8.779561 samples/sec; 0.045560 sec/tree, 22.780182 sec | |||
DFD: updateOutput : 93914.242852 samples/sec; 1.703686 sec | |||
DFD: train random forest in parallel : 8.636201 samples/sec; 0.046317 sec/tree, 23.158330 sec | |||
DFD: updateOutput : 94092.241963 samples/sec; 1.700465 sec | |||
``` | |||
That is another 13.8x speedup. | |||
## Patch 7: | |||
This patch improves `nn.Sparse2Dense` computation from | |||
``` | |||
for i in `seq 3`; do th -e "dt = require 'decisiontree'; torch.setdefaulttensortype('torch.FloatTensor'); dt.benchmark({'Sparse2Dense'}, {nTree=500,maxLeafNodes=8,minLeafSize=1,nFeature=1500,nActive=1300,nloop=1000})"; done | |||
Sparse2Dense: updateOutput : 1103.570777 samples/sec; 28.996786 sec | |||
Sparse2Dense: updateOutput : 1092.064331 samples/sec; 29.302309 sec | |||
Sparse2Dense: updateOutput : 1036.963572 samples/sec; 30.859334 sec | |||
``` | |||
to | |||
``` | |||
for i in `seq 3`; do th -e "dt = require 'decisiontree'; torch.setdefaulttensortype('torch.FloatTensor'); dt.benchmark({'Sparse2Dense'}, {nTree=500,maxLeafNodes=8,minLeafSize=1,nFeature=1500,nActive=1300,nloop=1000})"; done | |||
Sparse2Dense: updateOutput : 62995.834470 samples/sec; 0.507978 sec | |||
Sparse2Dense: updateOutput : 62471.568253 samples/sec; 0.512242 sec | |||
Sparse2Dense: updateOutput : 62965.099331 samples/sec; 0.508226 sec | |||
``` | |||
This represents a speedup of about 57x. | |||
## Patch 8: | |||
This patch improves `nn.Sparse2Dense` from the previous result to | |||
```
for i in `seq 3`; do th -e "dt = require 'decisiontree'; torch.setdefaulttensortype('torch.FloatTensor'); dt.benchmark({'Sparse2Dense'}, {nTree=500,maxLeafNodes=8,minLeafSize=1,nFeature=1500,nActive=1300,nloop=1000})"; done
Sparse2Dense: updateOutput : 124268.079914 samples/sec; 0.257515 sec | |||
Sparse2Dense: updateOutput : 114750.039542 samples/sec; 0.278873 sec | |||
Sparse2Dense: updateOutput : 122863.314766 samples/sec; 0.260458 sec | |||
``` | |||
which corresponds to another 1.95x speedup. | |||
## Patch 9: | |||
This patch moves the core of GBDT training, which used to be a big bottleneck, to C. It also
performs small optimizations across the board (faster scoring, faster branching, ...) that provide a
little more performance.
The original commit had this performance: | |||
``` | |||
th -e "dt = require 'decisiontree'; torch.setdefaulttensortype('torch.FloatTensor'); dt.benchmark({'GradientBoostTrainer'}, {nExample=100000, sparse=false, nFeature=836, nTree=5, downsampleRatio=1, minLeafSize=1000, maxLeafNodes=8})" | |||
GradientBoostTrainer: train single-thread : 500.414666 samples/sec; 39.966854 sec/tree, 199.834271 sec | |||
GradientBoostTrainer: train feature-parallel : 1227.228044 samples/sec; 16.296890 sec/tree, 81.484448 sec (4 threads) | |||
GradientBoostTrainer: train feature-parallel : 1385.926280 samples/sec; 14.430782 sec/tree, 72.153910 sec (8 threads) | |||
``` | |||
and the new version has | |||
``` | |||
GradientBoostTrainer: train single-thread : 15285.644631 samples/sec; 1.308417 sec/tree, 6.542086 sec | |||
GradientBoostTrainer: train feature-parallel : 43170.435932 samples/sec; 0.463280 sec/tree, 2.316400 sec (4 threads) | |||
GradientBoostTrainer: train feature-parallel : 50062.681239 samples/sec; 0.399499 sec/tree, 1.997496 sec (8 threads) | |||
``` | |||
That represents a speedup of about 30.5x over the baseline for 1 thread and 36.1x for 8 threads. | |||
Note that the performance doesn't increase much as we increase the number of threads, since we use
feature parallelism and the number of features evaluated is small (29 in this case) due to bagging:
with 29 features over 8 threads, the most loaded thread still evaluates ceil(29/8) = 4 features, capping the ideal speedup at about 7.25x before any overhead.
If we disable bagging, then we have the following result with 8 threads and the new code: | |||
``` | |||
GradientBoostTrainer: train single-thread : 590.823965 samples/sec; 33.851030 sec/tree, 169.255152 sec | |||
GradientBoostTrainer: train feature-parallel : 3232.188576 samples/sec; 6.187758 sec/tree, 30.938789 sec | |||
``` | |||
So processing 836 features now is much faster than processing 29 before. |
@@ -1,24 +0,0 @@
#ifndef _ERROR_H_ | |||
#define _ERROR_H_ | |||
#include "luaT.h" | |||
#include <string.h> | |||
static inline int _lua_error(lua_State *L, int ret, const char* file, int line) { | |||
int pos_ret = ret >= 0 ? ret : -ret; | |||
return luaL_error(L, "ERROR: (%s, %d): (%d, %s)\n", file, line, pos_ret, strerror(pos_ret)); | |||
} | |||
static inline int _lua_error_str(lua_State *L, const char *str, const char* file, int line) { | |||
return luaL_error(L, "ERROR: (%s, %d): (%s)\n", file, line, str); | |||
} | |||
static inline int _lua_error_str_str(lua_State *L, const char *str, const char* file, int line, const char *extra) { | |||
return luaL_error(L, "ERROR: (%s, %d): (%s: %s)\n", file, line, str, extra); | |||
} | |||
#define LUA_HANDLE_ERROR(L, ret) _lua_error(L, ret, __FILE__, __LINE__) | |||
#define LUA_HANDLE_ERROR_STR(L, str) _lua_error_str(L, str, __FILE__, __LINE__) | |||
#define LUA_HANDLE_ERROR_STR_STR(L, str, extra) _lua_error_str_str(L, str, __FILE__, __LINE__, extra) | |||
#endif |
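/* usage sketch (illustrative): the macros report errno-style failures to Lua, e.g.
 *   int ret = pthread_create(&thread, NULL, fn, arg);
 *   if (ret) return LUA_HANDLE_ERROR(L, ret); // formats file, line, and strerror(ret)
 */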
@@ -1,88 +0,0 @@
#ifndef TH_GENERIC_FILE | |||
#define TH_GENERIC_FILE "generic/CartTree.c" | |||
#else | |||
static int nn_(tree_fast_score)(lua_State *L) { | |||
THTensor *input = luaT_checkudata(L, 1, torch_Tensor); | |||
THTensor *score = luaT_checkudata(L, 3, torch_Tensor); | |||
long n_samples = THTensor_(size)(input, 0); | |||
long n_features = THTensor_(size)(input, 1); | |||
THTensor_(resize1d)(score, n_samples); | |||
real *input_data = THTensor_(data)(input); | |||
real *score_data = THTensor_(data)(score); | |||
lua_pushstring(L, "leftChild"); | |||
const int left_child_string = 4; | |||
lua_pushstring(L, "rightChild"); | |||
const int right_child_string = 5; | |||
lua_pushstring(L, "score"); | |||
const int score_string = 6; | |||
lua_pushstring(L, "splitFeatureId"); | |||
const int id_string = 7; | |||
lua_pushstring(L, "splitFeatureValue"); | |||
const int value_string = 8; | |||
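// the five strings above are pushed once and reused with lua_pushvalue so they
// are not re-created for every node; `node` below holds the absolute stack
// index of the current node table (the tree root is the argument at index 2)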
const int original_top = lua_gettop(L); | |||
for (long i = 0; i < n_samples; i++) { | |||
int node = 2; | |||
while (1) { | |||
int current_top = lua_gettop(L); | |||
lua_pushvalue(L, left_child_string); | |||
lua_rawget(L, node); | |||
lua_pushvalue(L, right_child_string); | |||
lua_rawget(L, node); | |||
if (lua_isnil(L, -2) && lua_isnil(L, -1)) { | |||
lua_pushvalue(L, score_string); | |||
lua_rawget(L, node); | |||
score_data[i] = lua_tonumber(L, -1); | |||
break; | |||
} | |||
if (lua_isnil(L, -2)) { | |||
// go to right | |||
node = current_top + 2; | |||
continue; | |||
} | |||
if (lua_isnil(L, -1)) { | |||
// go to left | |||
node = current_top + 1; | |||
continue; | |||
} | |||
lua_pushvalue(L, id_string); | |||
lua_rawget(L, node); | |||
lua_pushvalue(L, value_string); | |||
lua_rawget(L, node); | |||
long feature_id = lua_tointeger(L, -2); | |||
real feature_value = lua_tonumber(L, -1); | |||
real current_value = input_data[i * n_features + (feature_id-1)]; | |||
if (current_value < feature_value) { | |||
// go to left | |||
node = current_top + 1; | |||
} | |||
else { | |||
// go to right | |||
node = current_top + 2; | |||
} | |||
} | |||
lua_pop(L, lua_gettop(L) - original_top); | |||
} | |||
lua_pop(L, 5); | |||
lua_pushvalue(L, 3); | |||
return 1; | |||
} | |||
static const struct luaL_Reg nn_(CT__) [] = { | |||
{"CartTreeFastScore", nn_(tree_fast_score)}, | |||
{NULL, NULL} | |||
}; | |||
static void nn_(CT_init)(lua_State *L) | |||
{ | |||
luaT_pushmetatable(L, torch_Tensor); | |||
luaT_registeratname(L, nn_(CT__), "nn"); | |||
lua_pop(L,1); | |||
} | |||
#endif |
@@ -1,157 +0,0 @@
#ifndef TH_GENERIC_FILE | |||
#define TH_GENERIC_FILE "generic/DFD.c" | |||
#else | |||
static int nn_(DFD_computeOutput)(lua_State *L) { | |||
THLongTensor *outputkeys = luaT_checkudata(L, 1, "torch.LongTensor"); | |||
THTensor *outputvalues = luaT_checkudata(L, 2, torch_Tensor); | |||
THLongTensor *root_ids = luaT_checkudata(L, 3, "torch.LongTensor"); | |||
THLongTensor *left_child = luaT_checkudata(L, 4, "torch.LongTensor"); | |||
THLongTensor *right_child = luaT_checkudata(L, 5, "torch.LongTensor"); | |||
THLongTensor *split_feature_id = luaT_checkudata(L, 6, "torch.LongTensor"); | |||
THTensor *split_feature_value = luaT_checkudata(L, 7, torch_Tensor); | |||
THTensor *input = luaT_checkudata(L, 8, torch_Tensor); | |||
char only_last_node = lua_toboolean(L, 9); | |||
// gets some important sizes from the input | |||
long batch_size = THTensor_(size)(input, 0); | |||
long input_size = THTensor_(size)(input, 1); | |||
long roots_size = THLongTensor_size(root_ids, 0); | |||
long depth = THLongTensor_size(outputkeys, 1); | |||
// keeps track of the number of nodes traversed in the trees by each sample. | |||
// each traversed node maps to an output feature having a value of 1 | |||
long outputsize[batch_size]; | |||
for (long i = 0; i < batch_size; i++) | |||
outputsize[i] = 0; | |||
// gets direct pointers to the memory of each tensor for efficiency | |||
long *root_ids_data = THLongTensor_data(root_ids); | |||
long *left_child_data = THLongTensor_data(left_child); | |||
long *right_child_data = THLongTensor_data(right_child); | |||
real *split_feature_value_data = THTensor_(data)(split_feature_value); | |||
long *split_feature_id_data = THLongTensor_data(split_feature_id); | |||
long *outputkeys_data = THLongTensor_data(outputkeys); | |||
real *input_data = THTensor_(data)(input); | |||
// for each sample in the batch | |||
for (long sample_index = 0; sample_index < batch_size; sample_index++) { | |||
// gets pointers to the direct memory associated with each sample for efficiency | |||
const long outputkeys_offset = sample_index * depth; | |||
const long input_offset = sample_index * input_size; | |||
long *local_outputkeys_data = &outputkeys_data[outputkeys_offset]; | |||
real *local_input_data = &input_data[input_offset]; | |||
// for each tree in the forest | |||
for (long i = 0; i < roots_size; i++) { | |||
int root = 1; | |||
long node_id = root_ids_data[i]; | |||
// traverses the whole tree keeping track of which nodes were seen | |||
while (1) { | |||
if (root) { | |||
// root nodes aren't added to output because they are always traversed | |||
root = 0; | |||
} | |||
else if (!only_last_node) { | |||
long output_index = outputsize[sample_index];
// updates the outputsize for all samples traversing this node
outputsize[sample_index]++;
// sets the traversed node as a feature in output for exampleIds
local_outputkeys_data[output_index] = node_id;
} | |||
// gets the left and right nodes. values of 0 or less represent a missing node
long left_id = left_child_data[node_id-1]; | |||
long right_id = right_child_data[node_id-1]; | |||
if (left_id <= 0 && right_id <= 0) { | |||
if (only_last_node) { | |||
long output_index = outputsize[sample_index]; | |||
outputsize[sample_index]++; | |||
local_outputkeys_data[output_index] = node_id; | |||
} | |||
// if no children, stops | |||
break; | |||
} | |||
else if (left_id <= 0) { | |||
// if no left child, traverses right node | |||
node_id = right_id; | |||
} | |||
else if (right_id <= 0) { | |||
// if no right child, traverses left node | |||
node_id = left_id; | |||
} | |||
else { | |||
// if both left and right children, finds the direction for this sample | |||
// first get the reference from the node | |||
real split_value = split_feature_value_data[node_id-1]; | |||
long split_id = split_feature_id_data[node_id-1]-1; | |||
// then gets the value of the sample | |||
real node_value = local_input_data[split_id]; | |||
// and branches
if (node_value < split_value) | |||
node_id = left_id; | |||
else | |||
node_id = right_id; | |||
} | |||
} | |||
} | |||
} | |||
// now that we know which nodes were traversed for each sample, we can create the sparse output
// with 1 entry pair for each sample | |||
THTensor *input_feature = THTensor_(new)(); | |||
THLongTensor *indices = THLongTensor_new(); | |||
// pushes the return table with 2 children tables | |||
lua_newtable(L); | |||
lua_pushinteger(L, 1); | |||
lua_newtable(L); | |||
lua_pushinteger(L, 2); | |||
lua_newtable(L); | |||
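// stack is now: out, 1, keysTable, 2, valuesTable; the two lua_settable calls
// after the loop pop these pairs into the outer table as out[1] and out[2]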
// for each sample... | |||
for (long i = 0; i < batch_size; i++) { | |||
long j = outputsize[i]; | |||
// selects the tensor lines from the dense output | |||
THLongTensor_select(indices, outputkeys, 0, i); | |||
THTensor_(select)(input_feature, outputvalues, 0, i); | |||
// narrows the keys to actual number of nodes traversed and saves to the output | |||
lua_pushinteger(L, i+1); | |||
luaT_pushudata(L, THLongTensor_newNarrow(indices, 0, 0, j), "torch.LongTensor"); | |||
lua_settable(L, -5); | |||
// and narrows the values | |||
lua_pushinteger(L, i+1); | |||
luaT_pushudata(L, THTensor_(newNarrow)(input_feature, 0, 0, j), torch_Tensor); | |||
lua_settable(L, -3); | |||
} | |||
// pushes the two parts of the output into the output table | |||
lua_settable(L, -5); | |||
lua_settable(L, -3); | |||
THLongTensor_free(indices); | |||
THTensor_(free)(input_feature); | |||
return 1; | |||
} | |||
static const struct luaL_Reg nn_(DFD__) [] = { | |||
{"DFD_computeOutput", nn_(DFD_computeOutput)}, | |||
{NULL, NULL} | |||
}; | |||
static void nn_(DFD_init)(lua_State *L) | |||
{ | |||
luaT_pushmetatable(L, torch_Tensor); | |||
luaT_registeratname(L, nn_(DFD__), "nn"); | |||
lua_pop(L,1); | |||
} | |||
#endif |
@@ -1,392 +0,0 @@
#ifndef TH_GENERIC_FILE | |||
#define TH_GENERIC_FILE "generic/GBDT.c" | |||
#else | |||
#include "GBDT_internal.h" | |||
#include "GBDT_internal.c" | |||
// note that each one of the functions to find the best split is a subset of the next. | |||
// first we have one that can only evaluate a single feature, using the logic in lua to control the | |||
// features | |||
// then we have one that can go over a shard of features, following the feature parallelism
// introduced by the lua logic | |||
// and finally we have one that performs the feature parallelism itself in the special case of
// dense tensors | |||
// these functions are provided for completeness and to test in case the logic is to be changed | |||
// finds the best split for a given node and feature | |||
static int nn_(gb_findBestFeatureSplit)(lua_State *L) { | |||
THLongTensor *exampleIds = luaT_checkudata(L, 1, "torch.LongTensor"); | |||
const int dataset_index = 2; | |||
if (!lua_isnumber(L, 3)) | |||
return LUA_HANDLE_ERROR_STR(L, "third argument should be an integer"); | |||
long feature_id = lua_tointeger(L, 3); | |||
if (!lua_isnumber(L, 4)) | |||
return LUA_HANDLE_ERROR_STR(L, "fourth argument should be an integer"); | |||
long minLeafSize = lua_tointeger(L, 4); | |||
// Since minLeafSize == 1 corresponds to each sample in its own leaf, any value below it doesn't | |||
// make sense | |||
if (minLeafSize < 1) | |||
minLeafSize = 1; | |||
THTensor *grad = luaT_checkudata(L, 5, torch_Tensor); | |||
THTensor *hess = luaT_checkudata(L, 6, torch_Tensor); | |||
if (!THLongTensor_isContiguous(exampleIds)) | |||
return LUA_HANDLE_ERROR_STR(L, "exampleIds has to be contiguous"); | |||
if (!THTensor_(isContiguous)(grad)) | |||
return LUA_HANDLE_ERROR_STR(L, "grad has to be contiguous"); | |||
if (!THTensor_(isContiguous)(hess)) | |||
return LUA_HANDLE_ERROR_STR(L, "hessian has to be contiguous"); | |||
// initializes the static data | |||
nn_(GBInitialization) initialization_data; | |||
nn_(gb_initialize)(L, &initialization_data, exampleIds, grad, hess, dataset_index); | |||
// initializes the dynamic data | |||
GBRunData run_data; | |||
gb_create_run_data(&run_data, minLeafSize); | |||
// finds the best state possible for the split | |||
nn_(GBBestState) bs; | |||
nn_(gb_find_best_feature_split)(L, &initialization_data, &bs, feature_id, &run_data); | |||
lua_pop(L, lua_gettop(L) - initialization_data.splitInfo_index); | |||
// fills the table with the best split found and the lua logic above will do everything else
// if no state was found, returns nil | |||
if (bs.valid_state == 0) { | |||
lua_pop(L, 1); | |||
lua_pushnil(L); | |||
} | |||
else { | |||
nn_(gb_internal_split_info)(L, &bs, initialization_data.splitInfo_index); | |||
} | |||
gb_destroy_run_data(&run_data); | |||
return 1; | |||
} | |||
// finds the best split for a given node and shard of features | |||
// this is more efficient than calling the previous one multiple times | |||
static int nn_(gb_findBestSplit)(lua_State *L) { | |||
THLongTensor *exampleIds = luaT_checkudata(L, 1, "torch.LongTensor"); | |||
const int dataset_index = 2; | |||
THLongTensor *feature_ids = luaT_checkudata(L, 3, "torch.LongTensor"); | |||
if (!lua_isnumber(L, 4)) | |||
return LUA_HANDLE_ERROR_STR(L, "fourth argument should be an integer"); | |||
long minLeafSize = lua_tointeger(L, 4); | |||
// Since minLeafSize == 1 corresponds to each sample in its own leaf, any value below it doesn't | |||
// make sense | |||
if (minLeafSize < 1) | |||
minLeafSize = 1; | |||
if (!lua_isnumber(L, 5)) | |||
return LUA_HANDLE_ERROR_STR(L, "fifth argument should be an integer"); | |||
long shardId = lua_tointeger(L, 5); | |||
if (!lua_isnumber(L, 6)) | |||
return LUA_HANDLE_ERROR_STR(L, "sixth argument should be an integer"); | |||
long nShard = lua_tointeger(L, 6); | |||
THTensor *grad = luaT_checkudata(L, 7, torch_Tensor); | |||
THTensor *hess = luaT_checkudata(L, 8, torch_Tensor); | |||
if (!THLongTensor_isContiguous(exampleIds)) | |||
return LUA_HANDLE_ERROR_STR(L, "exampleIds has to be contiguous"); | |||
if (!THTensor_(isContiguous)(grad)) | |||
return LUA_HANDLE_ERROR_STR(L, "grad has to be contiguous"); | |||
if (!THTensor_(isContiguous)(hess)) | |||
return LUA_HANDLE_ERROR_STR(L, "hessian has to be contiguous"); | |||
// initializes the static data | |||
nn_(GBInitialization) initialization_data; | |||
nn_(gb_initialize)(L, &initialization_data, exampleIds, grad, hess, dataset_index); | |||
// initializes the dynamic data | |||
GBRunData run_data; | |||
gb_create_run_data(&run_data, minLeafSize); | |||
// initializes to evaluate all the features in this shard | |||
nn_(GBBestState) global_bs; | |||
global_bs.valid_state = 0; | |||
long n_features = THLongTensor_size(feature_ids, 0); | |||
if (!THLongTensor_isContiguous(feature_ids)) | |||
return LUA_HANDLE_ERROR_STR(L, "feature_ids must be contiguous"); | |||
long *feature_ids_data = THLongTensor_data(feature_ids); | |||
// for every feature | |||
for (long i = 0; i < n_features; i++) { | |||
long feature_id = feature_ids_data[i]; | |||
// if we are responsible for it | |||
if (nShard <= 1 || (feature_id % nShard) + 1 == shardId) { | |||
// finds the best state possible for the split | |||
nn_(GBBestState) bs; | |||
nn_(gb_find_best_feature_split)(L, &initialization_data, &bs, feature_id, &run_data); | |||
// if it's valid and better than one we found before, saves it | |||
if (bs.valid_state) { | |||
if (global_bs.valid_state == 0 || bs.gain < global_bs.gain) { | |||
global_bs = bs; | |||
} | |||
} | |||
} | |||
} | |||
lua_pop(L, lua_gettop(L) - initialization_data.splitInfo_index); | |||
// fills the table with the best split found and the lua logic above will do everything else
// if no state was found, returns nil | |||
if (global_bs.valid_state == 0) { | |||
lua_pop(L, 1); | |||
lua_pushnil(L); | |||
} | |||
else { | |||
nn_(gb_internal_split_info)(L, &global_bs, initialization_data.splitInfo_index); | |||
} | |||
gb_destroy_run_data(&run_data); | |||
return 1; | |||
} | |||
// all the info we have to pass to the slave threads so that they can do their jobs
// note that we do not pass the lua state since it isn't required. we perform direct C parallelism | |||
// instead of using lua's parallelism like with the previous version | |||
typedef struct { | |||
nn_(GBInitialization) *initialization_data; | |||
GBRunData *run_data; | |||
long *index; | |||
nn_(GBBestState) *global_bs; | |||
long n_features; | |||
long *feature_ids_data; | |||
pthread_mutex_t *mutex; | |||
THLongTensor *exampleIds; | |||
THTensor *input; | |||
THLongTensor **sorted_ids_per_feature; | |||
} nn_(ThreadInfo); | |||
// loops over all the features in parallel and finds the best global split | |||
static void* nn_(thread_worker)(void *arg) { | |||
nn_(ThreadInfo) *info = (nn_(ThreadInfo) *)arg; | |||
while (1) { | |||
pthread_mutex_lock(info->mutex); | |||
long index = (*info->index); | |||
(*info->index)++; | |||
pthread_mutex_unlock(info->mutex); | |||
if (index >= info->n_features) | |||
break; | |||
// performs part of steps (1) and (2) of gb_find_best_feature_split without having to access the
// lua state, using data pre-loaded by the main thread
long feature_id = info->feature_ids_data[index]; | |||
THLongTensor *exampleIdsWithFeature_ret = info->exampleIds; | |||
THLongTensor *featureExampleIds = info->sorted_ids_per_feature[index]; | |||
nn_(GBInitialization) *initialization_data = info->initialization_data; | |||
GBRunData *run_data = info->run_data; | |||
// performs steps (3) and (4) of gb_find_best_feature_split since (1) and (2) were already | |||
// performed before | |||
nn_(GBBestState) bs; | |||
nn_(gb_internal_create)(initialization_data->grad, initialization_data->hess, | |||
exampleIdsWithFeature_ret, &bs.state); | |||
nn_(gb_internal_get_best_split_special)(&bs, featureExampleIds, run_data->exampleMap, | |||
info->input, run_data->minLeafSize, feature_id); | |||
// saves to the global state if it's better | |||
if (bs.valid_state) { | |||
pthread_mutex_lock(info->mutex); | |||
if (info->global_bs->valid_state == 0 || bs.gain < info->global_bs->gain) { | |||
(*info->global_bs) = bs; | |||
} | |||
pthread_mutex_unlock(info->mutex); | |||
} | |||
} | |||
return NULL; | |||
} | |||
// finds the global best split by doing feature parallelism directly in C | |||
static int nn_(gb_findBestSplitFP)(lua_State *L) { | |||
THLongTensor *exampleIds = luaT_checkudata(L, 1, "torch.LongTensor"); | |||
const int dataset_index = 2; | |||
THLongTensor *feature_ids = luaT_checkudata(L, 3, "torch.LongTensor"); | |||
if (!lua_isnumber(L, 4)) | |||
return LUA_HANDLE_ERROR_STR(L, "fourth argument should be an integer"); | |||
long minLeafSize = lua_tointeger(L, 4); | |||
THTensor *grad = luaT_checkudata(L, 5, torch_Tensor); | |||
THTensor *hess = luaT_checkudata(L, 6, torch_Tensor); | |||
if (!lua_isnumber(L, 7)) | |||
return LUA_HANDLE_ERROR_STR(L, "seventh argument should be an integer"); | |||
long nThread = lua_tointeger(L, 7); | |||
if (!THLongTensor_isContiguous(exampleIds)) | |||
return LUA_HANDLE_ERROR_STR(L, "exampleIds has to be contiguous"); | |||
if (!THTensor_(isContiguous)(grad)) | |||
return LUA_HANDLE_ERROR_STR(L, "grad has to be contiguous"); | |||
if (!THTensor_(isContiguous)(hess)) | |||
return LUA_HANDLE_ERROR_STR(L, "hessian has to be contiguous"); | |||
pthread_mutex_t mutex; | |||
pthread_mutex_init(&mutex, NULL); | |||
// initializes the static data | |||
nn_(GBInitialization) initialization_data; | |||
nn_(gb_initialize)(L, &initialization_data, exampleIds, grad, hess, dataset_index); | |||
// initializes the dynamic data | |||
GBRunData run_data; | |||
gb_create_run_data(&run_data, minLeafSize); | |||
// initializes to evaluate all the features | |||
nn_(GBBestState) global_bs; | |||
global_bs.valid_state = 0; | |||
long n_features = THLongTensor_size(feature_ids, 0); | |||
if (!THLongTensor_isContiguous(feature_ids)) | |||
return LUA_HANDLE_ERROR_STR(L, "feature_ids must be contiguous"); | |||
long *feature_ids_data = THLongTensor_data(feature_ids); | |||
THTensor *input = luaT_checkudata(L, initialization_data.input_index, torch_Tensor); | |||
// performs step (1) of gb_find_best_feature_split so that we don't have to pass the lua state | |||
THLongTensor *sorted_ids_per_feature[n_features]; | |||
for (long i = 0; i < n_features; i++) { | |||
long feature_id = feature_ids_data[i]; | |||
lua_pushvalue(L, initialization_data.getSortedFeature_index); | |||
lua_pushvalue(L, initialization_data.dataset_index); | |||
lua_pushinteger(L, feature_id); | |||
lua_call(L, 2, 1); | |||
THLongTensor *featureExampleIds = luaT_checkudata(L, -1, "torch.LongTensor"); | |||
sorted_ids_per_feature[i] = featureExampleIds; | |||
} | |||
// performs step (2) of gb_find_best_feature_split since it's the same for all features when the
// data is dense | |||
long exampleIds_size = THLongTensor_size(initialization_data.exampleIds, 0); | |||
long *exampleIds_data = THLongTensor_data(initialization_data.exampleIds); | |||
int ret; | |||
kh_resize(long, run_data.exampleMap, exampleIds_size*8); | |||
for (long i = 0; i < exampleIds_size; i++) | |||
kh_put(long, run_data.exampleMap, exampleIds_data[i], &ret); | |||
// saves the info for the threads | |||
long index = 0; | |||
nn_(ThreadInfo) info; | |||
info.initialization_data = &initialization_data; | |||
info.run_data = &run_data; | |||
info.index = &index; | |||
info.global_bs = &global_bs; | |||
info.n_features = n_features; | |||
info.feature_ids_data = feature_ids_data; | |||
info.mutex = &mutex; | |||
info.exampleIds = exampleIds; | |||
info.input = input; | |||
info.sorted_ids_per_feature = sorted_ids_per_feature; | |||
pthread_t threads[nThread]; | |||
// let the threads run like crazy over the features to find the minimum | |||
for (long i = 0; i < nThread; i++) { | |||
int ret = pthread_create(&threads[i], NULL, nn_(thread_worker), &info); | |||
if (ret) | |||
return LUA_HANDLE_ERROR_STR(L, "falied to create thread"); | |||
} | |||
for (long i = 0; i < nThread; i++) { | |||
int ret = pthread_join(threads[i], NULL); | |||
if (ret) | |||
return LUA_HANDLE_ERROR_STR(L, "failed to join thread"); | |||
} | |||
lua_pop(L, lua_gettop(L) - initialization_data.splitInfo_index); | |||
// fills the table with the best split found and the lua logic above will do everything else
// if no state was found, returns nil | |||
if (global_bs.valid_state == 0) { | |||
lua_pop(L, 1); | |||
lua_pushnil(L); | |||
} | |||
else { | |||
nn_(gb_internal_split_info)(L, &global_bs, initialization_data.splitInfo_index); | |||
} | |||
gb_destroy_run_data(&run_data); | |||
pthread_mutex_destroy(&mutex); | |||
return 1; | |||
} | |||
// performs an efficient branch of the current examples based on a split info provided | |||
static int nn_(gb_branch)(lua_State *L) { | |||
if (!lua_istable(L, 1)) | |||
return LUA_HANDLE_ERROR_STR(L, "first argument must be a table"); | |||
THTensor *input = luaT_checkudata(L, 2, torch_Tensor); | |||
THLongTensor *exampleIds = luaT_checkudata(L, 3, "torch.LongTensor"); | |||
// gets direct access to the dataset | |||
long n_exampleIds = THLongTensor_size(exampleIds, 0); | |||
long *exampleIds_data = THLongTensor_data(exampleIds); | |||
long n_features = THTensor_(size)(input, 1); | |||
real *input_data = THTensor_(data)(input); | |||
// creates the tensors to be returned | |||
luaT_pushudata(L, THLongTensor_new(), "torch.LongTensor"); | |||
luaT_pushudata(L, THLongTensor_new(), "torch.LongTensor"); | |||
THLongTensor *leftExampleIds = luaT_checkudata(L, 4, "torch.LongTensor"); | |||
THLongTensor *rightExampleIds = luaT_checkudata(L, 5, "torch.LongTensor"); | |||
THLongTensor_resize1d(leftExampleIds, n_exampleIds); | |||
// gets direct access to the examples | |||
THLongTensor *splitExampleIds = leftExampleIds; | |||
long *splitExampleIds_data = THLongTensor_data(splitExampleIds); | |||
// gets the split info | |||
lua_pushstring(L, "splitId"); | |||
lua_rawget(L, 1); | |||
const long splitId = lua_tointeger(L, -1); | |||
lua_pushstring(L, "splitValue"); | |||
lua_rawget(L, 1); | |||
const real splitValue = lua_tonumber(L, -1); | |||
lua_pop(L, 2); | |||
long leftIdx = 0, rightIdx = 0; | |||
// goes over all the samples dividing them into the two sides | |||
for (long i = 0; i < n_exampleIds; i++) { | |||
long exampleId = exampleIds_data[i]; | |||
real val = input_data[(exampleId-1) * n_features + (splitId - 1)]; | |||
if (val <= splitValue) { | |||
leftIdx++; | |||
splitExampleIds_data[leftIdx-1] = exampleId; | |||
} | |||
else { | |||
rightIdx++; | |||
splitExampleIds_data[n_exampleIds - rightIdx + 1 - 1] = exampleId; | |||
} | |||
} | |||
// once done, the resulting tensors are just splits of the sample base. this is more efficient | |||
// than having 2 tensors since we didn't know where the split would happen (how much to each | |||
// side), but we knew that the sum would be constant | |||
THLongTensor_narrow(rightExampleIds, splitExampleIds, 0, n_exampleIds-rightIdx+1-1, rightIdx); | |||
THLongTensor_narrow(leftExampleIds, splitExampleIds, 0, 0, leftIdx); | |||
return 2; | |||
} | |||
static const struct luaL_Reg nn_(GBDT__) [] = { | |||
{"GBDT_findBestFeatureSplit", nn_(gb_findBestFeatureSplit)}, | |||
{"GBDT_findBestSplit", nn_(gb_findBestSplit)}, | |||
{"GBDT_findBestSplitFP", nn_(gb_findBestSplitFP)}, | |||
{"GBDT_branch", nn_(gb_branch)}, | |||
{NULL, NULL} | |||
}; | |||
static void nn_(GBDT_init)(lua_State *L) | |||
{ | |||
luaT_pushmetatable(L, torch_Tensor); | |||
luaT_registeratname(L, nn_(GBDT__), "nn"); | |||
lua_pop(L,1); | |||
} | |||
#endif |
@@ -1,312 +0,0 @@
// initializes the optimization structure based on the arguments provided, either filling directly | |||
// or making calls to lua to load some kind of data | |||
static void nn_(gb_initialize)(lua_State *L, nn_(GBInitialization) *initialization_data, | |||
THLongTensor *exampleIds, THTensor *grad, THTensor *hess, int dataset_index) { | |||
initialization_data->dataset_index = dataset_index; | |||
initialization_data->exampleIds = exampleIds; | |||
initialization_data->grad = grad; | |||
initialization_data->hess = hess; | |||
lua_newtable(L); | |||
initialization_data->splitInfo_index = lua_gettop(L); | |||
lua_pushstring(L, "input"); | |||
lua_gettable(L, dataset_index); | |||
initialization_data->input_index = lua_gettop(L); | |||
lua_pushstring(L, "getSortedFeature"); | |||
lua_gettable(L, dataset_index); | |||
initialization_data->getSortedFeature_index = lua_gettop(L); | |||
} | |||
// initializes a state that will be passed to the optimizer | |||
static void nn_(gb_internal_create)(THTensor *grad, THTensor *hessian, | |||
THLongTensor *exampleIds, nn_(GBState)* s) { | |||
long *exampleIds_data = THLongTensor_data(exampleIds); | |||
long n_examples = THLongTensor_size(exampleIds, 0); | |||
accreal leftGradientSum = 0; | |||
accreal leftHessianSum = 0; | |||
real *grad_data = THTensor_(data)(grad); | |||
real *hessian_data = THTensor_(data)(hessian); | |||
// only sums the relevant gradients and hessians | |||
for (long i = 0; i < n_examples; i++) { | |||
long exampleId = exampleIds_data[i]-1; | |||
leftGradientSum += grad_data[exampleId]; | |||
leftHessianSum += hessian_data[exampleId]; | |||
} | |||
// we move data from the left branch to the right branch | |||
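// (both hessian sums are seeded with 1, presumably as a smoothing constant)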
s->rightGradientSum = 0; | |||
s->rightHessianSum = 1; | |||
s->nExampleInRightBranch = 0; | |||
s->leftGradientSum = leftGradientSum; | |||
s->leftHessianSum = leftHessianSum + 1; | |||
s->nExampleInLeftBranch = n_examples; | |||
// stores the loss in parent for efficiency | |||
real lossInParent = computeGradientBoostLoss(s->leftGradientSum + s->rightGradientSum, | |||
s->leftHessianSum + s->rightHessianSum); | |||
s->lossInParent = lossInParent; | |||
// caches the direct pointers to the data for efficiency | |||
s->grad_data = grad_data; | |||
s->hessian_data = hessian_data; | |||
} | |||
// computes the gain obtained by performing the split | |||
static real nn_(computeSplitGain)(nn_(GBState) *s) { | |||
real lossInLeftBranch = computeGradientBoostLoss(s->leftGradientSum, s->leftHessianSum); | |||
real lossInRightBranch = computeGradientBoostLoss(s->rightGradientSum, s->rightHessianSum); | |||
return lossInLeftBranch + lossInRightBranch - s->lossInParent; | |||
} | |||
// uses the state information to build the table required by the lua library about the best split | |||
static void nn_(gb_internal_split_info)(lua_State *L, nn_(GBBestState) *bs, int res) { | |||
long feature_id = bs->feature_id; | |||
real feature_value = bs->feature_value; | |||
real gain = bs->gain; | |||
nn_(GBState) *s = &bs->state; | |||
lua_pushstring(L, "splitGain"); | |||
lua_pushnumber(L, gain); | |||
lua_rawset(L, res); | |||
lua_pushstring(L, "splitId"); | |||
lua_pushinteger(L, feature_id); | |||
lua_rawset(L, res); | |||
lua_pushstring(L, "splitValue"); | |||
lua_pushnumber(L, feature_value); | |||
lua_rawset(L, res); | |||
lua_pushstring(L, "leftChildSize"); | |||
lua_pushinteger(L, s->nExampleInLeftBranch); | |||
lua_rawset(L, res); | |||
lua_pushstring(L, "rightChildSize"); | |||
lua_pushinteger(L, s->nExampleInRightBranch); | |||
lua_rawset(L, res); | |||
lua_pushstring(L, "leftGradient"); | |||
lua_pushnumber(L, s->leftGradientSum); | |||
lua_rawset(L, res); | |||
lua_pushstring(L, "rightGradient"); | |||
lua_pushnumber(L, s->rightGradientSum); | |||
lua_rawset(L, res); | |||
lua_pushstring(L, "leftHessian"); | |||
lua_pushnumber(L, s->leftHessianSum); | |||
lua_rawset(L, res); | |||
lua_pushstring(L, "rightHessian"); | |||
lua_pushnumber(L, s->rightHessianSum); | |||
lua_rawset(L, res); | |||
} | |||
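// The splitInfo table built above has the shape: | |||
//   { splitGain = ..., splitId = ..., splitValue = ..., leftChildSize = ..., | |||
//     rightChildSize = ..., leftGradient = ..., rightGradient = ..., | |||
//     leftHessian = ..., rightHessian = ... } | |||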
// core of the computation, where we loop over all the relevant samples looking for the best split | |||
// we can find | |||
static void nn_(gb_internal_get_best_split)(lua_State *L, nn_(GBBestState) *bs, | |||
THLongTensor *featureExampleIds, khash_t(long)* exampleMap, int input_table_index, | |||
long minLeafSize, long feature_id) { | |||
nn_(GBState) current_state; | |||
nn_(GBState) best_state; | |||
current_state = bs->state; | |||
real best_gain = INFINITY; | |||
real best_value = 0; | |||
// if the data is dense, pre-loads direct access to it | |||
THTensor *input = NULL; | |||
real *input_data = NULL; | |||
long n_features = 0; | |||
  if (!lua_istable(L, input_table_index)) { | |||
    input = luaT_checkudata(L, input_table_index, torch_Tensor); | |||
    input_data = THTensor_(data)(input); | |||
    n_features = THTensor_(size)(input, 1); | |||
  } | |||
long stride = featureExampleIds->stride[0]; | |||
long *featureExampleIds_data = THLongTensor_data(featureExampleIds); | |||
khiter_t k; | |||
real previousSplitValue = 0; | |||
  // for each example having the given feature, from the largest to the smallest feature value... | |||
for (long i = THLongTensor_size(featureExampleIds, 0)-1; i >= 0; i--) { | |||
long exampleId = featureExampleIds_data[i * stride]; | |||
// checks if the sample is in the list of ones that have to be evaluated by this node | |||
k = kh_get(long, exampleMap, exampleId); | |||
if (k != kh_end(exampleMap)) { | |||
long exampleIdx = exampleId; | |||
// gets the split value, depending on whether the input is sparse or dense | |||
real splitValue; | |||
if (input_data) { | |||
splitValue = input_data[(exampleId-1) * n_features + feature_id-1]; | |||
} | |||
else { | |||
lua_pushinteger(L, exampleId); | |||
lua_gettable(L, input_table_index); | |||
lua_pushinteger(L, feature_id); | |||
lua_gettable(L, -2); | |||
splitValue = lua_tonumber(L, -1); | |||
lua_pop(L, 2); | |||
} | |||
// performs one update of the state, moving a sample from the left branch to the right | |||
real gradient = current_state.grad_data[exampleIdx-1]; | |||
real hessian = current_state.hessian_data[exampleIdx-1]; | |||
current_state.leftGradientSum -= gradient; | |||
current_state.rightGradientSum += gradient; | |||
current_state.leftHessianSum -= hessian; | |||
current_state.rightHessianSum += hessian; | |||
current_state.nExampleInLeftBranch--; | |||
current_state.nExampleInRightBranch++; | |||
// since we remove from the left, once this becomes true, it stays true forever | |||
// hence we stop the loop | |||
if (current_state.nExampleInLeftBranch < minLeafSize) | |||
break; | |||
if (current_state.nExampleInRightBranch >= minLeafSize) { | |||
// if the values are equal between the steps, it doesn't make sense to evaluate the score | |||
// since we won't be able to separate the two | |||
if (previousSplitValue != splitValue) { | |||
// computes the gain **without including the parent** since it doesn't change as we move | |||
// examples between branches | |||
real lossInLeftBranch = computeGradientBoostLoss(current_state.leftGradientSum, current_state.leftHessianSum); | |||
real lossInRightBranch = computeGradientBoostLoss(current_state.rightGradientSum, current_state.rightHessianSum); | |||
real current_gain = lossInLeftBranch + lossInRightBranch; | |||
if (current_gain < best_gain) { | |||
best_gain = current_gain; | |||
best_value = splitValue; | |||
best_state = current_state; | |||
} | |||
} | |||
} | |||
previousSplitValue = splitValue; | |||
} | |||
} | |||
// if there is a valid gain, then marks the state as valid and fills the meta-info | |||
if (!isfinite(best_gain)) { | |||
bs->valid_state = 0; | |||
} | |||
else { | |||
bs->valid_state = 1; | |||
bs->state = best_state; | |||
bs->feature_id = feature_id; | |||
bs->gain = nn_(computeSplitGain)(&bs->state); | |||
bs->feature_value = best_value; | |||
} | |||
} | |||
// exactly like the previous version, but accesses the data directly for efficiency. since | |||
// the dense case never touches the Lua state, this variant can be evaluated without one | |||
static void nn_(gb_internal_get_best_split_special)(nn_(GBBestState) *bs, | |||
THLongTensor *featureExampleIds, khash_t(long)* exampleMap, THTensor *input, long minLeafSize, | |||
long feature_id) { | |||
nn_(GBState) current_state; | |||
nn_(GBState) best_state; | |||
current_state = bs->state; | |||
real best_gain = INFINITY; | |||
real best_value = 0; | |||
real *input_data = NULL; | |||
long n_features = 0; | |||
input_data = THTensor_(data)(input); | |||
n_features = THTensor_(size)(input, 1); | |||
long stride = featureExampleIds->stride[0]; | |||
long *featureExampleIds_data = THLongTensor_data(featureExampleIds); | |||
khiter_t k; | |||
real previousSplitValue = 0; | |||
for (long i = THLongTensor_size(featureExampleIds, 0)-1; i >= 0; i--) { | |||
long exampleId = featureExampleIds_data[i * stride]; | |||
k = kh_get(long, exampleMap, exampleId); | |||
if (k != kh_end(exampleMap)) { | |||
long exampleIdx = exampleId; | |||
      // THIS is the main part that changes: the split value is read directly from the dense | |||
      // data. a special case just for this may seem excessive, but with a **lot** of samples | |||
      // to evaluate, the sparse/dense branch in the generic version becomes expensive | |||
real splitValue; | |||
splitValue = input_data[(exampleId-1) * n_features + feature_id-1]; | |||
real gradient = current_state.grad_data[exampleIdx-1]; | |||
real hessian = current_state.hessian_data[exampleIdx-1]; | |||
current_state.leftGradientSum -= gradient; | |||
current_state.rightGradientSum += gradient; | |||
current_state.leftHessianSum -= hessian; | |||
current_state.rightHessianSum += hessian; | |||
current_state.nExampleInLeftBranch--; | |||
current_state.nExampleInRightBranch++; | |||
// since we remove from the left, once this becomes true, it stays true forever | |||
// hence we stop the loop | |||
if (current_state.nExampleInLeftBranch < minLeafSize) | |||
break; | |||
      // the right branch only becomes a valid candidate once it has accumulated at least | |||
      // minLeafSize examples | |||
if (current_state.nExampleInRightBranch >= minLeafSize) { | |||
if (previousSplitValue != splitValue) { | |||
real lossInLeftBranch = computeGradientBoostLoss(current_state.leftGradientSum, current_state.leftHessianSum); | |||
real lossInRightBranch = computeGradientBoostLoss(current_state.rightGradientSum, current_state.rightHessianSum); | |||
real current_gain = lossInLeftBranch + lossInRightBranch; | |||
if (current_gain < best_gain) { | |||
best_gain = current_gain; | |||
best_value = splitValue; | |||
best_state = current_state; | |||
} | |||
} | |||
} | |||
previousSplitValue = splitValue; | |||
} | |||
} | |||
if (!isfinite(best_gain)) { | |||
bs->valid_state = 0; | |||
} | |||
else { | |||
bs->valid_state = 1; | |||
bs->state = best_state; | |||
bs->feature_id = feature_id; | |||
bs->gain = nn_(computeSplitGain)(&bs->state); | |||
bs->feature_value = best_value; | |||
} | |||
} | |||
// core of the computation to find the split for a given feature; it proceeds in 4 steps | |||
static void nn_(gb_find_best_feature_split)(lua_State *L, | |||
nn_(GBInitialization) *initialization_data, nn_(GBBestState) *bs, long feature_id, | |||
GBRunData *run_data) { | |||
// 1) loads the examples in the dataset ordered by their feature value | |||
lua_pushvalue(L, initialization_data->getSortedFeature_index); | |||
lua_pushvalue(L, initialization_data->dataset_index); | |||
lua_pushinteger(L, feature_id); | |||
lua_call(L, 2, 1); | |||
THLongTensor *featureExampleIds = luaT_checkudata(L, -1, "torch.LongTensor"); | |||
// 2) processes the data to find the intersection between the examples in the dataset and the | |||
// examples the current node has to evaluate | |||
THLongTensor *exampleIdsWithFeature_ret = gb_internal_prepare(L, initialization_data->exampleIds, | |||
run_data->exampleIdsWithFeature_cache, initialization_data->input_index, feature_id, | |||
run_data->exampleMap); | |||
if (!exampleIdsWithFeature_ret) { | |||
bs->valid_state = 0; | |||
return; | |||
} | |||
// 3) creates a new state to be used by the optimizer | |||
nn_(gb_internal_create)(initialization_data->grad, initialization_data->hess, | |||
exampleIdsWithFeature_ret, &bs->state); | |||
// 4) optimize away! | |||
nn_(gb_internal_get_best_split)(L, bs, featureExampleIds, run_data->exampleMap, | |||
initialization_data->input_index, run_data->minLeafSize, feature_id); | |||
} |
@@ -1,34 +0,0 @@ | |||
// representation of a state used while searching for the best split | |||
typedef struct { | |||
real leftGradientSum, rightGradientSum; | |||
real leftHessianSum, rightHessianSum; | |||
real lossInParent; | |||
long nExampleInLeftBranch, nExampleInRightBranch; | |||
real *grad_data, *hessian_data; | |||
} nn_(GBState); | |||
// representation for the best state found for a given feature | |||
typedef struct { | |||
nn_(GBState) state; | |||
real gain; | |||
long feature_id; | |||
real feature_value; | |||
int valid_state; | |||
} nn_(GBBestState); | |||
// full data that must be initialized before calling the optimizer | |||
typedef struct { | |||
  // the *_index fields hold positions on the lua stack | |||
int dataset_index; | |||
int splitInfo_index; | |||
int input_index; | |||
// position of the dataset's function to return the samples ordered for a given feature | |||
int getSortedFeature_index; | |||
// samples that this node has to evaluate | |||
THLongTensor *exampleIds; | |||
// cached gradient and hessian for all data | |||
THTensor *grad; | |||
THTensor *hess; | |||
} nn_(GBInitialization); |
@@ -1,90 +0,0 @@ | |||
#ifndef TH_GENERIC_FILE | |||
#define TH_GENERIC_FILE "generic/LogitBoostCriterion.c" | |||
#else | |||
#define EPS 1e-12 | |||
static int nn_(LogitBoostCriterion_updateOutput)(lua_State *L) | |||
{ | |||
THTensor *input = luaT_checkudata(L, 1, torch_Tensor); | |||
THTensor *target = luaT_checkudata(L, 2, torch_Tensor); | |||
THTensor *output = luaT_checkudata(L, 3, torch_Tensor); | |||
int sizeAverage = lua_toboolean(L, 4); | |||
if (THTensor_(nElement)(input) != THTensor_(nElement)(target)) { | |||
luaL_error(L, "inconsistent input and target size"); | |||
} | |||
THTensor_(resize1d)(output, 1); | |||
real sum = 0; | |||
TH_TENSOR_APPLY2(real, input, real, target, | |||
real x = *input_data; | |||
real y = *target_data; | |||
// math.log(1 + math.exp(target[i] <= 0 and input[i] or -input[i])) | |||
sum += log(1 + exp(y <= 0 ? x : -x)); | |||
); | |||
if (sizeAverage) | |||
sum /= THTensor_(nElement)(input); | |||
THTensor_(set1d)(output, 0, sum); | |||
return 0; | |||
} | |||
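// For p = sigmoid(x), the loss above equals log(1 + exp(x)) for non-positive targets | |||
// and log(1 + exp(-x)) otherwise; its first derivative w.r.t. x is the gradient below | |||
// (p for non-positive targets, p - 1 otherwise) and its second derivative is the | |||
// hessian p * (1 - p). | |||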
static int nn_(LogitBoostCriterion_updateGradInput)(lua_State *L) | |||
{ | |||
THTensor *input = luaT_checkudata(L, 1, torch_Tensor); | |||
THTensor *target = luaT_checkudata(L, 2, torch_Tensor); | |||
THTensor *gradInput = luaT_checkudata(L, 3, torch_Tensor); | |||
if (THTensor_(nElement)(input) != THTensor_(nElement)(target)) { | |||
luaL_error(L, "inconsistent input and target size"); | |||
} | |||
THTensor_(resizeAs)(gradInput, input); | |||
TH_TENSOR_APPLY3(real, gradInput, real, input, real, target, | |||
real x = *input_data; | |||
real y = *target_data; | |||
real p = (x >= 0) ? (1 / (1 + exp(-x))) : (1 - 1 / (1 + exp(x))); | |||
*gradInput_data = (y <= 0) ? p : (p - 1); | |||
); | |||
return 0; | |||
} | |||
static int nn_(LogitBoostCriterion_updateHessInput)(lua_State *L) | |||
{ | |||
THTensor *input = luaT_checkudata(L, 1, torch_Tensor); | |||
THTensor *target = luaT_checkudata(L, 2, torch_Tensor); | |||
THTensor *hessInput = luaT_checkudata(L, 3, torch_Tensor); | |||
if (THTensor_(nElement)(input) != THTensor_(nElement)(target)) { | |||
luaL_error(L, "inconsistent input and target size"); | |||
} | |||
THTensor_(resizeAs)(hessInput, input); | |||
TH_TENSOR_APPLY3(real, hessInput, real, input, real, target, | |||
real x = *input_data; | |||
real p = (x >= 0) ? (1 / (1 + exp(-x))) : (1 - 1 / (1 + exp(x))); | |||
*hessInput_data = p * (1.0 - p); | |||
); | |||
return 0; | |||
} | |||
static const struct luaL_Reg nn_(LogitBoostCriterion__) [] = { | |||
{"LogitBoostCriterion_updateOutput", nn_(LogitBoostCriterion_updateOutput)}, | |||
{"LogitBoostCriterion_updateGradInput", nn_(LogitBoostCriterion_updateGradInput)}, | |||
{"LogitBoostCriterion_updateHessInput", nn_(LogitBoostCriterion_updateHessInput)}, | |||
{NULL, NULL} | |||
}; | |||
static void nn_(LogitBoostCriterion_init)(lua_State *L) | |||
{ | |||
luaT_pushmetatable(L, torch_Tensor); | |||
luaT_registeratname(L, nn_(LogitBoostCriterion__), "nn"); | |||
lua_pop(L,1); | |||
} | |||
#endif |
@@ -1,90 +0,0 @@ | |||
#ifndef TH_GENERIC_FILE | |||
#define TH_GENERIC_FILE "generic/S2D.c" | |||
#else | |||
static int nn_(S2D_computeOutput)(lua_State *L) { | |||
THTensor *output = luaT_checkudata(L, 1, torch_Tensor); | |||
const int keys_index = 2; | |||
const int values_index = 3; | |||
const int masks_index = 4; | |||
  if (!lua_istable(L, keys_index)) | |||
    return LUA_HANDLE_ERROR_STR(L, "expected position 2 to be a table"); | |||
  if (!lua_istable(L, values_index)) | |||
    return LUA_HANDLE_ERROR_STR(L, "expected position 3 to be a table"); | |||
  if (!lua_istable(L, masks_index)) | |||
    return LUA_HANDLE_ERROR_STR(L, "expected position 4 to be a table"); | |||
THLongTensor *features = luaT_checkudata(L, 5, "torch.LongTensor"); | |||
const int original_top = lua_gettop(L); | |||
long outputsize = THLongTensor_size(features, 0); | |||
long batch_size = lua_objlen(L, keys_index); | |||
// initializes output | |||
THTensor_(resize2d)(output, batch_size, outputsize); | |||
THTensor_(zero)(output); | |||
real *output_data = THTensor_(data)(output); | |||
// iterates over samples | |||
lua_pushnil(L); | |||
const int local_top = lua_gettop(L); | |||
while (lua_next(L, keys_index) != 0) { | |||
// gets data corresponding to the current sample | |||
long i = lua_tointeger(L, -2)-1; | |||
real *current_output_data = &output_data[i * outputsize]; | |||
THLongTensor *keys = luaT_checkudata(L, -1, "torch.LongTensor"); | |||
lua_rawgeti(L, values_index, i+1); | |||
THTensor *values = luaT_checkudata(L, -1, torch_Tensor); | |||
lua_rawgeti(L, masks_index, i+1); | |||
THByteTensor *mask = luaT_checkudata(L, -1, "torch.ByteTensor"); | |||
long n_keys = THLongTensor_size(keys, 0); | |||
long n_values = THTensor_(size)(values, 0); | |||
// quick safety check | |||
if (n_keys != n_values) | |||
return LUA_HANDLE_ERROR_STR(L, "keys and values have to have the same size"); | |||
// gets the direct memory pointers | |||
long *keys_data = THLongTensor_data(keys); | |||
real *values_data = THTensor_(data)(values); | |||
unsigned char *mask_data = THByteTensor_data(mask); | |||
// for each value in the sparse input... | |||
for (long j = 0; j < n_keys; j++) { | |||
// loads the value and key | |||
real current_value = values_data[j]; | |||
long current_key = keys_data[j]; | |||
unsigned char current_mask = mask_data[j]; | |||
// if the feature is present in the map | |||
if (current_mask) | |||
// saves in the given position | |||
current_output_data[current_key-1] = current_value; | |||
} | |||
    // pops the values pushed while processing this key so the Lua stack does not overflow | |||
lua_pop(L, lua_gettop(L) - local_top); | |||
} | |||
// cleans up the trash we added to the stack | |||
lua_pop(L, lua_gettop(L) - original_top); | |||
return 0; | |||
} | |||
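// Worked example (sketch): with outputsize = 6, keys = {2, 5}, values = {0.5, 1.0} and | |||
// mask = {1, 1}, the corresponding output row becomes [0, 0.5, 0, 0, 1.0, 0], since | |||
// keys are 1-based positions in the dense row. | |||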
static const struct luaL_Reg nn_(S2D__) [] = { | |||
{"S2D_computeOutput", nn_(S2D_computeOutput)}, | |||
{NULL, NULL} | |||
}; | |||
static void nn_(S2D_init)(lua_State *L) | |||
{ | |||
luaT_pushmetatable(L, torch_Tensor); | |||
luaT_registeratname(L, nn_(S2D__), "nn"); | |||
lua_pop(L,1); | |||
} | |||
#endif |
@@ -1,445 +0,0 @@ | |||
#include "utils.h" | |||
#include "hash_map.h" | |||
#include "internal_hash_map.h" | |||
#include <pthread.h> | |||
hash_map_t hash_map_init(void) { | |||
return kh_init(long); | |||
} | |||
void hash_map_destroy(hash_map_t h_) { | |||
internal_hash_map_t h = (internal_hash_map_t) h_; | |||
kh_destroy(long, h); | |||
} | |||
void hash_map_clear(hash_map_t h_) { | |||
internal_hash_map_t h = (internal_hash_map_t) h_; | |||
kh_clear(long, h); | |||
} | |||
int hash_map_put(hash_map_t h_, long key, long val) { | |||
internal_hash_map_t h = (internal_hash_map_t) h_; | |||
int ret; | |||
khiter_t k = kh_put(long, h, key, &ret); | |||
ret = (ret >= 0); | |||
if (ret) | |||
kh_value(h, k) = val; | |||
return ret; | |||
} | |||
int hash_map_put_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_) { | |||
long *keys = THLongTensor_data(keys_); | |||
long *vals = THLongTensor_data(vals_); | |||
long size = get_tensor_size(keys_, Long); | |||
for (long i = 0; i < size; i++) | |||
if (!hash_map_put(h_, keys[i], vals[i])) | |||
return 0; | |||
return 1; | |||
} | |||
int hash_map_fill(hash_map_t h_, long key, long *counter) { | |||
internal_hash_map_t h = (internal_hash_map_t) h_; | |||
khiter_t k = kh_get(long, h, key); | |||
if (k == kh_end(h)) | |||
return hash_map_put(h_, key, ++(*counter)); | |||
return 1; | |||
} | |||
int hash_map_fill_tensor(hash_map_t h_, THLongTensor *keys_, long *counter) { | |||
long *keys = THLongTensor_data(keys_); | |||
long size = get_tensor_size(keys_, Long); | |||
for (long i = 0; i < size; i++) | |||
if (!hash_map_fill(h_, keys[i], counter)) | |||
return 0; | |||
return 1; | |||
} | |||
int hash_map_get(hash_map_t h_, long key, long* val) { | |||
internal_hash_map_t h = (internal_hash_map_t) h_; | |||
khiter_t k = kh_get(long, h, key); | |||
if (k == kh_end(h)) | |||
return 0; | |||
*val = kh_value(h, k); | |||
return 1; | |||
} | |||
void hash_map_get_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_, THByteTensor *mask_) { | |||
long *keys = THLongTensor_data(keys_); | |||
  long *vals = THLongTensor_data(vals_); | |||
unsigned char *mask = THByteTensor_data(mask_); | |||
long size = get_tensor_size(keys_, Long); | |||
for (long i = 0; i < size; i++) | |||
mask[i] = hash_map_get(h_, keys[i], &vals[i]); | |||
} | |||
void hash_map_del(hash_map_t h_, long key) { | |||
internal_hash_map_t h = (internal_hash_map_t) h_; | |||
khiter_t k = kh_get(long, h, key); | |||
if (k != kh_end(h)) | |||
kh_del(long, h, k); | |||
} | |||
void hash_map_del_tensor(hash_map_t h_, THLongTensor *keys_) { | |||
long *keys = THLongTensor_data(keys_); | |||
long size = get_tensor_size(keys_, Long); | |||
for (long i = 0; i < size; i++) | |||
hash_map_del(h_, keys[i]); | |||
} | |||
size_t hash_map_size(hash_map_t h_) { | |||
internal_hash_map_t h = (internal_hash_map_t) h_; | |||
return kh_size(h); | |||
} | |||
void hash_map_to_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_) { | |||
internal_hash_map_t h = (internal_hash_map_t) h_; | |||
long *keys = THLongTensor_data(keys_); | |||
long *vals = THLongTensor_data(vals_); | |||
long key, val, i = 0; | |||
kh_foreach(h, key, val, { | |||
keys[i] = key; | |||
vals[i] = val; | |||
i++; | |||
}); | |||
} | |||
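/* Typical C-side usage of the map API above (a sketch, error handling omitted): | |||
 *   hash_map_t h = hash_map_init(); | |||
 *   hash_map_put(h, 42, 7); | |||
 *   long val; | |||
 *   int found = hash_map_get(h, 42, &val);  => found == 1 and val == 7 | |||
 *   hash_map_destroy(h); | |||
 */ | |||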
static void autolock(hash_map_lua_t *h) { | |||
if (h->autolock) { | |||
pthread_mutex_lock(&h->mutex); | |||
} | |||
} | |||
static void autounlock(hash_map_lua_t *h) { | |||
if (h->autolock) { | |||
pthread_mutex_unlock(&h->mutex); | |||
} | |||
} | |||
int hash_map_autolock_on_lua(lua_State *L) { | |||
hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
h->autolock = 1; | |||
return 0; | |||
} | |||
int hash_map_autolock_off_lua(lua_State *L) { | |||
hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
h->autolock = 0; | |||
return 0; | |||
} | |||
int hash_map_init_lua(lua_State *L) { | |||
hash_map_lua_t **hp = (hash_map_lua_t**)lua_newuserdata(L, sizeof(hash_map_lua_t*)); | |||
*hp = (hash_map_lua_t*)malloc(sizeof(hash_map_lua_t)); | |||
hash_map_lua_t *h = *hp; | |||
h->refcount = 1; | |||
h->counter = 0; | |||
h->autolock = 0; | |||
h->h = hash_map_init(); | |||
pthread_mutexattr_t mutex_attr; | |||
pthread_mutexattr_init(&mutex_attr); | |||
pthread_mutexattr_settype(&mutex_attr, PTHREAD_MUTEX_RECURSIVE); | |||
pthread_mutex_init(&h->mutex, &mutex_attr); | |||
luaL_getmetatable(L, "dt.HashMap"); | |||
lua_setmetatable(L, -2); | |||
return 1; | |||
} | |||
int hash_map_gc_lua(lua_State *L) { | |||
hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
if (THAtomicDecrementRef(&h->refcount)) { | |||
pthread_mutex_destroy(&h->mutex); | |||
hash_map_destroy(h->h); | |||
free(h); | |||
} | |||
return 0; | |||
} | |||
int hash_map_retain_lua(lua_State *L) { | |||
hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
THAtomicIncrementRef(&h->refcount); | |||
return 0; | |||
} | |||
int hash_map_metatablename_lua(lua_State *L) { | |||
lua_pushstring(L, "dt.HashMap"); | |||
return 1; | |||
} | |||
int hash_map_clear_lua(lua_State *L) { | |||
hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
autolock(h); | |||
hash_map_clear(h->h); | |||
autounlock(h); | |||
return 0; | |||
} | |||
int hash_map_put_lua(lua_State *L) { | |||
hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
int ret; | |||
#if LUA_VERSION_NUM <= 501 | |||
#define lua_isinteger lua_isnumber | |||
#endif | |||
if (lua_isinteger(L, 2)) { | |||
if (!lua_isinteger(L, 3)) | |||
return LUA_HANDLE_ERROR_STR(L, "second parameter is not a number"); | |||
long key = lua_tointeger(L, 2); | |||
long val = lua_tointeger(L, 3); | |||
autolock(h); | |||
ret = hash_map_put(h->h, key, val); | |||
autounlock(h); | |||
} | |||
else { | |||
THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor"); | |||
THLongTensor *vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor"); | |||
check_tensor(L, keys, THLongTensor); | |||
check_tensor(L, vals, THLongTensor); | |||
check_tensors(L, keys, vals); | |||
autolock(h); | |||
ret = hash_map_put_tensor(h->h, keys, vals); | |||
autounlock(h); | |||
} | |||
if (!ret) | |||
return LUA_HANDLE_ERROR_STR(L, "failed to put into hash map"); | |||
return 0; | |||
} | |||
int hash_map_fill_lua(lua_State *L) { | |||
hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
int ret; | |||
if (lua_isinteger(L, 2)) { | |||
long key = lua_tointeger(L, 2); | |||
autolock(h); | |||
ret = hash_map_fill(h->h, key, &h->counter); | |||
autounlock(h); | |||
} | |||
else { | |||
THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor"); | |||
check_tensor(L, keys, THLongTensor); | |||
autolock(h); | |||
ret = hash_map_fill_tensor(h->h, keys, &h->counter); | |||
autounlock(h); | |||
} | |||
if (!ret) | |||
return LUA_HANDLE_ERROR_STR(L, "failed to fill into hash map"); | |||
return 0; | |||
} | |||
int hash_map_adjust_counter_lua(lua_State *L) { | |||
hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
internal_hash_map_t h = (internal_hash_map_t) h_->h; | |||
long val; | |||
kh_foreach_value(h, val, { | |||
if (val >= h_->counter) | |||
h_->counter = val; | |||
}); | |||
return 0; | |||
} | |||
int hash_map_set_counter_lua(lua_State *L) { | |||
hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
h_->counter = lua_tointeger(L, 2); | |||
return 0; | |||
} | |||
int hash_map_get_counter_lua(lua_State *L) { | |||
hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
lua_pushinteger(L, h_->counter); | |||
return 1; | |||
} | |||
static int hash_map_get_tensor_lua(lua_State *L, hash_map_lua_t *h, int inplace) { | |||
THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor"); | |||
check_tensor(L, keys, THLongTensor); | |||
THLongTensor *vals = inplace ? keys : NULL; | |||
THByteTensor *mask = NULL; | |||
int maskIdx = inplace ? 3 : 4; | |||
if (!inplace) { | |||
if (lua_gettop(L) < 3) { | |||
vals = THLongTensor_new(); | |||
} else { | |||
vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor"); | |||
check_tensor(L, vals, THLongTensor); | |||
} | |||
} | |||
if (lua_gettop(L) < maskIdx) { | |||
mask = THByteTensor_new(); | |||
} else { | |||
mask = (THByteTensor *)luaT_checkudata(L, maskIdx, "torch.ByteTensor"); | |||
check_tensor(L, mask, THByteTensor); | |||
} | |||
int n_dim = THLongTensor_nDimension(keys); | |||
THLongStorage *st = THLongStorage_newWithSize1(n_dim); | |||
for (int i = 0; i < n_dim; i++) { | |||
THLongStorage_set(st, i, THLongTensor_size(keys, i)); | |||
} | |||
THByteTensor_resize(mask, st, NULL); | |||
if (!inplace) THLongTensor_resize(vals, st, NULL); | |||
THLongStorage_free(st); | |||
autolock(h); | |||
hash_map_get_tensor(h->h, keys, vals, mask); | |||
autounlock(h); | |||
if (!inplace && lua_gettop(L) < 3) | |||
luaT_pushudata(L, vals, "torch.LongTensor"); | |||
if (lua_gettop(L) < maskIdx) | |||
luaT_pushudata(L, mask, "torch.ByteTensor"); | |||
return 2; | |||
} | |||
static int hash_map_get_table_lua(lua_State *L, hash_map_lua_t *h, int inplace) { | |||
const int kidx = 2; | |||
const int vidx = inplace ? 2 : 3; | |||
const int midx = inplace ? 3 : 4; | |||
const int narg = lua_gettop(L); | |||
if (inplace) { | |||
if (narg < 3) { | |||
LUA_HANDLE_ERROR_STR(L, "HashMap.getInplace requires two arguments."); | |||
} | |||
} else { | |||
if (narg < 4) { | |||
LUA_HANDLE_ERROR_STR(L, "HashMap.get requires three arguments."); | |||
} | |||
} | |||
int count = push_table_contents(L, kidx); | |||
verify_push_table_contents(L, vidx, count); | |||
verify_push_table_contents(L, midx, count); | |||
THLongTensor *keys; | |||
THLongTensor *vals; | |||
THByteTensor *mask; | |||
for (int i = count - 1; i >= 0; i--) { | |||
int maskIdx = i - count; | |||
int valIdx = maskIdx - count; | |||
int keyIdx = inplace ? valIdx : (valIdx - count); | |||
keys = (THLongTensor *)luaT_checkudata(L, keyIdx, "torch.LongTensor"); | |||
check_tensor(L, keys, THLongTensor); | |||
if (inplace) { | |||
vals = keys; | |||
} else { | |||
vals = (THLongTensor *)luaT_checkudata(L, valIdx, "torch.LongTensor"); | |||
} | |||
mask = (THByteTensor *)luaT_checkudata(L, maskIdx, "torch.ByteTensor"); | |||
int n_dim = THLongTensor_nDimension(keys); | |||
THLongStorage *st = THLongStorage_newWithSize1(n_dim); | |||
for (int i = 0; i < n_dim; i++) { | |||
THLongStorage_set(st, i, THLongTensor_size(keys, i)); | |||
} | |||
THByteTensor_resize(mask, st, NULL); | |||
THLongTensor_resize(vals, st, NULL); | |||
THLongStorage_free(st); | |||
autolock(h); | |||
hash_map_get_tensor(h->h, keys, vals, mask); | |||
autounlock(h); | |||
} | |||
lua_pop(L, (narg - 1) * count); | |||
return 2; | |||
} | |||
int hash_map_get_lua(lua_State *L) { | |||
hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
if (lua_isinteger(L, 2)) { | |||
long key = lua_tointeger(L, 2); | |||
long val; | |||
autolock(h); | |||
int ret = hash_map_get(h->h, key, &val); | |||
autounlock(h); | |||
if (ret) { | |||
lua_pushinteger(L, val); | |||
lua_pushinteger(L, 1); | |||
} | |||
else { | |||
lua_pushinteger(L, 0); | |||
lua_pushinteger(L, 0); | |||
} | |||
} else if (lua_istable(L, 2)) { | |||
return hash_map_get_table_lua(L, h, 0); | |||
} else { | |||
return hash_map_get_tensor_lua(L, h, 0); | |||
} | |||
return 2; | |||
} | |||
int hash_map_get_inplace_lua(lua_State *L) { | |||
hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
if (lua_isinteger(L, 2)) { | |||
LUA_HANDLE_ERROR_STR(L, "HashMap.getInplace does not support integer arguments."); | |||
} else if (lua_istable(L, 2)) { | |||
return hash_map_get_table_lua(L, h, 1); | |||
} else { | |||
return hash_map_get_tensor_lua(L, h, 1); | |||
} | |||
return 2; | |||
} | |||
int hash_map_del_lua(lua_State *L) { | |||
hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
if (lua_isinteger(L, 2)) { | |||
long key = lua_tointeger(L, 2); | |||
autolock(h); | |||
hash_map_del(h->h, key); | |||
autounlock(h); | |||
} | |||
else { | |||
THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor"); | |||
autolock(h); | |||
hash_map_del_tensor(h->h, keys); | |||
autounlock(h); | |||
} | |||
return 0; | |||
} | |||
int hash_map_size_lua(lua_State *L) { | |||
hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
long size = hash_map_size(h->h); | |||
lua_pushinteger(L, size); | |||
return 1; | |||
} | |||
int hash_map_to_tensor_lua(lua_State *L) { | |||
hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); | |||
THLongTensor *keys, *vals; | |||
if (lua_gettop(L) < 2) { | |||
keys = THLongTensor_new(); | |||
} | |||
else { | |||
keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor"); | |||
check_tensor(L, keys, THLongTensor); | |||
} | |||
if (lua_gettop(L) < 3) { | |||
vals = THLongTensor_new(); | |||
} | |||
else { | |||
vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor"); | |||
check_tensor(L, vals, THLongTensor); | |||
} | |||
size_t size = hash_map_size(h->h); | |||
THLongTensor_resize1d(keys, size); | |||
THLongTensor_resize1d(vals, size); | |||
autolock(h); | |||
hash_map_to_tensor(h->h, keys, vals); | |||
autounlock(h); | |||
if (lua_gettop(L) < 2) | |||
luaT_pushudata(L, keys, "torch.LongTensor"); | |||
if (lua_gettop(L) < 3) | |||
luaT_pushudata(L, vals, "torch.LongTensor"); | |||
return 2; | |||
} |
@@ -1,36 +0,0 @@ | |||
#include "luaT.h" | |||
#include "TH.h" | |||
typedef void* hash_map_t; | |||
hash_map_t hash_map_init(void); | |||
void hash_map_destroy(hash_map_t); | |||
void hash_map_clear(hash_map_t); | |||
int hash_map_put(hash_map_t, long key, long val); | |||
int hash_map_put_tensor(hash_map_t, THLongTensor *keys_, THLongTensor *vals_); | |||
int hash_map_fill(hash_map_t, long key, long *counter); | |||
int hash_map_fill_tensor(hash_map_t, THLongTensor *keys_, long *counter); | |||
int hash_map_get(hash_map_t, long key, long *val); | |||
void hash_map_get_tensor(hash_map_t, THLongTensor *keys_, THLongTensor *vals_, THByteTensor *mask_); | |||
void hash_map_del(hash_map_t, long key); | |||
void hash_map_del_tensor(hash_map_t, THLongTensor *keys_); | |||
size_t hash_map_size(hash_map_t); | |||
void hash_map_to_tensor(hash_map_t, THLongTensor *keys_, THLongTensor *vals_); | |||
int hash_map_autolock_on_lua(lua_State *L); | |||
int hash_map_autolock_off_lua(lua_State *L); | |||
int hash_map_init_lua(lua_State *L); | |||
int hash_map_gc_lua(lua_State *L); | |||
int hash_map_retain_lua(lua_State *L); | |||
int hash_map_metatablename_lua(lua_State *L); | |||
int hash_map_clear_lua(lua_State *L); | |||
int hash_map_put_lua(lua_State *L); | |||
int hash_map_fill_lua(lua_State *L); | |||
int hash_map_adjust_counter_lua(lua_State *L); | |||
int hash_map_set_counter_lua(lua_State *L); | |||
int hash_map_get_counter_lua(lua_State *L); | |||
int hash_map_get_lua(lua_State *L); | |||
int hash_map_get_inplace_lua(lua_State *L); | |||
int hash_map_del_lua(lua_State *L); | |||
int hash_map_size_lua(lua_State *L); | |||
int hash_map_to_tensor_lua(lua_State *L); |
@@ -1,77 +0,0 @@ | |||
#include "TH.h" | |||
#include "luaT.h" | |||
#ifdef _OPENMP | |||
#include "omp.h" | |||
#endif | |||
#include "error.h" | |||
#include "hash_map.h" | |||
#define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME) | |||
#define torch_Tensor TH_CONCAT_STRING_3(torch., Real, Tensor) | |||
#define nn_(NAME) TH_CONCAT_3(nn_, Real, NAME) | |||
#include "generic/LogitBoostCriterion.c" | |||
#include "THGenerateFloatTypes.h" | |||
#include "generic/DFD.c" | |||
#include "THGenerateFloatTypes.h" | |||
#include "generic/S2D.c" | |||
#include "THGenerateFloatTypes.h" | |||
#include "generic/CartTree.c" | |||
#include "THGenerateFloatTypes.h" | |||
#include "GBDT_common.h" | |||
#include "generic/GBDT.c" | |||
#include "THGenerateFloatTypes.h" | |||
static const struct luaL_Reg decisiontree_hash_map_routines[] = { | |||
{"__gc", hash_map_gc_lua}, | |||
{"retain", hash_map_retain_lua}, | |||
{"metatablename", hash_map_metatablename_lua}, | |||
{"clear", hash_map_clear_lua}, | |||
{"put", hash_map_put_lua}, | |||
{"fill", hash_map_fill_lua}, | |||
{"adjustCounter", hash_map_adjust_counter_lua}, | |||
{"getCounter", hash_map_get_counter_lua}, | |||
{"setCounter", hash_map_set_counter_lua}, | |||
{"get", hash_map_get_lua}, | |||
{"getInplace", hash_map_get_inplace_lua}, | |||
{"del", hash_map_del_lua}, | |||
{"size", hash_map_size_lua}, | |||
{"safe", hash_map_autolock_on_lua}, | |||
{"unsafe", hash_map_autolock_off_lua}, | |||
{"toTensors", hash_map_to_tensor_lua}, | |||
{"new", hash_map_init_lua}, | |||
{NULL, NULL} | |||
}; | |||
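/* Lua-side usage sketch (dt.HashMap is bound to this metatable in init.lua): | |||
 *   local h = dt.HashMap() | |||
 *   h:put(keys, vals)              -- two torch.LongTensors of equal size | |||
 *   local vals, mask = h:get(keys) -- values plus a ByteTensor presence mask | |||
 */ | |||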
DLL_EXPORT int luaopen_libdecisiontree(lua_State *L) | |||
{ | |||
// HashMap | |||
luaL_newmetatable(L, "dt.HashMap"); | |||
lua_pushstring(L, "__index"); | |||
lua_pushvalue(L, -2); | |||
lua_settable(L, -3); | |||
luaT_setfuncs(L, decisiontree_hash_map_routines, 0); | |||
nn_FloatLogitBoostCriterion_init(L); | |||
nn_DoubleLogitBoostCriterion_init(L); | |||
nn_FloatDFD_init(L); | |||
nn_DoubleDFD_init(L); | |||
nn_FloatS2D_init(L); | |||
nn_DoubleS2D_init(L); | |||
nn_FloatCT_init(L); | |||
nn_DoubleCT_init(L); | |||
nn_FloatGBDT_init(L); | |||
nn_DoubleGBDT_init(L); | |||
return 1; | |||
} |
@@ -1,70 +0,0 @@ | |||
require 'paths' | |||
--require 'xlua' | |||
require 'string' | |||
require 'os' | |||
--require 'sys' | |||
require 'nn' | |||
require 'moses' | |||
unpack = unpack or table.unpack | |||
local dt = require 'decisiontree._env' | |||
require "paths" | |||
paths.require 'libdecisiontree' | |||
dt.HashMap = torch.getmetatable("dt.HashMap").new | |||
dt.EPSILON = 1e-6 | |||
require 'decisiontree.SparseTensor' | |||
require 'decisiontree.math' | |||
require 'decisiontree.utils' | |||
--require 'decisiontree.WorkPool' | |||
require 'decisiontree.DecisionTree' | |||
require 'decisiontree.DecisionForest' | |||
require 'decisiontree.DecisionForestTrainer' | |||
require 'decisiontree.TreeState' | |||
require 'decisiontree.CartNode' | |||
require 'decisiontree.CartTree' | |||
require 'decisiontree.MSECriterion' | |||
require 'decisiontree.LogitBoostCriterion' | |||
require 'decisiontree.CartTrainer' | |||
require 'decisiontree.DataSet' | |||
require 'decisiontree.RandomForestTrainer' | |||
require 'decisiontree.GiniState' -- TreeState subclass | |||
require 'decisiontree.GradientBoostTrainer' | |||
require 'decisiontree.GradientBoostState' -- TreeState subclass | |||
--require 'decisiontree.test' | |||
--require 'decisiontree.benchmark' | |||
require 'decisiontree.DFD' | |||
require 'decisiontree.Sparse2Dense' | |||
return dt |
@@ -1,13 +0,0 @@ | |||
#include "khash.h" | |||
#include <pthread.h> | |||
KHASH_MAP_INIT_INT64(long, long) | |||
typedef khash_t(long)* internal_hash_map_t; | |||
typedef struct { | |||
hash_map_t h; | |||
int refcount; | |||
pthread_mutex_t mutex; | |||
int autolock; | |||
long counter; | |||
} hash_map_lua_t; |
@@ -1,627 +0,0 @@ | |||
/* The MIT License | |||
Copyright (c) 2008, 2009, 2011 by Attractive Chaos <attractor@live.co.uk> | |||
Permission is hereby granted, free of charge, to any person obtaining | |||
a copy of this software and associated documentation files (the | |||
"Software"), to deal in the Software without restriction, including | |||
without limitation the rights to use, copy, modify, merge, publish, | |||
distribute, sublicense, and/or sell copies of the Software, and to | |||
permit persons to whom the Software is furnished to do so, subject to | |||
the following conditions: | |||
The above copyright notice and this permission notice shall be | |||
included in all copies or substantial portions of the Software. | |||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, | |||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | |||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND | |||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS | |||
BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN | |||
ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN | |||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |||
SOFTWARE. | |||
*/ | |||
/* | |||
An example: | |||
#include "khash.h" | |||
KHASH_MAP_INIT_INT(32, char) | |||
int main() { | |||
int ret, is_missing; | |||
khiter_t k; | |||
khash_t(32) *h = kh_init(32); | |||
k = kh_put(32, h, 5, &ret); | |||
kh_value(h, k) = 10; | |||
k = kh_get(32, h, 10); | |||
is_missing = (k == kh_end(h)); | |||
k = kh_get(32, h, 5); | |||
kh_del(32, h, k); | |||
for (k = kh_begin(h); k != kh_end(h); ++k) | |||
if (kh_exist(h, k)) kh_value(h, k) = 1; | |||
kh_destroy(32, h); | |||
return 0; | |||
} | |||
*/ | |||
/* | |||
2013-05-02 (0.2.8): | |||
* Use quadratic probing. When the capacity is power of 2, stepping function | |||
i*(i+1)/2 guarantees to traverse each bucket. It is better than double | |||
hashing on cache performance and is more robust than linear probing. | |||
In theory, double hashing should be more robust than quadratic probing. | |||
However, my implementation is probably not for large hash tables, because | |||
the second hash function is closely tied to the first hash function, | |||
which reduces the effectiveness of double hashing. | |||
Reference: http://research.cs.vt.edu/AVresearch/hashing/quadratic.php | |||
2011-12-29 (0.2.7): | |||
* Minor code clean up; no actual effect. | |||
2011-09-16 (0.2.6): | |||
* The capacity is a power of 2. This seems to dramatically improve the | |||
speed for simple keys. Thank Zilong Tan for the suggestion. Reference: | |||
- http://code.google.com/p/ulib/ | |||
- http://nothings.org/computer/judy/ | |||
* Allow to optionally use linear probing which usually has better | |||
performance for random input. Double hashing is still the default as it | |||
is more robust to certain non-random input. | |||
* Added Wang's integer hash function (not used by default). This hash | |||
function is more robust to certain non-random input. | |||
2011-02-14 (0.2.5): | |||
* Allow to declare global functions. | |||
2009-09-26 (0.2.4): | |||
* Improve portability | |||
2008-09-19 (0.2.3): | |||
* Corrected the example | |||
* Improved interfaces | |||
2008-09-11 (0.2.2): | |||
* Improved speed a little in kh_put() | |||
2008-09-10 (0.2.1): | |||
* Added kh_clear() | |||
* Fixed a compiling error | |||
2008-09-02 (0.2.0): | |||
* Changed to token concatenation which increases flexibility. | |||
2008-08-31 (0.1.2): | |||
* Fixed a bug in kh_get(), which has not been tested previously. | |||
2008-08-31 (0.1.1): | |||
* Added destructor | |||
*/ | |||
#ifndef __AC_KHASH_H | |||
#define __AC_KHASH_H | |||
/*! | |||
@header | |||
Generic hash table library. | |||
*/ | |||
#define AC_VERSION_KHASH_H "0.2.8" | |||
#include <stdlib.h> | |||
#include <string.h> | |||
#include <limits.h> | |||
/* compiler specific configuration */ | |||
#if UINT_MAX == 0xffffffffu | |||
typedef unsigned int khint32_t; | |||
#elif ULONG_MAX == 0xffffffffu | |||
typedef unsigned long khint32_t; | |||
#endif | |||
#if ULONG_MAX == ULLONG_MAX | |||
typedef unsigned long khint64_t; | |||
#else | |||
typedef unsigned long long khint64_t; | |||
#endif | |||
#ifndef kh_inline | |||
#ifdef _MSC_VER | |||
#define kh_inline __inline | |||
#else | |||
#define kh_inline inline | |||
#endif | |||
#endif /* kh_inline */ | |||
#ifndef klib_unused | |||
#if (defined __clang__ && __clang_major__ >= 3) || (defined __GNUC__ && __GNUC__ >= 3) | |||
#define klib_unused __attribute__ ((__unused__)) | |||
#else | |||
#define klib_unused | |||
#endif | |||
#endif /* klib_unused */ | |||
typedef khint32_t khint_t; | |||
typedef khint_t khiter_t; | |||
#define __ac_isempty(flag, i) ((flag[i>>4]>>((i&0xfU)<<1))&2) | |||
#define __ac_isdel(flag, i) ((flag[i>>4]>>((i&0xfU)<<1))&1) | |||
#define __ac_iseither(flag, i) ((flag[i>>4]>>((i&0xfU)<<1))&3) | |||
#define __ac_set_isdel_false(flag, i) (flag[i>>4]&=~(1ul<<((i&0xfU)<<1))) | |||
#define __ac_set_isempty_false(flag, i) (flag[i>>4]&=~(2ul<<((i&0xfU)<<1))) | |||
#define __ac_set_isboth_false(flag, i) (flag[i>>4]&=~(3ul<<((i&0xfU)<<1))) | |||
#define __ac_set_isdel_true(flag, i) (flag[i>>4]|=1ul<<((i&0xfU)<<1)) | |||
#define __ac_fsize(m) ((m) < 16? 1 : (m)>>4) | |||
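/* each bucket has 2 flag bits (empty, deleted), so 16 buckets pack into one khint32_t: | |||
   i>>4 selects the flag word and (i&0xfU)<<1 the bit pair within it */ | |||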
#ifndef kroundup32 | |||
#define kroundup32(x) (--(x), (x)|=(x)>>1, (x)|=(x)>>2, (x)|=(x)>>4, (x)|=(x)>>8, (x)|=(x)>>16, ++(x)) | |||
#endif | |||
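/* kroundup32 rounds x up to the next power of two, e.g. 5 -> 8 and 16 -> 16 */ | |||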
#ifndef kcalloc | |||
#define kcalloc(N,Z) calloc(N,Z) | |||
#endif | |||
#ifndef kmalloc | |||
#define kmalloc(Z) malloc(Z) | |||
#endif | |||
#ifndef krealloc | |||
#define krealloc(P,Z) realloc(P,Z) | |||
#endif | |||
#ifndef kfree | |||
#define kfree(P) free(P) | |||
#endif | |||
static const double __ac_HASH_UPPER = 0.77; | |||
#define __KHASH_TYPE(name, khkey_t, khval_t) \ | |||
typedef struct kh_##name##_s { \ | |||
khint_t n_buckets, size, n_occupied, upper_bound; \ | |||
khint32_t *flags; \ | |||
khkey_t *keys; \ | |||
khval_t *vals; \ | |||
} kh_##name##_t; | |||
#define __KHASH_PROTOTYPES(name, khkey_t, khval_t) \ | |||
extern kh_##name##_t *kh_init_##name(void); \ | |||
extern void kh_destroy_##name(kh_##name##_t *h); \ | |||
extern void kh_clear_##name(kh_##name##_t *h); \ | |||
extern khint_t kh_get_##name(const kh_##name##_t *h, khkey_t key); \ | |||
extern int kh_resize_##name(kh_##name##_t *h, khint_t new_n_buckets); \ | |||
extern khint_t kh_put_##name(kh_##name##_t *h, khkey_t key, int *ret); \ | |||
extern void kh_del_##name(kh_##name##_t *h, khint_t x); | |||
#define __KHASH_IMPL(name, SCOPE, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) \ | |||
SCOPE kh_##name##_t *kh_init_##name(void) { \ | |||
return (kh_##name##_t*)kcalloc(1, sizeof(kh_##name##_t)); \ | |||
} \ | |||
SCOPE void kh_destroy_##name(kh_##name##_t *h) \ | |||
{ \ | |||
if (h) { \ | |||
kfree((void *)h->keys); kfree(h->flags); \ | |||
kfree((void *)h->vals); \ | |||
kfree(h); \ | |||
} \ | |||
} \ | |||
SCOPE void kh_clear_##name(kh_##name##_t *h) \ | |||
{ \ | |||
if (h && h->flags) { \ | |||
memset(h->flags, 0xaa, __ac_fsize(h->n_buckets) * sizeof(khint32_t)); \ | |||
h->size = h->n_occupied = 0; \ | |||
} \ | |||
} \ | |||
SCOPE khint_t kh_get_##name(const kh_##name##_t *h, khkey_t key) \ | |||
{ \ | |||
if (h->n_buckets) { \ | |||
khint_t k, i, last, mask, step = 0; \ | |||
mask = h->n_buckets - 1; \ | |||
k = __hash_func(key); i = k & mask; \ | |||
last = i; \ | |||
while (!__ac_isempty(h->flags, i) && (__ac_isdel(h->flags, i) || !__hash_equal(h->keys[i], key))) { \ | |||
i = (i + (++step)) & mask; \ | |||
if (i == last) return h->n_buckets; \ | |||
} \ | |||
return __ac_iseither(h->flags, i)? h->n_buckets : i; \ | |||
} else return 0; \ | |||
} \ | |||
SCOPE int kh_resize_##name(kh_##name##_t *h, khint_t new_n_buckets) \ | |||
{ /* This function uses 0.25*n_buckets bytes of working space instead of [sizeof(key_t+val_t)+.25]*n_buckets. */ \ | |||
khint32_t *new_flags = 0; \ | |||
khint_t j = 1; \ | |||
{ \ | |||
kroundup32(new_n_buckets); \ | |||
if (new_n_buckets < 4) new_n_buckets = 4; \ | |||
if (h->size >= (khint_t)(new_n_buckets * __ac_HASH_UPPER + 0.5)) j = 0; /* requested size is too small */ \ | |||
else { /* hash table size to be changed (shrink or expand); rehash */ \ | |||
new_flags = (khint32_t*)kmalloc(__ac_fsize(new_n_buckets) * sizeof(khint32_t)); \ | |||
if (!new_flags) return -1; \ | |||
memset(new_flags, 0xaa, __ac_fsize(new_n_buckets) * sizeof(khint32_t)); \ | |||
if (h->n_buckets < new_n_buckets) { /* expand */ \ | |||
khkey_t *new_keys = (khkey_t*)krealloc((void *)h->keys, new_n_buckets * sizeof(khkey_t)); \ | |||
if (!new_keys) { kfree(new_flags); return -1; } \ | |||
h->keys = new_keys; \ | |||
if (kh_is_map) { \ | |||
khval_t *new_vals = (khval_t*)krealloc((void *)h->vals, new_n_buckets * sizeof(khval_t)); \ | |||
if (!new_vals) { kfree(new_flags); return -1; } \ | |||
h->vals = new_vals; \ | |||
} \ | |||
} /* otherwise shrink */ \ | |||
} \ | |||
} \ | |||
if (j) { /* rehashing is needed */ \ | |||
for (j = 0; j != h->n_buckets; ++j) { \ | |||
if (__ac_iseither(h->flags, j) == 0) { \ | |||
khkey_t key = h->keys[j]; \ | |||
khval_t val; \ | |||
khint_t new_mask; \ | |||
new_mask = new_n_buckets - 1; \ | |||
if (kh_is_map) val = h->vals[j]; \ | |||
__ac_set_isdel_true(h->flags, j); \ | |||
while (1) { /* kick-out process; sort of like in Cuckoo hashing */ \ | |||
khint_t k, i, step = 0; \ | |||
k = __hash_func(key); \ | |||
i = k & new_mask; \ | |||
while (!__ac_isempty(new_flags, i)) i = (i + (++step)) & new_mask; \ | |||
__ac_set_isempty_false(new_flags, i); \ | |||
if (i < h->n_buckets && __ac_iseither(h->flags, i) == 0) { /* kick out the existing element */ \ | |||
{ khkey_t tmp = h->keys[i]; h->keys[i] = key; key = tmp; } \ | |||
if (kh_is_map) { khval_t tmp = h->vals[i]; h->vals[i] = val; val = tmp; } \ | |||
__ac_set_isdel_true(h->flags, i); /* mark it as deleted in the old hash table */ \ | |||
} else { /* write the element and jump out of the loop */ \ | |||
h->keys[i] = key; \ | |||
if (kh_is_map) h->vals[i] = val; \ | |||
break; \ | |||
} \ | |||
} \ | |||
} \ | |||
} \ | |||
if (h->n_buckets > new_n_buckets) { /* shrink the hash table */ \ | |||
h->keys = (khkey_t*)krealloc((void *)h->keys, new_n_buckets * sizeof(khkey_t)); \ | |||
if (kh_is_map) h->vals = (khval_t*)krealloc((void *)h->vals, new_n_buckets * sizeof(khval_t)); \ | |||
} \ | |||
kfree(h->flags); /* free the working space */ \ | |||
h->flags = new_flags; \ | |||
h->n_buckets = new_n_buckets; \ | |||
h->n_occupied = h->size; \ | |||
h->upper_bound = (khint_t)(h->n_buckets * __ac_HASH_UPPER + 0.5); \ | |||
} \ | |||
return 0; \ | |||
} \ | |||
SCOPE khint_t kh_put_##name(kh_##name##_t *h, khkey_t key, int *ret) \ | |||
{ \ | |||
khint_t x; \ | |||
if (h->n_occupied >= h->upper_bound) { /* update the hash table */ \ | |||
if (h->n_buckets > (h->size<<1)) { \ | |||
if (kh_resize_##name(h, h->n_buckets - 1) < 0) { /* clear "deleted" elements */ \ | |||
*ret = -1; return h->n_buckets; \ | |||
} \ | |||
} else if (kh_resize_##name(h, h->n_buckets + 1) < 0) { /* expand the hash table */ \ | |||
*ret = -1; return h->n_buckets; \ | |||
} \ | |||
} /* TODO: to implement automatically shrinking; resize() already support shrinking */ \ | |||
{ \ | |||
khint_t k, i, site, last, mask = h->n_buckets - 1, step = 0; \ | |||
x = site = h->n_buckets; k = __hash_func(key); i = k & mask; \ | |||
if (__ac_isempty(h->flags, i)) x = i; /* for speed up */ \ | |||
else { \ | |||
last = i; \ | |||
while (!__ac_isempty(h->flags, i) && (__ac_isdel(h->flags, i) || !__hash_equal(h->keys[i], key))) { \ | |||
if (__ac_isdel(h->flags, i)) site = i; \ | |||
i = (i + (++step)) & mask; \ | |||
if (i == last) { x = site; break; } \ | |||
} \ | |||
if (x == h->n_buckets) { \ | |||
if (__ac_isempty(h->flags, i) && site != h->n_buckets) x = site; \ | |||
else x = i; \ | |||
} \ | |||
} \ | |||
} \ | |||
if (__ac_isempty(h->flags, x)) { /* not present at all */ \ | |||
h->keys[x] = key; \ | |||
__ac_set_isboth_false(h->flags, x); \ | |||
++h->size; ++h->n_occupied; \ | |||
*ret = 1; \ | |||
} else if (__ac_isdel(h->flags, x)) { /* deleted */ \ | |||
h->keys[x] = key; \ | |||
__ac_set_isboth_false(h->flags, x); \ | |||
++h->size; \ | |||
*ret = 2; \ | |||
} else *ret = 0; /* Don't touch h->keys[x] if present and not deleted */ \ | |||
return x; \ | |||
} \ | |||
SCOPE void kh_del_##name(kh_##name##_t *h, khint_t x) \ | |||
{ \ | |||
if (x != h->n_buckets && !__ac_iseither(h->flags, x)) { \ | |||
__ac_set_isdel_true(h->flags, x); \ | |||
--h->size; \ | |||
} \ | |||
} | |||
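/* note on probing: i advances by ++step on each collision, i.e. to offsets k + j*(j+1)/2; | |||
   with a power-of-two capacity this triangular-number sequence visits every bucket */ | |||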
#define KHASH_DECLARE(name, khkey_t, khval_t) \ | |||
__KHASH_TYPE(name, khkey_t, khval_t) \ | |||
__KHASH_PROTOTYPES(name, khkey_t, khval_t) | |||
#define KHASH_INIT2(name, SCOPE, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) \ | |||
__KHASH_TYPE(name, khkey_t, khval_t) \ | |||
__KHASH_IMPL(name, SCOPE, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) | |||
#define KHASH_INIT(name, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) \ | |||
KHASH_INIT2(name, static kh_inline klib_unused, khkey_t, khval_t, kh_is_map, __hash_func, __hash_equal) | |||
/* --- BEGIN OF HASH FUNCTIONS --- */ | |||
/*! @function | |||
@abstract Integer hash function | |||
@param key The integer [khint32_t] | |||
@return The hash value [khint_t] | |||
*/ | |||
#define kh_int_hash_func(key) (khint32_t)(key) | |||
/*! @function | |||
@abstract Integer comparison function | |||
*/ | |||
#define kh_int_hash_equal(a, b) ((a) == (b)) | |||
/*! @function | |||
@abstract 64-bit integer hash function | |||
@param key The integer [khint64_t] | |||
@return The hash value [khint_t] | |||
*/ | |||
#define kh_int64_hash_func(key) (khint32_t)((key)>>33^(key)^(key)<<11) | |||
/*! @function | |||
@abstract 64-bit integer comparison function | |||
*/ | |||
#define kh_int64_hash_equal(a, b) ((a) == (b)) | |||
/*! @function | |||
@abstract const char* hash function | |||
@param s Pointer to a null terminated string | |||
@return The hash value | |||
*/ | |||
static kh_inline khint_t __ac_X31_hash_string(const char *s) | |||
{ | |||
khint_t h = (khint_t)*s; | |||
if (h) for (++s ; *s; ++s) h = (h << 5) - h + (khint_t)*s; | |||
return h; | |||
} | |||
/*! @function | |||
@abstract Another interface to const char* hash function | |||
@param key Pointer to a null terminated string [const char*] | |||
@return The hash value [khint_t] | |||
*/ | |||
#define kh_str_hash_func(key) __ac_X31_hash_string(key) | |||
/*! @function | |||
@abstract Const char* comparison function | |||
*/ | |||
#define kh_str_hash_equal(a, b) (strcmp(a, b) == 0) | |||
static kh_inline khint_t __ac_Wang_hash(khint_t key) | |||
{ | |||
key += ~(key << 15); | |||
key ^= (key >> 10); | |||
key += (key << 3); | |||
key ^= (key >> 6); | |||
key += ~(key << 11); | |||
key ^= (key >> 16); | |||
return key; | |||
} | |||
#define kh_int_hash_func2(key) __ac_Wang_hash((khint_t)key) | |||
/* --- END OF HASH FUNCTIONS --- */ | |||
/* Other convenient macros... */ | |||
/*! | |||
@abstract Type of the hash table. | |||
@param name Name of the hash table [symbol] | |||
*/ | |||
#define khash_t(name) kh_##name##_t | |||
/*! @function | |||
@abstract Initiate a hash table. | |||
@param name Name of the hash table [symbol] | |||
@return Pointer to the hash table [khash_t(name)*] | |||
*/ | |||
#define kh_init(name) kh_init_##name() | |||
/*! @function | |||
@abstract Destroy a hash table. | |||
@param name Name of the hash table [symbol] | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
*/ | |||
#define kh_destroy(name, h) kh_destroy_##name(h) | |||
/*! @function | |||
@abstract Reset a hash table without deallocating memory. | |||
@param name Name of the hash table [symbol] | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
*/ | |||
#define kh_clear(name, h) kh_clear_##name(h) | |||
/*! @function | |||
@abstract Resize a hash table. | |||
@param name Name of the hash table [symbol] | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
@param s New size [khint_t] | |||
*/ | |||
#define kh_resize(name, h, s) kh_resize_##name(h, s) | |||
/*! @function | |||
@abstract Insert a key to the hash table. | |||
@param name Name of the hash table [symbol] | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
@param k Key [type of keys] | |||
@param r Extra return code: -1 if the operation failed; | |||
0 if the key is present in the hash table; | |||
1 if the bucket is empty (never used); 2 if the element in | |||
the bucket has been deleted [int*] | |||
@return Iterator to the inserted element [khint_t] | |||
*/ | |||
#define kh_put(name, h, k, r) kh_put_##name(h, k, r) | |||
/*! @function | |||
@abstract Retrieve a key from the hash table. | |||
@param name Name of the hash table [symbol] | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
@param k Key [type of keys] | |||
@return Iterator to the found element, or kh_end(h) if the element is absent [khint_t] | |||
*/ | |||
#define kh_get(name, h, k) kh_get_##name(h, k) | |||
/*! @function | |||
@abstract Remove a key from the hash table. | |||
@param name Name of the hash table [symbol] | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
@param k Iterator to the element to be deleted [khint_t] | |||
*/ | |||
#define kh_del(name, h, k) kh_del_##name(h, k) | |||
/*! @function | |||
@abstract Test whether a bucket contains data. | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
@param x Iterator to the bucket [khint_t] | |||
@return 1 if containing data; 0 otherwise [int] | |||
*/ | |||
#define kh_exist(h, x) (!__ac_iseither((h)->flags, (x))) | |||
/*! @function | |||
@abstract Get key given an iterator | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
@param x Iterator to the bucket [khint_t] | |||
@return Key [type of keys] | |||
*/ | |||
#define kh_key(h, x) ((h)->keys[x]) | |||
/*! @function | |||
@abstract Get value given an iterator | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
@param x Iterator to the bucket [khint_t] | |||
@return Value [type of values] | |||
@discussion For hash sets, calling this results in segfault. | |||
*/ | |||
#define kh_val(h, x) ((h)->vals[x]) | |||
/*! @function | |||
@abstract Alias of kh_val() | |||
*/ | |||
#define kh_value(h, x) ((h)->vals[x]) | |||
/*! @function | |||
@abstract Get the start iterator | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
@return The start iterator [khint_t] | |||
*/ | |||
#define kh_begin(h) (khint_t)(0) | |||
/*! @function | |||
@abstract Get the end iterator | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
@return The end iterator [khint_t] | |||
*/ | |||
#define kh_end(h) ((h)->n_buckets) | |||
/*! @function | |||
@abstract Get the number of elements in the hash table | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
@return Number of elements in the hash table [khint_t] | |||
*/ | |||
#define kh_size(h) ((h)->size) | |||
/*! @function | |||
@abstract Get the number of buckets in the hash table | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
@return Number of buckets in the hash table [khint_t] | |||
*/ | |||
#define kh_n_buckets(h) ((h)->n_buckets) | |||
/*! @function | |||
@abstract Iterate over the entries in the hash table | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
@param kvar Variable to which key will be assigned | |||
@param vvar Variable to which value will be assigned | |||
@param code Block of code to execute | |||
*/ | |||
#define kh_foreach(h, kvar, vvar, code) { khint_t __i; \ | |||
for (__i = kh_begin(h); __i != kh_end(h); ++__i) { \ | |||
if (!kh_exist(h,__i)) continue; \ | |||
(kvar) = kh_key(h,__i); \ | |||
(vvar) = kh_val(h,__i); \ | |||
code; \ | |||
} } | |||
/*! @function | |||
@abstract Iterate over the values in the hash table | |||
@param h Pointer to the hash table [khash_t(name)*] | |||
@param vvar Variable to which value will be assigned | |||
@param code Block of code to execute | |||
*/ | |||
#define kh_foreach_value(h, vvar, code) { khint_t __i; \ | |||
for (__i = kh_begin(h); __i != kh_end(h); ++__i) { \ | |||
if (!kh_exist(h,__i)) continue; \ | |||
(vvar) = kh_val(h,__i); \ | |||
code; \ | |||
} } | |||
/* More convenient interfaces */ | |||
/*! @function | |||
@abstract Instantiate a hash set containing integer keys | |||
@param name Name of the hash table [symbol] | |||
*/ | |||
#define KHASH_SET_INIT_INT(name) \ | |||
KHASH_INIT(name, khint32_t, char, 0, kh_int_hash_func, kh_int_hash_equal) | |||
/*! @function | |||
@abstract Instantiate a hash map containing integer keys | |||
@param name Name of the hash table [symbol] | |||
@param khval_t Type of values [type] | |||
*/ | |||
#define KHASH_MAP_INIT_INT(name, khval_t) \ | |||
KHASH_INIT(name, khint32_t, khval_t, 1, kh_int_hash_func, kh_int_hash_equal) | |||
/*! @function | |||
@abstract Instantiate a hash set containing 64-bit integer keys | |||
@param name Name of the hash table [symbol] | |||
*/ | |||
#define KHASH_SET_INIT_INT64(name) \ | |||
KHASH_INIT(name, khint64_t, char, 0, kh_int64_hash_func, kh_int64_hash_equal) | |||
/*! @function | |||
@abstract Instantiate a hash map containing 64-bit integer keys | |||
@param name Name of the hash table [symbol] | |||
@param khval_t Type of values [type] | |||
*/ | |||
#define KHASH_MAP_INIT_INT64(name, khval_t) \ | |||
KHASH_INIT(name, khint64_t, khval_t, 1, kh_int64_hash_func, kh_int64_hash_equal) | |||
typedef const char *kh_cstr_t; | |||
/*! @function | |||
  @abstract Instantiate a hash set containing const char* keys | |||
@param name Name of the hash table [symbol] | |||
*/ | |||
#define KHASH_SET_INIT_STR(name) \ | |||
KHASH_INIT(name, kh_cstr_t, char, 0, kh_str_hash_func, kh_str_hash_equal) | |||
/*! @function | |||
@abstract Instantiate a hash map containing const char* keys | |||
@param name Name of the hash table [symbol] | |||
@param khval_t Type of values [type] | |||
*/ | |||
#define KHASH_MAP_INIT_STR(name, khval_t) \ | |||
KHASH_INIT(name, kh_cstr_t, khval_t, 1, kh_str_hash_func, kh_str_hash_equal) | |||
#endif /* __AC_KHASH_H */ |
@@ -1,84 +0,0 @@ | |||
local dt = require "decisiontree._env" | |||
local PSEUDOCOUNT = 1.0 | |||
local MIN_LOGISTIC = 1E-8 | |||
local MAX_LOGISTIC = 1.0 - MIN_LOGISTIC | |||
function dt.uniquecounts(counts, inputset, nclass) | |||
counts = counts or inputset.input.new() | |||
nclass = nclass or inputset.target:max() | |||
counts:resize(nclass):zero() | |||
inputset.target:apply(function(c) counts[c] = counts[c] + 1 end) | |||
return counts | |||
end | |||
local counts, logprobs | |||
function dt.entropy(inputset, nclass) | |||
local dt = require 'decisiontree' | |||
counts = dt.uniquecounts(counts, inputset, nclass) | |||
-- convert counts to categorical probabilities | |||
counts:add(0.0000001) -- prevent NaN | |||
counts:div(counts:sum()) | |||
logprobs = logprobs or counts.new() | |||
logprobs:resize(counts:size()) | |||
logprobs:log(counts):div(math.log(2)) -- log2(x) | |||
counts:cmul(logprobs) | |||
return -counts:sum() | |||
end | |||
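-- A plain-Lua hand-check of the entropy above (no torch needed); the class
-- counts are illustrative, not taken from the original tests:
local function entropy_of_counts(counts)
   local total, ent = 0, 0
   for _, c in ipairs(counts) do total = total + c end
   for _, c in ipairs(counts) do
      local p = c / total
      ent = ent - p * math.log(p) / math.log(2) -- log2(p), as in dt.entropy
   end
   return ent
end
assert(math.abs(entropy_of_counts({50, 50}) - 1.0) < 1e-9) -- even binary split: 1 bit
assert(entropy_of_counts({100}) == 0)                      -- pure node: 0 bits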
function dt.probabilityPositive(nPositive, nTotal) | |||
return (nPositive + PSEUDOCOUNT) / (nTotal + 2.0 * PSEUDOCOUNT); | |||
end | |||
function dt.logit(p) | |||
assert(p >= 0.0 and p <= 1.0, "Expecting probability for arg 1") | |||
local truncatedP = math.max(MIN_LOGISTIC, math.min(MAX_LOGISTIC, p)) | |||
return math.log(truncatedP / (1.0 - truncatedP)) | |||
end | |||
function dt.logistic(x) | |||
return (x >= 0) and (1 / (1 + math.exp(-x))) or (1 - 1 / (1 + math.exp(x))) | |||
end | |||
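-- dt.logit and dt.logistic are inverses on (0,1), up to the MIN/MAX_LOGISTIC
-- truncation above; a quick plain-Lua round-trip check with arbitrary values:
local function logit(p)    return math.log(p / (1 - p)) end
local function logistic(x) return 1 / (1 + math.exp(-x)) end
for _, p in ipairs({0.01, 0.25, 0.5, 0.9}) do
   assert(math.abs(logistic(logit(p)) - p) < 1e-9)
end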
function dt.computeGradientBoostLoss(gradient, hessian) | |||
return -gradient * gradient / hessian | |||
end | |||
function dt.computeNewtonScore(gradient, hessian) | |||
return -0.5 * gradient / hessian; | |||
end | |||
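-- Both helpers follow from the quadratic (Newton) model of the loss around 0,
-- L(w) = g*w + 0.5*h*w^2, whose minimizer is w* = -g/h with value -g^2/(2h).
-- Note the code halves the step (-0.5*g/h) and ranks splits by -g^2/h, i.e.
-- twice the ideal reduction; neither scaling changes which split wins.
-- A plain-Lua sanity check with arbitrary g and h:
local g, h = -3.0, 2.0
local function L(w) return g * w + 0.5 * h * w * w end
local wstar = -g / h
assert(math.abs(L(wstar) - (-g * g / (2 * h))) < 1e-12)           -- minimum value
assert(L(wstar) <= L(wstar + 0.1) and L(wstar) <= L(wstar - 0.1)) -- it is a minimum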
function dt.calculateLogitScore(nPositive, nTotal) | |||
local dt = require 'decisiontree' | |||
return dt.logit(dt.probabilityPositive(nPositive, nTotal)) | |||
end | |||
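-- With PSEUDOCOUNT = 1, an empty node gets the neutral prior:
-- probabilityPositive(0, 0) = 1/2, hence calculateLogitScore(0, 0) = 0.
-- A plain-Lua restatement of that smoothing:
local function probabilityPositive(npos, n) return (npos + 1) / (n + 2) end
assert(probabilityPositive(0, 0) == 0.5)
assert(math.log(0.5 / (1 - 0.5)) == 0)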
function dt.computeGini(leftCount, positiveLeftCount, rightCount, positiveRightCount) | |||
assert(torch.type(leftCount) == 'number', 'Expecting total number examples falling into leftBranch.') | |||
assert(torch.type(positiveLeftCount) == 'number', 'Expecting total number of positive examples falling into left branch.') | |||
assert(torch.type(rightCount) == 'number', 'Expecting total number of examples falling into the right branch.') | |||
assert(torch.type(positiveRightCount) == 'number', 'Expecting total number of positive examples falling into the right branch.') | |||
local total = leftCount + rightCount | |||
local pPositiveLeft = leftCount == 0 and 0 or (positiveLeftCount / leftCount) | |||
local leftGini = pPositiveLeft * (1.0 - pPositiveLeft) | |||
local pPositiveRight = rightCount == 0 and 0 or (positiveRightCount / rightCount) | |||
local rightGini = pPositiveRight * (1.0 - pPositiveRight) | |||
return (leftCount * leftGini + rightCount * rightGini) / total | |||
end |
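-- A hand-worked example of the size-weighted Gini impurity above (counts are
-- illustrative): a pure left branch contributes nothing, so only the mixed
-- right branch adds impurity.
local function gini(n, npos)
   local p = (n == 0) and 0 or npos / n
   return p * (1 - p)
end
local leftCount, posLeft, rightCount, posRight = 4, 4, 6, 3
local impurity = (leftCount * gini(leftCount, posLeft)
                + rightCount * gini(rightCount, posRight)) / (leftCount + rightCount)
assert(math.abs(impurity - 0.15) < 1e-12) -- (4*0 + 6*0.25) / 10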
@@ -1,40 +0,0 @@ | |||
package = "decisiontree" | |||
version = "scm-1" | |||
source = { | |||
url = "git://github.com/Twitter/decisiontree", | |||
tag = "master" | |||
} | |||
description = { | |||
summary = "Decision trees for Torch by Twitter", | |||
detailed = [[ | |||
Classification and regression trees (CART). | |||
Gradients boosted decision trees (GBDT). | |||
]], | |||
homepage = "https://github.com/Twitter/decisiontree", | |||
license = "BSD" | |||
} | |||
dependencies = { | |||
"torch >= 7.0", | |||
"moses >= 1.3.1", | |||
"xlua >= 1.0", | |||
"image >= 1.0", | |||
"luafilesystem >= 1.6.2", | |||
"sys >= 1.1", | |||
"paths >= 1.0", | |||
"ipc >= 1.0", | |||
"nn >= 1.0" | |||
} | |||
build = { | |||
type = "command", | |||
build_command = [[ | |||
cmake -E make_directory build; | |||
cd build; | |||
cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." -DCMAKE_INSTALL_PREFIX="$(PREFIX)" -DCMAKE_C_FLAGS=-fPIC -DCMAKE_CXX_FLAGS=-fPIC; | |||
$(MAKE) | |||
]], | |||
install_command = "cd build && $(MAKE) install" | |||
} |
@@ -1,817 +0,0 @@ | |||
local dt = require "decisiontree._env" | |||
local dttest = {} | |||
local nloop = 50 | |||
local epsilon = 0.000001 | |||
local mytester | |||
--e.g. usage: th -e "dt = require 'decisiontree'; dt.test()" | |||
local function testAccuracy(cartTree, name, dataset, minacc) | |||
assert(torch.isTypeOf(dataset, 'dt.DataSet')) | |||
minacc = minacc or 0.99 | |||
local output = torch.Tensor(dataset:size()) | |||
local target, input = dataset.target, dataset.input | |||
for i=1,dataset:size() do | |||
local stack = {} | |||
local score = cartTree:score(input[i], stack) | |||
output[i] = score >= 0 and 1 or 0 | |||
if dt.VERBOSE and torch.type(cartTree) == 'dt.CartTree' and target[i] ~= output[i] then | |||
print(cartTree:stackToString(stack, input[i])) | |||
print(i, score, target[i], output[i]) | |||
end | |||
end | |||
local accuracy = torch.eq(target, output):float():mean() | |||
mytester:assert(accuracy >= minacc, name .. ": insufficient accuracy: " .. accuracy .. " < " .. minacc) | |||
end | |||
function dttest.SparseTensor() | |||
local keys = torch.LongTensor{1,5,6,10} | |||
local values = torch.randn(keys:size(1)) | |||
local st = torch.SparseTensor(keys, values) | |||
mytester:assert(st[1] == values[1]) | |||
mytester:assert(st[5] == values[2]) | |||
mytester:assert(st[6] == values[3]) | |||
mytester:assert(st[10] == values[4]) | |||
mytester:assert(st[2] == nil) | |||
st:buildIndex() | |||
mytester:assert(st[1] == values[1]) | |||
mytester:assert(st[5] == values[2]) | |||
mytester:assert(st[6] == values[3]) | |||
mytester:assert(st[10] == values[4]) | |||
mytester:assert(st[2] == nil) | |||
-- test empty sparse tensor | |||
local est = torch.SparseTensor() | |||
end | |||
function dttest.GiniState() | |||
local featureId = 2 | |||
local minLeafSize = 0 | |||
local input = torch.Tensor({{0,1,0},{0,2,0},{0,3,0}}) | |||
local target = torch.Tensor({1, 1, 1}) | |||
local dataset = dt.DataSet(input, target, 3) | |||
local splitInfo1 = {_id=1} | |||
local splitInfo2 = {_id=2, leftChildSize = 1, rightChildSize = 2, splitGain = 0} | |||
local splitInfo3 = {_id=3, leftChildSize = 2, rightChildSize = 1, splitGain = -1} | |||
local exampleIds = torch.LongTensor{1,2,3} | |||
local treeState = dt.GiniState(exampleIds) | |||
function treeState.computeSplitInfo(self, splitFeatureId, splitFeatureValue) | |||
if splitFeatureId == featureId and splitFeatureValue == 2 then | |||
return splitInfo2 | |||
elseif splitFeatureId == featureId and splitFeatureValue == 3 then | |||
return splitInfo3 | |||
else | |||
error("Unhandled computeSplitInfo call "..splitFeatureId.." "..splitFeatureValue) | |||
end | |||
end | |||
local splitInfo = treeState:findBestFeatureSplit(dataset, featureId, minLeafSize) | |||
mytester:assert(splitInfo._id == splitInfo3._id) | |||
end | |||
function dttest.CartTree() | |||
local splitFeatureId = 100 | |||
local splitFeatureValue = 1.0 | |||
local function getBinaryCartTreeRootNode() | |||
local leftNodeScore = 0.2 | |||
local rightNodeScore = 0.4 | |||
local rootNode = dt.CartNode() | |||
rootNode.nodeId = 0 | |||
rootNode.score = 0.5 | |||
rootNode.splitFeatureId = splitFeatureId | |||
rootNode.splitFeatureValue = splitFeatureValue | |||
local leftChild = dt.CartNode() | |||
leftChild.score = leftNodeScore | |||
leftChild.nodeId = 1 | |||
local rightChild = dt.CartNode() | |||
rightChild.score = rightNodeScore | |||
rightChild.nodeId = 2 | |||
rootNode.leftChild = leftChild | |||
rootNode.rightChild = rightChild | |||
return rootNode | |||
end | |||
local function testScoreCartTreeBranchLeftIfMissing() | |||
local rootNode = getBinaryCartTreeRootNode() | |||
local cartTree = dt.CartTree(rootNode) | |||
local continuousFeatures = torch.SparseTensor() | |||
local score, nodeId = cartTree:score(continuousFeatures) | |||
mytester:assert(math.abs(rootNode.leftChild.score - score) < epsilon) | |||
mytester:assert(rootNode.leftChild.nodeId == nodeId) | |||
end | |||
local function testBranchRightWithFeature() | |||
local rootNode = getBinaryCartTreeRootNode() | |||
local cartTree = dt.CartTree(rootNode) | |||
local continuousFeatures = torch.zeros(100) | |||
continuousFeatures[splitFeatureId] = splitFeatureValue | |||
local score, nodeId = cartTree:score(continuousFeatures) | |||
mytester:assert(math.abs(rootNode.rightChild.score - score) < epsilon) | |||
mytester:assert(rootNode.rightChild.nodeId == nodeId) | |||
end | |||
local function testMissingRightNode() | |||
local rootNode = getBinaryCartTreeRootNode() | |||
rootNode.rightChild = nil | |||
local cartTree = dt.CartTree(rootNode) | |||
local continuousFeatures = torch.Tensor() | |||
local score, nodeId = cartTree:score(continuousFeatures) | |||
mytester:assert(math.abs(rootNode.leftChild.score - score) < epsilon) | |||
mytester:assert(rootNode.leftChild.nodeId == nodeId) | |||
end | |||
local function testMissingLeftNode() | |||
local rootNode = getBinaryCartTreeRootNode() | |||
rootNode.leftChild = nil | |||
local cartTree = dt.CartTree(rootNode) | |||
local continuousFeatures = torch.Tensor() | |||
local score, nodeId = cartTree:score(continuousFeatures) | |||
mytester:assert(math.abs(rootNode.rightChild.score - score) < epsilon) | |||
mytester:assert(rootNode.rightChild.nodeId == nodeId) | |||
end | |||
local function testMissingAllChildren() | |||
local rootNode = getBinaryCartTreeRootNode() | |||
rootNode.leftChild = nil | |||
rootNode.rightChild = nil | |||
local cartTree = dt.CartTree(rootNode) | |||
local continuousFeatures = torch.Tensor() | |||
local score, nodeId = cartTree:score(continuousFeatures) | |||
mytester:assert(math.abs(rootNode.score - score) < epsilon) | |||
mytester:assert(rootNode.nodeId == nodeId) | |||
end | |||
local function testScoreCartTreeBranchRandomlyRight() | |||
local rootNode = getBinaryCartTreeRootNode(); | |||
-- Force Branch Right | |||
local cartTree = dt.CartTree(rootNode, function() return false end); | |||
local continuousFeatures = torch.SparseTensor() | |||
local score, nodeId = cartTree:score(continuousFeatures) | |||
mytester:assert(math.abs(rootNode.rightChild.score - score) < epsilon) | |||
mytester:assert(rootNode.rightChild.nodeId == nodeId) | |||
end | |||
local function testScoreCartTreeBranchRandomlyLeft() | |||
local rootNode = getBinaryCartTreeRootNode(); | |||
-- Force Branch Left | |||
local cartTree = dt.CartTree(rootNode, function() return true end); | |||
local continuousFeatures = torch.SparseTensor() | |||
local score, nodeId = cartTree:score(continuousFeatures) | |||
mytester:assert(math.abs(rootNode.leftChild.score - score) < epsilon) | |||
mytester:assert(rootNode.leftChild.nodeId == nodeId) | |||
end | |||
testScoreCartTreeBranchLeftIfMissing() | |||
testBranchRightWithFeature() | |||
testMissingRightNode() | |||
testMissingLeftNode() | |||
testMissingAllChildren() | |||
testScoreCartTreeBranchRandomlyRight() | |||
testScoreCartTreeBranchRandomlyLeft() | |||
end | |||
function dttest.TreeState_branch() | |||
local _ = require 'moses' | |||
local binFeatureId = 1 | |||
local featureId = 2 | |||
local input = { | |||
torch.SparseTensor(torch.LongTensor{binFeatureId},torch.Tensor{1}), | |||
torch.SparseTensor(torch.LongTensor{binFeatureId,featureId},torch.Tensor{1,1}), | |||
torch.SparseTensor(torch.LongTensor{binFeatureId,featureId},torch.Tensor{0,2}), | |||
torch.SparseTensor(torch.LongTensor{binFeatureId,featureId},torch.Tensor{0,3}) | |||
} | |||
local target = torch.LongTensor(4):fill(1) | |||
local dataset = dt.DataSet(input, target) | |||
local treeState = dt.TreeState(torch.LongTensor():range(1,4)) | |||
local splitInfo = {splitId = binFeatureId, splitValue = 1} | |||
local function testBranchBinaryFeature() | |||
splitInfo = {splitId = binFeatureId, splitValue = 1} | |||
local leftBranch, rightBranch = treeState:branch(splitInfo, dataset) | |||
mytester:assert(leftBranch ~= nil and rightBranch ~= nil) | |||
mytester:assert(2 == leftBranch:size()) | |||
mytester:assert(leftBranch:contains(3)) | |||
mytester:assert(leftBranch:contains(4)) | |||
mytester:assert(2 == rightBranch:size()) | |||
mytester:assert(rightBranch:contains(1)) | |||
mytester:assert(rightBranch:contains(2)) | |||
end | |||
local function testBranchContinuousFeature() | |||
local splitValue = 2 | |||
splitInfo = {splitId = featureId, splitValue = splitValue} | |||
local leftBranch, rightBranch = treeState:branch(splitInfo, dataset) | |||
mytester:assert(leftBranch ~= nil and rightBranch ~= nil) | |||
mytester:assert(1 == leftBranch:size()) | |||
mytester:assert(leftBranch:contains(2)) | |||
mytester:assert(2 == rightBranch:size()) | |||
mytester:assert(rightBranch:contains(3)) | |||
mytester:assert(rightBranch:contains(4)) | |||
end | |||
testBranchBinaryFeature() | |||
testBranchContinuousFeature() | |||
end | |||
function dttest.DecisionForest() | |||
-- Create a test decision forest; each tree has only a single node, so its score == the score of its root node. | |||
local function createCartTreeWithSingleNode(score) | |||
local cartNode = dt.CartNode() | |||
cartNode.score = score | |||
return dt.CartTree(cartNode) | |||
end | |||
local function getTestDecisionForest() | |||
local cartTrees = { | |||
createCartTreeWithSingleNode(1), | |||
createCartTreeWithSingleNode(2), | |||
createCartTreeWithSingleNode(3) | |||
} | |||
local weight = torch.Tensor{10,20,30} | |||
local bias = 0.5 | |||
return dt.DecisionForest(cartTrees, weight, bias) | |||
end | |||
local function testScoreDecisionForest() | |||
local df = getTestDecisionForest() | |||
local continuousFeatures = torch.SparseTensor() | |||
local expectedResult = 1.0 * 10.0 + 2.0 * 20.0 + 3.0 * 30.0 + 0.5; | |||
local result = df:score(continuousFeatures) | |||
mytester:assert(math.abs(expectedResult - result) < epsilon) | |||
end | |||
testScoreDecisionForest() | |||
end | |||
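-- The expected value in testScoreDecisionForest is just the weighted sum of
-- per-tree scores plus the bias; restated in plain Lua:
local scores, weights, bias = {1, 2, 3}, {10, 20, 30}, 0.5
local total = bias
for i = 1, #scores do total = total + scores[i] * weights[i] end
assert(total == 140.5) -- 1*10 + 2*20 + 3*30 + 0.5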
function dttest.CartTrainer() | |||
local minLeafSize, maxLeafNodes = 1, 1000 | |||
local nExample = 100 | |||
-- 1. dense dataset | |||
local trainSet, validSet, clusterExamples, inputs, targets = dt.getDenseDummyData(nExample) | |||
-- assert that the dataset is valid | |||
for clusterId, exampleIds in ipairs(clusterExamples) do | |||
local exampleIdx = torch.LongTensor(exampleIds) | |||
local input = inputs:index(1,exampleIdx) | |||
local target = targets:index(1,exampleIdx) | |||
assert(input:std(1):mean() < 0.05) | |||
end | |||
local cartTrainer = dt.CartTrainer(trainSet, minLeafSize, maxLeafNodes) | |||
local treeState = dt.GiniState(trainSet:getExampleIds()) | |||
local cartTree, nleaf = cartTrainer:train(treeState, trainSet.featureIds) | |||
mytester:assert(nleaf == nExample) -- with minLeafSize = 1 and unbounded maxLeafNodes, dense inputs give every example its own leaf | |||
testAccuracy(cartTree, "dense train single-thread first", trainSet, 0.99) | |||
testAccuracy(cartTree, "dense valid single-thread first", validSet, 0.7) -- they don't generalize very well.. | |||
local cartTree, nleaf = cartTrainer:train(treeState, trainSet.featureIds) | |||
testAccuracy(cartTree, "dense single-thread second", trainSet) | |||
-- test feature parallelization | |||
local nThread = 2 | |||
cartTrainer:featureParallel(nThread) | |||
local treeState = dt.GiniState(trainSet:getExampleIds()) | |||
local cartTree, nleaf = cartTrainer:train(treeState, trainSet.featureIds) | |||
testAccuracy(cartTree, "dense feature-parallel", trainSet) | |||
-- 2. sparse-dense dataset | |||
local trainSet, validSet, clusterExamples, inputs, targets = dt.getSparseDummyData(nExample, nil, 10, nil, nil, 10) | |||
-- assert that the dataset is valid | |||
for clusterId, exampleIds in ipairs(clusterExamples) do | |||
local input = torch.Tensor(#exampleIds, 10):zero() | |||
for i, exampleId in ipairs(exampleIds) do | |||
input[i]:indexCopy(1, inputs[exampleId].keys, inputs[exampleId].values) | |||
end | |||
assert(input:std(1):mean() < 0.05) | |||
end | |||
local cartTrainer = dt.CartTrainer(trainSet, minLeafSize, maxLeafNodes) | |||
local treeState = dt.GiniState(trainSet:getExampleIds()) | |||
local cartTree, nleaf = cartTrainer:train(treeState, trainSet.featureIds) | |||
mytester:assert(nleaf == nExample) -- all 10 features are active here, so as in the dense case every example gets its own leaf | |||
testAccuracy(cartTree, "sparse-dense train single-thread first", trainSet, 0.99) | |||
local shuffle = torch.LongTensor():randperm(10) | |||
for i, input in ipairs(inputs) do | |||
input.keys = input.keys:index(1, shuffle) | |||
input.values = input.values:index(1, shuffle) | |||
input._map = nil | |||
end | |||
testAccuracy(cartTree, "sparse-dense shuffled keys train single-thread first", trainSet, 0.99) | |||
testAccuracy(cartTree, "sparse-dense valid single-thread first", validSet, 0.8) | |||
-- 3. sparse dataset | |||
local trainSet, validSet = dt.getSparseDummyData(nExample, 2, 10, nil, nil, 9) | |||
local cartTrainer = dt.CartTrainer(trainSet, minLeafSize, maxLeafNodes) | |||
local treeState = dt.GiniState(trainSet:getExampleIds()) | |||
local cartTree, nleaf = cartTrainer:train(treeState, trainSet.featureIds) | |||
cartTree.branchleft = function() return true end | |||
mytester:assert(nleaf < nExample) -- only 9 of 10 features are active, so examples with missing features end up sharing leaves | |||
testAccuracy(cartTree, "sparse train single-thread first", trainSet, 0.9) -- the TreeBrancher drops examples with missing features, making it difficult to overfit | |||
testAccuracy(cartTree, "sparse valid single-thread first", validSet, 0.8) | |||
end | |||
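-- The test above doubles as the basic training recipe; stripped to its
-- essentials (a hedged sketch, assuming the 'decisiontree' rock and torch
-- are installed) it reads:
local dt = require 'decisiontree'
local trainSet = dt.getDenseDummyData(100)         -- any dt.DataSet works here
local trainer = dt.CartTrainer(trainSet, 1, 1000)  -- minLeafSize, maxLeafNodes
local state = dt.GiniState(trainSet:getExampleIds())
local tree, nleaf = trainer:train(state, trainSet.featureIds)
local score = tree:score(trainSet.input[1])        -- scalar score for one example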
function dttest.GradientBoostTrainer() | |||
local nExample = 100 | |||
local trainSet, validSet = dt.getSparseDummyData(nExample, 2, 10, nil, nil, 9) | |||
local maxLeafNode, minLeafSize = nExample/2, nExample/10 | |||
local loss = nn.LogitBoostCriterion(false) | |||
local cartTrainer = dt.CartTrainer(trainSet, minLeafSize, maxLeafNode) | |||
local opt = { | |||
lossFunction=loss, | |||
treeTrainer=cartTrainer, | |||
shrinkage=0.1, | |||
downsampleRatio=6, | |||
featureBaggingSize=-1, | |||
nTree=14, | |||
evalFreq=8, | |||
earlyStop=0 -- no early-stopping | |||
} | |||
-- test single-thread | |||
local trainer = dt.GradientBoostTrainer(opt) | |||
local decisionForest = trainer:train(trainSet, trainSet.featureIds, validSet) | |||
mytester:assert(#decisionForest.trees == opt.nTree) | |||
testAccuracy(decisionForest, "sparse train single-thread first", trainSet, 0.98) | |||
testAccuracy(decisionForest, "sparse valid single-thread first", validSet, 0.95) | |||
-- test stateless | |||
local decisionForest = trainer:train(trainSet, trainSet.featureIds, validSet) | |||
mytester:assert(#decisionForest.trees == opt.nTree) | |||
testAccuracy(decisionForest, "sparse train single-thread second", trainSet, 0.98) | |||
testAccuracy(decisionForest, "sparse valid single-thread second", validSet, 0.95) | |||
-- test feature-parallel | |||
local nThread = 2 | |||
cartTrainer:featureParallel(nThread) | |||
local trainer = dt.GradientBoostTrainer(opt) | |||
local decisionForest = trainer:train(trainSet, trainSet.featureIds, validSet) | |||
mytester:assert(#decisionForest.trees == opt.nTree) | |||
testAccuracy(decisionForest, "sparse train feature-parallel first", trainSet, 0.98) | |||
testAccuracy(decisionForest, "sparse valid feature-parallel first", validSet, 0.95) | |||
end | |||
function dttest.RandomForestTrainer() | |||
local nExample = 100 | |||
local trainSet, validSet = dt.getSparseDummyData(nExample, 2, 10, nil, nil, 9) | |||
local opt = { | |||
activeRatio=0.5, | |||
featureBaggingSize=5, | |||
nTree=14, | |||
maxLeafNodes=nExample/2, | |||
minLeafSize=nExample/10, | |||
} | |||
local trainer = dt.RandomForestTrainer(opt) | |||
local decisionForest = trainer:train(trainSet, trainSet.featureIds) | |||
mytester:assert(#decisionForest.trees == opt.nTree) | |||
testAccuracy(decisionForest, "sparse train single-thread first", trainSet, 0.98) | |||
testAccuracy(decisionForest, "sparse valid single-thread first", validSet, 0.95) | |||
-- test stateless | |||
local decisionForest = trainer:train(trainSet, trainSet.featureIds) | |||
mytester:assert(#decisionForest.trees == opt.nTree) | |||
testAccuracy(decisionForest, "sparse train single-thread second", trainSet, 0.98) | |||
testAccuracy(decisionForest, "sparse valid single-thread second", validSet, 0.95) | |||
-- test tree-parallel | |||
local nThread = 2 | |||
trainer:treeParallel(nThread) | |||
local trainer = dt.RandomForestTrainer(opt) | |||
local decisionForest = trainer:train(trainSet, trainSet.featureIds) | |||
mytester:assert(#decisionForest.trees == opt.nTree) | |||
testAccuracy(decisionForest, "sparse train tree-parallel first", trainSet, 0.98) | |||
testAccuracy(decisionForest, "sparse valid tree-parallel first", validSet, 0.95) | |||
end | |||
function dttest.WorkPool() | |||
local nThread = 2 | |||
local wp = dt.WorkPool(nThread) | |||
-- 1. some easy tests | |||
local store = {key='nick',value=7} | |||
wp:update('storeKeyValue', store) | |||
wp:update('require', {libname='decisiontree', varname='dt'}) | |||
local bias = 2 | |||
local obj = nn.MSECriterion() | |||
wp:update('require', {libname='decisiontree', varname='dt'}) | |||
wp:writeup('execute', function(store) return bias + obj:updateOutput(torch.Tensor{1},torch.Tensor{1}) + store.nick end) | |||
local taskname, res = wp:read() | |||
mytester:assert(taskname == 'execute') | |||
mytester:assert(res == 9) | |||
-- 2. trying to reproduce a difficult error | |||
local trainSet, validSet = dt.getSparseDummyData() | |||
-- setup worker store (each worker will have its own copy) | |||
local store = { | |||
trainSet=trainSet, | |||
minLeafSize=2 | |||
} | |||
wp:update('storeKeysValues', store) | |||
-- arguments/upvalues | |||
local treeState = dt.GiniState(trainSet:getExampleIds()) | |||
local shardId = 1 | |||
local nShard = nThread | |||
local featureIds = trainSet.featureIds | |||
local task = function(store, args) | |||
assert(store.trainSet) | |||
assert(store.minLeafSize) | |||
local bestSplit = args.treeState:findBestSplit(store.trainSet, args.featureIds, store.minLeafSize, args.shardId, args.nShard) | |||
return bestSplit | |||
end | |||
local args = {treeState=treeState,featureIds=featureIds,shardId=shardId,nShard=nShard} | |||
wp:writeup("execute", {func=task,args=args}) | |||
local taskname, bestSplit = wp:read() | |||
mytester:assert(taskname == 'execute') | |||
mytester:assert(torch.type(bestSplit) == 'table') | |||
-- closure | |||
local task = function(store) | |||
assert(store.trainSet) | |||
assert(store.minLeafSize) | |||
local bestSplit = treeState:findBestSplit(store.trainSet, featureIds, store.minLeafSize, shardId, nShard) | |||
return bestSplit | |||
end | |||
wp:writeup("execute", task) | |||
local taskname, bestSplit = wp:read() | |||
mytester:assert(taskname == 'execute') | |||
mytester:assert(torch.type(bestSplit) == 'table') | |||
local task = function(store, args) | |||
assert(store.trainSet) | |||
assert(torch.isTypeOf(treeState, 'dt.TreeState'), torch.type(treeState)) | |||
local bestSplit = treeState:findBestSplit(store.trainSet, featureIds, store.minLeafSize, shardId, nShard) | |||
return bestSplit | |||
end | |||
local args = {featureIds=featureIds,shardId=shardId,nShard=nShard} | |||
wp:writeup("execute", {func=task,args=args}) | |||
local taskname, bestSplit = wp:read() | |||
mytester:assert(taskname == 'execute') | |||
mytester:assert(torch.type(bestSplit) == 'table') | |||
wp:terminate() | |||
end | |||
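-- The WorkPool protocol distilled from the test above (a hedged sketch,
-- assuming the 'decisiontree' rock and its ipc dependency are installed):
-- update() pushes state into every worker, writeup('execute', fn) ships a
-- task, and read() collects one (taskname, result) pair.
local dt = require 'decisiontree'
local wp = dt.WorkPool(2)
wp:update('storeKeyValue', {key='offset', value=10})
wp:writeup('execute', function(store) return store.offset + 1 end)
local taskname, res = wp:read()
assert(taskname == 'execute' and res == 11)
wp:terminate()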
function dttest.Sparse2Dense() | |||
local batchsize = 4 | |||
local minFeatureId, maxFeatureId = 10, 100 | |||
local input = {{},{}} | |||
for i=1,batchsize do | |||
local inputsize = math.random(5,10) | |||
input[1][i] = torch.LongTensor(inputsize):random(minFeatureId,maxFeatureId) | |||
input[2][i] = torch.Tensor(inputsize):uniform(0,1) | |||
end | |||
local s2d = nn.Sparse2Dense(torch.LongTensor():range(minFeatureId,maxFeatureId)) | |||
-- test 2d forward | |||
local output = s2d:forward(input) | |||
local output2 = torch.Tensor(batchsize, maxFeatureId-minFeatureId+1):zero() | |||
local featureMap = {} | |||
local j = 0 | |||
for i=minFeatureId,maxFeatureId do | |||
j = j + 1 | |||
featureMap[i] = j | |||
end | |||
for i=1,batchsize do | |||
local keys, values = input[1][i], input[2][i] | |||
for j=1,keys:size(1) do | |||
output2[{i,featureMap[keys[j] ]}] = values[j] | |||
end | |||
end | |||
mytester:assertTensorEq(output, output2, 0.000001) | |||
-- test 1d forward | |||
local input = {input[1][batchsize], input[2][batchsize]} | |||
local output = s2d:forward(input) | |||
mytester:assertTensorEq(output, output2[batchsize], 0.000001) | |||
end | |||
function dttest.Sparse2DenseDouble() | |||
local batchsize = 4 | |||
local minFeatureId, maxFeatureId = 10, 100 | |||
local input = {{},{}} | |||
for i=1,batchsize do | |||
local inputsize = math.random(5,10) | |||
input[1][i] = torch.LongTensor(inputsize):random(minFeatureId,maxFeatureId) | |||
input[2][i] = torch.Tensor(inputsize):uniform(0,1):double() | |||
end | |||
local s2d = nn.Sparse2Dense(torch.LongTensor():range(minFeatureId,maxFeatureId)) | |||
s2d:double() | |||
-- test 2d forward | |||
local output = s2d:forward(input) | |||
local output2 = torch.Tensor(batchsize, maxFeatureId-minFeatureId+1):zero():double() | |||
local featureMap = {} | |||
local j = 0 | |||
for i=minFeatureId,maxFeatureId do | |||
j = j + 1 | |||
featureMap[i] = j | |||
end | |||
for i=1,batchsize do | |||
local keys, values = input[1][i], input[2][i] | |||
for j=1,keys:size(1) do | |||
output2[{i,featureMap[keys[j] ]}] = values[j] | |||
end | |||
end | |||
mytester:assertTensorEq(output, output2, 0.000001) | |||
-- test 1d forward | |||
local input = {input[1][batchsize], input[2][batchsize]} | |||
local output = s2d:forward(input) | |||
mytester:assertTensorEq(output, output2[batchsize], 0.000001) | |||
end | |||
function dttest.LogitBoostCriterion() | |||
local input = torch.randn(10) | |||
local target = torch.LongTensor(10):random(0,1):type(torch.type(input)) | |||
local lb = nn.LogitBoostCriterion(false) | |||
local loss = lb:updateOutput(input, target) | |||
local loss2 = 0 | |||
for i=1,10 do | |||
loss2 = loss2 + math.log(1 + math.exp(target[i] <= 0 and input[i] or -input[i])) | |||
end | |||
mytester:assert(math.abs(loss - loss2) < 0.00001) | |||
local gradInput = lb:updateGradInput(input, target) | |||
local gradInput2 = gradInput:clone():zero() | |||
for i=1,10 do | |||
local p = dt.logistic(input[i]) | |||
gradInput2[i] = (target[i] <= 0) and p or (p - 1) | |||
end | |||
mytester:assertTensorEq(gradInput, gradInput2, 0.000001) | |||
local hessInput = lb:updateHessInput(input, target) | |||
local hessInput2 = hessInput:clone():zero() | |||
for i=1,10 do | |||
local p = dt.logistic(input[i]) | |||
hessInput2[i] = p * (1.0 - p) | |||
end | |||
mytester:assertTensorEq(hessInput, hessInput2, 0.000001) | |||
end | |||
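-- The three quantities checked above are the LogitBoost loss and its first
-- and second derivatives; a single-point, plain-Lua finite-difference check
-- (x and the 0/1 target are arbitrary):
local x, t, eps = 0.7, 1, 1e-6
local p = 1 / (1 + math.exp(-x))        -- dt.logistic(x)
local grad = (t <= 0) and p or (p - 1)
local hess = p * (1 - p)
local function f(y) return math.log(1 + math.exp(t <= 0 and y or -y)) end
local function df(y) local q = 1 / (1 + math.exp(-y)); return (t <= 0) and q or (q - 1) end
assert(math.abs((f(x + eps) - f(x - eps)) / (2 * eps) - grad) < 1e-6)
assert(math.abs((df(x + eps) - df(x - eps)) / (2 * eps) - hess) < 1e-6)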
function dttest.DFD() | |||
local nExample = 100 | |||
local batchsize = 4 | |||
local inputsize = 10 | |||
-- train Random Forest | |||
local trainSet, validSet, clusterExamples, inputs, targets = dt.getDenseDummyData(nExample, nil, inputsize) | |||
local opt = { | |||
activeRatio=0.5, | |||
featureBaggingSize=5, | |||
nTree=4, | |||
maxLeafNodes=nExample/2, | |||
minLeafSize=nExample/10, | |||
} | |||
local trainer = dt.RandomForestTrainer(opt) | |||
local df = trainer:train(trainSet, trainSet.featureIds) | |||
mytester:assert(#df.trees == opt.nTree) | |||
local dfd = nn.DFD(df) | |||
dfd = nn.DFD(dfd:getReconstructionInfo()) | |||
local dfd2 = nn.DFD(dfd:getReconstructionInfo(), true) | |||
local input = validSet.input:sub(1,batchsize) | |||
local output = dfd:forward(input) | |||
local output2 = dfd2:forward(input) | |||
local _ = require 'moses' | |||
local function hasKey(keys,key) | |||
local found = false | |||
keys:apply(function(x) | |||
if x == key then | |||
found = true | |||
end | |||
end) | |||
return found | |||
end | |||
for i=1,batchsize do | |||
local nodes = {} | |||
local keys = output[1][i] | |||
local keys2 = output2[1][i] | |||
for j,tree in ipairs(df.trees) do | |||
local stack = {} | |||
tree:score(input[i], stack) | |||
mytester:assert(hasKey(keys2, stack[#stack]._nodeId)) | |||
for k,node in ipairs(stack) do | |||
if k > 1 then | |||
assert(node._nodeId) | |||
mytester:assert(hasKey(keys, node._nodeId), string.format("missing key=%d in %s", node._nodeId, tostring(keys))) | |||
table.insert(nodes, node._nodeId) | |||
end | |||
end | |||
end | |||
mytester:assert(#nodes == keys:size(1)) | |||
mytester:assert(#df.trees == keys2:size(1)) | |||
end | |||
end | |||
function dttest.DFDDouble() | |||
local nExample = 100 | |||
local batchsize = 4 | |||
local inputsize = 10 | |||
-- train Random Forest | |||
local trainSet, validSet, clusterExamples, inputs, targets = dt.getDenseDummyData(nExample, nil, inputsize) | |||
local opt = { | |||
activeRatio=0.5, | |||
featureBaggingSize=5, | |||
nTree=4, | |||
maxLeafNodes=nExample/2, | |||
minLeafSize=nExample/10, | |||
} | |||
local trainer = dt.RandomForestTrainer(opt) | |||
local df = trainer:train(trainSet, trainSet.featureIds) | |||
mytester:assert(#df.trees == opt.nTree) | |||
local dfd = nn.DFD(df) | |||
dfd:double() | |||
dfd = nn.DFD(dfd:getReconstructionInfo()) | |||
local dfd2 = nn.DFD(dfd:getReconstructionInfo(), true) | |||
local input = validSet.input:sub(1,batchsize):double() | |||
local output = dfd:forward(input) | |||
local output2 = dfd2:forward(input) | |||
local _ = require 'moses' | |||
local function hasKey(keys,key) | |||
local found = false | |||
keys:apply(function(x) | |||
if x == key then | |||
found = true | |||
end | |||
end) | |||
return found | |||
end | |||
for i=1,batchsize do | |||
local nodes = {} | |||
local keys = output[1][i] | |||
local keys2 = output2[1][i] | |||
for j,tree in ipairs(df.trees) do | |||
local stack = {} | |||
tree:score(input[i], stack) | |||
mytester:assert(hasKey(keys2, stack[#stack]._nodeId)) | |||
for k,node in ipairs(stack) do | |||
if k > 1 then | |||
assert(node._nodeId) | |||
mytester:assert(hasKey(keys, node._nodeId), string.format("missing key=%d in %s", node._nodeId, tostring(keys))) | |||
table.insert(nodes, node._nodeId) | |||
end | |||
end | |||
end | |||
mytester:assert(#nodes == keys:size(1)) | |||
mytester:assert(#df.trees == keys2:size(1)) | |||
end | |||
end | |||
function dttest.uniquecounts() -- DEPRECATED | |||
local target = torch.LongTensor(100):random(1,3) | |||
local input = torch.Tensor() | |||
local inputset = {input=input, target=target} | |||
local counts = dt.uniquecounts(nil, inputset, 3) | |||
mytester:assert(counts:sum() == 100) | |||
mytester:assert(counts:nElement() == 3) | |||
local res = torch.Tensor(3):zero() | |||
target:apply(function(t) res[t] = res[t] + 1 end) | |||
mytester:assertTensorEq(counts, res) | |||
end | |||
function dttest.entropy() -- DEPRECATED | |||
-- 2 clusters with a bit overlap between classes: | |||
local input = torch.Tensor(100,2) | |||
input:narrow(1,1,50):normal(-1,.01) | |||
input:narrow(1,51,50):normal(2,.01) | |||
local target = torch.LongTensor(100):fill(3) | |||
target:narrow(1,1,45):fill(1) | |||
target:narrow(1,56,45):fill(2) | |||
local inputset = {input=input, target=target} | |||
-- test entropy() | |||
local fullent = dt.entropy(inputset) | |||
local halfset = {input=input:narrow(1,1,50), target=target:narrow(1,1,50)} | |||
local halfent = dt.entropy(halfset) | |||
local perfectset = {input=input:narrow(1,56,45), target=target:narrow(1,56,45)} | |||
local perfectent = dt.entropy(perfectset) | |||
mytester:assert(fullent > halfent) | |||
mytester:assert(halfent > perfectent) | |||
mytester:assert(perfectent < 0.0000001 and perfectent >= 0) | |||
end | |||
function dt.test(tests) | |||
math.randomseed(os.time()) | |||
mytester = torch.Tester() | |||
mytester:add(dttest) | |||
mytester:run(tests) | |||
end |
@@ -1,45 +0,0 @@ | |||
#include "error.h" | |||
#define check_tensors(L, a, b) \ | |||
do { \ | |||
if ((a)->nDimension != (b)->nDimension) \ | |||
return LUA_HANDLE_ERROR_STR((L), "different tensor dimensions"); \ | |||
for (int __local__var = 0; __local__var < (a)->nDimension; __local__var++) \ | |||
if ((a)->size[__local__var] != (b)->size[__local__var]) \ | |||
return LUA_HANDLE_ERROR_STR((L), "different tensor sizes"); \ | |||
} while (0) | |||
#define check_tensor(L, t, type) \ | |||
do { \ | |||
if (!type##_isContiguous(t)) \ | |||
return LUA_HANDLE_ERROR_STR((L), "tensor should be contiguous"); \ | |||
} while (0) | |||
#define get_tensor_size(t, type) \ | |||
(TH##type##Tensor_nElement(t)) | |||
#define get_tensor(L, idx, type) \ | |||
(TH##type##Tensor *)luaT_checkudata(L, idx, "torch." #type "Tensor") | |||
static int push_table_contents(lua_State *L, int arg) | |||
{ | |||
int size = 0; | |||
while(1) { | |||
lua_checkstack(L, 1); | |||
lua_rawgeti(L, arg, size + 1); | |||
if (lua_isnil(L, -1)) { | |||
lua_pop(L, 1); | |||
break; | |||
} | |||
size++; | |||
} | |||
return size; | |||
} | |||
#define verify_push_table_contents(L, idx, count) do { \ | |||
int __tmp_count = push_table_contents(L, idx); \ | |||
if (__tmp_count != count) { \ | |||
lua_pop(L, __tmp_count); \ | |||
LUA_HANDLE_ERROR_STR(L, "Table sizes do not match"); \ | |||
} \ | |||
} while(0) |
@@ -1,125 +0,0 @@ | |||
local dt = require "decisiontree._env" | |||
function dt.getBufferTable(name) | |||
local dt = require 'decisiontree' | |||
assert(torch.type(name) == 'string') | |||
dt.buffer = dt.buffer or {} | |||
dt.buffer[name] = dt.buffer[name] or {} | |||
return dt.buffer[name] | |||
end | |||
function dt.getSparseDummyData(nExample, nCluster, nFeature, overlap, nValid, nActive) | |||
local dt = require 'decisiontree' | |||
if torch.type(nExample) == 'table' then | |||
local opt = nExample | |||
nExample = opt.nExample | |||
nCluster = opt.nCluster | |||
nFeature = opt.nFeature | |||
overlap = opt.overlap | |||
nValid = opt.nValid | |||
nActive = opt.nActive | |||
end | |||
nExample = nExample or 100 -- training set size | |||
nCluster = nCluster or 10 | |||
assert(nCluster >= 2) | |||
nFeature = math.max(2, nFeature or 10) | |||
overlap = overlap or 0 | |||
nValid = nValid or nExample/10 -- validation set size | |||
nActive = nActive or math.max(2, nFeature / 2) | |||
-- sample nCluster centers | |||
local clusterCenter = torch.rand(nCluster, nFeature) | |||
local clusterLabel = torch.LongTensor(nCluster) | |||
local clusterExamples = {} | |||
for i=1,nCluster do | |||
clusterCenter[i]:add(i) | |||
clusterLabel[i] = i % 2 | |||
clusterExamples[i] = {} | |||
end | |||
local sparseCenter = torch.Tensor() | |||
local shuffle = torch.LongTensor() | |||
-- build dataset in pseudo-dense format | |||
local inputs = {} | |||
local targets = torch.Tensor(nExample+nValid) | |||
for i=1,nExample+nValid do | |||
local clusterIdx = torch.random(1,nCluster) | |||
table.insert(clusterExamples[clusterIdx], i) | |||
shuffle:randperm(nFeature) | |||
local keys = torch.LongTensor(nActive):copy(shuffle:narrow(1,1,nActive)) | |||
sparseCenter:index(clusterCenter[clusterIdx], 1, keys) | |||
local stdiv = i <= nExample and 100 or 1000 | |||
local values = torch.randn(nActive):div(stdiv):add(sparseCenter) | |||
table.insert(inputs, torch.SparseTensor(keys, values)) | |||
local label = clusterLabel[clusterIdx] | |||
if math.random() < overlap then | |||
targets[i] = label == 1 and 0 or 1 | |||
else | |||
targets[i] = label | |||
end | |||
end | |||
local _ = require 'moses' | |||
local validSet = dt.DataSet(_.slice(inputs, nExample+1, nExample+nValid), targets:narrow(1,nExample+1,nValid)) | |||
local trainSet = dt.DataSet(_.slice(inputs, 1, nExample), targets:narrow(1,1,nExample)) | |||
return trainSet, validSet, clusterExamples, inputs, targets | |||
end | |||
function dt.getDenseDummyData(nExample, nCluster, nFeature, overlap, nValid) | |||
local dt = require 'decisiontree' | |||
if torch.type(nExample) == 'table' then | |||
local opt = nExample | |||
nExample = opt.nExample | |||
nCluster = opt.nCluster | |||
nFeature = opt.nFeature | |||
overlap = opt.overlap | |||
nValid = opt.nValid | |||
end | |||
nExample = nExample or 100 -- training set size | |||
nCluster = nCluster or 10 | |||
assert(nCluster >= 2) | |||
nFeature = math.max(2, nFeature or 10) | |||
overlap = overlap or 0 | |||
nValid = nValid or nExample/10 -- validation set size | |||
-- sample nCluster centers | |||
local clusterCenter = torch.rand(nCluster, nFeature) | |||
local clusterLabel = torch.LongTensor(nCluster) | |||
local clusterExamples = {} | |||
for i=1,nCluster do | |||
clusterCenter[i]:add(i) | |||
clusterLabel[i] = i % 2 | |||
clusterExamples[i] = {} | |||
end | |||
-- build dataset in pseudo-dense format | |||
local inputs = torch.Tensor(nExample+nValid, nFeature) | |||
local targets = torch.Tensor(nExample+nValid) | |||
for i=1,nExample+nValid do | |||
local clusterIdx = torch.random(1,nCluster) | |||
table.insert(clusterExamples[clusterIdx], i) | |||
local stdiv = i <= nExample and 100 or 1000 | |||
inputs[i]:normal():div(stdiv):add(clusterCenter[clusterIdx]) | |||
local label = clusterLabel[clusterIdx] | |||
if math.random() < overlap then | |||
targets[i] = label == 1 and 0 or 1 | |||
else | |||
targets[i] = label | |||
end | |||
end | |||
local _ = require 'moses' | |||
local validSet = dt.DataSet(inputs:narrow(1,nExample+1,nValid), targets:narrow(1,nExample+1,nValid)) | |||
local trainSet = dt.DataSet(inputs:narrow(1,1,nExample), targets:narrow(1,1,nExample)) | |||
return trainSet, validSet, clusterExamples, inputs, targets | |||
end |
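-- A hedged usage sketch of the helpers above (requires torch and the
-- 'decisiontree' rock): both return (trainSet, validSet, clusterExamples,
-- inputs, targets), with nExample training and nValid validation rows drawn
-- from nCluster Gaussian clusters labeled clusterIdx % 2, flipped with
-- probability `overlap`.
local dt = require 'decisiontree'
local trainSet, validSet = dt.getDenseDummyData{nExample=200, nCluster=4, nFeature=8}
assert(torch.isTypeOf(trainSet, 'dt.DataSet') and validSet:size() == 20) -- nValid defaults to nExample/10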
@@ -1,2 +0,0 @@ | |||
build/ | |||
THNN_h.lua |
@@ -1,13 +0,0 @@ | |||
std = "luajit" | |||
globals = { | |||
"torch", | |||
"nn", | |||
"include", | |||
} | |||
unused_args = false | |||
files['test.lua'].redefined = false |
@@ -1,56 +0,0 @@ | |||
language: c | |||
compiler: | |||
- gcc | |||
- clang | |||
cache: | |||
directories: | |||
- $HOME/OpenBlasInstall | |||
sudo: false | |||
env: | |||
- TORCH_LUA_VERSION=LUAJIT21 | |||
- TORCH_LUA_VERSION=LUA51 | |||
- TORCH_LUA_VERSION=LUA52 | |||
addons: | |||
apt: | |||
packages: | |||
- cmake | |||
- gfortran | |||
- gcc-multilib | |||
- gfortran-multilib | |||
- liblapack-dev | |||
- build-essential | |||
- gcc | |||
- g++ | |||
- curl | |||
- cmake | |||
- libreadline-dev | |||
- git-core | |||
- libqt4-core | |||
- libqt4-gui | |||
- libqt4-dev | |||
- libjpeg-dev | |||
- libpng-dev | |||
- ncurses-dev | |||
- imagemagick | |||
- libzmq3-dev | |||
- gfortran | |||
- unzip | |||
- gnuplot | |||
- gnuplot-x11 | |||
before_script: | |||
- export ROOT_TRAVIS_DIR=$(pwd) | |||
- export INSTALL_PREFIX=~/torch/install | |||
- ls $HOME/OpenBlasInstall/lib || (cd /tmp/ && git clone https://github.com/xianyi/OpenBLAS.git -b master && cd OpenBLAS && (make NO_AFFINITY=1 -j$(getconf _NPROCESSORS_ONLN) 2>/dev/null >/dev/null) && make PREFIX=$HOME/OpenBlasInstall install) | |||
- git clone https://github.com/torch/distro.git ~/torch --recursive | |||
- cd ~/torch && git submodule update --init --recursive | |||
- mkdir build && cd build | |||
- export CMAKE_LIBRARY_PATH=$HOME/OpenBlasInstall/include:$HOME/OpenBlasInstall/lib:$CMAKE_LIBRARY_PATH | |||
- cmake .. -DCMAKE_INSTALL_PREFIX="${INSTALL_PREFIX}" -DCMAKE_BUILD_TYPE=Release -DWITH_${TORCH_LUA_VERSION}=ON | |||
- make && make install | |||
- cd $ROOT_TRAVIS_DIR | |||
- export LD_LIBRARY_PATH=${INSTALL_PREFIX}/lib:$LD_LIBRARY_PATH | |||
script: | |||
- ${INSTALL_PREFIX}/bin/luarocks make rocks/nn-scm-1.rockspec | |||
- export PATH=${INSTALL_PREFIX}/bin:$PATH | |||
- export TESTLUA=$(which luajit lua | head -n 1) | |||
- ${TESTLUA} -lnn -e "t=nn.test(); if t.errors[1] then os.exit(1) end" |
@@ -1,22 +0,0 @@ | |||
local Abs, parent = torch.class('nn.Abs', 'nn.Module') | |||
function Abs:__init() | |||
parent.__init(self) | |||
end | |||
function Abs:updateOutput(input) | |||
input.THNN.Abs_updateOutput( | |||
input:cdata(), | |||
self.output:cdata() | |||
) | |||
return self.output | |||
end | |||
function Abs:updateGradInput(input, gradOutput) | |||
input.THNN.Abs_updateGradInput( | |||
input:cdata(), | |||
gradOutput:cdata(), | |||
self.gradInput:cdata() | |||
) | |||
return self.gradInput | |||
end |
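-- nn.Abs follows the standard nn.Module contract: updateOutput fills
-- self.output and updateGradInput fills self.gradInput through the THNN C
-- backend. A hedged usage sketch (requires torch and nn):
local nn = require 'nn'
local m = nn.Abs()
local x = torch.Tensor{-2, 3}
local y = m:forward(x)                  -- {2, 3}
local gx = m:backward(x, torch.ones(2)) -- gradOutput * sign(x) = {-1, 1}
assert(y[1] == 2 and gx[1] == -1 and gx[2] == 1)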
@@ -1,32 +0,0 @@ | |||
local AbsCriterion, parent = torch.class('nn.AbsCriterion', 'nn.Criterion') | |||
function AbsCriterion:__init(sizeAverage) | |||
parent.__init(self) | |||
if sizeAverage ~= nil then | |||
self.sizeAverage = sizeAverage | |||
else | |||
self.sizeAverage = true | |||
end | |||
end | |||
function AbsCriterion:updateOutput(input, target) | |||
self.output_tensor = self.output_tensor or input.new(1) | |||
input.THNN.AbsCriterion_updateOutput( | |||
input:cdata(), | |||
target:cdata(), | |||
self.output_tensor:cdata(), | |||
self.sizeAverage | |||
) | |||
self.output = self.output_tensor[1] | |||
return self.output | |||
end | |||
function AbsCriterion:updateGradInput(input, target) | |||
input.THNN.AbsCriterion_updateGradInput( | |||
input:cdata(), | |||
target:cdata(), | |||
self.gradInput:cdata(), | |||
self.sizeAverage | |||
) | |||
return self.gradInput | |||
end |
@@ -1,66 +0,0 @@ | |||
local Add, parent = torch.class('nn.Add', 'nn.Module') | |||
function Add:__init(inputSize,scalar) | |||
parent.__init(self) | |||
local size = inputSize | |||
if scalar then size=1 end | |||
self.scalar = scalar | |||
self.bias = torch.Tensor(size) | |||
self.gradBias = torch.Tensor(size) | |||
self._ones = torch.Tensor{1} | |||
self:reset() | |||
end | |||
function Add:reset(stdv) | |||
if stdv then | |||
stdv = stdv * math.sqrt(3) | |||
else | |||
stdv = 1./math.sqrt(self.bias:size(1)) | |||
end | |||
self.bias:uniform(-stdv, stdv) | |||
end | |||
function Add:updateOutput(input) | |||
self.output:resizeAs(input):copy(input) | |||
if self.scalar then | |||
self.output:add(self.bias[1]); | |||
else | |||
if input:isSameSizeAs(self.bias) then | |||
self.output:add(self.bias) | |||
else | |||
local batchSize = input:size(1) | |||
if self._ones:size(1) ~= batchSize then | |||
self._ones:resize(batchSize):fill(1) | |||
end | |||
local bias = self.bias:view(-1) | |||
local output = self.output:view(batchSize, -1) | |||
output:addr(1, self._ones, bias) | |||
end | |||
end | |||
return self.output | |||
end | |||
function Add:updateGradInput(input, gradOutput) | |||
if self.gradInput then | |||
self.gradInput:resizeAs(gradOutput):copy(gradOutput) | |||
return self.gradInput | |||
end | |||
end | |||
function Add:accGradParameters(input, gradOutput, scale) | |||
scale = scale or 1 | |||
if self.gradBias:size(1) == 1 then | |||
self.gradBias[1] = self.gradBias[1] + scale*gradOutput:sum(); | |||
else | |||
if input:isSameSizeAs(self.bias) then | |||
self.gradBias:add(scale, gradOutput) | |||
else | |||
local gradOutput = gradOutput:view(input:size(1), -1) | |||
self.gradBias:view(-1):addmv(scale, gradOutput:t(), self._ones) | |||
end | |||
end | |||
end |
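-- A hedged usage sketch of nn.Add (requires torch and nn): with scalar=true a
-- single learned bias is added to every element; otherwise the bias has one
-- entry per input element and is broadcast over the batch via addr above.
local nn = require 'nn'
local m = nn.Add(5, true) -- scalar bias
m.bias[1] = 2
local y = m:forward(torch.zeros(3, 5))
assert(y[1][1] == 2 and y[3][5] == 2)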
@@ -1,50 +0,0 @@ | |||
local AddConstant, parent = torch.class('nn.AddConstant', 'nn.Module') | |||
function AddConstant:__init(constant_scalar,ip) | |||
parent.__init(self) | |||
self.constant_scalar = constant_scalar | |||
-- default for inplace is false | |||
self.inplace = ip or false | |||
if (ip and type(ip) ~= 'boolean') then | |||
error('in-place flag must be boolean') | |||
end | |||
end | |||
function AddConstant:updateOutput(input) | |||
assert(type(self.constant_scalar) == 'number' or | |||
(torch.isTensor(self.constant_scalar) and input:nDimension() <= 2 and | |||
input:size(input:nDimension()) == self.constant_scalar:size(1)), | |||
'input is not scalar or doesn\'t match the dimension of the constant!') | |||
local tmp | |||
if torch.isTensor(self.constant_scalar) and input:nDimension() == 2 then | |||
local nOutput = self.constant_scalar:size(1) | |||
tmp = self.constant_scalar.new() | |||
tmp:resize(1,nOutput) | |||
tmp:copy(self.constant_scalar) | |||
tmp = tmp:expand(input:size(1),nOutput) | |||
else | |||
tmp = self.constant_scalar | |||
end | |||
if self.inplace then | |||
input:add(tmp) | |||
self.output:set(input) | |||
else | |||
self.output:resizeAs(input) | |||
self.output:copy(input) | |||
self.output:add(tmp) | |||
end | |||
return self.output | |||
end | |||
function AddConstant:updateGradInput(input, gradOutput) | |||
if self.inplace then | |||
self.gradInput:set(gradOutput) | |||
-- restore previous input value | |||
input:add(-self.constant_scalar) | |||
else | |||
self.gradInput:resizeAs(gradOutput) | |||
self.gradInput:copy(gradOutput) | |||
end | |||
return self.gradInput | |||
end |
@@ -1,64 +0,0 @@ | |||
local THNN = require 'nn.THNN' | |||
local BCECriterion, parent = torch.class('nn.BCECriterion', 'nn.Criterion') | |||
function BCECriterion:__init(weights, sizeAverage) | |||
parent.__init(self) | |||
if sizeAverage ~= nil then | |||
self.sizeAverage = sizeAverage | |||
else | |||
self.sizeAverage = true | |||
end | |||
if weights ~= nil then | |||
assert(weights:dim() == 1, "weights input should be 1-D Tensor") | |||
self.weights = weights | |||
end | |||
end | |||
function BCECriterion:__len() | |||
return self.weights and #self.weights or 0 | |||
end | |||
function BCECriterion:updateOutput(input, target) | |||
-- - log(input) * target - log(1 - input) * (1 - target) | |||
assert( input:nElement() == target:nElement(), | |||
"input and target size mismatch") | |||
self.output_tensor = self.output_tensor or input.new(1) | |||
local weights = self.weights | |||
if weights ~= nil and target:dim() ~= 1 then | |||
weights = self.weights:view(1, target:size(2)):expandAs(target) | |||
end | |||
input.THNN.BCECriterion_updateOutput( | |||
input:cdata(), | |||
target:cdata(), | |||
self.output_tensor:cdata(), | |||
self.sizeAverage, | |||
THNN.optionalTensor(weights) | |||
) | |||
self.output = self.output_tensor[1] | |||
return self.output | |||
end | |||
function BCECriterion:updateGradInput(input, target) | |||
-- - (target - input) / ( input (1 - input) ) | |||
assert( input:nElement() == target:nElement(), | |||
"input and target size mismatch") | |||
local weights = self.weights | |||
if weights ~= nil and target:dim() ~= 1 then | |||
weights = self.weights:view(1, target:size(2)):expandAs(target) | |||
end | |||
input.THNN.BCECriterion_updateGradInput( | |||
input:cdata(), | |||
target:cdata(), | |||
self.gradInput:cdata(), | |||
self.sizeAverage, | |||
THNN.optionalTensor(weights) | |||
) | |||
return self.gradInput | |||
end |
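-- The comment in updateOutput gives the per-element loss; a plain-Lua
-- hand-check on two (input, target) pairs, averaged as sizeAverage=true does:
local samples = {{0.9, 1}, {0.2, 0}}
local sum = 0
for _, s in ipairs(samples) do
   local p, t = s[1], s[2]
   sum = sum - (math.log(p) * t + math.log(1 - p) * (1 - t))
end
assert(math.abs(sum / #samples - 0.5 * (-math.log(0.9) - math.log(0.8))) < 1e-12)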
@@ -1,213 +0,0 @@ | |||
--[[ | |||
This file implements Batch Normalization as described in the paper: | |||
"Batch Normalization: Accelerating Deep Network Training | |||
by Reducing Internal Covariate Shift" | |||
by Sergey Ioffe, Christian Szegedy | |||
This implementation is useful for inputs NOT coming from convolution layers. | |||
For convolution layers, use nn.SpatialBatchNormalization. | |||
The operation implemented is: | |||
y = ( x - mean(x) ) | |||
-------------------- * gamma + beta | |||
standard-deviation(x) | |||
where gamma and beta are learnable parameters. | |||
The learning of gamma and beta is optional. | |||
Usage: | |||
with learnable parameters: nn.BatchNormalization(N [,eps] [,momentum]) | |||
where N = dimensionality of input | |||
without learnable parameters: nn.BatchNormalization(N [,eps] [,momentum], false) | |||
eps is a small value added to the standard-deviation to avoid divide-by-zero. | |||
Defaults to 1e-5 | |||
At training time, this layer keeps a running estimate of its computed mean and std. | |||
The running sum is kept with a default momentum of 0.1 (unless overridden). | |||
At test time, this running mean/std is used to normalize. | |||
]]-- | |||
local BN,parent = torch.class('nn.BatchNormalization', 'nn.Module') | |||
local THNN = require 'nn.THNN' | |||
BN.__version = 2 | |||
BN.nDim = 2 | |||
function BN:__init(nOutput, eps, momentum, affine) | |||
parent.__init(self) | |||
assert(nOutput and type(nOutput) == 'number', | |||
'Missing argument #1: dimensionality of input. ') | |||
assert(nOutput ~= 0, 'To set affine=false call BatchNormalization' | |||
.. '(nOutput, eps, momentum, false) ') | |||
if affine ~= nil then | |||
assert(type(affine) == 'boolean', 'affine has to be true/false') | |||
self.affine = affine | |||
else | |||
self.affine = true | |||
end | |||
self.eps = eps or 1e-5 | |||
self.train = true | |||
self.momentum = momentum or 0.1 | |||
self.running_mean = torch.zeros(nOutput) | |||
self.running_var = torch.ones(nOutput) | |||
if self.affine then | |||
self.weight = torch.Tensor(nOutput) | |||
self.bias = torch.Tensor(nOutput) | |||
self.gradWeight = torch.Tensor(nOutput) | |||
self.gradBias = torch.Tensor(nOutput) | |||
self:reset() | |||
end | |||
end | |||
function BN:reset() | |||
if self.weight then | |||
self.weight:uniform() | |||
end | |||
if self.bias then | |||
self.bias:zero() | |||
end | |||
self.running_mean:zero() | |||
self.running_var:fill(1) | |||
end | |||
function BN:checkInputDim(input) | |||
local iDim = input:dim() | |||
assert(iDim == self.nDim or | |||
(iDim == self.nDim - 1 and self.train == false), string.format( | |||
'only mini-batch supported (%dD tensor), got %dD tensor instead', | |||
self.nDim, iDim)) | |||
local featDim = (iDim == self.nDim - 1) and 1 or 2 | |||
assert(input:size(featDim) == self.running_mean:nElement(), string.format( | |||
'got %d-feature tensor, expected %d', | |||
input:size(featDim), self.running_mean:nElement())) | |||
end | |||
local function makeContiguous(self, input, gradOutput) | |||
if not input:isContiguous() then | |||
self._input = self._input or input.new() | |||
self._input:resizeAs(input):copy(input) | |||
input = self._input | |||
end | |||
if gradOutput then | |||
if not gradOutput:isContiguous() then | |||
self._gradOutput = self._gradOutput or gradOutput.new() | |||
self._gradOutput:resizeAs(gradOutput):copy(gradOutput) | |||
gradOutput = self._gradOutput | |||
end | |||
end | |||
return input, gradOutput | |||
end | |||
local function makeBatch(self, input) | |||
local iDim = input:dim() | |||
if self.train == false and iDim == self.nDim - 1 then | |||
return nn.utils.addSingletonDimension(input, input, 1) | |||
else | |||
return input | |||
end | |||
end | |||
function BN:updateOutput(input) | |||
self:checkInputDim(input) | |||
input = makeContiguous(self, input) | |||
input = makeBatch(self, input) | |||
self.save_mean = self.save_mean or input.new() | |||
self.save_mean:resizeAs(self.running_mean) | |||
self.save_std = self.save_std or input.new() | |||
self.save_std:resizeAs(self.running_var) | |||
input.THNN.BatchNormalization_updateOutput( | |||
input:cdata(), | |||
self.output:cdata(), | |||
THNN.optionalTensor(self.weight), | |||
THNN.optionalTensor(self.bias), | |||
self.running_mean:cdata(), | |||
self.running_var:cdata(), | |||
self.save_mean:cdata(), | |||
self.save_std:cdata(), | |||
self.train, | |||
self.momentum, | |||
self.eps) | |||
return self.output | |||
end | |||
local function backward(self, input, gradOutput, scale, gradInput, gradWeight, gradBias) | |||
self:checkInputDim(input) | |||
self:checkInputDim(gradOutput) | |||
assert(self.save_mean and self.save_std, 'must call :updateOutput() first') | |||
input, gradOutput = makeContiguous(self, input, gradOutput) | |||
input = makeBatch(self, input) | |||
gradOutput = makeBatch(self, gradOutput) | |||
scale = scale or 1 | |||
if gradInput then | |||
gradInput:resizeAs(gradOutput) | |||
end | |||
input.THNN.BatchNormalization_backward( | |||
input:cdata(), | |||
gradOutput:cdata(), | |||
THNN.optionalTensor(gradInput), | |||
THNN.optionalTensor(gradWeight), | |||
THNN.optionalTensor(gradBias), | |||
THNN.optionalTensor(self.weight), | |||
self.running_mean:cdata(), | |||
self.running_var:cdata(), | |||
self.save_mean:cdata(), | |||
self.save_std:cdata(), | |||
self.train, | |||
scale, | |||
self.eps) | |||
return self.gradInput | |||
end | |||
function BN:backward(input, gradOutput, scale) | |||
return backward(self, input, gradOutput, scale, self.gradInput, self.gradWeight, self.gradBias) | |||
end | |||
function BN:updateGradInput(input, gradOutput) | |||
return backward(self, input, gradOutput, 1, self.gradInput) | |||
end | |||
function BN:accGradParameters(input, gradOutput, scale) | |||
return backward(self, input, gradOutput, scale, nil, self.gradWeight, self.gradBias) | |||
end | |||
function BN:read(file, version) | |||
parent.read(self, file) | |||
if version < 2 then | |||
if self.running_std then | |||
self.running_var = self.running_std:pow(-2):add(-self.eps) | |||
self.running_std = nil | |||
end | |||
end | |||
end | |||
function BN:clearState() | |||
-- first 5 buffers are not present in the current implementation, | |||
-- but we keep them for cleaning old saved models | |||
nn.utils.clear(self, { | |||
'buffer', | |||
'buffer2', | |||
'centered', | |||
'std', | |||
'normalized', | |||
'_input', | |||
'_gradOutput', | |||
'save_mean', | |||
'save_std', | |||
}) | |||
return parent.clearState(self) | |||
end | |||
function BN:__tostring__() | |||
return string.format('%s (%dD) (%d)', torch.type(self), self.nDim, self.running_mean:nElement()) | |||
end |
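-- A plain-Lua hand-check of the normalization formula from the header
-- comment, for one feature over a batch of 4 (gamma = 1, beta = 0, eps as
-- defaulted above; training-mode batch statistics use the biased 1/N variance):
local xs, eps = {1, 2, 3, 4}, 1e-5
local mean, var = 2.5, 0 -- mean = (1+2+3+4)/4
for _, x in ipairs(xs) do var = var + (x - mean) ^ 2 end
var = var / #xs
local y1 = (xs[1] - mean) / math.sqrt(var + eps)
assert(math.abs(y1 + 1.3416) < 1e-3) -- (1 - 2.5) / sqrt(1.25) ~ -1.3416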
@@ -1,163 +0,0 @@ | |||
local Bilinear, parent = torch.class('nn.Bilinear', 'nn.Module') | |||
local function isint(x) return type(x) == 'number' and x == math.floor(x) end | |||
function Bilinear:__assertInput(input) | |||
assert(input and type(input) == 'table' and #input == 2, | |||
'input should be a table containing two data Tensors') | |||
assert(input[1]:nDimension() == 2 and input[2]:nDimension() == 2, | |||
'input Tensors should be two-dimensional') | |||
assert(input[1]:size(1) == input[2]:size(1), | |||
'input Tensors should have the same number of rows (instances)') | |||
assert(input[1]:size(2) == self.weight:size(2), | |||
'dimensionality of first input is erroneous') | |||
assert(input[2]:size(2) == self.weight:size(3), | |||
'dimensionality of second input is erroneous') | |||
end | |||
function Bilinear:__assertInputGradOutput(input, gradOutput) | |||
assert(input[1]:size(1) == gradOutput:size(1), | |||
'number of rows in gradOutput does not match input') | |||
assert(gradOutput:size(2) == self.weight:size(1), | |||
'number of columns in gradOutput does not match output size of layer') | |||
end | |||
function Bilinear:__init(inputSize1, inputSize2, outputSize, bias) | |||
-- assertions: | |||
assert(self and inputSize1 and inputSize2 and outputSize, | |||
'should specify inputSize1 and inputSize2 and outputSize') | |||
assert(isint(inputSize1) and isint(inputSize2) and isint(outputSize), | |||
'inputSize1 and inputSize2 and outputSize should be integer numbers') | |||
assert(inputSize1 > 0 and inputSize2 > 0 and outputSize > 0, | |||
'inputSize1 and inputSize2 and outputSize should be positive numbers') | |||
-- set up model: | |||
parent.__init(self) | |||
local bias = ((bias == nil) and true) or bias | |||
self.weight = torch.Tensor(outputSize, inputSize1, inputSize2) | |||
self.gradWeight = torch.Tensor(outputSize, inputSize1, inputSize2) | |||
if bias then | |||
self.bias = torch.Tensor(outputSize) | |||
self.gradBias = torch.Tensor(outputSize) | |||
end | |||
self.gradInput = {torch.Tensor(), torch.Tensor()} | |||
self:reset() | |||
end | |||
function Bilinear:reset(stdv) | |||
assert(self) | |||
if stdv then | |||
assert(stdv and type(stdv) == 'number' and stdv > 0, | |||
'standard deviation should be a positive number') | |||
stdv = stdv * math.sqrt(3) | |||
else | |||
stdv = 1 / math.sqrt(self.weight:size(2)) | |||
end | |||
self.weight:uniform(-stdv, stdv) | |||
if self.bias then self.bias:uniform(-stdv, stdv) end | |||
return self | |||
end | |||
function Bilinear:updateOutput(input) | |||
assert(self) | |||
self:__assertInput(input) | |||
-- set up buffer: | |||
self.buff2 = self.buff2 or input[1].new() | |||
self.buff2:resizeAs(input[2]) | |||
-- compute output scores: | |||
self.output:resize(input[1]:size(1), self.weight:size(1)) | |||
for k = 1,self.weight:size(1) do | |||
torch.mm(self.buff2, input[1], self.weight[k]) | |||
self.buff2:cmul(input[2]) | |||
torch.sum(self.output:narrow(2, k, 1), self.buff2, 2) | |||
end | |||
if self.bias then | |||
self.output:add( | |||
self.bias:reshape(1, self.bias:nElement()):expandAs(self.output) | |||
) | |||
end | |||
return self.output | |||
end | |||
function Bilinear:updateGradInput(input, gradOutput) | |||
assert(self) | |||
if self.gradInput then | |||
self:__assertInputGradOutput(input, gradOutput) | |||
if #self.gradInput == 0 then | |||
for i = 1, 2 do self.gradInput[i] = input[1].new() end | |||
end | |||
-- compute d output / d input: | |||
self.gradInput[1]:resizeAs(input[1]):fill(0) | |||
self.gradInput[2]:resizeAs(input[2]):fill(0) | |||
-- do first slice of weight tensor (k = 1) | |||
self.gradInput[1]:mm(input[2], self.weight[1]:t()) | |||
self.gradInput[1]:cmul(gradOutput:narrow(2,1,1):expand(self.gradInput[1]:size(1), | |||
self.gradInput[1]:size(2))) | |||
self.gradInput[2]:addmm(1, input[1], self.weight[1]) | |||
self.gradInput[2]:cmul(gradOutput:narrow(2,1,1):expand(self.gradInput[2]:size(1), | |||
self.gradInput[2]:size(2))) | |||
-- do remaining slices of weight tensor | |||
if self.weight:size(1) > 1 then | |||
self.buff1 = self.buff1 or input[1].new() | |||
self.buff1:resizeAs(input[1]) | |||
for k = 2, self.weight:size(1) do | |||
self.buff1:mm(input[2], self.weight[k]:t()) | |||
self.buff1:cmul(gradOutput:narrow(2,k,1):expand(self.gradInput[1]:size(1), | |||
self.gradInput[1]:size(2))) | |||
self.gradInput[1]:add(self.buff1) | |||
self.buff2:mm(input[1], self.weight[k]) | |||
self.buff2:cmul(gradOutput:narrow(2,k,1):expand(self.gradInput[2]:size(1), | |||
self.gradInput[2]:size(2))) | |||
self.gradInput[2]:add(self.buff2) | |||
end | |||
end | |||
return self.gradInput | |||
end | |||
end | |||
function Bilinear:accGradParameters(input, gradOutput, scale) | |||
local scale = scale or 1 | |||
self:__assertInputGradOutput(input, gradOutput) | |||
assert(scale and type(scale) == 'number' and scale >= 0) | |||
-- make sure we have buffer: | |||
self.buff1 = self.buff1 or input[1].new() | |||
self.buff1:resizeAs(input[1]) | |||
-- accumulate parameter gradients: | |||
for k = 1,self.weight:size(1) do | |||
torch.cmul( | |||
self.buff1, input[1], gradOutput:narrow(2, k, 1):expandAs(input[1]) | |||
) | |||
self.gradWeight[k]:addmm(self.buff1:t(), input[2]) | |||
end | |||
if self.bias then self.gradBias:add(scale, gradOutput:sum(1)) end | |||
end | |||
function Bilinear:sharedAccUpdateGradParameters(input, gradOutput, lr) | |||
-- we do not need to accumulate parameters when sharing: | |||
self:defaultAccUpdateGradParameters(input, gradOutput, lr) | |||
end | |||
function Bilinear:__tostring__() | |||
return torch.type(self) .. | |||
string.format( | |||
'(%dx%d -> %d) %s', | |||
self.weight:size(2), self.weight:size(3), self.weight:size(1), | |||
(self.bias == nil and ' without bias' or '') | |||
) | |||
end | |||
function Bilinear:clearState() | |||
if self.buff2 then self.buff2:set() end | |||
if self.buff1 then self.buff1:set() end | |||
return parent.clearState(self) | |||
end |
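-- Illustrative usage sketch, not part of the original file (run after require 'nn'):
local blExample = nn.Bilinear(10, 5, 3)                 -- inputSize1=10, inputSize2=5, outputSize=3
local blInput = {torch.randn(4, 10), torch.randn(4, 5)} -- a batch of 4 instance pairs
local blOut = blExample:forward(blInput)                -- 4x3: out[n][k] = x1[n]' * W[k] * x2[n] + b[k]
blExample:backward(blInput, torch.randn(4, 3))          -- gradInput is a table of two tensors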
@@ -1,71 +0,0 @@ | |||
local Bottle, parent = torch.class("nn.Bottle", "nn.Decorator") | |||
local unpack = unpack or table.unpack | |||
function Bottle:__init(module, nInputDim, nOutputDim) | |||
parent.__init(self, module) | |||
self.nInputDim = nInputDim or 2 | |||
self.nOutputDim = nOutputDim or self.nInputDim | |||
self.dimDelta = self.nInputDim - self.nOutputDim | |||
-- Used to reshape the gradients | |||
self.inShape = torch.Tensor(self.nInputDim) | |||
self.outShape = torch.Tensor(self.nOutputDim) | |||
end | |||
function Bottle:updateOutput(input) | |||
-- first batchDims dimensions will be fused | |||
local batchDims = input:dim() - self.nInputDim + 1 | |||
-- see if bottle is required | |||
if batchDims > 1 then | |||
-- bottle the first dims | |||
local inSize = torch.LongTensor(input:size()) | |||
local squeezeSize = inSize[{{1, batchDims - 1}}]:prod() | |||
self.inShape:copy(inSize[{{batchDims, input:dim()}}]) | |||
self.inShape[{{1}}]:mul(squeezeSize) | |||
-- Forward with the module's dimension | |||
local newInput = input:view(unpack(self.inShape:totable())) | |||
local output = self.modules[1]:updateOutput(newInput) | |||
assert(output:dim() == self.nOutputDim, | |||
"Wrong number of output dims on module. Expected: " .. | |||
self.nOutputDim .. ' but got ' .. | |||
tostring(output and output:dim())) | |||
self.outShape:copy(torch.LongTensor(output:size())) | |||
if math.abs(self.dimDelta) > 0 then | |||
inSize:resize(inSize:size(1) - self.dimDelta) | |||
end | |||
inSize[{{batchDims, inSize:size(1)}}]:copy(self.outShape) | |||
inSize[{{batchDims}}]:div(squeezeSize) | |||
-- unbottle | |||
self.output:set(output:view(unpack(torch.totable(inSize)))) | |||
else | |||
self.output:set(self.modules[1]:updateOutput(input)) | |||
end | |||
return self.output | |||
end | |||
function Bottle:updateGradInput(input, gradOutput) | |||
if input:dim() > self.nInputDim then | |||
local input_ = input:view(unpack(self.inShape:totable())) | |||
local gradOutput_ = gradOutput:view(unpack(self.outShape:totable())) | |||
self.modules[1]:updateGradInput(input_, gradOutput_) | |||
if self.modules[1].gradInput then | |||
self.gradInput:set(self.modules[1].gradInput:viewAs(input)) | |||
else | |||
self.gradInput = nil | |||
end | |||
else | |||
if self.modules[1].gradInput then | |||
self.gradInput:set(self.modules[1]:updateGradInput(input, gradOutput)) | |||
else | |||
self.gradInput = nil | |||
end | |||
end | |||
return self.gradInput | |||
end | |||
function Bottle:accGradParameters(input, gradOutput, scale) | |||
if input:dim() > self.nInputDim then | |||
input = input:view(unpack(self.inShape:totable())) | |||
gradOutput = gradOutput:view(unpack(self.outShape:totable())) | |||
end | |||
self.modules[1]:accGradParameters(input, gradOutput, scale) | |||
end |
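-- Illustrative usage sketch, not part of the original file (run after require 'nn'):
-- Bottle folds the extra leading dimensions into the batch so that a module
-- expecting 2D input can be applied to higher-dimensional input.
local btExample = nn.Bottle(nn.Linear(10, 3))
local btOut = btExample:forward(torch.randn(4, 5, 10)) -- seen as 20x10 inside; output is 4x5x3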
@@ -1,127 +0,0 @@ | |||
local CAdd, parent = torch.class("nn.CAdd", "nn.Module") | |||
function CAdd:__init(...) | |||
parent.__init(self) | |||
local arg = {...} | |||
self.size = torch.LongStorage() | |||
local n = #arg | |||
if n == 1 and torch.type(arg[1]) == 'torch.LongStorage' then | |||
self.size:resize(#arg[1]):copy(arg[1]) | |||
else | |||
self.size:resize(n) | |||
for i=1,n do | |||
self.size[i] = arg[i] | |||
end | |||
end | |||
self.bias = torch.Tensor(self.size) | |||
self.gradBias = torch.Tensor(self.size) | |||
self.output:resize(self.size) | |||
self:reset() | |||
end | |||
function CAdd:reset(stdv) | |||
if stdv then | |||
--std of uniform distribution on interval [-a,a] = a/sqrt(3) | |||
stdv = stdv * math.sqrt(3) | |||
else | |||
stdv = 1.0/math.sqrt(self.bias:nElement()) | |||
end | |||
self.bias:uniform(-stdv,stdv) | |||
end | |||
function CAdd:updateOutput(input) | |||
self._output = self._output or input.new() | |||
self._bias = self._bias or input.new() | |||
self._expand = self._expand or input.new() | |||
self._repeat = self._repeat or input.new() | |||
self.output:resizeAs(input):copy(input) | |||
if input:nElement() == self.bias:nElement() then | |||
self.output:add(self.bias) | |||
else | |||
if self.bias:dim() == input:dim() then | |||
self._output:set(self.output) | |||
self._bias:set(self.bias) | |||
else | |||
local batchSize = input:size(1) | |||
self._output:view(self.output, batchSize, -1) | |||
self._bias:view(self.bias, 1, -1) | |||
end | |||
self._expand:expandAs(self._bias, self._output) | |||
--expandAs uses stride 0 and self._expand is not contiguous | |||
--cuda ops may assume contiguous input | |||
if torch.type(input) == 'torch.CudaTensor' then | |||
self._repeat:resizeAs(self._expand):copy(self._expand) | |||
self._output:add(self._repeat) | |||
else | |||
self._output:add(self._expand) | |||
end | |||
end | |||
return self.output | |||
end | |||
function CAdd:updateGradInput(input, gradOutput) | |||
self.gradInput = self.gradInput or input.new() | |||
self.gradInput:resizeAs(gradOutput):copy(gradOutput) | |||
return self.gradInput | |||
end | |||
function CAdd:accGradParameters(input, gradOutput, scale) | |||
scale = scale or 1 | |||
self._gradBias = self._gradBias or gradOutput.new() | |||
self._gradOutput = self._gradOutput or gradOutput.new() | |||
self._repeat = self._repeat or gradOutput.new() | |||
if self.bias:nElement() == gradOutput:nElement() then | |||
self.gradBias:add(scale, gradOutput) | |||
else | |||
if self.bias:dim() == gradOutput:dim() then | |||
self._gradBias:set(self.gradBias) | |||
self._gradOutput:set(gradOutput) | |||
else | |||
local batchSize = input:size(1) | |||
self._gradBias:view(self.gradBias, 1, -1) | |||
self._gradOutput:view(gradOutput, batchSize, -1) | |||
end | |||
self._gradBias:expandAs(self._gradBias, self._gradOutput) | |||
--expandAs uses stride 0 and self._gradBias is not contiguous | |||
--cuda ops may assume contiguous input | |||
if torch.type(self._gradBias) == 'torch.CudaTensor' then | |||
self._repeat:resizeAs(self._gradBias):copy(self._gradBias) | |||
self._repeat:add(scale, self._gradOutput) | |||
self._gradBias:copy(self._repeat) | |||
else | |||
self._gradBias:add(scale, self._gradOutput) | |||
end | |||
end | |||
end | |||
function CAdd:type(type, tensorCache) | |||
if type then | |||
self:clearState() | |||
end | |||
return parent.type(self, type, tensorCache) | |||
end | |||
function CAdd:clearState() | |||
nn.utils.clear(self, { | |||
'_gradBias', | |||
'_expand', | |||
'_output', | |||
'_bias', | |||
'_repeat' | |||
}) | |||
return parent.clearState(self) | |||
end |
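-- Illustrative usage sketch, not part of the original file (run after require 'nn'):
local caExample = nn.CAdd(5)                       -- one learned bias per feature
local caOut = caExample:forward(torch.randn(4, 5)) -- the bias is broadcast across the batch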
@@ -1,36 +0,0 @@ | |||
local CAddTable, parent = torch.class('nn.CAddTable', 'nn.Module') | |||
function CAddTable:__init(ip) | |||
parent.__init(self) | |||
self.inplace = ip | |||
self.gradInput = {} | |||
end | |||
function CAddTable:updateOutput(input) | |||
if self.inplace then | |||
self.output:set(input[1]) | |||
else | |||
self.output:resizeAs(input[1]):copy(input[1]) | |||
end | |||
for i=2,#input do | |||
self.output:add(input[i]) | |||
end | |||
return self.output | |||
end | |||
function CAddTable:updateGradInput(input, gradOutput) | |||
for i=1,#input do | |||
self.gradInput[i] = self.gradInput[i] or input[1].new() | |||
if self.inplace then | |||
self.gradInput[i]:set(gradOutput) | |||
else | |||
self.gradInput[i]:resizeAs(input[i]):copy(gradOutput) | |||
end | |||
end | |||
for i=#input+1, #self.gradInput do | |||
self.gradInput[i] = nil | |||
end | |||
return self.gradInput | |||
end |
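-- Illustrative usage sketch, not part of the original file (run after require 'nn'):
local catExample = nn.CAddTable()
local catOut = catExample:forward({torch.randn(2, 3), torch.randn(2, 3)}) -- elementwise sum, 2x3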
@@ -1,43 +0,0 @@ | |||
local CAddTensorTable, parent = torch.class('nn.CAddTensorTable', 'nn.Module') | |||
function CAddTensorTable:__init() | |||
parent.__init(self) | |||
self.gradInput = {} | |||
end | |||
function CAddTensorTable:updateOutput(input) | |||
local currentOutput = {} | |||
for i=1,#input[2] do | |||
currentOutput[i] = currentOutput[i] or input[1].new() | |||
currentOutput[i]:resizeAs(input[1]) | |||
currentOutput[i]:copy(input[2][i]) | |||
currentOutput[i]:add(input[1]) | |||
end | |||
for i = #input[2]+1, #currentOutput do | |||
currentOutput[i] = nil | |||
end | |||
self.output = currentOutput | |||
return self.output | |||
end | |||
function CAddTensorTable:updateGradInput(input, gradOutput) | |||
self.gradInput[1] = self.gradInput[1] or input[1].new() | |||
self.gradInput[1]:resizeAs(input[1]) | |||
self.gradInput[1]:copy(gradOutput[1]) | |||
for i=2, #input[2] do | |||
self.gradInput[1]:add(gradOutput[i]) | |||
end | |||
self.gradInput[2] = self.gradInput[2] or {} | |||
for i=1,#input[2] do | |||
self.gradInput[2][i] = self.gradInput[2][i] or input[1].new() | |||
self.gradInput[2][i]:resizeAs(input[1]) | |||
self.gradInput[2][i]:copy(gradOutput[i]) | |||
end | |||
for i=#input[2]+1, #self.gradInput[2] do | |||
self.gradInput[2][i] = nil | |||
end | |||
return self.gradInput | |||
end |
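-- Illustrative usage sketch, not part of the original file (run after require 'nn'):
-- the first input tensor is added to every tensor of the second (table) input.
local cattExample = nn.CAddTensorTable()
local cattOut = cattExample:forward({torch.randn(3), {torch.randn(3), torch.randn(3)}})
-- cattOut is a table of two 3-element tensors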
@@ -1,26 +0,0 @@ | |||
local CDivTable, parent = torch.class('nn.CDivTable', 'nn.Module') | |||
function CDivTable:__init() | |||
parent.__init(self) | |||
self.gradInput = {} | |||
end | |||
function CDivTable:updateOutput(input) | |||
self.output:resizeAs(input[1]):copy(input[1]) | |||
self.output:cdiv(input[2]) | |||
return self.output | |||
end | |||
function CDivTable:updateGradInput(input, gradOutput) | |||
self.gradInput[1] = self.gradInput[1] or input[1].new() | |||
self.gradInput[2] = self.gradInput[2] or input[1].new() | |||
self.gradInput[1]:resizeAs(input[1]):copy(gradOutput):cdiv(input[2]) | |||
self.gradInput[2]:resizeAs(input[2]):zero():addcdiv(-1,self.gradInput[1],input[2]):cmul(input[1]) | |||
for i=#input+1, #self.gradInput do | |||
self.gradInput[i] = nil | |||
end | |||
return self.gradInput | |||
end |
@@ -1,14 +0,0 @@ | |||
CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR) | |||
CMAKE_POLICY(VERSION 2.6) | |||
INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}/../torch7/lib/TH) | |||
INCLUDE_DIRECTORIES(${CMAKE_CURRENT_BINARY_DIR}/../torch7/lib/TH) | |||
ADD_SUBDIRECTORY(lib) | |||
FILE(STRINGS lib/THNN/generic/THNN.h THNN_headers NEWLINE_CONSUME) | |||
FILE(WRITE THNN_h.lua "return [[") | |||
FILE(APPEND THNN_h.lua ${THNN_headers}) | |||
FILE(APPEND THNN_h.lua "]]") | |||
FILE(GLOB luasrc *.lua) | |||
ADD_TORCH_PACKAGE(nn "" "${luasrc}") |
@@ -1,46 +0,0 @@ | |||
local CMaxTable, parent = torch.class('nn.CMaxTable', 'nn.Module') | |||
function CMaxTable:__init() | |||
parent.__init(self) | |||
self.gradInput = {} | |||
self.maxIdx = torch.Tensor() | |||
self.mask = torch.Tensor() | |||
self.maxVals = torch.Tensor() | |||
self.gradMaxVals = torch.Tensor() | |||
end | |||
function CMaxTable:updateOutput(input) | |||
self.output:resizeAs(input[1]):copy(input[1]) | |||
self.maxIdx:resizeAs(input[1]):fill(1) | |||
for i=2,#input do | |||
self.maskByteTensor = self.maskByteTensor or | |||
(torch.type(self.output) == 'torch.CudaTensor' and | |||
torch.CudaByteTensor() or torch.ByteTensor()) | |||
self.mask:gt(input[i], self.output) | |||
self.maskByteTensor:resize(self.mask:size()):copy(self.mask) | |||
self.maxIdx:maskedFill(self.maskByteTensor, i) | |||
self.maxVals:maskedSelect(input[i], self.maskByteTensor) | |||
self.output:maskedCopy(self.maskByteTensor, self.maxVals) | |||
end | |||
return self.output | |||
end | |||
function CMaxTable:updateGradInput(input, gradOutput) | |||
for i=1,#input do | |||
self.gradInput[i] = self.gradInput[i] or input[i].new() | |||
self.gradInput[i]:resizeAs(input[i]):fill(0.0) | |||
self.maskByteTensor = self.maskByteTensor or | |||
(torch.type(self.output) == 'torch.CudaTensor' and | |||
torch.CudaByteTensor() or torch.ByteTensor()) | |||
self.mask:eq(self.maxIdx, i) | |||
self.maskByteTensor:resize(self.mask:size()):copy(self.mask) | |||
self.gradMaxVals:maskedSelect(gradOutput, self.maskByteTensor) | |||
self.gradInput[i]:maskedCopy(self.maskByteTensor, self.gradMaxVals) | |||
end | |||
for i=#input+1, #self.gradInput do | |||
self.gradInput[i] = nil | |||
end | |||
return self.gradInput | |||
end |
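-- Illustrative usage sketch, not part of the original file (run after require 'nn'):
local cmaxExample = nn.CMaxTable()
local cmaxOut = cmaxExample:forward({torch.Tensor{1, 4}, torch.Tensor{3, 2}}) -- {3, 4}, elementwise max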
@@ -1,46 +0,0 @@ | |||
local CMinTable, parent = torch.class('nn.CMinTable', 'nn.Module') | |||
function CMinTable:__init() | |||
parent.__init(self) | |||
self.gradInput = {} | |||
self.minIdx = torch.Tensor() | |||
self.mask = torch.Tensor() | |||
self.minVals = torch.Tensor() | |||
  self.gradMinVals = torch.Tensor()
end | |||
function CMinTable:updateOutput(input) | |||
self.output:resizeAs(input[1]):copy(input[1]) | |||
self.minIdx:resizeAs(input[1]):fill(1) | |||
for i=2,#input do | |||
self.maskByteTensor = self.maskByteTensor or | |||
(torch.type(self.output) == 'torch.CudaTensor' and | |||
torch.CudaByteTensor() or torch.ByteTensor()) | |||
self.mask:lt(input[i], self.output) | |||
self.maskByteTensor:resize(self.mask:size()):copy(self.mask) | |||
self.minIdx:maskedFill(self.maskByteTensor, i) | |||
self.minVals:maskedSelect(input[i], self.maskByteTensor) | |||
self.output:maskedCopy(self.maskByteTensor, self.minVals) | |||
end | |||
return self.output | |||
end | |||
function CMinTable:updateGradInput(input, gradOutput) | |||
for i=1,#input do | |||
self.gradInput[i] = self.gradInput[i] or input[i].new() | |||
self.gradInput[i]:resizeAs(input[i]):fill(0.0) | |||
self.maskByteTensor = self.maskByteTensor or | |||
(torch.type(self.output) == 'torch.CudaTensor' and | |||
torch.CudaByteTensor() or torch.ByteTensor()) | |||
self.mask:eq(self.minIdx, i) | |||
self.maskByteTensor:resize(self.mask:size()):copy(self.mask) | |||
      self.gradMinVals:maskedSelect(gradOutput, self.maskByteTensor)
      self.gradInput[i]:maskedCopy(self.maskByteTensor, self.gradMinVals)
end | |||
for i=#input+1, #self.gradInput do | |||
self.gradInput[i] = nil | |||
end | |||
return self.gradInput | |||
end |
@@ -1,166 +0,0 @@ | |||
local CMul, parent = torch.class('nn.CMul', 'nn.Module') | |||
function CMul:__init(...) | |||
parent.__init(self) | |||
local arg = {...} | |||
self.size = torch.LongStorage() | |||
local n = #arg | |||
if n == 1 and torch.type(arg[1]) == 'torch.LongStorage' then | |||
self.size:resize(#arg[1]):copy(arg[1]) | |||
else | |||
self.size:resize(n) | |||
for i=1,n do | |||
self.size[i] = arg[i] | |||
end | |||
end | |||
self.weight = torch.Tensor(self.size) | |||
self.gradWeight = torch.Tensor(self.size) | |||
self.output:resize(self.size) | |||
self:reset() | |||
end | |||
function CMul:reset(stdv) | |||
if stdv then | |||
stdv = stdv * math.sqrt(3) | |||
else | |||
stdv = 1./math.sqrt(self.weight:nElement()) | |||
end | |||
self.weight:uniform(-stdv,stdv) | |||
end | |||
function CMul:updateOutput(input) | |||
-- lazy-initialize | |||
self._output = self._output or input.new() | |||
self._weight = self._weight or input.new() | |||
self._expand = self._expand or input.new() | |||
self._repeat = self._repeat or input.new() | |||
self.output:resizeAs(input):copy(input) | |||
if input:nElement() == self.weight:nElement() then | |||
self._output:view(self.output, -1) | |||
self._weight:view(self.weight, -1) | |||
self._output:cmul(self._weight) | |||
else | |||
if self.weight:dim() == input:dim() then | |||
self._output:set(self.output) | |||
self._weight:set(self.weight) | |||
else | |||
local batchSize = input:size(1) | |||
self._output:view(self.output, batchSize, -1) | |||
self._weight:view(self.weight, 1, -1) | |||
end | |||
self._expand:expandAs(self._weight, self._output) | |||
if torch.type(input) == 'torch.CudaTensor' then | |||
self._repeat:resizeAs(self._expand):copy(self._expand) | |||
self._output:cmul(self._repeat) | |||
else | |||
self._output:cmul(self._expand) | |||
end | |||
end | |||
return self.output | |||
end | |||
function CMul:updateGradInput(input, gradOutput) | |||
if not self.gradInput then | |||
return | |||
end | |||
self._gradOutput = self._gradOutput or input.new() | |||
self._gradInput = self._gradInput or input.new() | |||
self.gradInput:resizeAs(input):zero() | |||
if self.weight:nElement() == gradOutput:nElement() then | |||
self.gradInput:addcmul(1, self.weight, gradOutput) | |||
else | |||
if self.weight:dim() == input:dim() then | |||
nn.utils.contiguousView(self._gradOutput, gradOutput, gradOutput:size()) | |||
nn.utils.contiguousView(self._gradInput, self.gradInput, self.gradInput:size()) | |||
self._weight:set(self.weight) | |||
else | |||
local batchSize = input:size(1) | |||
nn.utils.contiguousView(self._gradOutput, gradOutput, batchSize, -1) | |||
nn.utils.contiguousView(self._gradInput, self.gradInput, batchSize, -1) | |||
self._weight:view(self.weight, 1, -1) | |||
end | |||
self._expand:expandAs(self._weight, self._gradOutput) | |||
if torch.type(input) == 'torch.CudaTensor' then | |||
self._repeat:resizeAs(self._expand):copy(self._expand) | |||
self._gradInput:addcmul(1, self._repeat, self._gradOutput) | |||
else | |||
self._gradInput:addcmul(1, self._expand, self._gradOutput) | |||
end | |||
end | |||
return self.gradInput | |||
end | |||
function CMul:accGradParameters(input, gradOutput, scale) | |||
scale = scale or 1 | |||
self._input = self._input or input.new() | |||
self._gradWeight = self._gradWeight or input.new() | |||
self._sum = self._sum or input.new() | |||
if self.weight:nElement() == gradOutput:nElement() then | |||
self.gradWeight:addcmul(scale, input, gradOutput) | |||
else | |||
if self.weight:dim() == input:dim() then | |||
nn.utils.contiguousView(self._input, input, input:size()) | |||
nn.utils.contiguousView(self._gradOutput, gradOutput, gradOutput:size()) | |||
self._gradWeight:set(self.gradWeight) | |||
self._repeat:cmul(self._input, self._gradOutput) | |||
local sumInto = self._sum | |||
local sumFrom = self._repeat | |||
for i=1,self.weight:dim() do | |||
if self.weight:size(i) ~= input:size(i) then | |||
sumInto:sum(sumFrom, i) | |||
sumInto = sumFrom | |||
sumFrom = sumFrom == self._repeat and self._sum or self._repeat | |||
end | |||
end | |||
self._gradWeight:add(scale, sumFrom) | |||
else | |||
local batchSize = input:size(1) | |||
nn.utils.contiguousView(self._input, input, batchSize, -1) | |||
nn.utils.contiguousView(self._gradOutput, gradOutput, batchSize, -1) | |||
self._gradWeight:view(self.gradWeight, 1, -1) | |||
self._repeat:cmul(self._input, self._gradOutput) | |||
self._sum:sum(self._repeat, 1) | |||
self._gradWeight:add(scale, self._sum) | |||
end | |||
end | |||
end | |||
function CMul:type(type, tensorCache) | |||
if type then | |||
self:clearState() | |||
end | |||
return parent.type(self, type, tensorCache) | |||
end | |||
function CMul:clearState() | |||
nn.utils.clear(self, { | |||
'_input', | |||
'_output', | |||
'_weight', | |||
'_gradWeight', | |||
'_expand', | |||
'_repeat', | |||
'_sum', | |||
}) | |||
return parent.clearState(self) | |||
end |
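-- Illustrative usage sketch, not part of the original file (run after require 'nn'):
local cmExample = nn.CMul(5)                       -- one learned scale per feature
local cmOut = cmExample:forward(torch.randn(4, 5)) -- the weight is broadcast across the batch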
@@ -1,55 +0,0 @@ | |||
local CMulTable, parent = torch.class('nn.CMulTable', 'nn.Module') | |||
function CMulTable:__init() | |||
parent.__init(self) | |||
self.gradInput = {} | |||
end | |||
function CMulTable:updateOutput(input) | |||
self.output:resizeAs(input[1]):copy(input[1]) | |||
for i=2,#input do | |||
self.output:cmul(input[i]) | |||
end | |||
return self.output | |||
end | |||
function CMulTable:updateGradInput_efficient(input, gradOutput) | |||
self.tout = self.tout or input[1].new() | |||
self.tout:resizeAs(self.output) | |||
for i=1,#input do | |||
self.gradInput[i] = self.gradInput[i] or input[1].new() | |||
self.gradInput[i]:resizeAs(input[i]):copy(gradOutput) | |||
self.tout:copy(self.output):cdiv(input[i]) | |||
self.gradInput[i]:cmul(self.tout) | |||
end | |||
for i=#input+1, #self.gradInput do | |||
self.gradInput[i] = nil | |||
end | |||
return self.gradInput | |||
end | |||
function CMulTable:updateGradInput(input, gradOutput) | |||
for i=1,#input do | |||
self.gradInput[i] = self.gradInput[i] or input[1].new() | |||
self.gradInput[i]:resizeAs(input[i]):copy(gradOutput) | |||
for j=1,#input do | |||
if i~=j then | |||
self.gradInput[i]:cmul(input[j]) | |||
end | |||
end | |||
end | |||
for i=#input+1, #self.gradInput do | |||
self.gradInput[i] = nil | |||
end | |||
return self.gradInput | |||
end | |||
function CMulTable:clearState() | |||
if self.tout then self.tout:set() end | |||
return parent.clearState(self) | |||
end |
@@ -1,136 +0,0 @@ | |||
# Contributing to Torch7 Core (torch7, nn, cutorch, cunn) | |||
Thanks a lot! There are plenty of ways you can help! | |||
Please take a moment to review this document in order to make the contribution | |||
process easy and effective for everyone involved. | |||
Following these guidelines helps to communicate that you respect the time of | |||
the developers managing and developing this open source project. In return, | |||
they should reciprocate that respect in addressing your issue or assessing | |||
patches and features. | |||
## Using the issue tracker | |||
The [issue tracker](https://github.com/torch/nn/issues) is | |||
the preferred channel for [bug reports](#bugs), [feature requests](#features)
and [submitting pull requests](#pull-requests), but please respect the following | |||
restrictions: | |||
* Please **do not** use the issue tracker for personal support requests (use | |||
[mailing-list](http://groups.google.com/forum/#!forum/torch7)). | |||
* Please **do not** open issues regarding the code in a torch package
outside the core. For example, don't open issues about the
REPL in the nn issue tracker; use the trepl issue tracker for that.
<a name="bugs"></a> | |||
## Bug reports | |||
A bug is a _demonstrable problem_ that is caused by the code in the repository. | |||
Good bug reports are extremely helpful - thank you! | |||
Guidelines for bug reports: | |||
1. **Use the GitHub issue search** — check if the issue has already been | |||
reported. | |||
2. **Check if the issue has been fixed** — try to reproduce it using the | |||
latest `master` or development branch in the repository. | |||
3. **Isolate the problem** — ideally create a reduced test case that is within reason,
preferably within 100 lines of code.
A good bug report shouldn't leave others needing to chase you up for more
information. Please try to be as detailed as possible in your report. What is
your environment? What steps will reproduce the issue? On what OS do you
experience the problem? What would you expect the outcome to be? All these
details will help people fix any potential bugs.
<a name="features"></a> | |||
## Feature requests | |||
Feature requests are welcome. Keep in mind that Torch is community-developed
and the maintainers are not exclusively Torch developers. The purpose of filing
a feature request is to make others who may want to implement a feature aware
of the interest in it.
<a name="pull-requests"></a> | |||
## Pull requests | |||
Good pull requests - patches, improvements, new features - are a fantastic | |||
help. They should remain focused in scope **and avoid containing unrelated | |||
commits.** | |||
**Please ask first** before embarking on any significant pull request (e.g. | |||
implementing features, refactoring code, porting to a different language), | |||
otherwise you risk spending a lot of time working on something that the | |||
project's developers might not want to merge into the project. | |||
Please adhere to the coding conventions used throughout a project (indentation, | |||
accurate comments, etc.) and any other requirements (such as test coverage). | |||
Adhering to the following process is the best way to get your work
included in the project:
1. [Fork](https://help.github.com/articles/fork-a-repo) the project, clone your | |||
fork, and configure the remotes: | |||
```bash | |||
# Clone your fork of the repo into the current directory | |||
git clone https://github.com/<your-username>/nn.git | |||
# Navigate to the newly cloned directory | |||
cd nn | |||
# Assign the original repo to a remote called "upstream" | |||
git remote add upstream https://github.com/torch/nn.git | |||
``` | |||
2. If you cloned a while ago, get the latest changes from upstream: | |||
```bash | |||
git checkout master | |||
git pull upstream master | |||
``` | |||
3. Create a new topic branch (off the main project development branch) to | |||
contain your feature, change, or fix: | |||
```bash | |||
git checkout -b <topic-branch-name> | |||
``` | |||
4. Commit your changes in logical chunks. Please try to adhere to these [git commit
message guidelines](http://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html).
Use Git's [interactive rebase](https://help.github.com/articles/about-git-rebase)
feature to tidy up your commits before making them public. This helps us keep the
commit history clean and organized in logical blocks as torch grows.
For example: | |||
- If you are adding a new function or a module, keep the module + tests + doc
in a single commit unless logically warranted.
- If you are fixing a bug, keep the bugfix in a single commit unless logically warranted.
5. Locally merge (or rebase) the upstream development branch into your topic branch: | |||
```bash | |||
git pull [--rebase] upstream master | |||
``` | |||
6. Push your topic branch up to your fork: | |||
```bash | |||
git push origin <topic-branch-name> | |||
``` | |||
7. [Open a Pull Request](https://help.github.com/articles/using-pull-requests/) | |||
with a clear title and description. | |||
**IMPORTANT**: By submitting a patch, you agree to allow the project owners to | |||
license your work under the terms of the BSD License. | |||
## Development workflow tips | |||
* While you are changing Lua files, you can simply symlink the cloned nn directory to `~/torch/install/share/lua/5.1/nn`, so that any change is reflected in the current install without constantly having to run `luarocks make rocks/*` (see the sketch below).
* If you are changing C files, run `luarocks make rocks/*` after every change.
* To test, you can simply use: `th -lnn -e "nn.test()"`
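A minimal sketch of that workflow, assuming the default Torch install prefix
`~/torch/install` (adjust the paths to your setup):

```bash
# From inside your cloned nn directory: move the installed module aside and
# symlink the clone into the Torch Lua path (paths are assumptions)
mv ~/torch/install/share/lua/5.1/nn ~/torch/install/share/lua/5.1/nn.orig
ln -s "$(pwd)" ~/torch/install/share/lua/5.1/nn

# After changing C files, rebuild:
luarocks make rocks/*

# Run the test suite:
th -lnn -e "nn.test()"
```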
@@ -1,36 +0,0 @@ | |||
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) | |||
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) | |||
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) | |||
Copyright (c) 2011-2013 NYU (Clement Farabet) | |||
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) | |||
Copyright (c) 2006 Idiap Research Institute (Samy Bengio) | |||
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are met: | |||
1. Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
2. Redistributions in binary form must reproduce the above copyright | |||
notice, this list of conditions and the following disclaimer in the | |||
documentation and/or other materials provided with the distribution. | |||
3. Neither the names of Deepmind Technologies, NYU, NEC Laboratories America | |||
and IDIAP Research Institute nor the names of its contributors may be | |||
used to endorse or promote products derived from this software without | |||
specific prior written permission. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE | |||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |||
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |||
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |||
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |||
POSSIBILITY OF SUCH DAMAGE. |
@@ -1,57 +0,0 @@ | |||
local CReLU, parent = torch.class('nn.CReLU', 'nn.Sequential') | |||
function CReLU:__init(nInputDims, inplace) | |||
parent.__init(self) | |||
self.nInputDims = nInputDims | |||
self.inplace = inplace or false | |||
local concatTable = nn.ConcatTable() | |||
concatTable:add(nn.Identity()) | |||
concatTable:add(nn.MulConstant(-1)) | |||
self:add(concatTable) | |||
self:add(nn.JoinTable(2)) | |||
self:add(nn.ReLU(self.inplace)) | |||
end | |||
function CReLU:updateOutput(input) | |||
local input_ | |||
local batched = input:dim() == (self.nInputDims + 1) | |||
if not batched then | |||
input_ = input:view(1, -1) | |||
else | |||
input_ = input:view(input:size(1), -1) | |||
end | |||
parent.updateOutput(self, input_) | |||
local osize = input:size() | |||
if not batched then | |||
osize[1] = osize[1] * 2 | |||
else | |||
osize[2] = osize[2] * 2 | |||
end | |||
self.output:resize(osize) | |||
return self.output | |||
end | |||
function CReLU:backward(input, gradOutput) | |||
return self:updateGradInput(input, gradOutput) | |||
end | |||
function CReLU:updateGradInput(input, gradOutput) | |||
local batched = input:dim() == (self.nInputDims + 1) | |||
if not batched then | |||
parent.updateGradInput(self, input:view(1, -1), gradOutput:view(1, -1)) | |||
else | |||
parent.updateGradInput(self, input:view(input:size(1), -1), | |||
gradOutput:view(input:size(1), -1)) | |||
end | |||
self.gradInput:resizeAs(input) | |||
return self.gradInput | |||
end | |||
function CReLU:__tostring__() | |||
return "CReLU()" | |||
end |
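-- Illustrative usage sketch, not part of the original file (run after require 'nn'):
local crExample = nn.CReLU(1)                      -- inputs have one non-batch dimension
local crOut = crExample:forward(torch.randn(4, 5)) -- 4x10: [ReLU(x), ReLU(-x)] concatenated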
@@ -1,26 +0,0 @@ | |||
local CSubTable, parent = torch.class('nn.CSubTable', 'nn.Module') | |||
function CSubTable:__init() | |||
parent.__init(self) | |||
self.gradInput = {} | |||
end | |||
function CSubTable:updateOutput(input) | |||
self.output:resizeAs(input[1]):copy(input[1]) | |||
self.output:add(-1,input[2]) | |||
return self.output | |||
end | |||
function CSubTable:updateGradInput(input, gradOutput) | |||
self.gradInput[1] = self.gradInput[1] or input[1].new() | |||
self.gradInput[2] = self.gradInput[2] or input[1].new() | |||
self.gradInput[1]:resizeAs(input[1]):copy(gradOutput) | |||
self.gradInput[2]:resizeAs(input[2]):copy(gradOutput):mul(-1) | |||
for i=#input+1, #self.gradInput do | |||
self.gradInput[i] = nil | |||
end | |||
return self.gradInput | |||
end |
@@ -1,5 +0,0 @@ | |||
local Clamp, Parent = torch.class('nn.Clamp', 'nn.HardTanh') | |||
function Clamp:__init(min_value, max_value) | |||
Parent.__init(self, min_value, max_value) | |||
end |
@@ -1,82 +0,0 @@ | |||
local THNN = require 'nn.THNN' | |||
local ClassNLLCriterion, parent = torch.class('nn.ClassNLLCriterion', 'nn.Criterion') | |||
function ClassNLLCriterion:__init(weights, sizeAverage, ignoreIndex) | |||
parent.__init(self) | |||
self.sizeAverage = (sizeAverage == nil) and true or sizeAverage | |||
self.ignoreIndex = ignoreIndex or -100 -- this target index will be ignored | |||
if weights then | |||
assert(weights:dim() == 1, "weights input should be 1-D Tensor") | |||
self.weights = weights | |||
end | |||
self.output_tensor = torch.zeros(1) | |||
self.total_weight_tensor = torch.ones(1) | |||
self.target = torch.zeros(1):long() | |||
end | |||
function ClassNLLCriterion:__len() | |||
if (self.weights) then | |||
return #self.weights | |||
else | |||
return 0 | |||
end | |||
end | |||
function ClassNLLCriterion:updateOutput(input, target) | |||
if type(target) == 'number' then | |||
if torch.typename(input):find('torch%.Cuda.*Tensor') then | |||
self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda() | |||
else | |||
self.target = self.target:long() | |||
end | |||
self.target:resize(1) | |||
self.target[1] = target | |||
elseif torch.typename(input):find('torch%.Cuda.*Tensor') then | |||
self.target = torch.CudaLongTensor and target:cudaLong() or target | |||
else | |||
self.target = target:long() | |||
end | |||
input.THNN.ClassNLLCriterion_updateOutput( | |||
input:cdata(), | |||
self.target:cdata(), | |||
self.output_tensor:cdata(), | |||
self.sizeAverage, | |||
THNN.optionalTensor(self.weights), | |||
self.total_weight_tensor:cdata(), | |||
self.ignoreIndex | |||
) | |||
self.output = self.output_tensor[1] | |||
return self.output, self.total_weight_tensor[1] | |||
end | |||
function ClassNLLCriterion:updateGradInput(input, target) | |||
if type(target) == 'number' then | |||
if torch.typename(input):find('torch%.Cuda.*Tensor') then | |||
self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda() | |||
else | |||
self.target = self.target:long() | |||
end | |||
self.target:resize(1) | |||
self.target[1] = target | |||
elseif torch.typename(input):find('torch%.Cuda.*Tensor') then | |||
self.target = torch.CudaLongTensor and target:cudaLong() or target | |||
else | |||
self.target = target:long() | |||
end | |||
self.gradInput:resizeAs(input):zero() | |||
input.THNN.ClassNLLCriterion_updateGradInput( | |||
input:cdata(), | |||
self.target:cdata(), | |||
self.gradInput:cdata(), | |||
self.sizeAverage, | |||
THNN.optionalTensor(self.weights), | |||
self.total_weight_tensor:cdata(), | |||
self.ignoreIndex | |||
) | |||
return self.gradInput | |||
end |
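-- Illustrative usage sketch, not part of the original file (run after require 'nn'):
-- the criterion expects log-probabilities, e.g. from nn.LogSoftMax.
local nllModel = nn.Sequential():add(nn.Linear(10, 3)):add(nn.LogSoftMax())
local nllCrit = nn.ClassNLLCriterion()
local nllX, nllT = torch.randn(4, 10), torch.LongTensor{1, 3, 2, 1}
local nllLoss = nllCrit:forward(nllModel:forward(nllX), nllT)
local nllGrad = nllCrit:backward(nllModel.output, nllT)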
@@ -1,118 +0,0 @@ | |||
local ClassSimplexCriterion, parent | |||
= torch.class('nn.ClassSimplexCriterion', 'nn.MSECriterion') | |||
--[[ | |||
This file implements a criterion for multi-class classification. | |||
It learns an embedding per class, where each class' embedding | |||
is a point on an (N-1)-dimensional simplex, where N is | |||
the number of classes. | |||
For example usage of this class, look at doc/criterion.md | |||
Reference: http://arxiv.org/abs/1506.08230 | |||
]]-- | |||
--[[ | |||
function regsplex(n): | |||
regsplex returns the coordinates of the vertices of a | |||
regular simplex centered at the origin. | |||
The Euclidean norms of the vectors specifying the vertices are | |||
all equal to 1. The input n is the dimension of the vectors; | |||
the simplex has n+1 vertices. | |||
input: | |||
n -- dimension of the vectors specifying the vertices of the simplex | |||
output: | |||
a -- tensor dimensioned (n+1,n) whose rows are | |||
vectors specifying the vertices | |||
reference: | |||
http://en.wikipedia.org/wiki/Simplex#Cartesian_coordinates_for_regular_n-dimensional_simplex_in_Rn | |||
--]] | |||
local function regsplex(n) | |||
local a = torch.zeros(n+1,n) | |||
for k = 1,n do | |||
-- determine the last nonzero entry in the vector for the k-th vertex | |||
if k==1 then a[k][k] = 1 end | |||
if k>1 then a[k][k] = math.sqrt( 1 - a[{ {k},{1,k-1} }]:norm()^2 ) end | |||
-- fill the k-th coordinates for the vectors of the remaining vertices | |||
local c = (a[k][k]^2 - 1 - 1/n) / a[k][k] | |||
a[{ {k+1,n+1},{k} }]:fill(c) | |||
end | |||
return a | |||
end | |||
function ClassSimplexCriterion:__init(nClasses) | |||
parent.__init(self) | |||
assert(nClasses and nClasses > 1 and nClasses == (nClasses -(nClasses % 1)), | |||
"Required positive integer argument nClasses > 1") | |||
self.nClasses = nClasses | |||
-- embedding the simplex in a space of dimension strictly greater than | |||
-- the minimum possible (nClasses-1) is critical for effective training. | |||
local simp = regsplex(nClasses - 1) | |||
self.simplex = torch.cat(simp, | |||
torch.zeros(simp:size(1), nClasses -simp:size(2)), | |||
2) | |||
self._target = torch.Tensor(nClasses) | |||
end | |||
local function transformTarget(self, target) | |||
if torch.type(target) == 'number' then | |||
self._target:resize(self.nClasses) | |||
self._target:copy(self.simplex[target]) | |||
elseif torch.isTensor(target) then | |||
assert(target:dim() == 1, '1D tensors only!') | |||
local nSamples = target:size(1) | |||
self._target:resize(nSamples, self.nClasses) | |||
for i=1,nSamples do | |||
self._target[i]:copy(self.simplex[target[i]]) | |||
end | |||
end | |||
end | |||
function ClassSimplexCriterion:updateOutput(input, target) | |||
transformTarget(self, target) | |||
assert(input:nElement() == self._target:nElement()) | |||
self.output_tensor = self.output_tensor or input.new(1) | |||
input.THNN.MSECriterion_updateOutput( | |||
input:cdata(), | |||
self._target:cdata(), | |||
self.output_tensor:cdata(), | |||
self.sizeAverage | |||
) | |||
self.output = self.output_tensor[1] | |||
return self.output | |||
end | |||
function ClassSimplexCriterion:updateGradInput(input, target) | |||
assert(input:nElement() == self._target:nElement()) | |||
input.THNN.MSECriterion_updateGradInput( | |||
input:cdata(), | |||
self._target:cdata(), | |||
self.gradInput:cdata(), | |||
self.sizeAverage | |||
) | |||
return self.gradInput | |||
end | |||
function ClassSimplexCriterion:getPredictions(input) | |||
if input:dim() == 1 then | |||
input = input:view(1, -1) | |||
end | |||
return torch.mm(input, self.simplex:t()) | |||
end | |||
function ClassSimplexCriterion:getTopPrediction(input) | |||
local prod = self:getPredictions(input) | |||
local _, maxs = prod:max(prod:nDimension()) | |||
return maxs:view(-1) | |||
end |
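-- Illustrative usage sketch, not part of the original file (run after require 'nn'):
local csCrit = nn.ClassSimplexCriterion(5)        -- 5 classes, 5-dimensional targets
local csOutput = torch.randn(4, 5)                -- network output, one row per sample
local csLoss = csCrit:forward(csOutput, torch.LongTensor{1, 2, 5, 3})
local csPred = csCrit:getTopPrediction(csOutput)  -- class whose simplex vertex scores highest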
@@ -1,30 +0,0 @@ | |||
local Collapse, parent = torch.class('nn.Collapse', 'nn.Module') | |||
function Collapse:__init(nInputDim) | |||
parent.__init(self) | |||
self.nInputDim = nInputDim | |||
end | |||
function Collapse:updateOutput(input) | |||
if not input:isContiguous() then | |||
self._input = self._input or input.new() | |||
self._input:resize(input:size()):copy(input) | |||
input = self._input | |||
end | |||
if input:dim() > self.nInputDim then | |||
self.output:view(input,input:size(1),-1) | |||
else | |||
self.output:view(input,-1) | |||
end | |||
return self.output | |||
end | |||
function Collapse:updateGradInput(input, gradOutput) | |||
self.gradInput:view(gradOutput, input:size()) | |||
return self.gradInput | |||
end | |||
function Collapse:clearState() | |||
self._input = nil | |||
end |
@@ -1,158 +0,0 @@ | |||
local Concat, parent = torch.class('nn.Concat', 'nn.Container') | |||
function Concat:__init(dimension) | |||
parent.__init(self) | |||
self.outputSize = torch.LongStorage() | |||
self.dimension = dimension | |||
end | |||
function Concat:updateOutput(input) | |||
self.outputSize = self.outputSize or torch.LongStorage() | |||
local outs = {} | |||
for i=1,#self.modules do | |||
local currentOutput = self:rethrowErrors(self.modules[i], i, 'updateOutput', input) | |||
outs[i] = currentOutput | |||
if i == 1 then | |||
self.outputSize:resize(currentOutput:dim()):copy(currentOutput:size()) | |||
else | |||
self.outputSize[self.dimension] = self.outputSize[self.dimension] + currentOutput:size(self.dimension) | |||
end | |||
end | |||
self.output:resize(self.outputSize) | |||
local offset = 1 | |||
for i,module in ipairs(self.modules) do | |||
local currentOutput = outs[i] | |||
self.output:narrow(self.dimension, offset, currentOutput:size(self.dimension)):copy(currentOutput) | |||
offset = offset + currentOutput:size(self.dimension) | |||
end | |||
return self.output | |||
end | |||
local function retable(t1, t2, f) | |||
for k, v in ipairs(t2) do | |||
if (torch.type(v) == "table") then | |||
t1[k] = retable(t1[k] or {}, t2[k], f) | |||
else | |||
f(t1, k, v) | |||
end | |||
end | |||
for i=#t2+1, #t1 do | |||
t1[i] = nil | |||
end | |||
return t1 | |||
end | |||
local function backward(self, method, input, gradOutput, scale) | |||
local isTable = torch.type(input) == 'table' | |||
local wasTable = torch.type(self.gradInput) == 'table' | |||
scale = scale or 1 | |||
if isTable then | |||
local offset = 1 | |||
for i,module in ipairs(self.modules) do | |||
local currentOutput = module.output | |||
local currentGradInput = self:rethrowErrors(module, i, method, input, | |||
gradOutput:narrow(self.dimension, offset, currentOutput:size(self.dimension)), scale) | |||
if torch.type(currentGradInput) ~= 'table' then | |||
error"currentGradInput is not a table!" | |||
end | |||
if #input ~= #currentGradInput then | |||
error("table size mismatch: "..#input.." ~= "..#currentGradInput) | |||
end | |||
if i == 1 then | |||
self.gradInput = wasTable and self.gradInput or {} | |||
retable(self.gradInput, currentGradInput, | |||
function(t, k, v) | |||
t[k] = t[k] or v:clone() | |||
t[k]:resizeAs(v) | |||
t[k]:copy(v) | |||
end | |||
) | |||
else | |||
retable(self.gradInput, currentGradInput, | |||
function(t, k, v) | |||
if t[k] then | |||
t[k]:add(v) | |||
else | |||
t[k] = v:clone() | |||
end | |||
end | |||
) | |||
end | |||
offset = offset + currentOutput:size(self.dimension) | |||
end | |||
else | |||
self.gradInput = (not wasTable) and self.gradInput:resizeAs(input) or input:clone() | |||
local offset = 1 | |||
for i,module in ipairs(self.modules) do | |||
local currentOutput = module.output | |||
local currentGradInput = self:rethrowErrors(module, i, method, input, | |||
gradOutput:narrow(self.dimension, offset, currentOutput:size(self.dimension)), scale) | |||
if currentGradInput then -- if the module does not produce a gradInput (for example first layer), then ignore it and move on. | |||
if i==1 then | |||
self.gradInput:copy(currentGradInput) | |||
else | |||
self.gradInput:add(currentGradInput) | |||
end | |||
end | |||
offset = offset + currentOutput:size(self.dimension) | |||
end | |||
end | |||
return self.gradInput | |||
end | |||
function Concat:updateGradInput(input, gradOutput) | |||
return backward(self, 'updateGradInput', input, gradOutput) | |||
end | |||
function Concat:backward(input, gradOutput, scale) | |||
return backward(self, 'backward', input, gradOutput, scale) | |||
end | |||
function Concat:accGradParameters(input, gradOutput, scale) | |||
scale = scale or 1 | |||
local offset = 1 | |||
for i,module in ipairs(self.modules) do | |||
local currentOutput = module.output | |||
self:rethrowErrors(module, i, 'accGradParameters', input, | |||
gradOutput:narrow(self.dimension, offset, currentOutput:size(self.dimension)), | |||
scale) | |||
offset = offset + currentOutput:size(self.dimension) | |||
end | |||
end | |||
function Concat:accUpdateGradParameters(input, gradOutput, lr) | |||
local offset = 1 | |||
for i,module in ipairs(self.modules) do | |||
local currentOutput = module.output | |||
self:rethrowErrors(module, i, 'accUpdateGradParameters', | |||
input, | |||
gradOutput:narrow(self.dimension, offset, currentOutput:size(self.dimension)), | |||
lr) | |||
offset = offset + currentOutput:size(self.dimension) | |||
end | |||
end | |||
function Concat:__tostring__() | |||
local tab = ' ' | |||
local line = '\n' | |||
local next = ' |`-> ' | |||
local lastNext = ' `-> ' | |||
local ext = ' | ' | |||
local extlast = ' ' | |||
local last = ' ... -> ' | |||
local str = torch.type(self) | |||
str = str .. ' {' .. line .. tab .. 'input' | |||
for i=1,#self.modules do | |||
if i == #self.modules then | |||
str = str .. line .. tab .. lastNext .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. extlast) | |||
else | |||
str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. ext) | |||
end | |||
end | |||
str = str .. line .. tab .. last .. 'output' | |||
str = str .. line .. '}' | |||
return str | |||
end |
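-- Illustrative usage sketch, not part of the original file (run after require 'nn'):
local ccExample = nn.Concat(2)                      -- concatenate branch outputs along dim 2
ccExample:add(nn.Linear(10, 3))
ccExample:add(nn.Linear(10, 7))
local ccOut = ccExample:forward(torch.randn(4, 10)) -- 4x10 in, 4x(3+7) out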
@@ -1,118 +0,0 @@ | |||
local ConcatTable, parent = torch.class('nn.ConcatTable', 'nn.Container') | |||
function ConcatTable:__init() | |||
parent.__init(self) | |||
self.modules = {} | |||
self.output = {} | |||
end | |||
function ConcatTable:updateOutput(input) | |||
for i=1,#self.modules do | |||
self.output[i] = self:rethrowErrors(self.modules[i], i, 'updateOutput', input) | |||
end | |||
return self.output | |||
end | |||
local function retable(t1, t2, f) | |||
for k, v in ipairs(t2) do | |||
if (torch.type(v) == "table") then | |||
t1[k] = retable(t1[k] or {}, t2[k], f) | |||
else | |||
f(t1, k, v) | |||
end | |||
end | |||
for i=#t2+1, #t1 do | |||
t1[i] = nil | |||
end | |||
return t1 | |||
end | |||
local function backward(self, method, input, gradOutput, scale) | |||
local isTable = torch.type(input) == 'table' | |||
local wasTable = torch.type(self.gradInput) == 'table' | |||
if isTable then | |||
for i,module in ipairs(self.modules) do | |||
local currentGradInput = self:rethrowErrors(module, i, method, input, gradOutput[i], scale) | |||
if torch.type(currentGradInput) ~= 'table' then | |||
error"currentGradInput is not a table!" | |||
end | |||
if #input ~= #currentGradInput then | |||
error("table size mismatch: "..#input.." ~= "..#currentGradInput) | |||
end | |||
if i == 1 then | |||
self.gradInput = wasTable and self.gradInput or {} | |||
retable(self.gradInput, currentGradInput, | |||
function(t, k, v) | |||
t[k] = t[k] or v:clone() | |||
t[k]:resize(v:size()) | |||
t[k]:copy(v) | |||
end | |||
) | |||
else | |||
retable(self.gradInput, currentGradInput, | |||
function(t, k, v) | |||
if t[k] then | |||
t[k]:add(v) | |||
else | |||
t[k] = v:clone() | |||
end | |||
end | |||
) | |||
end | |||
end | |||
else | |||
self.gradInput = (not wasTable) and self.gradInput or input:clone() | |||
for i,module in ipairs(self.modules) do | |||
local currentGradInput = self:rethrowErrors(module, i, method, input, gradOutput[i], scale) | |||
if i == 1 then | |||
self.gradInput:resize(currentGradInput:size()):copy(currentGradInput) | |||
else | |||
self.gradInput:add(currentGradInput) | |||
end | |||
end | |||
end | |||
return self.gradInput | |||
end | |||
function ConcatTable:updateGradInput(input, gradOutput) | |||
return backward(self, 'updateGradInput', input, gradOutput) | |||
end | |||
function ConcatTable:backward(input, gradOutput, scale) | |||
return backward(self, 'backward', input, gradOutput, scale) | |||
end | |||
function ConcatTable:accGradParameters(input, gradOutput, scale) | |||
scale = scale or 1 | |||
for i,module in ipairs(self.modules) do | |||
self:rethrowErrors(module, i, 'accGradParameters', input, gradOutput[i], scale) | |||
end | |||
end | |||
function ConcatTable:accUpdateGradParameters(input, gradOutput, lr) | |||
for i,module in ipairs(self.modules) do | |||
self:rethrowErrors(module, i, 'accUpdateGradParameters', input, gradOutput[i], lr) | |||
end | |||
end | |||
function ConcatTable:__tostring__() | |||
local tab = ' ' | |||
local line = '\n' | |||
local next = ' |`-> ' | |||
local lastNext = ' `-> ' | |||
local ext = ' | ' | |||
local extlast = ' ' | |||
local last = ' ... -> ' | |||
local str = torch.type(self) | |||
str = str .. ' {' .. line .. tab .. 'input' | |||
for i=1,#self.modules do | |||
if i == #self.modules then | |||
str = str .. line .. tab .. lastNext .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. extlast) | |||
else | |||
str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. ext) | |||
end | |||
end | |||
str = str .. line .. tab .. last .. 'output' | |||
str = str .. line .. '}' | |||
return str | |||
end |
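-- Illustrative usage sketch, not part of the original file (run after require 'nn'):
-- like nn.Concat, but the branch outputs are returned as a table rather than
-- concatenated into a single tensor.
local ctExample = nn.ConcatTable()
ctExample:add(nn.Linear(10, 3))
ctExample:add(nn.Identity())
local ctOut = ctExample:forward(torch.randn(4, 10)) -- {4x3 tensor, 4x10 tensor}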
@@ -1,36 +0,0 @@ | |||
------------------------------------------------------------------------ | |||
--[[ Constant ]]-- | |||
------------------------------------------------------------------------ | |||
local Constant, parent = torch.class("nn.Constant", "nn.Module") | |||
function Constant:__init(value, nInputDim) | |||
self.value = value | |||
if torch.type(self.value) == 'number' then | |||
self.value = torch.Tensor{self.value} | |||
end | |||
assert(torch.isTensor(self.value), "Expecting number or tensor at arg 1") | |||
self.nInputDim = nInputDim | |||
parent.__init(self) | |||
end | |||
function Constant:updateOutput(input) | |||
if self.nInputDim and input:dim() > self.nInputDim then | |||
local vsize = self.value:size():totable() | |||
self.output:resize(input:size(1), table.unpack(vsize)) | |||
local value = self.value:view(1, table.unpack(vsize)) | |||
self.output:copy(value:expand(self.output:size())) | |||
else | |||
self.output:resize(self.value:size()):copy(self.value) | |||
end | |||
return self.output | |||
end | |||
function Constant:updateGradInput(input, gradOutput) | |||
self.gradInput:resizeAs(input):zero() | |||
return self.gradInput | |||
end |
@@ -1,149 +0,0 @@ | |||
local Container, parent = torch.class('nn.Container', 'nn.Module') | |||
function Container:__init(...) | |||
parent.__init(self, ...) | |||
self.modules = {} | |||
end | |||
function Container:add(module) | |||
table.insert(self.modules, module) | |||
return self | |||
end | |||
function Container:get(index) | |||
return self.modules[index] | |||
end | |||
function Container:size() | |||
return #self.modules | |||
end | |||
local _, XPCALL_ARGS = xpcall(function(x) return x ~= nil end, function() end, 1) | |||
local TRACEBACK_WARNING = "WARNING: If you see a stack trace below, it doesn't point to the place where this error occurred. Please use only the one above." | |||
function Container:rethrowErrors(module, moduleIndex, funcName, ...) | |||
assert(module == self.modules[moduleIndex], | |||
"mismatch between moduleIndex and self.modules in rethrowErrors") | |||
local function handleError(err) | |||
-- This will be executed only in the first container that handles the error. | |||
if not err:find(TRACEBACK_WARNING) then | |||
local traceback = debug.traceback() | |||
-- Remove this handler from the stack | |||
local _, first_line_end = traceback:find('^.-\n') | |||
local _, second_line_end = traceback:find('^.-\n.-\n') | |||
traceback = traceback:sub(1, first_line_end) .. traceback:sub(second_line_end+1) | |||
err = err .. '\n' .. traceback .. '\n\n' .. TRACEBACK_WARNING | |||
else | |||
-- Remove file path | |||
err = err:sub(err:find('\n')+1) | |||
end | |||
local msg = string.format('In %d module of %s:', | |||
moduleIndex, torch.type(self)) | |||
-- Preceding newline has to be here, because Lua will prepend a file path. | |||
err = '\n' .. msg .. '\n' .. err | |||
return err | |||
end | |||
-- Lua 5.1 doesn't support passing arguments through xpcall, so they have to | |||
-- be passed via a closure. This incurs some overhead, so it's better not to | |||
-- make it the default. | |||
local ok, ret, noret | |||
if not XPCALL_ARGS then | |||
local args = {...} | |||
local unpack = unpack or table.unpack | |||
ok, ret, noret = xpcall(function() | |||
return module[funcName](module, unpack(args)) | |||
end, | |||
handleError) | |||
else | |||
ok, ret, noret = xpcall(module[funcName], handleError, module, ...) | |||
end | |||
assert(noret == nil, "rethrowErrors supports only one return argument") | |||
if not ok then error(ret) end | |||
return ret | |||
end | |||
function Container:applyToModules(func) | |||
for _, module in ipairs(self.modules) do | |||
func(module) | |||
end | |||
end | |||
function Container:zeroGradParameters() | |||
self:applyToModules(function(module) module:zeroGradParameters() end) | |||
end | |||
function Container:updateParameters(learningRate) | |||
self:applyToModules(function(module) module:updateParameters(learningRate) end) | |||
end | |||
function Container:training() | |||
self:applyToModules(function(module) module:training() end) | |||
parent.training(self) | |||
end | |||
function Container:evaluate() | |||
self:applyToModules(function(module) module:evaluate() end) | |||
parent.evaluate(self) | |||
end | |||
function Container:share(mlp, ...) | |||
for i=1,#self.modules do | |||
self.modules[i]:share(mlp.modules[i], ...); | |||
end | |||
return self | |||
end | |||
function Container:reset(stdv) | |||
self:applyToModules(function(module) module:reset(stdv) end) | |||
end | |||
function Container:parameters() | |||
local function tinsert(to, from) | |||
if type(from) == 'table' then | |||
for i=1,#from do | |||
tinsert(to,from[i]) | |||
end | |||
else | |||
table.insert(to,from) | |||
end | |||
end | |||
local w = {} | |||
local gw = {} | |||
for i=1,#self.modules do | |||
local mw,mgw = self.modules[i]:parameters() | |||
if mw then | |||
tinsert(w,mw) | |||
tinsert(gw,mgw) | |||
end | |||
end | |||
return w,gw | |||
end | |||
function Container:clearState() | |||
-- don't call set because it might reset referenced tensors | |||
local function clear(f) | |||
if self[f] then | |||
if torch.isTensor(self[f]) then | |||
self[f] = self[f].new() | |||
elseif type(self[f]) == 'table' then | |||
self[f] = {} | |||
else | |||
self[f] = nil | |||
end | |||
end | |||
end | |||
clear('output') | |||
clear('gradInput') | |||
if self.modules then | |||
for i,module in pairs(self.modules) do | |||
module:clearState() | |||
end | |||
end | |||
return self | |||
end |
@@ -1,21 +0,0 @@ | |||
local Contiguous, parent = torch.class('nn.Contiguous', 'nn.Module') | |||
function Contiguous:updateOutput(input) | |||
if not input:isContiguous() then | |||
if self.output:storage() == input:storage() then self.output:set() end | |||
self.output:resizeAs(input):copy(input) | |||
else | |||
self.output:set(input) | |||
end | |||
return self.output | |||
end | |||
function Contiguous:updateGradInput(input, gradOutput) | |||
if not gradOutput:isContiguous() then | |||
if self.gradInput:storage() == gradOutput:storage() then self.gradInput:set() end | |||
self.gradInput:resizeAs(gradOutput):copy(gradOutput) | |||
else | |||
self.gradInput:set(gradOutput) | |||
end | |||
return self.gradInput | |||
end |
@@ -1,245 +0,0 @@ | |||
------------------------------------------------------------------------ | |||
--[[ nn.Convert ]]--
------------------------------------------------------------------------ | |||
local _ = require 'moses' | |||
local Convert, parent = torch.class("nn.Convert", "nn.Container") | |||
function Convert:__init(inputShape, outputShape) | |||
if outputShape and not inputShape then | |||
error"Expecting non-nil arg 1 when arg 2 is provided" | |||
end | |||
inputShape = inputShape or 'b*' | |||
outputShape = outputShape or inputShape | |||
self.inputShape = inputShape:find('b') and inputShape or ('b'..inputShape) | |||
self.outputShape = outputShape:find('b') and outputShape or ('b'..outputShape) | |||
self.inputBatchDim = self.inputShape:find('b') | |||
self.outputBatchDim = self.outputShape:find('b') | |||
if self.inputShape == 'b*' or self.outputShape == 'b*' then | |||
assert(self.inputShape == 'b*' and self.outputShape == 'b*', 'Both or neither shapes must be b*') | |||
self.nInputDim = -1 | |||
self.nOutputDim = -1 | |||
self.transposition = true | |||
else | |||
-- number of dims in batch mode | |||
self.nInputDim = #self.inputShape | |||
self.nOutputDim = #self.outputShape | |||
-- is the outputShape just a transposition of the inputShape? | |||
if self.nInputDim == self.nOutputDim then | |||
self.transposition = true | |||
for i=1,self.nInputDim do | |||
if not self.outputShape:find(self.inputShape:sub(i,i)) then | |||
self.transposition = false | |||
break | |||
end | |||
end | |||
end | |||
end | |||
parent.__init(self) | |||
end | |||
function Convert:buildConverter(input) | |||
if self.transposition then | |||
self.converter = self:transpose(self.outputShape) | |||
else | |||
if (torch.type(self[self.outputShape]) ~= 'function') then | |||
error(string.format("Unrecognized conversion of shape %s to %s", self.inputShape, self.outputShape)) | |||
end | |||
self.converter = self[self.outputShape](self, input) | |||
end | |||
assert(torch.isTensor(self.output), "Expecting Tensor output") | |||
self.converter:type(torch.type(self.output)) | |||
self.modules[1] = self.converter | |||
end | |||
function Convert:updateOutput(input) | |||
assert(torch.isTensor(input), "expecting Tensor") | |||
if not torch.isTypeOf(input, torch.type(self.output)) then | |||
-- handle different input type | |||
self._input = self._input or self.output.new() | |||
self._input:resize(input:size()):copy(input) | |||
input = self._input | |||
end | |||
self.batchMode = true | |||
if input:dim() < self.nInputDim then | |||
-- handle non-batch mode | |||
local inputSize = input:size():totable() | |||
table.insert(inputSize, self.inputBatchDim, 1) | |||
self.__input = self.__input or input.new() | |||
self.__input:set(input):resize(table.unpack(inputSize)) | |||
input = self.__input | |||
self.batchMode = false | |||
end | |||
if not self.converter then | |||
self:buildConverter(input) | |||
end | |||
self.output = self.converter:updateOutput(input) | |||
if not self.batchMode then | |||
local outputSize = self.output:size():totable() | |||
table.remove(outputSize, self.outputBatchDim) | |||
self.__output = self.__output or self.output.new() | |||
self.__output:set(self.output):resize(table.unpack(outputSize)) | |||
self.output = self.__output | |||
end | |||
return self.output | |||
end | |||
function Convert:updateGradInput(input, gradOutput) | |||
local input_ = input | |||
input = self._input or input | |||
if not self.batchMode then | |||
input = self.__input | |||
self.__gradOutput = self.__gradOutput or gradOutput.new() | |||
self.__gradOutput:set(gradOutput):resize(self.converter.output:size()) | |||
gradOutput = self.__gradOutput | |||
end | |||
local gradInput = self.converter:updateGradInput(input, gradOutput) | |||
if not self.batchMode then | |||
self.__gradInput = self.__gradInput or gradInput.new() | |||
self.__gradInput:set(gradInput):resize(input_:size()) | |||
gradInput = self.__gradInput | |||
end | |||
if self._input then | |||
self._gradInput = self._gradInput or input.new() | |||
self._gradInput:resize(input:size()):copy(gradInput) | |||
self.gradInput = self._gradInput | |||
else | |||
self.gradInput = gradInput | |||
end | |||
return self.gradInput | |||
end | |||
function Convert:accGradParameters(input, gradOutput, scale) | |||
input = self.batchMode and self.__input or self._input or input | |||
gradOutput = self.batchMode and self.__gradOutput or gradOutput | |||
self.converter:accGradParameters(input, gradOutput, scale) | |||
end | |||
function Convert:accUpdateGradParameters(input, gradOutput, lr) | |||
input = self.batchMode and self.__input or self._input or input | |||
gradOutput = self.batchMode and self.__gradOutput or gradOutput | |||
self.converter:accUpdateGradParameters(input, gradOutput, lr) | |||
end | |||
function Convert:bf(input) | |||
local b_pos = self:findAxis('b', self.inputShape) | |||
local dim = #self.inputShape | |||
if self.inputShape == 'bt' then | |||
error"Conversion of shape bt to bf not supported: open an issue on github" | |||
end | |||
-- was b | |||
if dim == 1 then | |||
return nn.Reshape(1) | |||
end | |||
-- was b... | |||
local modula | |||
if b_pos ~= 1 then | |||
modula = nn.Transpose({1, b_pos}) | |||
end | |||
if dim > 2 then | |||
local transpose = modula | |||
local sampleSize = input:select(self:findAxis('b'),1):nElement() | |||
local reshape = nn.Reshape(sampleSize) | |||
if transpose then | |||
modula = nn.Sequential() | |||
modula:add(transpose) | |||
modula:add(reshape) | |||
else | |||
modula = reshape | |||
end | |||
end | |||
return modula or nn.Identity() | |||
end | |||
function Convert:b(input) | |||
local b_pos = self:findAxis('b') | |||
if self.inputShape == 'bt' or self.inputShape == 'tb' then | |||
local t_pos = self:findAxis('t') | |||
-- select first set of classes | |||
return nn.Select(t_pos, 1) | |||
elseif self.inputShape == 'bf' or self.inputShape == 'fb' then | |||
-- this won't work as expected when size(f) > 1
local f_pos = self:findAxis('f') | |||
if input:size(f_pos) > 1 then | |||
error("Cannot convert shape "..self.inputShape.." to b when feature > 1") | |||
end | |||
return nn.Select(f_pos, 1) | |||
else | |||
error("Cannot convert shape "..self.inputShape.." to shape b") | |||
end | |||
end | |||
function Convert:default() | |||
return nn.Identity() | |||
end | |||
function Convert:bt() | |||
local b_pos = self:findAxis('b') | |||
local modula | |||
if self.inputShape == 'b' then | |||
modula = nn.Reshape(1) | |||
else | |||
error("cannot convert shape '"..self.inputShape.."' to bt") | |||
end | |||
return modula | |||
end | |||
function Convert:transpose(newShape) | |||
if newShape == self.inputShape then | |||
return nn.Identity() | |||
end | |||
local inputShape = {} | |||
for i=1,#self.inputShape do | |||
table.insert(inputShape, self.inputShape:sub(i,i)) | |||
end | |||
local transpositions = {} | |||
for i=1,#newShape do | |||
local j = _.indexOf(inputShape, newShape:sub(i,i)) | |||
if i ~= j then | |||
local char = inputShape[i] | |||
inputShape[i] = inputShape[j] | |||
inputShape[j] = char | |||
table.insert(transpositions, {j, i}) | |||
end | |||
end | |||
return nn.Transpose(table.unpack(transpositions)) | |||
end | |||
function Convert:findAxis(axis_char, shape, silent) | |||
shape = shape or self.inputShape | |||
local axis_pos = shape:find(axis_char) | |||
if (not silent) and (not axis_pos) then | |||
error("Provided shape '"..shape.."' has no axis '"..axis_char.."'", 2) | |||
end | |||
return axis_pos | |||
end | |||
function Convert:clearState() | |||
self._input = nil | |||
self._gradInput = nil | |||
self.__input = nil | |||
self.__output = nil | |||
self.__gradInput = nil | |||
self.__gradOutput = nil | |||
end | |||
function Convert:type(type) | |||
self:clearState() | |||
return parent.type(self, type) | |||
end |
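-- Usage sketch (assumes torch and nn are loaded): axis characters name one
-- dimension each, e.g. 'b' = batch, 'f' = feature.
local c = nn.Convert('bhwc', 'bf')              -- flatten per-sample dims
local y = c:forward(torch.randn(8, 32, 32, 3))  -- y is 8 x 3072
local t = nn.Convert('bf', 'fb')                -- a pure transposition
local z = t:forward(torch.randn(8, 10))         -- z is 10 x 8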
@@ -1,42 +0,0 @@ | |||
local Copy, parent = torch.class('nn.Copy', 'nn.Module') | |||
function Copy:__init(intype, outtype, forceCopy, dontCast) | |||
intype = intype or torch.Tensor.__typename | |||
outtype = outtype or torch.Tensor.__typename | |||
self.dontCast = dontCast | |||
parent.__init(self) | |||
self.gradInput = torch.getmetatable(intype).new() | |||
self.output = torch.getmetatable(outtype).new() | |||
if (not forceCopy) and intype == outtype then | |||
self.updateOutput = function(self, input) | |||
self.output:set(input) | |||
return input | |||
end | |||
self.updateGradInput = function(self, input, gradOutput) | |||
self.gradInput:set(gradOutput) | |||
return gradOutput | |||
end | |||
end | |||
end | |||
function Copy:updateOutput(input) | |||
self.output:resize(input:size()):copy(input) | |||
return self.output | |||
end | |||
function Copy:updateGradInput(input, gradOutput) | |||
self.gradInput:resize(gradOutput:size()):copy(gradOutput) | |||
return self.gradInput | |||
end | |||
function Copy:type(type, tensorCache) | |||
if type and self.dontCast then | |||
return self | |||
end | |||
return parent.type(self, type, tensorCache) | |||
end |
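-- Usage sketch (assumes torch and nn are loaded): cast tensors between
-- types inside a network.
local cast = nn.Copy('torch.DoubleTensor', 'torch.FloatTensor')
local y = cast:forward(torch.randn(2, 3)) -- y is a FloatTensor copy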
@@ -1,175 +0,0 @@ | |||
local Cosine, parent = torch.class('nn.Cosine', 'nn.Module') | |||
function Cosine:__init(inputSize,outputSize) | |||
parent.__init(self) | |||
self.weight = torch.Tensor(outputSize,inputSize) | |||
self.gradWeight = torch.Tensor(outputSize,inputSize) | |||
self:reset() | |||
end | |||
function Cosine:reset(stdv) | |||
if stdv then | |||
stdv = stdv * math.sqrt(3) | |||
else | |||
stdv = 1./math.sqrt(self.weight:size(1)) | |||
end | |||
self.weight:uniform(-stdv, stdv) | |||
end | |||
function Cosine:updateOutput(input) | |||
local inputSize = self.weight:size(2) | |||
local outputSize = self.weight:size(1) | |||
self._weightNorm = self._weightNorm or self.weight.new() | |||
self._inputNorm = self._inputNorm or self.weight.new() | |||
-- y_j = (w_j * x) / ( || w_j || * || x || ) | |||
self._weightNorm:norm(self.weight,2,2):add(1e-12) | |||
if input:dim() == 1 then | |||
self.output:resize(outputSize):zero() | |||
self.output:addmv(1, self.weight, input) | |||
self.__norm = input:norm()+1e-12 | |||
self.output:cdiv(self._weightNorm:view(outputSize)):div(self.__norm) | |||
elseif input:dim() == 2 then | |||
local batchSize = input:size(1) | |||
local nElement = self.output:nElement() | |||
self.output:resize(batchSize, outputSize) | |||
if self.output:nElement() ~= nElement then | |||
self.output:zero() | |||
end | |||
self.output:addmm(0, self.output, 1, input, self.weight:t()) | |||
self._inputNorm:norm(input,2,2):add(1e-12) | |||
self.output:cdiv(self._weightNorm:view(1,outputSize):expandAs(self.output)) | |||
self.output:cdiv(self._inputNorm:expandAs(self.output)) | |||
else | |||
error('input must be vector or matrix') | |||
end | |||
return self.output | |||
end | |||
function Cosine:updateGradInput(input, gradOutput) | |||
if not self.gradInput then | |||
return | |||
end | |||
local inputSize = self.weight:size(2) | |||
local outputSize = self.weight:size(1) | |||
--[[
dy_j/dx_i = w_ji / ( || w_j || * || x || ) - y_j * x_i / || x ||^2
--]]
local nElement = self.gradInput:nElement() | |||
self.gradInput:resizeAs(input) | |||
if self.gradInput:nElement() ~= nElement then | |||
self.gradInput:zero() | |||
end | |||
if input:dim() == 1 then | |||
self._weight = self._weight or input.new() | |||
self._weight:resizeAs(self.weight):copy(self.weight) | |||
self._weight:cdiv(self._weightNorm:expandAs(self.weight)) | |||
self._weight:div(self.__norm) | |||
self._weight:addr(1, self._weight, -1/(self.__norm*self.__norm), self.output, input) | |||
self.gradInput:addmv(0, 1, self._weight:t(), gradOutput) | |||
elseif input:dim() == 2 then | |||
local inputNorm = self._inputNorm:expandAs(input) | |||
local weightNorm = self._weightNorm:view(1,outputSize):expandAs(gradOutput) | |||
self.gradInput:copy(input):cdiv(inputNorm) | |||
self._gradOutput = self._gradOutput or gradOutput.new() | |||
self._gradOutput:resizeAs(gradOutput):copy(gradOutput) | |||
self._gradOutput:cmul(self.output) | |||
self._sum = self._sum or input.new() | |||
self._sum:sum(self._gradOutput, 2) | |||
self.gradInput:cmul(self._sum:expandAs(input)) | |||
self._gradOutput:resizeAs(gradOutput):copy(gradOutput) | |||
self._gradOutput:cdiv(weightNorm) | |||
self.gradInput:addmm(-1, self.gradInput, 1, self._gradOutput, self.weight) | |||
self.gradInput:cdiv(inputNorm) | |||
end | |||
return self.gradInput | |||
end | |||
function Cosine:accGradParameters(input, gradOutput, scale) | |||
scale = scale or 1 | |||
local inputSize = self.weight:size(2) | |||
local outputSize = self.weight:size(1) | |||
--[[
dy_j/dw_ji = x_i / ( || w_j || * || x || ) - y_j * w_ji / || w_j ||^2
--]]
if input:dim() == 1 then | |||
self._gradOutput = self._gradOutput or gradOutput.new() | |||
self._gradOutput:resizeAs(gradOutput):copy(gradOutput) | |||
local weightNorm = self._weightNorm:view(outputSize) | |||
self._gradOutput:cdiv(weightNorm) | |||
self.gradWeight:addr(scale/self.__norm, self._gradOutput, input) | |||
self._gradOutput:cdiv(weightNorm) | |||
self._gradOutput:cmul(self.output) | |||
self._weight = self._weight or self.weight.new() | |||
self._weight:resizeAs(self.weight):copy(self.weight)
self._weight:cmul(self._gradOutput:view(outputSize, 1):expandAs(self.weight)) | |||
self.gradWeight:add(-1, self._weight) | |||
elseif input:dim() == 2 then | |||
self._weight = self._weight or self.weight.new() | |||
self._weight:resizeAs(self.weight):copy(self.weight) | |||
self._gradOutput = self._gradOutput or gradOutput.new() | |||
self._gradOutput:resizeAs(gradOutput):copy(gradOutput) | |||
self._gradOutput:cmul(self.output) | |||
self._sum = self._sum or input.new() | |||
self._sum:sum(self._gradOutput, 1) | |||
local grad = self._sum[1] | |||
grad:cdiv(self._weightNorm:select(2,1)) | |||
self._weight:cmul(grad:view(outputSize,1):expandAs(self._weight)) | |||
local input_ = self._gradOutput | |||
input_:resizeAs(input):copy(input) | |||
input_:cdiv(self._inputNorm:expandAs(input)) | |||
self._weight:addmm(-1, self._weight, 1, gradOutput:t(), input_) | |||
self._weight:cdiv(self._weightNorm:expandAs(self._weight)) | |||
self.gradWeight:add(self._weight) | |||
else | |||
error"1D or 2D input expected" | |||
end | |||
end | |||
function Cosine:type(type, tensorCache) | |||
if type then | |||
-- prevent premature memory allocations | |||
self._input = nil | |||
self._weight = nil | |||
self._inputNorm = nil | |||
self._weightNorm = nil | |||
self._gradOutput = nil | |||
self._sum = nil | |||
end | |||
return parent.type(self, type, tensorCache) | |||
end | |||
function Cosine:clearState() | |||
nn.utils.clear(self, { | |||
'_input', | |||
'_weight', | |||
'_gradOutput', | |||
'_sum', | |||
'_inputNorm', | |||
'_weightNorm', | |||
}) | |||
return parent.clearState(self) | |||
end |
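-- Usage sketch (assumes torch and nn are loaded): each output unit is the
-- cosine similarity between the input and one row of the weight matrix.
local cos = nn.Cosine(10, 4)                  -- inputSize = 10, outputSize = 4
local out = cos:forward(torch.randn(10))      -- 4 values in [-1, 1]
local batch = cos:forward(torch.randn(8, 10)) -- 8 x 4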
@@ -1,116 +0,0 @@ | |||
local CosineDistance, parent = torch.class('nn.CosineDistance', 'nn.Module') | |||
function CosineDistance:__init() | |||
parent.__init(self) | |||
self.gradInput = {torch.Tensor(), torch.Tensor()} | |||
end | |||
local function makeContiguous(self, input1, input2) | |||
if not input1:isContiguous() then | |||
self._input1 = self._input1 or input1.new() | |||
self._input1:resizeAs(input1):copy(input1) | |||
input1 = self._input1 | |||
end | |||
if not input2:isContiguous() then | |||
self._input2 = self._input2 or input2.new() | |||
self._input2:resizeAs(input2):copy(input2) | |||
input2 = self._input2 | |||
end | |||
return input1, input2 | |||
end | |||
function CosineDistance:updateOutput(input) | |||
local input1, input2 = input[1], input[2] | |||
input1, input2 = makeContiguous(self, input1, input2) | |||
if input1:dim() == 1 then | |||
input1 = input1:view(1,-1) | |||
input2 = input2:view(1,-1) | |||
end | |||
if not self.buffer then | |||
self.buffer = input1.new() | |||
self.w1 = input1.new() | |||
self.w22 = input1.new() | |||
self.w = input1.new() | |||
self.w32 = input1.new() | |||
self.ones = input1.new() | |||
end | |||
self.buffer:cmul(input1,input2) | |||
self.w1:sum(self.buffer,2) | |||
local epsilon = 1e-12 | |||
self.buffer:cmul(input1,input1) | |||
self.w22:sum(self.buffer,2):add(epsilon) | |||
self.ones:resizeAs(self.w22):fill(1) | |||
self.w22:cdiv(self.ones, self.w22) | |||
self.w:resizeAs(self.w22):copy(self.w22) | |||
self.buffer:cmul(input2,input2) | |||
self.w32:sum(self.buffer,2):add(epsilon) | |||
self.w32:cdiv(self.ones, self.w32) | |||
self.w:cmul(self.w32) | |||
self.w:sqrt() | |||
self.output:cmul(self.w1,self.w) | |||
self.output:resize(input1:size(1)) | |||
return self.output | |||
end | |||
function CosineDistance:updateGradInput(input, gradOutput) | |||
local v1 = input[1] | |||
local v2 = input[2] | |||
local not_batch = false | |||
v1, v2 = makeContiguous(self, v1, v2) | |||
if v1:dim() == 1 then | |||
v1 = v1:view(1,-1) | |||
v2 = v2:view(1,-1) | |||
not_batch = true | |||
end | |||
if #self.gradInput ~= 2 then | |||
self.gradInput[1] = self.gradInput[1] or v1.new() | |||
self.gradInput[2] = self.gradInput[2] or v1.new() | |||
end | |||
local gw1 = self.gradInput[1] | |||
local gw2 = self.gradInput[2] | |||
gw1:resizeAs(v1):copy(v2) | |||
gw2:resizeAs(v1):copy(v1) | |||
self.buffer:cmul(self.w1,self.w22) | |||
gw1:addcmul(-1,self.buffer:expandAs(v1),v1) | |||
gw1:cmul(self.w:expandAs(v1)) | |||
self.buffer:cmul(self.w1,self.w32) | |||
gw2:addcmul(-1,self.buffer:expandAs(v1),v2) | |||
gw2:cmul(self.w:expandAs(v1)) | |||
local go = gradOutput:view(-1,1):expandAs(v1) | |||
gw1:cmul(go) | |||
gw2:cmul(go) | |||
if not_batch then | |||
self.gradInput[1]:resize(gw1:size(2)) | |||
self.gradInput[2]:resize(gw2:size(2)) | |||
end | |||
return self.gradInput | |||
end | |||
function CosineDistance:clearState() | |||
nn.utils.clear(self, { | |||
'buffer', | |||
'w1', | |||
'w22', | |||
'w', | |||
'w32', | |||
'ones', | |||
}) | |||
return parent.clearState(self) | |||
end |
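-- Usage sketch (assumes torch and nn are loaded): takes a pair {x1, x2}
-- and returns the row-wise cosine similarity.
local d = nn.CosineDistance()
local out = d:forward({torch.randn(5, 10), torch.randn(5, 10)}) -- 5 similarities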
@@ -1,142 +0,0 @@ | |||
local CosineEmbeddingCriterion, parent = torch.class('nn.CosineEmbeddingCriterion', 'nn.Criterion') | |||
function CosineEmbeddingCriterion:__init(margin) | |||
parent.__init(self) | |||
margin = margin or 0 | |||
self.margin = margin | |||
self.gradInput = {torch.Tensor(), torch.Tensor()} | |||
self.sizeAverage = true | |||
end | |||
function CosineEmbeddingCriterion:updateOutput(input,y) | |||
local input1, input2 = input[1], input[2] | |||
-- keep backward compatibility | |||
if type(y) == 'number' then | |||
self._y = self._y or input1.new(1) | |||
self._y[1] = y | |||
y = self._y | |||
end | |||
if input1:dim() == 1 then | |||
input1 = input1:view(1,-1) | |||
input2 = input2:view(1,-1) | |||
end | |||
if not self.buffer then | |||
self.buffer = input1.new() | |||
self.w1 = input1.new() | |||
self.w22 = input1.new() | |||
self.w = input1.new() | |||
self.w32 = input1.new() | |||
self._outputs = input1.new() | |||
-- comparison operators behave differently from cuda/c implementations | |||
if input1:type() == 'torch.CudaTensor' then | |||
self._idx = input1.new() | |||
else | |||
self._idx = torch.ByteTensor() | |||
end | |||
end | |||
self.buffer:cmul(input1,input2) | |||
self.w1:sum(self.buffer,2) | |||
local epsilon = 1e-12 | |||
self.buffer:cmul(input1,input1) | |||
self.w22:sum(self.buffer,2):add(epsilon) | |||
-- self._outputs is also used as a temporary buffer | |||
self._outputs:resizeAs(self.w22):fill(1) | |||
self.w22:cdiv(self._outputs, self.w22) | |||
self.w:resizeAs(self.w22):copy(self.w22) | |||
self.buffer:cmul(input2,input2) | |||
self.w32:sum(self.buffer,2):add(epsilon) | |||
self.w32:cdiv(self._outputs, self.w32) | |||
self.w:cmul(self.w32) | |||
self.w:sqrt() | |||
self._outputs:cmul(self.w1,self.w) | |||
self._outputs = self._outputs:select(2,1) | |||
y.eq(self._idx,y,-1) | |||
self._outputs[self._idx] = self._outputs[self._idx]:add(-self.margin):cmax(0) | |||
y.eq(self._idx,y,1) | |||
self._outputs[self._idx] = self._outputs[self._idx]:mul(-1):add(1) | |||
self.output = self._outputs:sum() | |||
if self.sizeAverage then | |||
self.output = self.output/y:size(1) | |||
end | |||
return self.output | |||
end | |||
function CosineEmbeddingCriterion:updateGradInput(input, y) | |||
local v1 = input[1] | |||
local v2 = input[2] | |||
local not_batch = false | |||
-- keep backward compatibility | |||
if type(y) == 'number' then | |||
self._y = self._y or v1.new(1)
self._y[1] = y | |||
y = self._y | |||
end | |||
if v1:dim() == 1 then | |||
v1 = v1:view(1,-1) | |||
v2 = v2:view(1,-1) | |||
not_batch = true | |||
end | |||
local gw1 = self.gradInput[1] | |||
local gw2 = self.gradInput[2] | |||
gw1:resizeAs(v1):copy(v2) | |||
gw2:resizeAs(v1):copy(v1) | |||
self.buffer:cmul(self.w1,self.w22) | |||
gw1:addcmul(-1,self.buffer:expandAs(v1),v1) | |||
gw1:cmul(self.w:expandAs(v1)) | |||
self.buffer:cmul(self.w1,self.w32) | |||
gw2:addcmul(-1,self.buffer:expandAs(v1),v2) | |||
gw2:cmul(self.w:expandAs(v1)) | |||
-- self._idx = self._outputs <= 0 | |||
y.le(self._idx,self._outputs,0) | |||
self._idx = self._idx:view(-1,1):expand(gw1:size()) | |||
gw1[self._idx] = 0 | |||
gw2[self._idx] = 0 | |||
y.eq(self._idx,y,1) | |||
self._idx = self._idx:view(-1,1):expand(gw2:size()) | |||
gw1[self._idx] = gw1[self._idx]:mul(-1) | |||
gw2[self._idx] = gw2[self._idx]:mul(-1) | |||
if self.sizeAverage then | |||
gw1:div(y:size(1)) | |||
gw2:div(y:size(1)) | |||
end | |||
if not_batch then | |||
self.gradInput[1]:resize(gw1:size(2)) | |||
self.gradInput[2]:resize(gw2:size(2)) | |||
end | |||
return self.gradInput | |||
end | |||
function CosineEmbeddingCriterion:type(type) | |||
self._idx = nil | |||
parent.type(self,type) | |||
-- comparison operators behave differently from cuda/c implementations | |||
if type == 'torch.CudaTensor' then | |||
self._idx = torch.CudaTensor() | |||
else | |||
self._idx = torch.ByteTensor() | |||
end | |||
return self | |||
end |
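-- Usage sketch (assumes torch and nn are loaded): y = 1 pulls a pair
-- together, y = -1 pushes it apart beyond the margin.
local crit = nn.CosineEmbeddingCriterion(0.1)
local x1, x2 = torch.randn(8, 20), torch.randn(8, 20)
local y = torch.Tensor(8):random(0, 1):mul(2):add(-1) -- entries in {-1, 1}
local loss = crit:forward({x1, x2}, y)
local grads = crit:backward({x1, x2}, y) -- {gradX1, gradX2}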
@@ -1,64 +0,0 @@ | |||
local Criterion = torch.class('nn.Criterion') | |||
function Criterion:__init() | |||
self.gradInput = torch.Tensor() | |||
self.output = 0 | |||
end | |||
function Criterion:updateOutput(input, target) | |||
end | |||
function Criterion:forward(input, target) | |||
return self:updateOutput(input, target) | |||
end | |||
function Criterion:backward(input, target) | |||
return self:updateGradInput(input, target) | |||
end | |||
function Criterion:updateGradInput(input, target) | |||
end | |||
function Criterion:clone() | |||
local f = torch.MemoryFile("rw"):binary() | |||
f:writeObject(self) | |||
f:seek(1) | |||
local clone = f:readObject() | |||
f:close() | |||
return clone | |||
end | |||
function Criterion:type(type, tensorCache) | |||
assert(type, 'Criterion: must provide a type to convert to') | |||
-- find all tensors and convert them | |||
for key,param in pairs(self) do | |||
self[key] = nn.utils.recursiveType(param, type, tensorCache) | |||
end | |||
return self | |||
end | |||
function Criterion:float() | |||
return self:type('torch.FloatTensor') | |||
end | |||
function Criterion:double() | |||
return self:type('torch.DoubleTensor') | |||
end | |||
function Criterion:cuda() | |||
return self:type('torch.CudaTensor') | |||
end | |||
function Criterion:cudaHalf() | |||
return self:type('torch.CudaHalfTensor') | |||
end | |||
function Criterion:cudaDouble() | |||
return self:type('torch.CudaDoubleTensor') | |||
end | |||
function Criterion:__call__(input, target) | |||
self.output = self:forward(input, target) | |||
self.gradInput = self:backward(input, target) | |||
return self.output, self.gradInput | |||
end |
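-- Usage sketch (assumes torch and nn are loaded): every criterion follows
-- this forward/backward protocol, shown here with nn.MSECriterion.
local crit = nn.MSECriterion()
local input, target = torch.randn(4, 3), torch.randn(4, 3)
local loss = crit:forward(input, target)       -- calls updateOutput
local gradInput = crit:backward(input, target) -- calls updateGradInput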
@@ -1,17 +0,0 @@ | |||
local CriterionTable, parent = torch.class('nn.CriterionTable', 'nn.Module') | |||
function CriterionTable:__init(criterion) | |||
parent.__init(self) | |||
self.criterion = criterion | |||
self.gradInput = {criterion.gradInput} | |||
end | |||
function CriterionTable:updateOutput(input) | |||
self.output = self.criterion:updateOutput(table.unpack(input)) | |||
return self.output | |||
end | |||
function CriterionTable:updateGradInput(input, gradOutput) | |||
self.criterion:updateGradInput(table.unpack(input)) | |||
return self.gradInput | |||
end |
@@ -1,42 +0,0 @@ | |||
local CrossEntropyCriterion, Criterion = torch.class('nn.CrossEntropyCriterion', 'nn.Criterion') | |||
function CrossEntropyCriterion:__init(weights, sizeAverage) | |||
Criterion.__init(self) | |||
self.lsm = nn.LogSoftMax() | |||
self.nll = nn.ClassNLLCriterion(weights, sizeAverage) | |||
self.sizeAverage = self.nll.sizeAverage | |||
self.oldSizeAverage = self.sizeAverage | |||
end | |||
function CrossEntropyCriterion:updateOutput(input, target) | |||
input = input:squeeze() | |||
target = type(target) == 'number' and target or target:squeeze() | |||
-- only propagate if value has changed to preserve old behavior | |||
-- of setting nll.sizeAverage directly | |||
if self.sizeAverage ~= self.oldSizeAverage then | |||
self.nll.sizeAverage = self.sizeAverage | |||
end | |||
self.lsm:updateOutput(input) | |||
self.nll:updateOutput(self.lsm.output, target) | |||
self.output = self.nll.output | |||
self.oldSizeAverage = self.sizeAverage | |||
return self.output | |||
end | |||
function CrossEntropyCriterion:updateGradInput(input, target) | |||
local size = input:size() | |||
input = input:squeeze() | |||
target = type(target) == 'number' and target or target:squeeze() | |||
-- only propagate if value has changed to preserve old behavior | |||
-- of setting nll.sizeAverage directly | |||
if self.sizeAverage ~= self.oldSizeAverage then | |||
self.nll.sizeAverage = self.sizeAverage | |||
end | |||
self.nll:updateGradInput(self.lsm.output, target) | |||
self.lsm:updateGradInput(input, self.nll.gradInput) | |||
self.gradInput:view(self.lsm.gradInput, size) | |||
self.oldSizeAverage = self.sizeAverage | |||
return self.gradInput | |||
end | |||
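-- Usage sketch (assumes torch and nn are loaded): input holds raw scores;
-- LogSoftMax + ClassNLL are applied internally.
local crit = nn.CrossEntropyCriterion()
local scores = torch.randn(4, 10)            -- 4 samples, 10 classes
local target = torch.LongTensor{1, 3, 10, 2} -- class indices
local loss = crit:forward(scores, target)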
return nn.CrossEntropyCriterion |
@@ -1,47 +0,0 @@ | |||
local Decorator, parent = torch.class("nn.Decorator", "nn.Container") | |||
function Decorator:__init(module) | |||
parent.__init(self) | |||
-- so that it can be handled like a Container | |||
self.modules[1] = module | |||
end | |||
function Decorator:updateOutput(input) | |||
self.output = self.modules[1]:updateOutput(input) | |||
return self.output | |||
end | |||
function Decorator:updateGradInput(input, gradOutput) | |||
self.gradInput = self.modules[1]:updateGradInput(input, gradOutput) | |||
return self.gradInput | |||
end | |||
function Decorator:accGradParameters(input, gradOutput, scale) | |||
self.modules[1]:accGradParameters(input, gradOutput, scale) | |||
end | |||
function Decorator:accUpdateGradParameters(input, gradOutput, lr) | |||
self.modules[1]:accUpdateGradParameters(input, gradOutput, lr) | |||
end | |||
function Decorator:sharedAccUpdateGradParameters(input, gradOutput, lr) | |||
self.modules[1]:sharedAccUpdateGradParameters(input, gradOutput, lr) | |||
end | |||
function Decorator:__tostring__() | |||
if self.modules[1].__tostring__ then | |||
return torch.type(self) .. ' @ ' .. self.modules[1]:__tostring__() | |||
else | |||
return torch.type(self) .. ' @ ' .. torch.type(self.modules[1]) | |||
end | |||
end | |||
function Decorator.decorate(class) | |||
class.updateOutput = nn.Decorator.updateOutput | |||
class.updateGradInput = nn.Decorator.updateGradInput | |||
class.accGradParameters = nn.Decorator.accGradParameters | |||
class.accUpdateGradParameters = nn.Decorator.accUpdateGradParameters | |||
class.sharedAccUpdateGradParameters = nn.Decorator.sharedAccUpdateGradParameters | |||
class.__tostring__ = nn.Decorator.__tostring__ | |||
end |
@@ -1,116 +0,0 @@ | |||
------------------------------------------------------------------------ | |||
--[[ DepthConcat ]]-- | |||
------------------------------------------------------------------------ | |||
local DepthConcat, _ = torch.class('nn.DepthConcat', 'nn.Concat') | |||
function DepthConcat:windowNarrow(output, currentOutput, offset) | |||
local outputWindow = output:narrow(self.dimension, offset, currentOutput:size(self.dimension)) | |||
for dim=1,self.outputSize:size(1) do | |||
local currentSize = currentOutput:size(dim) | |||
if dim ~= self.dimension and self.outputSize[dim] ~= currentSize then | |||
-- 5x5 vs 3x3 -> start = [(5-3)/2] + 1 = 2 (1 pad each side) | |||
-- 9x9 vs 5x5 -> start = [(9-5)/2] + 1 = 3 (2 pad each side) | |||
-- 9x9 vs 4x4 -> start = [(9-4)/2] + 1 = 3.5 (2 pad, 3 pad) | |||
local start = math.floor(((self.outputSize[dim] - currentSize) / 2) + 1) | |||
outputWindow = outputWindow:narrow(dim, start, currentSize) | |||
end | |||
end | |||
return outputWindow | |||
end | |||
function DepthConcat:updateOutput(input) | |||
self.outputSize = self.outputSize or torch.LongStorage() | |||
local outs = {} | |||
for i=1,#self.modules do | |||
local currentOutput = self:rethrowErrors(self.modules[i], i, 'updateOutput', input) | |||
outs[i] = currentOutput | |||
if i == 1 then | |||
self.outputSize:resize(currentOutput:dim()):copy(currentOutput:size()) | |||
else | |||
self.outputSize[self.dimension] = self.outputSize[self.dimension] + currentOutput:size(self.dimension) | |||
for dim=1,self.outputSize:size(1) do | |||
if dim ~= self.dimension then | |||
-- take the maximum size (shouldn't change anything for batch dim) | |||
self.outputSize[dim] = math.max(self.outputSize[dim], currentOutput:size(dim)) | |||
end | |||
end | |||
end | |||
end | |||
self.output:resize(self.outputSize):zero() --zero for padding | |||
local offset = 1 | |||
for i,module in ipairs(self.modules) do | |||
local currentOutput = outs[i] | |||
local outputWindow = self:windowNarrow(self.output, currentOutput, offset) | |||
outputWindow:copy(currentOutput) | |||
offset = offset + currentOutput:size(self.dimension) | |||
end | |||
return self.output | |||
end | |||
function DepthConcat:updateGradInput(input, gradOutput) | |||
self.gradInput:resizeAs(input) | |||
local offset = 1 | |||
for i,module in ipairs(self.modules) do | |||
local currentOutput = module.output | |||
local gradOutputWindow = self:windowNarrow(gradOutput, currentOutput, offset) | |||
local currentGradInput = self:rethrowErrors(module, i, 'updateGradInput', input, gradOutputWindow) | |||
if i==1 then | |||
self.gradInput:copy(currentGradInput) | |||
else | |||
self.gradInput:add(currentGradInput) | |||
end | |||
offset = offset + currentOutput:size(self.dimension) | |||
end | |||
return self.gradInput | |||
end | |||
function DepthConcat:accGradParameters(input, gradOutput, scale) | |||
scale = scale or 1 | |||
local offset = 1 | |||
for i,module in ipairs(self.modules) do | |||
local currentOutput = module.output | |||
local gradOutputWindow = self:windowNarrow(gradOutput, currentOutput, offset) | |||
self:rethrowErrors(module, i, 'accGradParameters', input, gradOutputWindow, scale) | |||
offset = offset + currentOutput:size(self.dimension) | |||
end | |||
end | |||
function DepthConcat:backward(input, gradOutput, scale) | |||
self.gradInput:resizeAs(input) | |||
scale = scale or 1 | |||
local offset = 1 | |||
for i,module in ipairs(self.modules) do | |||
local currentOutput = module.output | |||
local gradOutputWindow = self:windowNarrow(gradOutput, currentOutput, offset) | |||
local currentGradInput = self:rethrowErrors(module, i, 'backward', input, gradOutputWindow) | |||
if i==1 then | |||
self.gradInput:copy(currentGradInput) | |||
else | |||
self.gradInput:add(currentGradInput) | |||
end | |||
offset = offset + currentOutput:size(self.dimension) | |||
end | |||
return self.gradInput | |||
end | |||
function DepthConcat:accUpdateGradParameters(input, gradOutput, lr) | |||
local offset = 1 | |||
for i,module in ipairs(self.modules) do | |||
local currentOutput = module.output | |||
local gradOutputWindow = self:windowNarrow(gradOutput, currentOutput, offset) | |||
self:rethrowErrors(module, i, 'accUpdateGradParameters', input, gradOutputWindow, lr) | |||
offset = offset + currentOutput:size(self.dimension) | |||
end | |||
end |
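-- Usage sketch (assumes torch and nn are loaded): concatenate feature maps
-- of different spatial sizes along dimension 2; smaller maps are centered
-- and zero-padded.
local dc = nn.DepthConcat(2)
dc:add(nn.SpatialConvolution(3, 8, 3, 3))
dc:add(nn.SpatialConvolution(3, 8, 5, 5))
local y = dc:forward(torch.randn(1, 3, 32, 32)) -- 1 x 16 x 30 x 30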
@@ -1,34 +0,0 @@ | |||
local DistKLDivCriterion, parent = torch.class('nn.DistKLDivCriterion', 'nn.Criterion') | |||
function DistKLDivCriterion:__init() | |||
parent.__init(self) | |||
self.sizeAverage = true | |||
end | |||
function DistKLDivCriterion:updateOutput(input, target) | |||
assert(input:dim() == target:dim() and | |||
torch.LongTensor(input:size()):eq(torch.LongTensor(target:size())):all(), | |||
'input and target should have the same size') | |||
self.output_tensor = self.output_tensor or input.new(1) | |||
input.THNN.DistKLDivCriterion_updateOutput( | |||
input:cdata(), | |||
target:cdata(), | |||
self.output_tensor:cdata(), | |||
self.sizeAverage | |||
) | |||
self.output = self.output_tensor[1] | |||
return self.output | |||
end | |||
function DistKLDivCriterion:updateGradInput(input, target) | |||
assert(input:dim() == target:dim() and | |||
torch.LongTensor(input:size()):eq(torch.LongTensor(target:size())):all(), | |||
'input and target should have the same size') | |||
input.THNN.DistKLDivCriterion_updateGradInput( | |||
input:cdata(), | |||
target:cdata(), | |||
self.gradInput:cdata(), | |||
self.sizeAverage | |||
) | |||
return self.gradInput | |||
end |
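-- Usage sketch (assumes torch and nn are loaded): input must hold
-- log-probabilities, target must hold probabilities.
local crit = nn.DistKLDivCriterion()
local input = nn.LogSoftMax():forward(torch.randn(4, 5))
local target = torch.rand(4, 5)
target:cdiv(target:sum(2):expandAs(target)) -- normalize rows to sum to 1
local loss = crit:forward(input, target)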
@@ -1,142 +0,0 @@ | |||
--[[ | |||
Probabilistic criterion for a triplet Siamese model for learning embeddings.
Ref: https://arxiv.org/pdf/1610.00243.pdf | |||
loss = -log( exp(-X) / ( exp(-X) + exp(-Y) ) ) | |||
where | |||
X : Distance between similar samples | |||
Y : Distance between dissimilar samples | |||
The loss can be broken down into the following log expansion:
loss = -log( exp(-X) ) - (-log( exp(-X) + exp(-Y) )) | |||
= -log( exp(-X) ) + log( exp(-X) + exp(-Y) ) | |||
= -(-X) + log( exp(-X) + exp(-Y) ) | |||
= X + log( exp(-X) + exp(-Y) ) | |||
Gradients: | |||
dLoss/dX = 1 + 1 / (exp(-X) + exp(-Y)) * -1 * exp(-X) | |||
= 1 - exp(-X) / (exp(-X) + exp(-Y)) | |||
dLoss/dY = 0 + 1 / (exp(-X) + exp(-Y)) * -1 * exp(-Y) | |||
= -exp(-Y) / (exp(-X) + exp(-Y)) | |||
--]] | |||
local DistanceRatioCriterion, parent = torch.class('nn.DistanceRatioCriterion', | |||
'nn.Criterion') | |||
function DistanceRatioCriterion:__init(sizeAverage) | |||
parent.__init(self) | |||
if sizeAverage ~= nil then | |||
self.sizeAverage = sizeAverage | |||
else | |||
self.sizeAverage = true | |||
end | |||
end | |||
--[[ | |||
loss = -log( exp(-X) ) - (-log( exp(-X) + exp(-Y) )) | |||
= -log( exp(-X) ) + log( exp(-X) + exp(-Y) ) | |||
= -(-X) + log( exp(-X) + exp(-Y) ) | |||
= X + log( exp(-X) + exp(-Y) ) | |||
--]] | |||
function DistanceRatioCriterion:updateOutput(input) | |||
assert(#input == 2, "Invalid number of inputs") | |||
local X = input[1] | |||
local Y = input[2] | |||
assert(X:nElement() == Y:nElement(), "Number of distances doesn't match.")
assert(X:size(1) == Y:size(1), "Invalid distances' size.") | |||
-- Compute exp(-X) and exp(-Y) | |||
self._expMinusX = self._expMinusX or X.new() | |||
self._expMinusY = self._expMinusY or Y.new() | |||
-- Compute ( exp(-X) + exp(-Y) ) | |||
self._expMinusX:resizeAs(X):copy(X):mul(-1):exp() | |||
self._expMinusY:resizeAs(Y):copy(Y):mul(-1):exp() | |||
self._sumExpMinusXY = self._sumExpMinusXY or X.new()
self._sumExpMinusXY:resizeAs(self._expMinusX):copy(self._expMinusX) | |||
:add(self._expMinusY) | |||
-- Compute log( exp(-X) + exp(-Y) ) | |||
self._logSumExpMinusXY = self._logSumExpMinusXY or self._sumExpMinusXY.new() | |||
self._logSumExpMinusXY:resizeAs(self._sumExpMinusXY) | |||
:copy(self._sumExpMinusXY):log() | |||
-- Compute loss = X + log( exp(-X) + exp(-Y) )
self.loss = self.loss or self._logSumExpMinusXY.new() | |||
self.loss:resizeAs(X):copy(X):add(self._logSumExpMinusXY) | |||
if self.sizeAverage then | |||
return self.loss:sum()/X:size(1) | |||
else | |||
return self.loss:sum() | |||
end | |||
end | |||
--[[ | |||
Gradients: | |||
dLoss/dX = 1 + 1 / (exp(-X) + exp(-Y)) * -1 * exp(-X) | |||
= 1 - exp(-X) / (exp(-X) + exp(-Y)) | |||
dLoss/dY = 0 + 1 / (exp(-X) + exp(-Y)) * -1 * exp(-Y) | |||
= -exp(-Y) / (exp(-X) + exp(-Y)) | |||
--]] | |||
function DistanceRatioCriterion:updateGradInput(input) | |||
assert(#input == 2, "Invalid number of inputs") | |||
local X = input[1] | |||
local Y = input[2] | |||
assert(X:nElement() == Y:nElement(), "Number of distances doesn't match.")
assert(X:size(1) == Y:size(1), "Invalid distances' size.") | |||
-- dLoss/dX | |||
-- -exp(-X) | |||
self.dX = self.dX or X.new() | |||
self.dX:resizeAs(self._expMinusX):copy(self._expMinusX):mul(-1) | |||
-- -exp(-X) / (exp(-X) + exp(-Y)) | |||
self.dX:cdiv(self._sumExpMinusXY) | |||
-- 1 - exp(-X) / (exp(-X) + exp(-Y)) | |||
self.dX:add(1) | |||
-- dLoss/dY | |||
-- -exp(-Y) | |||
self.dY = self.dY or Y.new() | |||
self.dY:resizeAs(self._expMinusY):copy(self._expMinusY):mul(-1) | |||
-- -exp(-Y) / (exp(-X) + exp(-Y)) | |||
self.dY:cdiv(self._sumExpMinusXY) | |||
if self.sizeAverage then | |||
self.dX:div(X:size(1)) | |||
self.dY:div(X:size(1)) | |||
end | |||
return {self.dX, self.dY} | |||
end | |||
function DistanceRatioCriterion:type(type, tensorCache) | |||
if type then | |||
self._expMinusX = nil | |||
self._expMinusY = nil | |||
self._sumExpMinusXY = nil | |||
self._logSumExpMinusXY = nil | |||
self.loss = nil | |||
self.dX = nil | |||
self.dY = nil | |||
end | |||
return parent.type(self, type, tensorCache) | |||
end |
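-- Usage sketch (assumes torch and nn are loaded): X holds distances for
-- similar pairs, Y for dissimilar pairs; lower X and higher Y lower the loss.
local crit = nn.DistanceRatioCriterion(true)
local X, Y = torch.rand(8), torch.rand(8)
local loss = crit:forward({X, Y})
local dX, dY = unpack(crit:backward({X, Y})) -- gradients w.r.t. X and Y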
@@ -1,124 +0,0 @@ | |||
local DontCast, parent = torch.class("nn.DontCast", "nn.Decorator") | |||
local function recursiveTypeCopy(dst, src, type_str) | |||
if torch.type(src) == 'table' then | |||
dst = (torch.type(dst) == 'table') and dst or {} | |||
for k, v in pairs(src) do | |||
dst[k] = recursiveTypeCopy(dst[k], v, type_str) | |||
end | |||
elseif torch.isTensor(src) then | |||
dst = (torch.type(dst) == type_str) and dst or torch.getmetatable(type_str).new() | |||
dst:resize(src:size()) | |||
if src:nElement() > 0 then | |||
dst:copy(src) | |||
end | |||
end | |||
return dst | |||
end | |||
local function tableTensorType(src) | |||
if type(src) == 'table' then | |||
local type_str, found | |||
for k,v in pairs(src) do | |||
type_str, found = tableTensorType(v) | |||
if found then | |||
return type_str, true | |||
end | |||
end | |||
return type_str, found | |||
else | |||
return torch.type(src), torch.isTensor(src) | |||
end | |||
end | |||
function DontCast:__init(module, castin, castout, moduleType) | |||
parent.__init(self, module) | |||
self.castin = castin | |||
self.castout = (castout == nil) and castin or castout | |||
self.moduleType = moduleType | |||
if (self.castin or self.castout) and not self.moduleType then | |||
local moduleType, found = tableTensorType(module.output) | |||
if found then | |||
self.moduleType = moduleType | |||
else | |||
moduleType, found = tableTensorType(module:parameters()) | |||
if found then | |||
self.moduleType = moduleType | |||
else | |||
error"Cannot extrapolate moduleType. Provide constructor argument 4" | |||
end | |||
end | |||
end | |||
end | |||
function DontCast:updateOutput(input) | |||
if self.castin and tableTensorType(input) ~= self.moduleType then | |||
self._input = recursiveTypeCopy(self._input, input, self.moduleType) | |||
input = self._input | |||
end | |||
local output = self.modules[1]:updateOutput(input) | |||
if self.castout then | |||
self.output = recursiveTypeCopy(self.output, output, tableTensorType(self.output)) | |||
else | |||
self.output = output | |||
end | |||
return self.output | |||
end | |||
function DontCast:updateGradInput(input, gradOutput) | |||
if self.castin and tableTensorType(input) ~= self.moduleType then | |||
input = self._input | |||
end | |||
if self.castout and tableTensorType(gradOutput) ~= self.moduleType then | |||
self._gradOutput = recursiveTypeCopy(self._gradOutput, gradOutput, self.moduleType) | |||
gradOutput = self._gradOutput | |||
end | |||
local gradInput = self.modules[1]:updateGradInput(input, gradOutput) | |||
if self.castin then | |||
self.gradInput = recursiveTypeCopy(self.gradInput, gradInput, tableTensorType(self.gradInput)) | |||
else | |||
self.gradInput = gradInput | |||
end | |||
return self.gradInput | |||
end | |||
function DontCast:accGradParameters(input, gradOutput, scale) | |||
if self.castin and tableTensorType(input) ~= self.moduleType then | |||
input = self._input | |||
end | |||
if self.castout and tableTensorType(gradOutput) ~= self.moduleType then | |||
gradOutput = self._gradOutput | |||
end | |||
self.modules[1]:accGradParameters(input, gradOutput, scale) | |||
end | |||
function DontCast:accUpdateGradParameters(input, gradOutput, lr) | |||
if self.castin and tableTensorType(input) ~= self.moduleType then | |||
input = self._input | |||
end | |||
if self.castout and tableTensorType(gradOutput) ~= self.moduleType then | |||
gradOutput = self._gradOutput | |||
end | |||
self.modules[1]:accUpdateGradParameters(input, gradOutput, lr) | |||
end | |||
function DontCast:type(type) | |||
if self.castout and tableTensorType(self.output) ~= type then | |||
self.output = recursiveTypeCopy(nil, self.output, type) | |||
end | |||
if self.castin and tableTensorType(self.gradInput) ~= type then | |||
self.gradInput = recursiveTypeCopy(nil, self.gradInput, type) | |||
end | |||
return self | |||
end |
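-- Usage sketch (assumes torch and nn are loaded): keep an inner module in
-- float while the surrounding network passes double tensors.
local inner = nn.Linear(3, 2):float()
local wrapped = nn.DontCast(inner, true, true, 'torch.FloatTensor')
local y = wrapped:forward(torch.randn(4, 3)) -- double in, double out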
@@ -1,61 +0,0 @@ | |||
local DotProduct, parent = torch.class('nn.DotProduct', 'nn.Module') | |||
function DotProduct:__init() | |||
parent.__init(self) | |||
self.gradInput = {torch.Tensor(), torch.Tensor()} | |||
end | |||
function DotProduct:updateOutput(input) | |||
local input1, input2 = input[1], input[2] | |||
if input1:dim() == 1 then | |||
-- convert non-batch input to batch input
input1 = input1:view(1,-1) | |||
input2 = input2:view(1,-1) | |||
end | |||
if not self.buffer then | |||
self.buffer = input1.new() | |||
end | |||
self.buffer:cmul(input1, input2) | |||
self.output:sum(self.buffer, 2) | |||
self.output:resize(input1:size(1)) | |||
return self.output | |||
end | |||
function DotProduct:updateGradInput(input, gradOutput) | |||
local v1 = input[1] | |||
local v2 = input[2] | |||
local not_batch = false | |||
if #self.gradInput ~= 2 then | |||
self.gradInput[1] = self.gradInput[1] or input[1].new() | |||
self.gradInput[2] = self.gradInput[2] or input[2].new() | |||
end | |||
if v1:dim() == 1 then | |||
v1 = v1:view(1,-1) | |||
v2 = v2:view(1,-1) | |||
not_batch = true | |||
end | |||
local gw1 = self.gradInput[1] | |||
local gw2 = self.gradInput[2] | |||
gw1:resizeAs(v1):copy(v2) | |||
gw2:resizeAs(v2):copy(v1) | |||
local go = gradOutput:view(-1,1):expandAs(v1) | |||
gw1:cmul(go) | |||
gw2:cmul(go) | |||
if not_batch then | |||
-- unbatch gradInput | |||
self.gradInput[1]:set(gw1:select(1,1)) | |||
self.gradInput[2]:set(gw2:select(1,1)) | |||
end | |||
return self.gradInput | |||
end | |||
function DotProduct:clearState() | |||
if self.buffer then self.buffer:set() end | |||
return parent.clearState(self) | |||
end |
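-- Usage sketch (assumes torch and nn are loaded): row-wise dot product of
-- a pair of tensors.
local dot = nn.DotProduct()
local out = dot:forward({torch.randn(5, 10), torch.randn(5, 10)}) -- 5 values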
@@ -1,70 +0,0 @@ | |||
local Dropout, Parent = torch.class('nn.Dropout', 'nn.Module') | |||
function Dropout:__init(p,v1,inplace,stochasticInference) | |||
Parent.__init(self) | |||
self.p = p or 0.5 | |||
self.train = true | |||
self.inplace = inplace | |||
self.stochastic_inference = stochasticInference or false | |||
-- version 2 scales output during training instead of evaluation | |||
self.v2 = not v1 | |||
if self.p >= 1 or self.p < 0 then | |||
error('<Dropout> illegal percentage, must be 0 <= p < 1') | |||
end | |||
self.noise = torch.Tensor() | |||
end | |||
function Dropout:updateOutput(input) | |||
if self.inplace then | |||
self.output:set(input) | |||
else | |||
self.output:resizeAs(input):copy(input) | |||
end | |||
if self.p > 0 then | |||
if self.train or self.stochastic_inference then | |||
self.noise:resizeAs(input) | |||
self.noise:bernoulli(1-self.p) | |||
if self.v2 then | |||
self.noise:div(1-self.p) | |||
end | |||
self.output:cmul(self.noise) | |||
elseif not self.v2 then | |||
self.output:mul(1-self.p) | |||
end | |||
end | |||
return self.output | |||
end | |||
function Dropout:updateGradInput(input, gradOutput) | |||
if self.inplace then | |||
self.gradInput:set(gradOutput) | |||
else | |||
self.gradInput:resizeAs(gradOutput):copy(gradOutput) | |||
end | |||
if self.train then | |||
if self.p > 0 then | |||
self.gradInput:cmul(self.noise) -- simply mask the gradients with the noise vector | |||
end | |||
else | |||
if not self.v2 and self.p > 0 then | |||
self.gradInput:mul(1-self.p) | |||
end | |||
end | |||
return self.gradInput | |||
end | |||
function Dropout:setp(p) | |||
self.p = p | |||
end | |||
function Dropout:__tostring__() | |||
return string.format('%s(%f)', torch.type(self), self.p) | |||
end | |||
function Dropout:clearState() | |||
if self.noise then | |||
self.noise:set() | |||
end | |||
return Parent.clearState(self) | |||
end |
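-- Usage sketch (assumes torch and nn are loaded): with the default v2
-- behaviour, surviving activations are scaled by 1/(1-p) during training,
-- so evaluation is the identity.
local drop = nn.Dropout(0.5)
drop:training()
local y1 = drop:forward(torch.ones(4)) -- about half zeros, survivors equal 2
drop:evaluate()
local y2 = drop:forward(torch.ones(4)) -- identical to the input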
@@ -1,45 +0,0 @@ | |||
local ELU, parent = torch.class('nn.ELU', 'nn.Module') | |||
--[[ | |||
Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter | |||
Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) | |||
http://arxiv.org/pdf/1511.07289.pdf | |||
--]] | |||
function ELU:__init(alpha, inplace) | |||
parent.__init(self) | |||
self.alpha = alpha or 1 | |||
assert(type(self.alpha) == 'number') | |||
self.inplace = inplace or false | |||
assert(type(self.inplace) == 'boolean') | |||
end | |||
function ELU:updateOutput(input) | |||
local inplace = self.inplace or false | |||
input.THNN.ELU_updateOutput( | |||
input:cdata(), | |||
self.output:cdata(), | |||
self.alpha, | |||
inplace | |||
) | |||
return self.output | |||
end | |||
function ELU:updateGradInput(input, gradOutput) | |||
local inplace = self.inplace or false | |||
input.THNN.ELU_updateGradInput( | |||
input:cdata(), | |||
gradOutput:cdata(), | |||
self.gradInput:cdata(), | |||
self.output:cdata(), | |||
self.alpha, | |||
inplace | |||
) | |||
return self.gradInput | |||
end | |||
function ELU:__tostring__() | |||
return string.format('%s (alpha:%f)', torch.type(self), self.alpha) | |||
end |
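-- Usage sketch (assumes torch and nn are loaded): f(x) = x for x > 0 and
-- alpha * (exp(x) - 1) for x <= 0.
local elu = nn.ELU(1.0)
local y = elu:forward(torch.Tensor{-2, 0, 2}) -- approx {-0.865, 0, 2}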
@@ -1,19 +0,0 @@ | |||
local mt = { | |||
__index = function(table, key) | |||
error("nn."..key.." is only supported for Float or Double Tensors.") | |||
end | |||
} | |||
local tensors = { | |||
torch.ByteTensor, | |||
torch.CharTensor, | |||
torch.ShortTensor, | |||
torch.IntTensor, | |||
torch.LongTensor, | |||
} | |||
for _, t in ipairs(tensors) do | |||
t.nn = {} | |||
setmetatable(t.nn, mt) | |||
end |
@@ -1,197 +0,0 @@ | |||
local Euclidean, parent = torch.class('nn.Euclidean', 'nn.Module') | |||
function Euclidean:__init(inputSize,outputSize) | |||
parent.__init(self) | |||
self.weight = torch.Tensor(inputSize,outputSize) | |||
self.gradWeight = torch.Tensor(inputSize,outputSize) | |||
-- state | |||
self.gradInput:resize(inputSize) | |||
self.output:resize(outputSize) | |||
self.fastBackward = true | |||
self:reset() | |||
end | |||
function Euclidean:reset(stdv) | |||
if stdv then | |||
stdv = stdv * math.sqrt(3) | |||
else | |||
stdv = 1./math.sqrt(self.weight:size(1)) | |||
end | |||
if nn.oldSeed then | |||
for i=1,self.weight:size(2) do | |||
self.weight:select(2, i):apply(function() | |||
return torch.uniform(-stdv, stdv) | |||
end) | |||
end | |||
else | |||
self.weight:uniform(-stdv, stdv) | |||
end | |||
end | |||
local function view(res, src, ...) | |||
local args = {...} | |||
if src:isContiguous() then | |||
res:view(src, table.unpack(args)) | |||
else | |||
res:reshape(src, table.unpack(args)) | |||
end | |||
end | |||
function Euclidean:updateOutput(input) | |||
-- lazy initialize buffers | |||
self._input = self._input or input.new() | |||
self._weight = self._weight or self.weight.new() | |||
self._expand = self._expand or self.output.new() | |||
self._expand2 = self._expand2 or self.output.new() | |||
self._repeat = self._repeat or self.output.new() | |||
self._repeat2 = self._repeat2 or self.output.new() | |||
local inputSize, outputSize = self.weight:size(1), self.weight:size(2) | |||
-- y_j = || w_j - x || = || x - w_j || | |||
if input:dim() == 1 then | |||
view(self._input, input, inputSize, 1) | |||
self._expand:expandAs(self._input, self.weight) | |||
self._repeat:resizeAs(self._expand):copy(self._expand) | |||
self._repeat:add(-1, self.weight) | |||
self.output:norm(self._repeat, 2, 1) | |||
self.output:resize(outputSize) | |||
elseif input:dim() == 2 then | |||
local batchSize = input:size(1) | |||
view(self._input, input, batchSize, inputSize, 1) | |||
self._expand:expand(self._input, batchSize, inputSize, outputSize) | |||
-- make the expanded tensor contiguous (requires lots of memory) | |||
self._repeat:resizeAs(self._expand):copy(self._expand) | |||
self._weight:view(self.weight, 1, inputSize, outputSize) | |||
self._expand2:expandAs(self._weight, self._repeat) | |||
if torch.type(input) == 'torch.CudaTensor' then | |||
-- requires lots of memory, but minimizes cudaMallocs and loops | |||
self._repeat2:resizeAs(self._expand2):copy(self._expand2) | |||
self._repeat:add(-1, self._repeat2) | |||
else | |||
self._repeat:add(-1, self._expand2) | |||
end | |||
self.output:norm(self._repeat, 2, 2) | |||
self.output:resize(batchSize, outputSize) | |||
else | |||
error"1D or 2D input expected" | |||
end | |||
return self.output | |||
end | |||
function Euclidean:updateGradInput(input, gradOutput) | |||
if not self.gradInput then | |||
return | |||
end | |||
self._div = self._div or input.new() | |||
self._output = self._output or self.output.new() | |||
self._gradOutput = self._gradOutput or input.new() | |||
self._expand3 = self._expand3 or input.new() | |||
if not self.fastBackward then | |||
self:updateOutput(input) | |||
end | |||
local inputSize, outputSize = self.weight:size(1), self.weight:size(2) | |||
--[[
dy_j/dx = -2 * (w_j - x) / ( 2 * || w_j - x || ) = (x - w_j) / y_j
--]]
-- to prevent div by zero (NaN) bugs | |||
self._output:resizeAs(self.output):copy(self.output):add(0.0000001) | |||
view(self._gradOutput, gradOutput, gradOutput:size()) | |||
self._div:cdiv(gradOutput, self._output) | |||
if input:dim() == 1 then | |||
self._div:resize(1, outputSize) | |||
self._expand3:expandAs(self._div, self.weight) | |||
if torch.type(input) == 'torch.CudaTensor' then | |||
self._repeat2:resizeAs(self._expand3):copy(self._expand3) | |||
self._repeat2:cmul(self._repeat) | |||
else | |||
self._repeat2:cmul(self._repeat, self._expand3) | |||
end | |||
self.gradInput:sum(self._repeat2, 2) | |||
self.gradInput:resizeAs(input) | |||
elseif input:dim() == 2 then | |||
local batchSize = input:size(1) | |||
self._div:resize(batchSize, 1, outputSize) | |||
self._expand3:expand(self._div, batchSize, inputSize, outputSize) | |||
if torch.type(input) == 'torch.CudaTensor' then | |||
self._repeat2:resizeAs(self._expand3):copy(self._expand3) | |||
self._repeat2:cmul(self._repeat) | |||
else | |||
self._repeat2:cmul(self._repeat, self._expand3) | |||
end | |||
self.gradInput:sum(self._repeat2, 3) | |||
self.gradInput:resizeAs(input) | |||
else | |||
error"1D or 2D input expected" | |||
end | |||
return self.gradInput | |||
end | |||
function Euclidean:accGradParameters(input, gradOutput, scale) | |||
local inputSize, outputSize = self.weight:size(1), self.weight:size(2) | |||
scale = scale or 1 | |||
--[[
dy_j/dw_j = 2 * (w_j - x) / ( 2 * || w_j - x || ) = (w_j - x) / y_j
--]]
-- assumes a preceding call to updateGradInput | |||
if input:dim() == 1 then | |||
self.gradWeight:add(-scale, self._repeat2) | |||
elseif input:dim() == 2 then | |||
self._sum = self._sum or input.new() | |||
self._sum:sum(self._repeat2, 1) | |||
self._sum:resize(inputSize, outputSize) | |||
self.gradWeight:add(-scale, self._sum) | |||
else | |||
error"1D or 2D input expected" | |||
end | |||
end | |||
function Euclidean:type(type, tensorCache) | |||
if type then | |||
-- prevent premature memory allocations | |||
self:clearState() | |||
end | |||
return parent.type(self, type, tensorCache) | |||
end | |||
function Euclidean:clearState() | |||
nn.utils.clear(self, { | |||
'_input', | |||
'_output', | |||
'_gradOutput', | |||
'_weight', | |||
'_div', | |||
'_sum', | |||
'_expand', | |||
'_expand2', | |||
'_expand3', | |||
'_repeat', | |||
'_repeat2', | |||
}) | |||
return parent.clearState(self) | |||
end |
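-- Usage sketch (assumes torch and nn are loaded): each output is the
-- Euclidean distance between the input and one weight column (template).
local eu = nn.Euclidean(10, 4)          -- inputSize = 10, outputSize = 4
local out = eu:forward(torch.randn(10)) -- 4 non-negative distances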
@@ -1,9 +0,0 @@ | |||
local Exp = torch.class('nn.Exp', 'nn.Module') | |||
function Exp:updateOutput(input) | |||
return self.output:exp(input) | |||
end | |||
function Exp:updateGradInput(input, gradOutput) | |||
return self.gradInput:cmul(self.output, gradOutput) | |||
end |