--- nn.Sparse2Dense: turns sparse (key, value) input into a dense output tensor.
-- The dense layout is produced by the C routine `S2D_computeOutput` (not
-- visible here); presumably each output column corresponds to one entry of
-- `features` -- confirm against the C implementation.
-- The backward pass is not implemented.
local S2D, parent = torch.class("nn.Sparse2Dense", "nn.Module")
local dt = require 'decisiontree._env'

--- Constructor.
-- @param features non-empty Lua table of feature ids, or a tensor of feature
--   ids. A table is converted to a torch.LongTensor.
function S2D:__init(features)
   parent.__init(self)
   if torch.type(features) == 'table' then
      assert(#features > 0, "Sparse2Dense: features table must not be empty")
      features = torch.LongTensor(features)
   end
   assert(torch.isTensor(features), "Sparse2Dense: features must be a table or a tensor")
   self.features = features
   -- dt.HashMap filled from self.features; built lazily in updateOutput and
   -- never serialized (see write/read).
   self.featureMap = nil
   -- One ByteTensor / LongTensor buffer per sample, reused across calls and
   -- handed to featureMap:get.
   self.masks = {}
   self.mappedKeys = {}
end

--- Forward pass.
-- @param input either {keys, values} where both entries are tensors (one
--   unbatched sample), or {keysList, valuesList} where both entries are Lua
--   tables holding one tensor per sample (a batch).
-- @return self.output; for an unbatched sample it is flattened with view(-1).
function S2D:updateOutput(input)
   -- The hash map is not serialized, so (re)build it on first use.
   if not self.featureMap then
      self.featureMap = dt.HashMap()
      self.featureMap:fill(self.features)
   end

   local batched, keys, values
   if torch.isTensor(input[1]) then
      -- Wrap a single sample so the code below only handles lists.
      keys = {input[1]}
      values = {input[2]}
      batched = false
   else
      keys = input[1]
      values = input[2]
      batched = true
   end
   assert(#keys == #values, "Sparse2Dense: number of key and value tensors must match")

   -- Keep exactly one mask/mappedKeys buffer per sample: allocate missing
   -- buffers, then drop surplus ones from the end so both tables remain
   -- proper sequences.
   local masks = self.masks
   local mappedKeys = self.mappedKeys
   local nKeys = #keys
   local nMasks = #masks
   for i = nMasks + 1, nKeys do
      masks[i] = torch.ByteTensor()
      mappedKeys[i] = torch.LongTensor()
   end
   for i = nMasks, nKeys + 1, -1 do
      masks[i] = nil
      mappedKeys[i] = nil
   end

   -- NOTE(review): featureMap:get presumably fills mappedKeys/masks in place
   -- from `keys`; confirm against the HashMap implementation.
   self.featureMap:get(keys, mappedKeys, masks)
   self.output = self.output or torch.Tensor():type(self._type)
   self.output.nn.S2D_computeOutput(self.output, mappedKeys, values, masks, self.features)
   if not batched then
      self.output = self.output:view(-1)
   end
   return self.output
end

--- Convert the module to tensor type `type`, or query the current type.
-- nn.Module:type recursively converts every tensor reachable from self,
-- including tensors nested inside tables. The index buffers must keep their
-- own types: self.features stays as given (a LongTensor when built from a
-- table), while self.mappedKeys / self.masks hold the LongTensor / ByteTensor
-- buffers created in updateOutput. They are therefore detached while the
-- parent converts everything else, and reattached afterwards.
-- @param type target tensor type name; when nil the current type is returned.
-- @param tensorCache optional cache shared across modules during conversion.
-- @return self when converting, otherwise the current type.
function S2D:type(type, tensorCache)
   if not type then
      return parent.type(self)
   end
   local features, masks, mappedKeys = self.features, self.masks, self.mappedKeys
   self.features, self.masks, self.mappedKeys = nil, nil, nil
   parent.type(self, type, tensorCache)
   self.features, self.masks, self.mappedKeys = features, masks, mappedKeys
   return self
end

--- Backward pass is not supported; always raises an error.
function S2D:updateGradInput(input, gradOutput)
   error"Not Implemented"
end

--- Reset the module; the hash map is dropped so that updateOutput rebuilds
-- it from self.features on the next call.
function S2D:reset()
   parent.reset(self)
   self.featureMap = nil
end

--- Serialize the module without its hash map; updateOutput rebuilds the map
-- from self.features. The live map is reattached after writing so that this
-- module (e.g. the source of a clone()) does not have to rebuild it.
function S2D:write(file)
   local featureMap = self.featureMap
   self.featureMap = nil
   parent.write(self, file)
   self.featureMap = featureMap
end

--- Deserialize the module; the hash map is cleared so that updateOutput
-- rebuilds it from the loaded self.features.
function S2D:read(file)
   self.featureMap = nil
   parent.read(self, file)
end