You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

ClassSimplexCriterion.lua 3.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  --[[
  nn.ClassSimplexCriterion -- a criterion for multi-class classification.

  Each of the N classes is assigned a target embedding: one vertex of a
  regular (N-1)-dimensional simplex (built once, at construction time, and
  padded into an N-dimensional space). The loss is the mean-squared error
  between the network output and the vertex of the target class, which is
  why the class derives from nn.MSECriterion.

  Example usage is given in doc/criterion.md.
  Reference: http://arxiv.org/abs/1506.08230
  ]]--
  local ClassSimplexCriterion, parent =
     torch.class('nn.ClassSimplexCriterion', 'nn.MSECriterion')
  --[[
  regsplex(n)

  Builds the vertex coordinates of a regular simplex centred at the origin,
  with every vertex vector having Euclidean norm 1.

  input:
     n -- dimension of the vertex vectors (the simplex has n+1 vertices)
  output:
     a tensor of size (n+1) x n whose rows are the vertex vectors

  Construction, one column at a time: the diagonal entry of row k is chosen
  so that row k has unit norm; every later row then receives one shared
  value in column k, chosen so that row k has dot product -1/n with each of
  those later rows.

  reference:
  http://en.wikipedia.org/wiki/Simplex#Cartesian_coordinates_for_regular_n-dimensional_simplex_in_Rn
  --]]
  local function regsplex(n)
     local verts = torch.zeros(n + 1, n)
     for row = 1, n do
        -- diagonal entry = last nonzero coordinate of this vertex
        if row == 1 then
           verts[row][row] = 1
        else
           local prefixNorm = verts[{ {row}, {1, row - 1} }]:norm()
           verts[row][row] = math.sqrt(1 - prefixNorm^2)
        end
        -- read the diagonal back from the tensor so the stored precision is used
        local diag = verts[row][row]
        -- one value shared by column `row` of every remaining vertex
        local shared = (diag^2 - 1 - 1/n) / diag
        verts[{ {row + 1, n + 1}, {row} }]:fill(shared)
     end
     return verts
  end
  -- Constructor.
  -- nClasses -- number of classes; must be an integer greater than 1.
  -- Sets self.nClasses, self.simplex (an nClasses x nClasses tensor whose
  -- row i is the target embedding of class i) and a scratch self._target.
  function ClassSimplexCriterion:__init(nClasses)
     parent.__init(self)
     -- `x - x % 1` equals x only for finite integral x (inf and nan fail)
     local valid = nClasses and nClasses > 1
        and nClasses == (nClasses - (nClasses % 1))
     assert(valid, "Required positive integer argument nClasses > 1")
     self.nClasses = nClasses
     -- The simplex for nClasses vertices lives in nClasses-1 dimensions; it is
     -- padded with a zero column so the embedding space has dimension strictly
     -- greater than that minimum, which is critical for effective training.
     local vertices = regsplex(nClasses - 1)
     local nRows, nCols = vertices:size(1), vertices:size(2)
     local padding = torch.zeros(nRows, nClasses - nCols)
     self.simplex = torch.cat(vertices, padding, 2)
     self._target = torch.Tensor(nClasses)
  end
  -- Replace class-index targets with their simplex embeddings, written into
  -- self._target:
  --   * number t         -> self._target becomes a vector of nClasses entries
  --                         holding row t of self.simplex
  --   * 1D tensor target -> self._target becomes (nSamples x nClasses); row i
  --                         holds the simplex row of class target[i]
  -- Tensors with any other number of dimensions fail the '1D tensors only!'
  -- assertion. Any other kind of target is rejected with an error, rather
  -- than silently leaving a stale self._target from a previous call.
  local function transformTarget(self, target)
     if torch.type(target) == 'number' then
        self._target:resize(self.nClasses)
        self._target:copy(self.simplex[target])
     elseif torch.isTensor(target) then
        assert(target:dim() == 1, '1D tensors only!')
        local nSamples = target:size(1)
        self._target:resize(nSamples, self.nClasses)
        for i = 1, nSamples do
           self._target[i]:copy(self.simplex[target[i]])
        end
     else
        error('ClassSimplexCriterion: target must be a class index (number) '
              .. 'or a 1D tensor of class indices, got ' .. torch.type(target))
     end
  end
  -- Forward pass: embed `target` onto the simplex (into self._target) and
  -- return the MSE between `input` and that embedding, as computed by the
  -- THNN MSECriterion kernel (self.sizeAverage is forwarded to the kernel).
  -- input must hold exactly as many elements as the embedded target.
  function ClassSimplexCriterion:updateOutput(input, target)
     transformTarget(self, target)
     assert(input:nElement() == self._target:nElement())
     -- 1-element tensor of the input's type, used as the kernel's output slot
     if not self.output_tensor then
        self.output_tensor = input.new(1)
     end
     local THNN = input.THNN
     THNN.MSECriterion_updateOutput(input:cdata(), self._target:cdata(),
                                    self.output_tensor:cdata(), self.sizeAverage)
     self.output = self.output_tensor[1]
     return self.output
  end
  -- Backward pass: gradient of the MSE loss with respect to `input`, written
  -- into and returned as self.gradInput (self.sizeAverage is forwarded to
  -- the THNN kernel).
  -- The target is re-embedded here instead of reusing whatever the last
  -- updateOutput call left in self._target, so the gradient always
  -- corresponds to the `target` argument actually passed in.
  function ClassSimplexCriterion:updateGradInput(input, target)
     transformTarget(self, target)
     assert(input:nElement() == self._target:nElement())
     input.THNN.MSECriterion_updateGradInput(
        input:cdata(),
        self._target:cdata(),
        self.gradInput:cdata(),
        self.sizeAverage
     )
     return self.gradInput
  end
  -- Score every class for each sample as the dot product between the input
  -- row and that class's simplex embedding (a row of self.simplex).
  -- input: a single sample (1D) or a batch with one sample per row (2D);
  --        each sample must have nClasses entries to match self.simplex.
  -- returns: a (nSamples x nClasses) tensor of scores; a 1D input is
  --          treated as a batch of one.
  function ClassSimplexCriterion:getPredictions(input)
     if input:dim() == 1 then
        -- view() requires contiguous storage; contiguous() returns the
        -- tensor itself when it already is, and a compact copy otherwise
        -- (e.g. for a strided slice of a larger tensor)
        input = input:contiguous():view(1, -1)
     end
     return torch.mm(input, self.simplex:t())
  end
  -- For each sample, return the index of the class whose simplex embedding
  -- has the largest dot product with the input (scores from getPredictions).
  -- The result is always a 1D tensor of indices, even for a single 1D input.
  function ClassSimplexCriterion:getTopPrediction(input)
     local scores = self:getPredictions(input)
     local classDim = scores:dim()
     local _, best = scores:max(classDim)
     return best:view(-1)
  end