You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

levenshtein_nfa.go 6.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. // Copyright (c) 2018 Couchbase, Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package levenshtein
  15. import (
  16. "math"
  17. "sort"
  18. )
  // Distance is the result of running a Levenshtein automaton over two
// strings.
//
// An automaton built for a maximum distance can only measure distances up to
// that maximum precisely. Exact is meant to carry such a precise value, and
// Atleast a lower bound used once the true distance exceeds the maximum (the
// automaton then reports maxDistance+1).
type Distance interface {
	distance() uint8
}

// Compile-time checks that both result kinds satisfy Distance.
var (
	_ Distance = Exact{}
	_ Distance = Atleast{}
)

// Exact is a Distance whose value is known precisely.
type Exact struct {
	d uint8
}

// distance returns the stored edit distance.
func (x Exact) distance() uint8 { return x.d }

// Atleast is a Distance whose stored value is a lower bound on the true edit
// distance.
type Atleast struct {
	d uint8
}

// distance returns the stored lower bound.
func (x Atleast) distance() uint8 { return x.d }
  // characteristicVector returns a bitset with bit i set exactly when
// query[i] == c.
//
// Only positions 0..63 fit in the uint64 result; a match at a later position
// sets no bit (the shift yields 0).
func characteristicVector(query []rune, c rune) uint64 {
	var bits uint64
	for pos, r := range query {
		if r != c {
			continue
		}
		bits |= uint64(1) << uint(pos)
	}
	return bits
}
  // NFAState is one state of the Levenshtein NFA. Offset is a position in the
// query, Distance is the number of edits spent reaching that position, and
// InTranspose marks a state in the middle of a Damerau transposition.
type NFAState struct {
	Offset      uint32
	Distance    uint8
	InTranspose bool
}

// NFAStates implements sort.Interface. It orders states by Offset, then by
// Distance, and on a full tie puts non-transposing states first.
type NFAStates []NFAState

// Len returns the number of states.
func (s NFAStates) Len() int { return len(s) }

// Less reports whether s[i] sorts before s[j].
func (s NFAStates) Less(i, j int) bool {
	a, b := s[i], s[j]
	switch {
	case a.Offset != b.Offset:
		return a.Offset < b.Offset
	case a.Distance != b.Distance:
		return a.Distance < b.Distance
	default:
		return b.InTranspose && !a.InTranspose
	}
}

// Swap exchanges s[i] and s[j].
func (s NFAStates) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

// imply reports whether st subsumes other. This holds when other's distance
// is at least st's distance plus the gap between their offsets. The
// comparison becomes strict when other is in a transposition and st is not.
func (st *NFAState) imply(other NFAState) bool {
	var gap uint32
	if st.Offset > other.Offset {
		gap = st.Offset - other.Offset
	} else {
		gap = other.Offset - st.Offset
	}
	reach := uint32(st.Distance) + gap
	if st.InTranspose || !other.InTranspose {
		return uint32(other.Distance) >= reach
	}
	return uint32(other.Distance) > reach
}
  85. type MultiState struct {
  86. states []NFAState
  87. }
  88. func (ms *MultiState) States() []NFAState {
  89. return ms.states
  90. }
  91. func (ms *MultiState) Clear() {
  92. ms.states = ms.states[:0]
  93. }
  94. func newMultiState() *MultiState {
  95. return &MultiState{states: make([]NFAState, 0)}
  96. }
  97. func (ms *MultiState) normalize() uint32 {
  98. minOffset := uint32(math.MaxUint32)
  99. for _, s := range ms.states {
  100. if s.Offset < minOffset {
  101. minOffset = s.Offset
  102. }
  103. }
  104. if minOffset == uint32(math.MaxUint32) {
  105. minOffset = 0
  106. }
  107. for i := 0; i < len(ms.states); i++ {
  108. ms.states[i].Offset -= minOffset
  109. }
  110. sort.Sort(NFAStates(ms.states))
  111. return minOffset
  112. }
  113. func (ms *MultiState) addStates(nState NFAState) {
  114. for _, s := range ms.states {
  115. if s.imply(nState) {
  116. return
  117. }
  118. }
  119. i := 0
  120. for i < len(ms.states) {
  121. if nState.imply(ms.states[i]) {
  122. ms.states = append(ms.states[:i], ms.states[i+1:]...)
  123. } else {
  124. i++
  125. }
  126. }
  127. ms.states = append(ms.states, nState)
  128. }
  // extractBit reports whether bit pos of bitset is set. Positions of 64 or
// more always report false.
func extractBit(bitset uint64, pos uint8) bool {
	return bitset&(uint64(1)<<pos) != 0
}
  // dist returns the absolute difference |left - right| without unsigned
// wraparound.
func dist(left, right uint32) uint32 {
	if right > left {
		left, right = right, left
	}
	return left - right
}
  140. type LevenshteinNFA struct {
  141. mDistance uint8
  142. damerau bool
  143. }
  144. func newLevenshtein(maxD uint8, transposition bool) *LevenshteinNFA {
  145. return &LevenshteinNFA{mDistance: maxD,
  146. damerau: transposition,
  147. }
  148. }
  149. func (la *LevenshteinNFA) maxDistance() uint8 {
  150. return la.mDistance
  151. }
  152. func (la *LevenshteinNFA) msDiameter() uint8 {
  153. return 2*la.mDistance + 1
  154. }
  155. func (la *LevenshteinNFA) initialStates() *MultiState {
  156. ms := MultiState{}
  157. nfaState := NFAState{}
  158. ms.addStates(nfaState)
  159. return &ms
  160. }
  161. func (la *LevenshteinNFA) multistateDistance(ms *MultiState,
  162. queryLen uint32) Distance {
  163. minDistance := Atleast{d: la.mDistance + 1}
  164. for _, s := range ms.states {
  165. t := s.Distance + uint8(dist(queryLen, s.Offset))
  166. if t <= uint8(la.mDistance) {
  167. if minDistance.distance() > t {
  168. minDistance.d = t
  169. }
  170. }
  171. }
  172. if minDistance.distance() == la.mDistance+1 {
  173. return Atleast{d: la.mDistance + 1}
  174. }
  175. return minDistance
  176. }
  177. func (la *LevenshteinNFA) simpleTransition(state NFAState,
  178. symbol uint64, ms *MultiState) {
  179. if state.Distance < la.mDistance {
  180. // insertion
  181. ms.addStates(NFAState{Offset: state.Offset,
  182. Distance: state.Distance + 1,
  183. InTranspose: false})
  184. // substitution
  185. ms.addStates(NFAState{Offset: state.Offset + 1,
  186. Distance: state.Distance + 1,
  187. InTranspose: false})
  188. n := la.mDistance + 1 - state.Distance
  189. for d := uint8(1); d < n; d++ {
  190. if extractBit(symbol, d) {
  191. // for d > 0, as many deletion and character match
  192. ms.addStates(NFAState{Offset: state.Offset + 1 + uint32(d),
  193. Distance: state.Distance + d,
  194. InTranspose: false})
  195. }
  196. }
  197. if la.damerau && extractBit(symbol, 1) {
  198. ms.addStates(NFAState{
  199. Offset: state.Offset,
  200. Distance: state.Distance + 1,
  201. InTranspose: true})
  202. }
  203. }
  204. if extractBit(symbol, 0) {
  205. ms.addStates(NFAState{Offset: state.Offset + 1,
  206. Distance: state.Distance,
  207. InTranspose: false})
  208. }
  209. if state.InTranspose && extractBit(symbol, 0) {
  210. ms.addStates(NFAState{Offset: state.Offset + 2,
  211. Distance: state.Distance,
  212. InTranspose: false})
  213. }
  214. }
  215. func (la *LevenshteinNFA) transition(cState *MultiState,
  216. dState *MultiState, scv uint64) {
  217. dState.Clear()
  218. mask := (uint64(1) << la.msDiameter()) - uint64(1)
  219. for _, state := range cState.states {
  220. cv := (scv >> state.Offset) & mask
  221. la.simpleTransition(state, cv, dState)
  222. }
  223. sort.Sort(NFAStates(dState.states))
  224. }
  225. func (la *LevenshteinNFA) computeDistance(query, other []rune) Distance {
  226. cState := la.initialStates()
  227. nState := newMultiState()
  228. for _, i := range other {
  229. nState.Clear()
  230. chi := characteristicVector(query, i)
  231. la.transition(cState, nState, chi)
  232. cState, nState = nState, cState
  233. }
  234. return la.multistateDistance(cState, uint32(len(query)))
  235. }