You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

linalg.lua 2.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. --[[
  2. Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ]]--
  13. --[[[
  14. -- @module lua_ffi/linalg
  15. -- This module contains ffi interfaces to linear algebra routines
  16. --]]
  17. local ffi = require 'ffi'
  18. local exports = {}
  19. ffi.cdef [[
  20. void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C);
  21. bool kad_ssyev_simple (int N, float *A, float *output);
  22. ]]
  23. local function table_to_ffi(a, m, n)
  24. local a_conv = ffi.new("float[?]", m * n)
  25. for i = 1, m or #a do
  26. for j = 1, n or #a[1] do
  27. a_conv[(i - 1) * n + (j - 1)] = a[i][j]
  28. end
  29. end
  30. return a_conv
  31. end
  32. local function ffi_to_table(a, m, n)
  33. local res = {}
  34. for i = 0, m - 1 do
  35. res[i + 1] = {}
  36. for j = 0, n - 1 do
  37. res[i + 1][j + 1] = a[i * n + j]
  38. end
  39. end
  40. return res
  41. end
  42. exports.sgemm = function(a, m, b, n, k, trans_a, trans_b)
  43. if type(a) == 'table' then
  44. -- Need to convert, slow!
  45. a = table_to_ffi(a, m, k)
  46. end
  47. if type(b) == 'table' then
  48. b = table_to_ffi(b, k, n)
  49. end
  50. local res = ffi.new("float[?]", m * n)
  51. ffi.C.kad_sgemm_simple(trans_a or 0, trans_b or 0, m, n, k, ffi.cast('const float*', a),
  52. ffi.cast('const float*', b), ffi.cast('float*', res))
  53. return res
  54. end
  55. exports.eigen = function(a, n)
  56. if type(a) == 'table' then
  57. -- Need to convert, slow!
  58. n = n or #a
  59. a = table_to_ffi(a, n, n)
  60. end
  61. local res = ffi.new("float[?]", n)
  62. if ffi.C.kad_ssyev_simple(n, ffi.cast('float*', a), res) then
  63. return res, a
  64. end
  65. return nil
  66. end
  67. exports.ffi_to_table = ffi_to_table
  68. exports.table_to_ffi = table_to_ffi
  69. return exports