You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.lua 4.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. local dt = require "decisiontree._env"
  -- Returns the buffer table registered under `name`, creating it on first use.
  -- Buffers are kept in the `buffer` field of the 'decisiontree' module table;
  -- they are meant to be thread-local scratch space and are not serialized.
  -- @param name  string key identifying the buffer (asserted to be a string)
  -- @return      the (possibly freshly created) table stored under `name`
  function dt.getBufferTable(name)
    local dt = require 'decisiontree'
    assert(torch.type(name) == 'string')
    local registry = dt.buffer
    if not registry then
      registry = {}
      dt.buffer = registry
    end
    local buf = registry[name]
    if not buf then
      buf = {}
      registry[name] = buf
    end
    return buf
  end
  -- Generates a synthetic (dummy) sparse binary-classification dataset.
  --
  -- nCluster centers are drawn uniformly in [0,1) and shifted by their index i;
  -- cluster i is labelled i % 2. Each example picks a random cluster, keeps
  -- nActive distinct randomly chosen feature keys, and takes the center's values
  -- at those keys plus Gaussian noise (std 1/100 for training examples,
  -- 1/1000 for validation examples).
  --
  -- Arguments may be passed positionally or as a single option table with the
  -- same field names:
  --   nExample : number of training examples (default 100)
  --   nCluster : number of clusters, asserted >= 2 (default 10)
  --   nFeature : number of features, clamped to at least 2 (default 10)
  --   overlap  : probability of flipping an example's label (default 0)
  --   nValid   : number of validation examples
  --              (default math.max(1, math.floor(nExample/10)))
  --   nActive  : number of keys per example, asserted <= nFeature
  --              (default math.max(2, math.floor(nFeature/2)))
  --
  -- Returns: trainSet and validSet (dt.DataSet), clusterExamples (cluster index
  -- -> list of example indices), inputs (table of torch.SparseTensor with the
  -- nExample training examples first, then the nValid validation examples) and
  -- targets (torch.Tensor of 0/1 labels, same ordering).
  function dt.getSparseDummyData(nExample, nCluster, nFeature, overlap, nValid, nActive)
    local dt = require 'decisiontree'
    if torch.type(nExample) == 'table' then
      local opt = nExample
      nExample = opt.nExample
      nCluster = opt.nCluster
      nFeature = opt.nFeature
      overlap = opt.overlap
      nValid = opt.nValid
      nActive = opt.nActive
    end
    nExample = nExample or 100 -- training set size
    nCluster = nCluster or 10
    assert(nCluster >= 2)
    nFeature = math.max(2, nFeature or 10)
    overlap = overlap or 0
    -- Defaults are floored so tensor sizes and narrow() lengths are integers;
    -- nValid is kept >= 1 because narrow() rejects a zero-length slice.
    nValid = nValid or math.max(1, math.floor(nExample/10)) -- validation set size
    nActive = nActive or math.max(2, math.floor(nFeature/2))
    assert(nActive <= nFeature, "nActive must not exceed nFeature")
    -- sample nCluster centers
    local clusterCenter = torch.rand(nCluster, nFeature)
    local clusterLabel = torch.LongTensor(nCluster)
    local clusterExamples = {}
    for i=1,nCluster do
      clusterCenter[i]:add(i)
      clusterLabel[i] = i % 2
      clusterExamples[i] = {}
    end
    local nTotal = nExample + nValid
    local sparseCenter = torch.Tensor()
    local shuffle = torch.LongTensor()
    -- build dataset in pseudo-dense format
    local inputs = {}
    local targets = torch.Tensor(nTotal)
    for i=1,nTotal do
      local clusterIdx = torch.random(1,nCluster)
      table.insert(clusterExamples[clusterIdx], i)
      -- pick nActive distinct feature keys via a random permutation
      shuffle:randperm(nFeature)
      local keys = torch.LongTensor(nActive):copy(shuffle:narrow(1,1,nActive))
      sparseCenter:index(clusterCenter[clusterIdx], 1, keys)
      -- validation examples (i > nExample) get smaller noise than training ones
      local stdiv = i <= nExample and 100 or 1000
      local values = torch.randn(nActive):div(stdiv):add(sparseCenter)
      table.insert(inputs, torch.SparseTensor(keys, values))
      local label = clusterLabel[clusterIdx]
      -- Flip the label with probability `overlap`. The draw uses torch's RNG
      -- (like every other draw here) so torch.manualSeed fully determines the
      -- dataset; nothing is drawn when overlap <= 0, since no flip can occur.
      if overlap > 0 and torch.uniform() < overlap then
        targets[i] = label == 1 and 0 or 1
      else
        targets[i] = label
      end
    end
    local _ = require 'moses'
    local validSet = dt.DataSet(_.slice(inputs, nExample+1, nTotal), targets:narrow(1,nExample+1,nValid))
    local trainSet = dt.DataSet(_.slice(inputs, 1, nExample), targets:narrow(1,1,nExample))
    return trainSet, validSet, clusterExamples, inputs, targets
  end
  -- Generates a synthetic (dummy) dense binary-classification dataset.
  --
  -- nCluster centers are drawn uniformly in [0,1) and shifted by their index i;
  -- cluster i is labelled i % 2. Each example picks a random cluster and equals
  -- that cluster's center plus Gaussian noise (std 1/100 for training examples,
  -- 1/1000 for validation examples).
  --
  -- Arguments may be passed positionally or as a single option table with the
  -- same field names:
  --   nExample : number of training examples (default 100)
  --   nCluster : number of clusters, asserted >= 2 (default 10)
  --   nFeature : number of features, clamped to at least 2 (default 10)
  --   overlap  : probability of flipping an example's label (default 0)
  --   nValid   : number of validation examples
  --              (default math.max(1, math.floor(nExample/10)))
  --
  -- Returns: trainSet and validSet (dt.DataSet), clusterExamples (cluster index
  -- -> list of example indices), inputs (a (nExample+nValid) x nFeature
  -- torch.Tensor, training rows first) and targets (torch.Tensor of 0/1 labels).
  function dt.getDenseDummyData(nExample, nCluster, nFeature, overlap, nValid)
    local dt = require 'decisiontree'
    if torch.type(nExample) == 'table' then
      local opt = nExample
      nExample = opt.nExample
      nCluster = opt.nCluster
      nFeature = opt.nFeature
      overlap = opt.overlap
      nValid = opt.nValid
    end
    nExample = nExample or 100 -- training set size
    nCluster = nCluster or 10
    assert(nCluster >= 2)
    nFeature = math.max(2, nFeature or 10)
    overlap = overlap or 0
    -- Floored so tensor sizes and narrow() lengths are integers; kept >= 1
    -- because narrow() rejects a zero-length slice.
    nValid = nValid or math.max(1, math.floor(nExample/10)) -- validation set size
    -- sample nCluster centers
    local clusterCenter = torch.rand(nCluster, nFeature)
    local clusterLabel = torch.LongTensor(nCluster)
    local clusterExamples = {}
    for i=1,nCluster do
      clusterCenter[i]:add(i)
      clusterLabel[i] = i % 2
      clusterExamples[i] = {}
    end
    local nTotal = nExample + nValid
    local inputs = torch.Tensor(nTotal, nFeature)
    local targets = torch.Tensor(nTotal)
    for i=1,nTotal do
      local clusterIdx = torch.random(1,nCluster)
      table.insert(clusterExamples[clusterIdx], i)
      -- validation examples (i > nExample) get smaller noise than training ones
      local stdiv = i <= nExample and 100 or 1000
      inputs[i]:normal():div(stdiv):add(clusterCenter[clusterIdx])
      local label = clusterLabel[clusterIdx]
      -- Flip the label with probability `overlap`. The draw uses torch's RNG
      -- (like every other draw here) so torch.manualSeed fully determines the
      -- dataset; nothing is drawn when overlap <= 0, since no flip can occur.
      if overlap > 0 and torch.uniform() < overlap then
        targets[i] = label == 1 and 0 or 1
      else
        targets[i] = label
      end
    end
    local validSet = dt.DataSet(inputs:narrow(1,nExample+1,nValid), targets:narrow(1,nExample+1,nValid))
    local trainSet = dt.DataSet(inputs:narrow(1,1,nExample), targets:narrow(1,1,nExample))
    return trainSet, validSet, clusterExamples, inputs, targets
  end