diff --git a/levenshtein.ml b/levenshtein.ml index c22868d8..2902c64f 100644 --- a/levenshtein.ml +++ b/levenshtein.ml @@ -46,11 +46,21 @@ module NDA = struct let get_compare nda = nda.compare + let rec mem_tr ~compare tr l = match tr, l with + | _, [] -> false + | Success, Success::_ -> true + | Epsilon (i,j), Epsilon(i',j')::_ -> i=i' && j=j' + | Upon (Any,i,j), Upon(Any,i',j')::_ when i=i' && j=j' -> true + | Upon (Char c,i,j), Upon(Char c',i',j')::_ + when compare c c' = 0 && i=i' && j=j' -> true + | _, _::l' -> mem_tr ~compare tr l' + (* build NDA from the "get : int -> 'a" function *) let make ~compare ~limit ~len ~get = - let m = Array.make_matrix len limit [] in + let m = Array.make_matrix (len+1) (limit+1) [] in let add_transition i j tr = - m.(i).(j) <- tr :: m.(i).(j) + if not (mem_tr ~compare tr m.(i).(j)) + then m.(i).(j) <- tr :: m.(i).(j) in (* internal transitions *) for i = 0 to len-1 do @@ -71,9 +81,9 @@ module NDA = struct for j = 0 to limit do (* deletions at the end *) if j < limit - then add_transition (len-1) j (Upon (Any, len-1, j+1)); + then add_transition len j (Upon (Any, len, j+1)); (* win in any case *) - add_transition (len-1) j Success; + add_transition len j Success; done; { matrix=m; compare; } @@ -83,14 +93,11 @@ end (** deterministic automaton *) module DFA = struct - type 'a transition = - | Success - | Upon of 'a * int (* transition to state i *) - | Otherwise of int (* transition to state i *) - type 'a t = { compare : 'a -> 'a -> int; - mutable transitions : 'a transition list array; + mutable transitions : ('a * int) list array; + mutable is_final : bool array; + mutable otherwise : int array; (* transition by default *) mutable len : int; } @@ -98,23 +105,43 @@ module DFA = struct compare; len = 0; transitions = Array.make size []; + is_final = Array.make size false; + otherwise = Array.make size ~-1; } + let _double_array a = + let a' = Array.make (2 * Array.length a) a.(0) in + Array.blit a 0 a' 0 (Array.length a); + a' + (* add a new state *) let add_state dfa = let n = dfa.len in (* resize *) if n = Array.length dfa.transitions then begin - let a' = Array.make (2*n) [] in - Array.blit dfa.transitions 0 a' 0 n; - dfa.transitions <- a' + dfa.transitions <- _double_array dfa.transitions; + dfa.is_final <- _double_array dfa.is_final; + dfa.otherwise <- _double_array dfa.otherwise; end; dfa.len <- n + 1; n + let rec __mem_tr ~compare tr l = match tr, l with + | _, [] -> false + | (c,i), (c',i')::l' -> + (i=i' && compare c c' = 0) + || __mem_tr ~compare tr l' + (* add transition *) let add_transition dfa i tr = - dfa.transitions.(i) <- tr :: dfa.transitions.(i) + if not (__mem_tr ~compare:dfa.compare tr dfa.transitions.(i)) + then dfa.transitions.(i) <- tr :: dfa.transitions.(i) + + let add_otherwise dfa i j = + dfa.otherwise.(i) <- j + + let set_final dfa i = + dfa.is_final.(i) <- true (* set of pairs of ints: used for representing a set of states of the NDA *) module NDAStateSet = Set.Make(struct @@ -122,6 +149,17 @@ module DFA = struct let compare = Pervasives.compare end) + (* + let set_to_string s = + let b = Buffer.create 15 in + Buffer.add_char b '{'; + NDAStateSet.iter + (fun (x,y) -> Printf.bprintf b "(%d,%d)" x y) + s; + Buffer.add_char b '}'; + Buffer.contents b + *) + (* list of characters that can specifically be followed from the given set *) let chars_from_set nda set = NDAStateSet.fold @@ -144,11 +182,13 @@ module DFA = struct let set = ref set in while not (Queue.is_empty q) do let state = Queue.pop q in + (*Printf.printf "saturate epsilon: add state %d,%d\n" (fst state)(snd state);*) + set := NDAStateSet.add state !set; List.iter (fun tr' -> match tr' with - | NDA.Epsilon (i,j) -> + | NDA.Epsilon (i,j) -> if not (NDAStateSet.mem (i,j) !set) - then (set := NDAStateSet.add (i,j) !set; Queue.push (i,j) q) + then Queue.push (i,j) q | _ -> () ) (NDA.get nda state) done; @@ -159,19 +199,27 @@ module DFA = struct (* find the transition that matches the given char (if any); may raise exceptions Not_found or LeadToSuccess. *) let rec get_transition_for_char nda c transitions = - match c, transitions with - | _, NDA.Success::_ -> raise LeadToSuccess - | NDA.Char c', NDA.Upon (NDA.Char c'', i, j) :: transitions' -> - if nda.NDA.compare c' c'' = 0 + match transitions with + | NDA.Success::_ -> raise LeadToSuccess + | NDA.Upon (NDA.Char c', i, j) :: transitions' -> + if nda.NDA.compare c c' = 0 then i, j else get_transition_for_char nda c transitions' - | NDA.Any, NDA.Upon (NDA.Any, i, j) :: _ -> i, j - | _, NDA.Upon (NDA.Any, i, j) :: transitions' -> + | NDA.Upon (NDA.Any, i, j) :: transitions' -> begin try get_transition_for_char nda c transitions' with Not_found -> i, j (* only if no other transition works *) end - | _, _::transitions' -> get_transition_for_char nda c transitions' - | _, [] -> raise Not_found + | _::transitions' -> get_transition_for_char nda c transitions' + | [] -> raise Not_found + + let rec get_transitions_for_any nda acc transitions = + match transitions with + | NDA.Success::_ -> raise LeadToSuccess + | NDA.Upon (NDA.Any, i, j) :: transitions' -> + let acc = NDAStateSet.add (i,j) acc in + get_transitions_for_any nda acc transitions' + | _:: transitions' -> get_transitions_for_any nda acc transitions' + | [] -> acc (* follow transition for given NDA.char, returns a new state and a boolean indicating whether it's final *) @@ -179,10 +227,11 @@ module DFA = struct let is_final = ref false in let set' = NDAStateSet.fold (fun state acc -> - (* possible transitions *) let transitions = NDA.get nda state in + (* among possible transitions, follow the one that matches c + the most closely *) try - let state' = get_transition_for_char nda c transitions in + let state' = get_transition_for_char nda c transitions in NDAStateSet.add state' acc with | LeadToSuccess -> is_final := true; acc @@ -192,19 +241,16 @@ module DFA = struct let set' = saturate_epsilon nda set' in set', !is_final - (* only follow "Any" transitions *) - let follow_other_transition nda set = + let follow_transition_any nda set = let is_final = ref false in let set' = NDAStateSet.fold (fun state acc -> - (* possible transitions *) let transitions = NDA.get nda state in + (* among possible transitions, follow the ones that are labelled with "*" *) try - let state' = get_transition_for_char nda NDA.Any transitions in - NDAStateSet.add state' acc + get_transitions_for_any nda acc transitions with | LeadToSuccess -> is_final := true; acc - | Not_found -> acc (* state dies *) ) set NDAStateSet.empty in let set' = saturate_epsilon nda set' in @@ -217,12 +263,17 @@ module DFA = struct let chars = chars_from_set nda set in List.iter (fun c -> - let set', is_final = follow_transition nda set (NDA.Char c) in - k ~is_final (NDA.Char c) set') - chars; + (*Printf.printf "iterate_transition follows %c (at %s)\n" + (Obj.magic c) (set_to_string set);*) + let set', is_final = follow_transition nda set c in + if not (NDAStateSet.is_empty set') + then k ~is_final (NDA.Char c) set'; + ) chars; (* remaining transitions, with only "Any" *) - let set', is_final = follow_other_transition nda set in - k ~is_final NDA.Any set' + (*Printf.printf "iterate transition follows * (at %s)\n" (set_to_string set);*) + let set', is_final = follow_transition_any nda set in + if not (NDAStateSet.is_empty set') + then k ~is_final NDA.Any set' module StateSetMap = Map.Make(NDAStateSet) @@ -237,20 +288,22 @@ module DFA = struct (* traverse the NDA. Currently we're at [set] *) let rec traverse nda dfa states set = - let set = saturate_epsilon nda set in let set_i = get_state dfa states set in iterate_transition_set nda set (fun ~is_final c set' -> - let set'_i = get_state dfa states set' in + let set_i' = get_state dfa states set' in (* did we reach success? *) if is_final - then add_transition dfa set'_i Success; + then set_final dfa set_i' + (* link set -> set' *) - match c with - | NDA.Any -> - add_transition dfa set_i (Otherwise set'_i) + else match c with | NDA.Char c' -> - add_transition dfa set_i (Upon (c', set'_i)) + add_transition dfa set_i (c', set_i'); + traverse nda dfa states set' + | NDA.Any -> + add_otherwise dfa set_i set_i'; + traverse nda dfa states set' ) let of_nda nda = @@ -259,13 +312,38 @@ module DFA = struct (* map (set of NDA states) to int (state in DFA) *) let states = ref StateSetMap.empty in (* traverse the NDA to build the NFA *) - traverse nda dfa states (NDAStateSet.singleton (0,0)); + let set = NDAStateSet.singleton (0,0) in + let set = saturate_epsilon nda set in + traverse nda dfa states set; + (*StateSetMap.iter + (fun set i -> + Printf.printf "set %s --> state %d\n" (set_to_string set) i + ) !states; *) dfa let get dfa i = dfa.transitions.(i) + + let otherwise dfa i = + dfa.otherwise.(i) + + let is_final dfa i = + dfa.is_final.(i) end +let debug_print oc dfa = + Printf.fprintf oc "automaton of %d states\n" dfa.DFA.len; + for i = 0 to dfa.DFA.len-1 do + let transitions = DFA.get dfa i in + if DFA.is_final dfa i + then Printf.fprintf oc " success %d\n" i; + List.iter + (fun (c, j) -> Printf.fprintf oc " (%c) %d -> %d\n" c i j ) transitions; + let o = DFA.otherwise dfa i in + if o >= 0 + then Printf.fprintf oc " (*) %d -> %d\n" i o + done + type 'a automaton = 'a DFA.t let of_array ?(compare=Pervasives.compare) ~limit a = @@ -279,6 +357,8 @@ let of_list ?compare ~limit l = let of_string ~limit a = let compare = Char.compare in let nda = NDA.make ~compare ~limit ~len:(String.length a) ~get:(String.get a) in + (*debug_print_nda stdout nda; + flush stdout;*) let dfa = DFA.of_nda nda in dfa @@ -288,48 +368,75 @@ type match_result = exception FoundDistance of int -let rec __has_success = function - | [] -> false - | DFA.Success :: _ -> true - | _ :: l' -> __has_success l' - -let rec __find_char ~compare c l k = match l with - | [] -> () - | DFA.Upon (c', next) :: l' -> +let rec __find_char ~compare c l = match l with + | [] -> raise Not_found + | (c', next) :: l' -> if compare c c' = 0 - then k next - else __find_char ~compare c l' k - | _ :: l' -> __find_char ~compare c l' k - -let rec __find_otherwise l k = match l with - | [] -> () - | DFA.Otherwise next :: _ -> k next - | _::l' -> __find_otherwise l' k + then next + else __find_char ~compare c l' (* real matching function *) let __match ~len ~get dfa = - let rec search ~dist i state = - if i = len then raise (FoundDistance dist) + let rec search i state = + (*Printf.printf "at state %d (dist %d)\n" i dist;*) + if i = len + then DFA.is_final dfa state else begin let transitions = DFA.get dfa state in - if __has_success transitions - then raise (FoundDistance dist); (* current char *) let c = get i in - __find_char ~compare:dfa.DFA.compare c transitions - (fun next -> search ~dist (i+1) next); - __find_otherwise transitions - (fun next -> search ~dist:(dist+1) (i+1) next); + try + let next = __find_char ~compare:dfa.DFA.compare c transitions in + search (i+1) next + with Not_found -> + let o = DFA.otherwise dfa state in + if o >= 0 + then search (i+1) o + else false end in - try - search ~dist:0 0 0; - TooFar - with FoundDistance i -> - Distance i + search 0 0 let match_with dfa a = __match ~len:(Array.length a) ~get:(Array.get a) dfa let match_with_string dfa s = __match ~len:(String.length s) ~get:(String.get s) dfa + +(** {6 Index for one-to-many matching} *) + +(** Continuation list *) +type 'a klist = + [ + | `Nil + | `Cons of 'a * (unit -> 'a klist) + ] + +module Index = struct + type ('a, 'b) node = + | Empty + | Node of 'b option * ('a, 'b) assoc_list + and ('a, 'b) assoc_list = ('a * ('a, 'b) node) list + + type ('a, 'b) t = { + tree : ('a, 'b) node; + compare : 'a -> 'a -> int; + } + + let empty ?(compare=Pervasives.compare) () = { + tree = Empty; + compare; + } + + let add idx arr value = + assert false (* TODO *) + + let add_string idx arr str = + assert false (* TODO *) + + let retrieve ~limit idx arr = + assert false (* TODO *) + + let retrieve_string ~limit idx str = + assert false (* TODO *) +end diff --git a/levenshtein.mli b/levenshtein.mli index 6e126493..e4b63184 100644 --- a/levenshtein.mli +++ b/levenshtein.mli @@ -25,7 +25,11 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *) -(** {1 Levenshtein distance} *) +(** {1 Levenshtein distance} + +We take inspiration from +http://blog.notdot.net/2010/07/Damn-Cool-Algorithms-Levenshtein-Automata +for the main algorithm and ideas. However some parts are adapted *) (** {2 Automaton} *) @@ -41,13 +45,44 @@ val of_list : ?compare:('a -> 'a -> int) -> limit:int -> 'a list -> 'a automaton val of_string : limit:int -> string -> char automaton (** Automaton for the special case of strings *) -type match_result = - | TooFar - | Distance of int +val debug_print : out_channel -> char automaton -> unit + (** Output the automaton on the given channel. Only for string automata. *) -val match_with : 'a automaton -> 'a array -> match_result +val match_with : 'a automaton -> 'a array -> bool (** [match_with a s] matches the string [s] against [a], and returns - the distance from [s] to the word represented by [a] *) + [true] if the distance from [s] to the word represented by [a] is smaller + than the limit used to build [a] *) -val match_with_string : char automaton -> string -> match_result +val match_with_string : char automaton -> string -> bool (** Specialized version of {!match_with} for strings *) + +(** {6 Index for one-to-many matching} *) + +(** Continuation list *) +type 'a klist = + [ + | `Nil + | `Cons of 'a * (unit -> 'a klist) + ] + +module Index : sig + type ('a, 'b) t + (** Index that maps 'a strings to values of type 'b. Internally it is + based on a trie. *) + + val empty : ?compare:('a -> 'a -> int) -> unit -> ('a, 'b) t + (** Empty index, possibly with a specific comparison function *) + + val add : ('a, 'b) t -> 'a array -> 'b -> ('a, 'b) t + (** Add a char array to the index. If a value was already present + for this char it is replaced. *) + + val add_string : (char, 'b) t -> string -> 'b -> (char, 'b) t + (** Add a string to a char index *) + + val retrieve : limit:int -> ('a, 'b) t -> 'a array -> 'b klist + (** Lazy list of objects associated to strings close to + the query string *) + + val retrieve_string : limit:int -> (char,'b) t -> string -> 'b klist +end