
math.lua 3.2KB

local dt = require "decisiontree._env"

local PSEUDOCOUNT = 1.0
local MIN_LOGISTIC = 1E-8
local MAX_LOGISTIC = 1.0 - MIN_LOGISTIC

-- Create counts of possible results (last column of each row is the result)
function dt.uniquecounts(counts, inputset, nclass)
   counts = counts or inputset.input.new()
   nclass = nclass or inputset.target:max()
   counts:resize(nclass):zero()
   inputset.target:apply(function(c) counts[c] = counts[c] + 1 end)
   return counts
end
-- Entropy is the sum of -p(x)log(p(x)) across all the different possible results
local counts, logprobs
function dt.entropy(inputset, nclass)
   local dt = require 'decisiontree'
   counts = dt.uniquecounts(counts, inputset, nclass)

   -- convert counts to categorical probabilities
   counts:add(0.0000001) -- prevent NaN
   counts:div(counts:sum())

   logprobs = logprobs or counts.new()
   logprobs:resize(counts:size())
   logprobs:log(counts):div(math.log(2)) -- log2(x)

   counts:cmul(logprobs)
   return -counts:sum()
end
-- Compute and return the probability of a positive label,
-- smoothed with a pseudocount (Laplace smoothing).
function dt.probabilityPositive(nPositive, nTotal)
   return (nPositive + PSEUDOCOUNT) / (nTotal + 2.0 * PSEUDOCOUNT)
end
-- Ref. https://en.wikipedia.org/wiki/Logit
-- Calculates the logit of a probability.
-- Logit represents the log-odds. Probabilities transformed to logit 'space' can be combined linearly.
function dt.logit(p)
   assert(p >= 0.0 and p <= 1.0, "Expecting probability for arg 1")
   local truncatedP = math.max(MIN_LOGISTIC, math.min(MAX_LOGISTIC, p))
   return math.log(truncatedP / (1.0 - truncatedP))
end

-- Logistic (sigmoid) function, the inverse of dt.logit.
-- The two branches avoid overflow of math.exp for large |x|.
function dt.logistic(x)
   return (x >= 0) and (1 / (1 + math.exp(-x))) or (1 - 1 / (1 + math.exp(x)))
end
-- Gradient boosting loss of a node given its aggregated gradient and hessian: -g^2 / h.
function dt.computeGradientBoostLoss(gradient, hessian)
   return -gradient * gradient / hessian
end

-- Newton score of a node given its aggregated gradient and hessian: -0.5 * g / h.
function dt.computeNewtonScore(gradient, hessian)
   return -0.5 * gradient / hessian
end
-- Calculates the logit score for a node in a decision tree based on the probability of a positive label.
-- params: number of positive examples and total number of examples.
function dt.calculateLogitScore(nPositive, nTotal)
   local dt = require 'decisiontree'
   return dt.logit(dt.probabilityPositive(nPositive, nTotal))
end
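-- Usage sketch (editorial illustration, not part of math.lua), assuming Torch and
-- the decisiontree package are installed: a node with 30 positives out of 40
-- examples gets a smoothed probability and a logit score that dt.logistic maps
-- back to the same probability.
--
--   local dt = require 'decisiontree'
--   local p = dt.probabilityPositive(30, 40)      -- (30 + 1) / (40 + 2) ≈ 0.738
--   local score = dt.calculateLogitScore(30, 40)  -- dt.logit(p) ≈ 1.036
--   assert(math.abs(dt.logistic(score) - p) < 1e-6)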
-- Compute and return the Gini impurity score based on an input contingency table.
function dt.computeGini(leftCount, positiveLeftCount, rightCount, positiveRightCount)
   assert(torch.type(leftCount) == 'number', 'Expecting total number of examples falling into the left branch.')
   assert(torch.type(positiveLeftCount) == 'number', 'Expecting total number of positive examples falling into the left branch.')
   assert(torch.type(rightCount) == 'number', 'Expecting total number of examples falling into the right branch.')
   assert(torch.type(positiveRightCount) == 'number', 'Expecting total number of positive examples falling into the right branch.')

   local total = leftCount + rightCount

   local pPositiveLeft = leftCount == 0 and 0 or (positiveLeftCount / leftCount)
   local leftGini = pPositiveLeft * (1.0 - pPositiveLeft)

   local pPositiveRight = rightCount == 0 and 0 or (positiveRightCount / rightCount)
   local rightGini = pPositiveRight * (1.0 - pPositiveRight)

   -- weighted average of the two branch impurities
   return (leftCount * leftGini + rightCount * rightGini) / total
end
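
A short usage sketch of dt.computeGini (my own illustration, assuming Torch and the decisiontree package are installed): it returns the count-weighted average of the two branch impurities, so a lower value means the split separates positives from negatives more cleanly.

-- Illustration only: score a candidate split that sends
-- 40 examples (30 positive) left and 60 examples (10 positive) right.
local dt = require 'decisiontree'

local gini = dt.computeGini(40, 30, 60, 10)
-- left:  p = 30/40 = 0.75   -> 0.75 * 0.25     = 0.1875
-- right: p = 10/60 ≈ 0.167  -> 0.167 * 0.833   ≈ 0.139
-- weighted: (40 * 0.1875 + 60 * 0.139) / 100   ≈ 0.158
print(gini)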