aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/torch/decisiontree/generic/CartTree.c
blob: eb29fcf025857bca830e87eb9bcc32533252fa76 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/CartTree.c"
#else

/*
 * Lua binding: CartTreeFastScore(input, root, score) -> score
 *
 *   arg 1  input  tensor; dims 0 and 1 are read as n_samples and n_features
 *   arg 2  root   table whose raw fields "leftChild", "rightChild", "score",
 *                 "splitFeatureId" and "splitFeatureValue" describe a node
 *   arg 3  score  tensor; resized to n_samples and filled with one value
 *                 per input row
 *
 * For every row the tree is walked from the root: a node with neither child
 * is a leaf and its "score" is stored; a node with a single child descends
 * into that child; otherwise the row's feature splitFeatureId (1-based) is
 * compared with splitFeatureValue (strictly smaller goes left, anything
 * else goes right).
 *
 * The node being visited lives in one fixed stack slot that is overwritten
 * on each descent, so Lua stack usage is constant regardless of tree depth.
 *
 * Raises a Lua error if arg 2 is not a table, if a child is not a table, or
 * if splitFeatureId is outside [1, n_features].
 *
 * NOTE(review): element access uses i * n_features + (id - 1), i.e. it
 * assumes `input` is laid out contiguously row by row -- confirm that the
 * callers guarantee this.
 */
static int nn_(tree_fast_score)(lua_State *L) {
  THTensor *input = luaT_checkudata(L, 1, torch_Tensor);
  luaL_checktype(L, 2, LUA_TTABLE);
  THTensor *score = luaT_checkudata(L, 3, torch_Tensor);
  long n_samples = THTensor_(size)(input, 0);
  long n_features = THTensor_(size)(input, 1);
  THTensor_(resize1d)(score, n_samples);
  real *input_data = THTensor_(data)(input);
  real *score_data = THTensor_(data)(score);

  // Discard any extra arguments so the absolute indices below are exact.
  lua_settop(L, 3);

  const int root = 2;
  lua_pushstring(L, "leftChild");
  const int left_child_string = 4;
  lua_pushstring(L, "rightChild");
  const int right_child_string = 5;
  lua_pushstring(L, "score");
  const int score_string = 6;
  lua_pushstring(L, "splitFeatureId");
  const int id_string = 7;
  lua_pushstring(L, "splitFeatureValue");
  const int value_string = 8;
  // Slot holding the node currently being visited (re-seeded per sample).
  lua_pushvalue(L, root);
  const int node = 9;

  for (long i = 0; i < n_samples; i++) {
    lua_pushvalue(L, root);
    lua_replace(L, node);

    while (1) {
      lua_pushvalue(L, left_child_string);
      lua_rawget(L, node);
      lua_pushvalue(L, right_child_string);
      lua_rawget(L, node);
      // stack: ..., node, leftChild (-2), rightChild (-1)
      const int has_left = !lua_isnil(L, -2);
      const int has_right = !lua_isnil(L, -1);

      if (!has_left && !has_right) {
        // leaf: record its score and move on to the next sample
        lua_pop(L, 2);
        lua_pushvalue(L, score_string);
        lua_rawget(L, node);
        score_data[i] = lua_tonumber(L, -1);
        lua_pop(L, 1);
        break;
      }

      int go_left;
      if (!has_left) {
        go_left = 0;
      }
      else if (!has_right) {
        go_left = 1;
      }
      else {
        lua_pushvalue(L, id_string);
        lua_rawget(L, node);
        lua_pushvalue(L, value_string);
        lua_rawget(L, node);
        long feature_id = lua_tointeger(L, -2);
        real feature_value = lua_tonumber(L, -1);
        lua_pop(L, 2);

        if (feature_id < 1 || feature_id > n_features) {
          return luaL_error(L,
                            "CartTreeFastScore: splitFeatureId %d out of range [1, %d]",
                            (int)feature_id, (int)n_features);
        }

        real current_value = input_data[i * n_features + (feature_id - 1)];
        go_left = current_value < feature_value;
      }

      if (go_left) {
        // drop rightChild so that leftChild is on top
        lua_pop(L, 1);
        if (!lua_istable(L, -1)) {
          return luaL_error(L, "CartTreeFastScore: leftChild is not a table");
        }
        lua_replace(L, node);
      }
      else {
        if (!lua_istable(L, -1)) {
          return luaL_error(L, "CartTreeFastScore: rightChild is not a table");
        }
        // rightChild (top) becomes the current node, then drop leftChild
        lua_replace(L, node);
        lua_pop(L, 1);
      }
    }
  }

  // Drop the key strings and the node slot; the score tensor (index 3) is
  // left on top and is the single return value.
  lua_settop(L, 3);
  return 1;
}

/* Functions exported by this file; the all-NULL entry terminates the list. */
static const struct luaL_Reg nn_(CT__)[] = {
  { .name = "CartTreeFastScore", .func = nn_(tree_fast_score) },
  { .name = NULL,                .func = NULL }
};

/*
 * Registers the functions listed in CT__ under the name "nn", using the
 * torch_Tensor metatable pushed for the duration of the call.
 * The Lua stack is left as it was found.
 */
static void nn_(CT_init)(lua_State *L) {
  /* push the metatable associated with torch_Tensor */
  luaT_pushmetatable(L, torch_Tensor);

  /* attach the CT__ function list at name "nn" */
  luaT_registeratname(L, nn_(CT__), "nn");

  /* remove the metatable pushed above */
  lua_pop(L, 1);
}

#endif