functor interface to Levenshtein automaton/index

This commit is contained in:
Simon Cruanes 2014-03-05 16:49:10 +01:00
parent c5473857f8
commit b6310ae17d
2 changed files with 576 additions and 549 deletions

View file

@ -24,49 +24,123 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*) *)
(** {1 Levenshtein distance} *) (** {1 Levenshtein distance} *)
module NDA = struct module type STRING = sig
type 'a char = type char_
type t
val of_list : char_ list -> t
val get : t -> int -> char_
val length : t -> int
val compare_char : char_ -> char_ -> int
end
(** Continuation list *)
type 'a klist =
[
| `Nil
| `Cons of 'a * (unit -> 'a klist)
]
let rec klist_to_list = function
| `Nil -> []
| `Cons (x,k) -> x :: klist_to_list (k ())
module type S = sig
type char_
type string_
(** {6 Automaton} *)
type automaton
(** Levenshtein automaton *)
val of_string : limit:int -> string_ -> automaton
(** Build an automaton from an array, with a maximal distance [limit] *)
val of_list : limit:int -> char_ list -> automaton
(** Build an automaton from a list, with a maximal distance [limit] *)
val debug_print : (out_channel -> char_ -> unit) ->
out_channel -> automaton -> unit
(** Output the automaton on the given channel. *)
val match_with : automaton -> string_ -> bool
(** [match_with a s] matches the string [s] against [a], and returns
[true] if the distance from [s] to the word represented by [a] is smaller
than the limit used to build [a] *)
(** {6 Index for one-to-many matching} *)
module Index : sig
type 'b t
(** Index that maps strings to values of type 'b. Internally it is
based on a trie. *)
val empty : 'b t
(** Empty index *)
val is_empty : _ t -> bool
val add : 'b t -> string_ -> 'b -> 'b t
(** Add a char array to the index. If a value was already present
for this array it is replaced. *)
val remove : 'b t -> string_ -> 'b -> 'b t
(** Remove a string from the index. *)
val retrieve : limit:int -> 'b t -> string_ -> 'b klist
(** Lazy list of objects associated to strings close to the query string *)
val of_list : (string_ * 'b) list -> 'b t
val to_list : 'b t -> (string_ * 'b) list
(* TODO sequence/iteration functions *)
end
end
module Make(Str : STRING) = struct
type string_ = Str.t
type char_ = Str.char_
module NDA = struct
type char =
| Any | Any
| Char of 'a | Char of char_
type 'a transition = type transition =
| Success | Success
| Upon of 'a char * int * int | Upon of char * int * int
| Epsilon of int * int | Epsilon of int * int
(* non deterministic automaton *) (* non deterministic automaton *)
type 'a t = { type t = transition list array array
compare : 'a -> 'a -> int;
matrix : 'a transition list array array;
}
let length nda = Array.length nda.matrix let length nda = Array.length nda
let get_compare nda = nda.compare let rec mem_tr tr l = match tr, l with
let rec mem_tr ~compare tr l = match tr, l with
| _, [] -> false | _, [] -> false
| Success, Success::_ -> true | Success, Success::_ -> true
| Epsilon (i,j), Epsilon(i',j')::_ -> i=i' && j=j' | Epsilon (i,j), Epsilon(i',j')::_ -> i=i' && j=j'
| Upon (Any,i,j), Upon(Any,i',j')::_ when i=i' && j=j' -> true | Upon (Any,i,j), Upon(Any,i',j')::_ when i=i' && j=j' -> true
| Upon (Char c,i,j), Upon(Char c',i',j')::_ | Upon (Char c,i,j), Upon(Char c',i',j')::_
when compare c c' = 0 && i=i' && j=j' -> true when Str.compare_char c c' = 0 && i=i' && j=j' -> true
| _, _::l' -> mem_tr ~compare tr l' | _, _::l' -> mem_tr tr l'
(* build NDA from the "get : int -> 'a" function *) (* build NDA from the string *)
let make ~compare ~limit ~len ~get = let make ~limit s =
let m = Array.make_matrix (len+1) (limit+1) [] in let len = Str.length s in
let m = Array.make_matrix (len +1) (limit+1) [] in
let add_transition i j tr = let add_transition i j tr =
if not (mem_tr ~compare tr m.(i).(j)) if not (mem_tr tr m.(i).(j))
then m.(i).(j) <- tr :: m.(i).(j) then m.(i).(j) <- tr :: m.(i).(j)
in in
(* internal transitions *) (* internal transitions *)
for i = 0 to len-1 do for i = 0 to len-1 do
for j = 0 to limit do for j = 0 to limit do
(* correct char *) (* correct char *)
add_transition i j (Upon (Char (get i), i+1, j)); add_transition i j (Upon (Char (Str.get s i), i+1, j));
(* other transitions *) (* other transitions *)
if j < limit then begin if j < limit then begin
(* substitution *) (* substitution *)
@ -85,29 +159,27 @@ module NDA = struct
(* win in any case *) (* win in any case *)
add_transition len j Success; add_transition len j Success;
done; done;
{ matrix=m; compare; } m
let get nda (i,j) = let get nda (i,j) =
nda.matrix.(i).(j) nda.(i).(j)
let is_final nda (i,j) = let is_final nda (i,j) =
List.exists List.exists
(function Success -> true | _ -> false) (function Success -> true | _ -> false)
(get nda (i,j)) (get nda (i,j))
end end
(** deterministic automaton *) (** deterministic automaton *)
module DFA = struct module DFA = struct
type 'a t = { type t = {
compare : 'a -> 'a -> int; mutable transitions : (char_ * int) list array;
mutable transitions : ('a * int) list array;
mutable is_final : bool array; mutable is_final : bool array;
mutable otherwise : int array; (* transition by default *) mutable otherwise : int array; (* transition by default *)
mutable len : int; mutable len : int;
} }
let create ~compare size = { let create size = {
compare;
len = 0; len = 0;
transitions = Array.make size []; transitions = Array.make size [];
is_final = Array.make size false; is_final = Array.make size false;
@ -131,15 +203,15 @@ module DFA = struct
dfa.len <- n + 1; dfa.len <- n + 1;
n n
let rec __mem_tr ~compare tr l = match tr, l with let rec __mem_tr tr l = match tr, l with
| _, [] -> false | _, [] -> false
| (c,i), (c',i')::l' -> | (c,i), (c',i')::l' ->
(i=i' && compare c c' = 0) (i=i' && compare c c' = 0)
|| __mem_tr ~compare tr l' || __mem_tr tr l'
(* add transition *) (* add transition *)
let add_transition dfa i tr = let add_transition dfa i tr =
if not (__mem_tr ~compare:dfa.compare tr dfa.transitions.(i)) if not (__mem_tr tr dfa.transitions.(i))
then dfa.transitions.(i) <- tr :: dfa.transitions.(i) then dfa.transitions.(i) <- tr :: dfa.transitions.(i)
let add_otherwise dfa i j = let add_otherwise dfa i j =
@ -171,7 +243,7 @@ module DFA = struct
List.fold_left List.fold_left
(fun acc tr -> match tr with (fun acc tr -> match tr with
| NDA.Upon (NDA.Char c, _, _) -> | NDA.Upon (NDA.Char c, _, _) ->
if List.exists (fun c' -> nda.NDA.compare c c' = 0) acc if List.exists (fun c' -> Str.compare_char c c' = 0) acc
then acc then acc
else c :: acc (* new char! *) else c :: acc (* new char! *)
| _ -> acc | _ -> acc
@ -201,7 +273,7 @@ module DFA = struct
may raise exceptions Not_found or LeadToSuccess. *) may raise exceptions Not_found or LeadToSuccess. *)
let rec get_transition_for_char nda c acc transitions = let rec get_transition_for_char nda c acc transitions =
match transitions with match transitions with
| NDA.Upon (NDA.Char c', i, j) :: transitions' when nda.NDA.compare c c' = 0 -> | NDA.Upon (NDA.Char c', i, j) :: transitions' when Str.compare_char c c' = 0 ->
(* follow same char *) (* follow same char *)
let acc = NDAStateSet.add (i, j) acc in let acc = NDAStateSet.add (i, j) acc in
get_transition_for_char nda c acc transitions' get_transition_for_char nda c acc transitions'
@ -303,8 +375,7 @@ module DFA = struct
) )
let of_nda nda = let of_nda nda =
let compare = NDA.get_compare nda in let dfa = create (NDA.length nda) in
let dfa = create ~compare (NDA.length nda) in
(* map (set of NDA states) to int (state in DFA) *) (* map (set of NDA states) to int (state in DFA) *)
let states = ref StateSetMap.empty in let states = ref StateSetMap.empty in
(* traverse the NDA to build the NFA *) (* traverse the NDA to build the NFA *)
@ -325,73 +396,65 @@ module DFA = struct
let is_final dfa i = let is_final dfa i =
dfa.is_final.(i) dfa.is_final.(i)
end end
let debug_print oc dfa = let debug_print pp_char oc dfa =
Printf.fprintf oc "automaton of %d states\n" dfa.DFA.len; Printf.fprintf oc "automaton of %d states\n" dfa.DFA.len;
for i = 0 to dfa.DFA.len-1 do for i = 0 to dfa.DFA.len-1 do
let transitions = DFA.get dfa i in let transitions = DFA.get dfa i in
if DFA.is_final dfa i if DFA.is_final dfa i
then Printf.fprintf oc " success %d\n" i; then Printf.fprintf oc " success %d\n" i;
List.iter List.iter
(fun (c, j) -> Printf.fprintf oc " %d --%c--> %d\n" i c j ) transitions; (fun (c, j) -> Printf.fprintf oc " %d --%a--> %d\n" i pp_char c j ) transitions;
let o = DFA.otherwise dfa i in let o = DFA.otherwise dfa i in
if o >= 0 if o >= 0
then Printf.fprintf oc " %d --*--> %d\n" i o then Printf.fprintf oc " %d --*--> %d\n" i o
done done
type 'a automaton = 'a DFA.t type automaton = DFA.t
let of_array ?(compare=Pervasives.compare) ~limit a = let of_string ~limit s =
let nda = NDA.make ~compare ~limit ~len:(Array.length a) ~get:(Array.get a) in let nda = NDA.make ~limit s in
let dfa = DFA.of_nda nda in let dfa = DFA.of_nda nda in
dfa dfa
let of_list ?compare ~limit l = let of_list ~limit l =
of_array ?compare ~limit (Array.of_list l) of_string ~limit (Str.of_list l)
let of_string ~limit a = type match_result =
let compare = Char.compare in
let nda = NDA.make ~compare ~limit ~len:(String.length a) ~get:(String.get a) in
(*debug_print_nda stdout nda;
flush stdout;*)
let dfa = DFA.of_nda nda in
dfa
type match_result =
| TooFar | TooFar
| Distance of int | Distance of int
exception FoundDistance of int exception FoundDistance of int
let rec __find_char ~compare c l = match l with let rec __find_char c l = match l with
| [] -> raise Not_found | [] -> raise Not_found
| (c', next) :: l' -> | (c', next) :: l' ->
if compare c c' = 0 if compare c c' = 0
then next then next
else __find_char ~compare c l' else __find_char c l'
(* transition for [c] in state [i] of [dfa]; (* transition for [c] in state [i] of [dfa];
@raise Not_found if no transition matches *) @raise Not_found if no transition matches *)
let __transition dfa i c = let __transition dfa i c =
let transitions = DFA.get dfa i in let transitions = DFA.get dfa i in
try try
__find_char ~compare:dfa.DFA.compare c transitions __find_char c transitions
with Not_found -> with Not_found ->
let o = DFA.otherwise dfa i in let o = DFA.otherwise dfa i in
if o >= 0 if o >= 0
then o then o
else raise Not_found else raise Not_found
(* real matching function *) let match_with dfa a =
let __match ~len ~get dfa = let len = Str.length a in
let rec search i state = let rec search i state =
(*Printf.printf "at state %d (dist %d)\n" i dist;*) (*Printf.printf "at state %d (dist %d)\n" i dist;*)
if i = len if i = len
then DFA.is_final dfa state then DFA.is_final dfa state
else begin else begin
(* current char *) (* current char *)
let c = get i in let c = Str.get a i in
try try
let next = __transition dfa state c in let next = __transition dfa state c in
search (i+1) next search (i+1) next
@ -400,29 +463,15 @@ let __match ~len ~get dfa =
in in
search 0 0 search 0 0
let match_with dfa a = (** {6 Index for one-to-many matching} *)
__match ~len:(Array.length a) ~get:(Array.get a) dfa
let match_with_string dfa s = module Index = struct
__match ~len:(String.length s) ~get:(String.get s) dfa type key = char_
(** {6 Index for one-to-many matching} *) module M = Map.Make(struct
type t = key
(** Continuation list *) let compare = Str.compare_char
type 'a klist = end)
[
| `Nil
| `Cons of 'a * (unit -> 'a klist)
]
let rec klist_to_list = function
| `Nil -> []
| `Cons (x,k) -> x :: klist_to_list (k ())
module Index(X : Map.OrderedType) = struct
type key = X.t
module M = Map.Make(X)
type 'b t = type 'b t =
| Node of 'b option * 'b t M.t | Node of 'b option * 'b t M.t
@ -439,7 +488,8 @@ module Index(X : Map.OrderedType) = struct
the continuation k takes the leaf, and returns a leaf option the continuation k takes the leaf, and returns a leaf option
that replaces the old leaf. that replaces the old leaf.
This function returns the new trie. *) This function returns the new trie. *)
let goto_leaf ~len ~get node k = let goto_leaf s node k =
let len = Str.length s in
(* insert the value in given [node], assuming the current index (* insert the value in given [node], assuming the current index
in [arr] is [i]. [k] is given the resulting tree. *) in [arr] is [i]. [k] is given the resulting tree. *)
let rec goto node i rebuild = match node with let rec goto node i rebuild = match node with
@ -447,7 +497,7 @@ module Index(X : Map.OrderedType) = struct
let node' = k node in let node' = k node in
rebuild node' rebuild node'
| Node (opt, m) -> | Node (opt, m) ->
let c = get i in let c = Str.get s i in
let t' = let t' =
try M.find c m try M.find c m
with Not_found -> empty with Not_found -> empty
@ -460,18 +510,19 @@ module Index(X : Map.OrderedType) = struct
in in
goto node 0 (fun t -> t) goto node 0 (fun t -> t)
let __add ~len ~get trie value = let add trie s value =
goto_leaf ~len ~get trie goto_leaf s trie
(function (function
| Node (_, m) -> Node (Some value, m)) | Node (_, m) -> Node (Some value, m))
let __remove ~len ~get trie value = let remove trie s value =
goto_leaf ~len ~get trie goto_leaf s trie
(function (function
| Node (_, m) -> Node (None, m)) | Node (_, m) -> Node (None, m))
(* traverse the automaton and the idx, yielding a klist of values *) (* traverse the automaton and the idx, yielding a klist of values *)
let __retrieve dfa idx = let retrieve ~limit idx s =
let dfa = of_string ~limit s in
(* traverse at index i in automaton, with (* traverse at index i in automaton, with
[fk] the failure continuation *) [fk] the failure continuation *)
let rec traverse node i ~(fk:unit->'a klist) = let rec traverse node i ~(fk:unit->'a klist) =
@ -495,66 +546,40 @@ module Index(X : Map.OrderedType) = struct
in in
traverse idx 0 ~fk:(fun () -> `Nil) traverse idx 0 ~fk:(fun () -> `Nil)
let add idx arr value =
__add ~len:(Array.length arr) ~get:(Array.get arr) idx value
let remove idx arr value =
__remove ~len:(Array.length arr) ~get:(Array.get arr) idx value
let retrieve ~limit idx arr =
let automaton = of_array ~compare:X.compare ~limit arr in
__retrieve automaton idx
let of_list l = let of_list l =
List.fold_left List.fold_left
(fun acc (arr,v) -> add acc arr v) (fun acc (arr,v) -> add acc arr v)
empty l empty l
let __to_list ~of_list idx = let to_list idx =
let rec explore acc trail node = match node with let rec explore acc trail node = match node with
| Node (opt, m) -> | Node (opt, m) ->
(* first, yield current value, if any *) (* first, yield current value, if any *)
let acc = match opt with let acc = match opt with
| None -> acc | None -> acc
| Some v -> (of_list (List.rev trail), v) :: acc | Some v -> (Str.of_list (List.rev trail), v) :: acc
in in
M.fold M.fold
(fun c node' acc -> explore acc (c::trail) node') (fun c node' acc -> explore acc (c::trail) node')
m acc m acc
in in
explore [] [] idx explore [] [] idx
end
let to_list idx =
__to_list ~of_list:Array.of_list idx
end end
module StrIndex = struct include Make(struct
include Index(Char) type t = string
type char_ = char
let add_string idx str value = let compare_char = Char.compare
__add ~len:(String.length str) ~get:(String.get str) idx value let length = String.length
let get = String.get
let remove_string idx str value =
__remove ~len:(String.length str) ~get:(String.get str) idx value
let retrieve_string ~limit idx str =
let automaton = of_string ~limit str in
__retrieve automaton idx
let of_str_list l =
List.fold_left
(fun acc (str,v) -> add_string acc str v)
empty l
let to_str_list idx =
(* clumsy conversion [char list -> string] *)
let of_list l = let of_list l =
let s = String.make (List.length l) ' ' in let s = String.make (List.length l) ' ' in
List.iteri (fun i c -> s.[i] <- c) l; List.iteri (fun i c -> s.[i] <- c) l;
s s
in end)
__to_list ~of_list idx
end let debug_print = debug_print output_char
(* (*
open Batteries;; open Batteries;;

View file

@ -31,34 +31,19 @@ We take inspiration from
http://blog.notdot.net/2010/07/Damn-Cool-Algorithms-Levenshtein-Automata http://blog.notdot.net/2010/07/Damn-Cool-Algorithms-Levenshtein-Automata
for the main algorithm and ideas. However some parts are adapted *) for the main algorithm and ideas. However some parts are adapted *)
(** {2 Automaton} *) (** {2 Abstraction over Strings} *)
type 'a automaton module type STRING = sig
(** Levenshtein automaton for characters of type 'a *) type char_
type t
val of_array : ?compare:('a -> 'a -> int) -> limit:int -> 'a array -> 'a automaton val of_list : char_ list -> t
(** Build an automaton from an array, with a maximal distance [limit] *) val get : t -> int -> char_
val length : t -> int
val compare_char : char_ -> char_ -> int
end
val of_list : ?compare:('a -> 'a -> int) -> limit:int -> 'a list -> 'a automaton (** {2 Continuation list} *)
(** Build an automaton from a list, with a maximal distance [limit] *)
val of_string : limit:int -> string -> char automaton
(** Automaton for the special case of strings *)
val debug_print : out_channel -> char automaton -> unit
(** Output the automaton on the given channel. Only for string automata. *)
val match_with : 'a automaton -> 'a array -> bool
(** [match_with a s] matches the string [s] against [a], and returns
[true] if the distance from [s] to the word represented by [a] is smaller
than the limit used to build [a] *)
val match_with_string : char automaton -> string -> bool
(** Specialized version of {!match_with} for strings *)
(** {6 Index for one-to-many matching} *)
(** Continuation list *)
type 'a klist = type 'a klist =
[ [
| `Nil | `Nil
@ -68,11 +53,37 @@ type 'a klist =
val klist_to_list : 'a klist -> 'a list val klist_to_list : 'a klist -> 'a list
(** Helper. *) (** Helper. *)
module Index(X : Map.OrderedType) : sig (** {2 Signature} *)
type key = X.t
module type S = sig
type char_
type string_
(** {6 Automaton} *)
type automaton
(** Levenshtein automaton *)
val of_string : limit:int -> string_ -> automaton
(** Build an automaton from an array, with a maximal distance [limit] *)
val of_list : limit:int -> char_ list -> automaton
(** Build an automaton from a list, with a maximal distance [limit] *)
val debug_print : (out_channel -> char_ -> unit) ->
out_channel -> automaton -> unit
(** Output the automaton on the given channel. *)
val match_with : automaton -> string_ -> bool
(** [match_with a s] matches the string [s] against [a], and returns
[true] if the distance from [s] to the word represented by [a] is smaller
than the limit used to build [a] *)
(** {6 Index for one-to-many matching} *)
module Index : sig
type 'b t type 'b t
(** Index that maps [key] strings to values of type 'b. Internally it is (** Index that maps strings to values of type 'b. Internally it is
based on a trie. *) based on a trie. *)
val empty : 'b t val empty : 'b t
@ -80,37 +91,28 @@ module Index(X : Map.OrderedType) : sig
val is_empty : _ t -> bool val is_empty : _ t -> bool
val add : 'b t -> key array -> 'b -> 'b t val add : 'b t -> string_ -> 'b -> 'b t
(** Add a char array to the index. If a value was already present (** Add a char array to the index. If a value was already present
for this array it is replaced. *) for this array it is replaced. *)
val remove : 'b t -> key array -> 'b -> 'b t val remove : 'b t -> string_ -> 'b -> 'b t
(** Remove a char array from the index. *) (** Remove a string from the index. *)
val retrieve : limit:int -> 'b t -> key array -> 'b klist val retrieve : limit:int -> 'b t -> string_ -> 'b klist
(** Lazy list of objects associated to strings close to (** Lazy list of objects associated to strings close to the query string *)
the query string *)
val of_list : (key array * 'b) list -> 'b t val of_list : (string_ * 'b) list -> 'b t
val to_list : 'b t -> (key array * 'b) list val to_list : 'b t -> (string_ * 'b) list
(* TODO sequence/iteration functions *) (* TODO sequence/iteration functions *)
end
end end
(** Specific case for strings *) module Make(Str : STRING) : S
module StrIndex : sig with type string_ = Str.t
include module type of Index(Char) and type char_ = Str.char_
val add_string : 'b t -> string -> 'b -> 'b t include S with type char_ = char and type string_ = string
(** Add a string to a char index *)
val remove_string : 'b t -> string -> 'b -> 'b t val debug_print : out_channel -> automaton -> unit
(** Remove a string from a char index *)
val retrieve_string : limit:int -> 'b t -> string -> 'b klist
val of_str_list : (string * 'b) list -> 'b t
val to_str_list : 'b t -> (string * 'b) list
end