From 421cb1332bf9da33a7f2c42ddcc7da00f3243aee Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Wed, 16 Sep 2015 14:10:57 +0200 Subject: [PATCH] new tests in `CCTrie`; bugfix in `CCTrie.below` --- src/data/CCTrie.ml | 129 +++++++++++++++++++++++++++++++++++--------- src/data/CCTrie.mli | 4 ++ 2 files changed, 108 insertions(+), 25 deletions(-) diff --git a/src/data/CCTrie.ml b/src/data/CCTrie.ml index bdebe9b8..422e7352 100644 --- a/src/data/CCTrie.ml +++ b/src/data/CCTrie.ml @@ -112,8 +112,36 @@ module type S = sig val below : key -> 'a t -> (key * 'a) sequence (** All bindings whose key is smaller or equal to the given key *) + + (**/**) + val check_invariants: _ t -> bool + (**/**) end +(*$inject + module T = MakeList(CCInt) + module S = String + + let l1 = [ [1;2], "12"; [1], "1"; [2;1], "21"; [1;2;3], "123"; [], "[]" ] + let t1 = T.of_list l1 + + let small_l l = List.fold_left (fun acc (k,v) -> List.length k+acc) 0 l + *) + +(*$T + String.of_list ["a", 1; "b", 2] |> String.size = 2 + String.of_list ["a", 1; "b", 2; "a", 3] |> String.size = 2 + String.of_list ["a", 1; "b", 2] |> String.find_exn "a" = 1 + String.of_list ["a", 1; "b", 2] |> String.find_exn "b" = 2 + String.of_list ["a", 1; "b", 2] |> String.find "c" = None + + String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find_exn "cat" = 1 + String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find_exn "catogan" = 2 + String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find_exn "foo" = 3 + String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find "cato" = None +*) + + module Make(W : WORD) = struct type char_ = W.char_ type key = W.t @@ -139,13 +167,22 @@ module Make(W : WORD) = struct | Node (None, map) when M.is_empty map -> false | _ -> true + let rec check_invariants = function + | Empty -> true + | Cons (_, t) -> check_invariants t + | Node (None, map) when M.is_empty map -> false + | Node (_, map) -> + M.for_all (fun _ v -> check_invariants v) map + let is_empty = function | Empty -> true | _ -> false let _id x = x - let _fold_seq f ~finish acc seq = + (* fold [f] on [seq] with accumulator [acc], and call [finish] + on the accumulator once [seq] is exhausted *) + let _fold_seq_and_then f ~finish acc seq = let acc = ref acc in seq (fun x -> acc := f !acc x); finish !acc @@ -258,12 +295,20 @@ module Make(W : WORD) = struct rebuild (_mk_node value' map) in let word = W.to_seq key in - _fold_seq goto ~finish (t, _id) word + _fold_seq_and_then goto ~finish (t, _id) word let add k v t = update k (fun _ -> Some v) t let remove k t = update k (fun _ -> None) t + (*$T + T.add [3] "3" t1 |> T.find_exn [3] = "3" + T.add [3] "3" t1 |> T.find_exn [1;2] = "12" + T.remove [1;2] t1 |> T.find [1;2] = None + T.remove [1;2] t1 |> T.find [1] = Some "1" + T.remove [1;2] t1 |> T.find [] = Some "[]" + *) + let find_exn k t = (* at subtree [t], and character [c] *) let goto t c = match t with @@ -278,7 +323,7 @@ module Make(W : WORD) = struct | _ -> raise Not_found in let word = W.to_seq k in - _fold_seq goto ~finish t word + _fold_seq_and_then goto ~finish t word let find k t = try Some (find_exn k t) @@ -308,6 +353,11 @@ module Make(W : WORD) = struct f acc key v ) _id t acc + (*$T + T.fold (fun acc k v -> (k,v) :: acc) [] t1 \ + |> List.sort Pervasives.compare = List.sort Pervasives.compare l1 + *) + let iter f t = _fold (fun () path y -> f (W.of_list (path [])) y) @@ -379,6 +429,17 @@ module Make(W : WORD) = struct in _mk_node v map' + (*$Q & ~small:(fun (a,b) -> List.length a + List.length b) ~count:30 + Q.(let p = list (pair printable_string small_int) in pair p p) \ + (fun (l1,l2) -> \ + let t1 = S.of_list l1 and t2 = S.of_list l2 in \ + let t = S.merge (fun a _ -> Some a) t1 t2 in \ + S.to_seq t |> Sequence.for_all \ + (fun (k,v) -> S.find k t1 = Some v || S.find k t2 = Some v) && \ + S.to_seq t1 |> Sequence.for_all (fun (k,v) -> S.find k t <> None) && \ + S.to_seq t2 |> Sequence.for_all (fun (k,v) -> S.find k t <> None)) + *) + let rec size t = match t with | Empty -> 0 | Cons (_, t') -> size t' @@ -388,6 +449,10 @@ module Make(W : WORD) = struct (fun _ t' acc -> size t' + acc) map s + (*$T + T.size t1 = List.length l1 + *) + let to_list t = fold (fun acc k v -> (k,v)::acc) [] t let of_list l = @@ -398,7 +463,7 @@ module Make(W : WORD) = struct let to_seq_values t k = iter_values k t let of_seq seq = - _fold_seq (fun acc (k,v) -> add k v acc) ~finish:_id empty seq + _fold_seq_and_then (fun acc (k,v) -> add k v acc) ~finish:_id empty seq let rec to_tree t () = let _tree_node x l () = `Node (x,l) in @@ -415,10 +480,10 @@ module Make(W : WORD) = struct (** {6 Ranges} *) - (* range above or below a threshold. + (* range above (if [above = true]) or below a threshold . [p c c'] must return [true] if [c'], in the tree, meets some criterion w.r.t [c] which is a part of the key. *) - let _half_range ~p key t k = + let _half_range ~above ~p key t k = (* at subtree [cur = Some (t,trail)] or [None], alternatives above [alternatives], and char [c] in [key]. *) let on_char (cur, alternatives) c = @@ -429,7 +494,12 @@ module Make(W : WORD) = struct if W.compare c c' = 0 then Some (t', _difflist_add trail c), alternatives else None, alternatives - | Some (Node (_, map), trail) -> + | Some (Node (o, map), trail) -> + (* if [not above], [o]'s key is below [key] so add it *) + begin match o with + | Some v when not above -> k (W.of_list (trail []), v) + | _ -> () + end; let alternatives = let seq = _seq_map map in let seq = _filter_map_seq @@ -450,8 +520,14 @@ module Make(W : WORD) = struct (* run through the current path (if any) and alternatives *) and finish (cur,alternatives) = begin match cur with - | Some (t, prefix) -> + | Some (t, prefix) when above -> + (* subtree prefixed by input key, therefore above key *) _iter_prefix ~prefix (fun key' v -> k (key', v)) t + | Some (Node (Some v, _), prefix) when not above -> + (* yield the value for key *) + assert (W.of_list (prefix []) = key); + k (key, v) + | Some _ | None -> () end; List.iter @@ -459,13 +535,30 @@ module Make(W : WORD) = struct alternatives in let word = W.to_seq key in - _fold_seq on_char ~finish (Some(t,_id), []) word + _fold_seq_and_then on_char ~finish (Some(t,_id), []) word let above key t = - _half_range ~p:(fun c c' -> W.compare c c' < 0) key t + _half_range ~above:true ~p:(fun c c' -> W.compare c c' < 0) key t let below key t = - _half_range ~p:(fun c c' -> W.compare c c' > 0) key t + _half_range ~above:false ~p:(fun c c' -> W.compare c c' > 0) key t + + (*$= & ~printer:CCPrint.(to_string (list (pair (list int) string))) + [ [1], "1"; [1;2], "12"; [1;2;3], "123"; [2;1], "21" ] \ + (T.above [1] t1 |> Sequence.sort |> Sequence.to_list) + [ [1;2], "12"; [1;2;3], "123"; [2;1], "21" ] \ + (T.above [1;1] t1 |> Sequence.sort |> Sequence.to_list) + [ [], "[]"; [1], "1"; [1;2], "12" ] \ + (T.below [1;2] t1 |> Sequence.sort |> Sequence.to_list) + [ [], "[]"; [1], "1" ] \ + (T.below [1;1] t1 |> Sequence.sort |> Sequence.to_list) + *) + + (*$Q & ~small:List.length + Q.(list (pair printable_string small_int)) (fun l -> \ + let t = S.of_list l in \ + S.check_invariants t) + *) end module type ORDERED = sig @@ -499,17 +592,3 @@ module String = Make(struct List.iter (fun c -> Buffer.add_char buf c) l; Buffer.contents buf end) - -(*$T - String.of_list ["a", 1; "b", 2] |> String.size = 2 - String.of_list ["a", 1; "b", 2; "a", 3] |> String.size = 2 - String.of_list ["a", 1; "b", 2] |> String.find_exn "a" = 1 - String.of_list ["a", 1; "b", 2] |> String.find_exn "b" = 2 - String.of_list ["a", 1; "b", 2] |> String.find "c" = None - - String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find_exn "cat" = 1 - String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find_exn "catogan" = 2 - String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find_exn "foo" = 3 - String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find "cato" = None -*) - diff --git a/src/data/CCTrie.mli b/src/data/CCTrie.mli index b7afccd7..3176e48a 100644 --- a/src/data/CCTrie.mli +++ b/src/data/CCTrie.mli @@ -112,6 +112,10 @@ module type S = sig val below : key -> 'a t -> (key * 'a) sequence (** All bindings whose key is smaller or equal to the given key *) + + (**/**) + val check_invariants: _ t -> bool + (**/**) end (** {2 Implementation} *)