mirror of
https://github.com/c-cube/ocaml-containers.git
synced 2025-12-06 03:05:28 -05:00
new tests in CCTrie; bugfix in CCTrie.below
This commit is contained in:
parent
eee7b2318a
commit
421cb1332b
2 changed files with 108 additions and 25 deletions
|
|
@ -112,8 +112,36 @@ module type S = sig
|
|||
|
||||
val below : key -> 'a t -> (key * 'a) sequence
|
||||
(** All bindings whose key is smaller or equal to the given key *)
|
||||
|
||||
(**/**)
|
||||
val check_invariants: _ t -> bool
|
||||
(**/**)
|
||||
end
|
||||
|
||||
(*$inject
|
||||
module T = MakeList(CCInt)
|
||||
module S = String
|
||||
|
||||
let l1 = [ [1;2], "12"; [1], "1"; [2;1], "21"; [1;2;3], "123"; [], "[]" ]
|
||||
let t1 = T.of_list l1
|
||||
|
||||
let small_l l = List.fold_left (fun acc (k,v) -> List.length k+acc) 0 l
|
||||
*)
|
||||
|
||||
(*$T
|
||||
String.of_list ["a", 1; "b", 2] |> String.size = 2
|
||||
String.of_list ["a", 1; "b", 2; "a", 3] |> String.size = 2
|
||||
String.of_list ["a", 1; "b", 2] |> String.find_exn "a" = 1
|
||||
String.of_list ["a", 1; "b", 2] |> String.find_exn "b" = 2
|
||||
String.of_list ["a", 1; "b", 2] |> String.find "c" = None
|
||||
|
||||
String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find_exn "cat" = 1
|
||||
String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find_exn "catogan" = 2
|
||||
String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find_exn "foo" = 3
|
||||
String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find "cato" = None
|
||||
*)
|
||||
|
||||
|
||||
module Make(W : WORD) = struct
|
||||
type char_ = W.char_
|
||||
type key = W.t
|
||||
|
|
@ -139,13 +167,22 @@ module Make(W : WORD) = struct
|
|||
| Node (None, map) when M.is_empty map -> false
|
||||
| _ -> true
|
||||
|
||||
let rec check_invariants = function
|
||||
| Empty -> true
|
||||
| Cons (_, t) -> check_invariants t
|
||||
| Node (None, map) when M.is_empty map -> false
|
||||
| Node (_, map) ->
|
||||
M.for_all (fun _ v -> check_invariants v) map
|
||||
|
||||
let is_empty = function
|
||||
| Empty -> true
|
||||
| _ -> false
|
||||
|
||||
let _id x = x
|
||||
|
||||
let _fold_seq f ~finish acc seq =
|
||||
(* fold [f] on [seq] with accumulator [acc], and call [finish]
|
||||
on the accumulator once [seq] is exhausted *)
|
||||
let _fold_seq_and_then f ~finish acc seq =
|
||||
let acc = ref acc in
|
||||
seq (fun x -> acc := f !acc x);
|
||||
finish !acc
|
||||
|
|
@ -258,12 +295,20 @@ module Make(W : WORD) = struct
|
|||
rebuild (_mk_node value' map)
|
||||
in
|
||||
let word = W.to_seq key in
|
||||
_fold_seq goto ~finish (t, _id) word
|
||||
_fold_seq_and_then goto ~finish (t, _id) word
|
||||
|
||||
let add k v t = update k (fun _ -> Some v) t
|
||||
|
||||
let remove k t = update k (fun _ -> None) t
|
||||
|
||||
(*$T
|
||||
T.add [3] "3" t1 |> T.find_exn [3] = "3"
|
||||
T.add [3] "3" t1 |> T.find_exn [1;2] = "12"
|
||||
T.remove [1;2] t1 |> T.find [1;2] = None
|
||||
T.remove [1;2] t1 |> T.find [1] = Some "1"
|
||||
T.remove [1;2] t1 |> T.find [] = Some "[]"
|
||||
*)
|
||||
|
||||
let find_exn k t =
|
||||
(* at subtree [t], and character [c] *)
|
||||
let goto t c = match t with
|
||||
|
|
@ -278,7 +323,7 @@ module Make(W : WORD) = struct
|
|||
| _ -> raise Not_found
|
||||
in
|
||||
let word = W.to_seq k in
|
||||
_fold_seq goto ~finish t word
|
||||
_fold_seq_and_then goto ~finish t word
|
||||
|
||||
let find k t =
|
||||
try Some (find_exn k t)
|
||||
|
|
@ -308,6 +353,11 @@ module Make(W : WORD) = struct
|
|||
f acc key v
|
||||
) _id t acc
|
||||
|
||||
(*$T
|
||||
T.fold (fun acc k v -> (k,v) :: acc) [] t1 \
|
||||
|> List.sort Pervasives.compare = List.sort Pervasives.compare l1
|
||||
*)
|
||||
|
||||
let iter f t =
|
||||
_fold
|
||||
(fun () path y -> f (W.of_list (path [])) y)
|
||||
|
|
@ -379,6 +429,17 @@ module Make(W : WORD) = struct
|
|||
in
|
||||
_mk_node v map'
|
||||
|
||||
(*$Q & ~small:(fun (a,b) -> List.length a + List.length b) ~count:30
|
||||
Q.(let p = list (pair printable_string small_int) in pair p p) \
|
||||
(fun (l1,l2) -> \
|
||||
let t1 = S.of_list l1 and t2 = S.of_list l2 in \
|
||||
let t = S.merge (fun a _ -> Some a) t1 t2 in \
|
||||
S.to_seq t |> Sequence.for_all \
|
||||
(fun (k,v) -> S.find k t1 = Some v || S.find k t2 = Some v) && \
|
||||
S.to_seq t1 |> Sequence.for_all (fun (k,v) -> S.find k t <> None) && \
|
||||
S.to_seq t2 |> Sequence.for_all (fun (k,v) -> S.find k t <> None))
|
||||
*)
|
||||
|
||||
let rec size t = match t with
|
||||
| Empty -> 0
|
||||
| Cons (_, t') -> size t'
|
||||
|
|
@ -388,6 +449,10 @@ module Make(W : WORD) = struct
|
|||
(fun _ t' acc -> size t' + acc)
|
||||
map s
|
||||
|
||||
(*$T
|
||||
T.size t1 = List.length l1
|
||||
*)
|
||||
|
||||
let to_list t = fold (fun acc k v -> (k,v)::acc) [] t
|
||||
|
||||
let of_list l =
|
||||
|
|
@ -398,7 +463,7 @@ module Make(W : WORD) = struct
|
|||
let to_seq_values t k = iter_values k t
|
||||
|
||||
let of_seq seq =
|
||||
_fold_seq (fun acc (k,v) -> add k v acc) ~finish:_id empty seq
|
||||
_fold_seq_and_then (fun acc (k,v) -> add k v acc) ~finish:_id empty seq
|
||||
|
||||
let rec to_tree t () =
|
||||
let _tree_node x l () = `Node (x,l) in
|
||||
|
|
@ -415,10 +480,10 @@ module Make(W : WORD) = struct
|
|||
|
||||
(** {6 Ranges} *)
|
||||
|
||||
(* range above or below a threshold.
|
||||
(* range above (if [above = true]) or below a threshold .
|
||||
[p c c'] must return [true] if [c'], in the tree, meets some criterion
|
||||
w.r.t [c] which is a part of the key. *)
|
||||
let _half_range ~p key t k =
|
||||
let _half_range ~above ~p key t k =
|
||||
(* at subtree [cur = Some (t,trail)] or [None], alternatives above
|
||||
[alternatives], and char [c] in [key]. *)
|
||||
let on_char (cur, alternatives) c =
|
||||
|
|
@ -429,7 +494,12 @@ module Make(W : WORD) = struct
|
|||
if W.compare c c' = 0
|
||||
then Some (t', _difflist_add trail c), alternatives
|
||||
else None, alternatives
|
||||
| Some (Node (_, map), trail) ->
|
||||
| Some (Node (o, map), trail) ->
|
||||
(* if [not above], [o]'s key is below [key] so add it *)
|
||||
begin match o with
|
||||
| Some v when not above -> k (W.of_list (trail []), v)
|
||||
| _ -> ()
|
||||
end;
|
||||
let alternatives =
|
||||
let seq = _seq_map map in
|
||||
let seq = _filter_map_seq
|
||||
|
|
@ -450,8 +520,14 @@ module Make(W : WORD) = struct
|
|||
(* run through the current path (if any) and alternatives *)
|
||||
and finish (cur,alternatives) =
|
||||
begin match cur with
|
||||
| Some (t, prefix) ->
|
||||
| Some (t, prefix) when above ->
|
||||
(* subtree prefixed by input key, therefore above key *)
|
||||
_iter_prefix ~prefix (fun key' v -> k (key', v)) t
|
||||
| Some (Node (Some v, _), prefix) when not above ->
|
||||
(* yield the value for key *)
|
||||
assert (W.of_list (prefix []) = key);
|
||||
k (key, v)
|
||||
| Some _
|
||||
| None -> ()
|
||||
end;
|
||||
List.iter
|
||||
|
|
@ -459,13 +535,30 @@ module Make(W : WORD) = struct
|
|||
alternatives
|
||||
in
|
||||
let word = W.to_seq key in
|
||||
_fold_seq on_char ~finish (Some(t,_id), []) word
|
||||
_fold_seq_and_then on_char ~finish (Some(t,_id), []) word
|
||||
|
||||
let above key t =
|
||||
_half_range ~p:(fun c c' -> W.compare c c' < 0) key t
|
||||
_half_range ~above:true ~p:(fun c c' -> W.compare c c' < 0) key t
|
||||
|
||||
let below key t =
|
||||
_half_range ~p:(fun c c' -> W.compare c c' > 0) key t
|
||||
_half_range ~above:false ~p:(fun c c' -> W.compare c c' > 0) key t
|
||||
|
||||
(*$= & ~printer:CCPrint.(to_string (list (pair (list int) string)))
|
||||
[ [1], "1"; [1;2], "12"; [1;2;3], "123"; [2;1], "21" ] \
|
||||
(T.above [1] t1 |> Sequence.sort |> Sequence.to_list)
|
||||
[ [1;2], "12"; [1;2;3], "123"; [2;1], "21" ] \
|
||||
(T.above [1;1] t1 |> Sequence.sort |> Sequence.to_list)
|
||||
[ [], "[]"; [1], "1"; [1;2], "12" ] \
|
||||
(T.below [1;2] t1 |> Sequence.sort |> Sequence.to_list)
|
||||
[ [], "[]"; [1], "1" ] \
|
||||
(T.below [1;1] t1 |> Sequence.sort |> Sequence.to_list)
|
||||
*)
|
||||
|
||||
(*$Q & ~small:List.length
|
||||
Q.(list (pair printable_string small_int)) (fun l -> \
|
||||
let t = S.of_list l in \
|
||||
S.check_invariants t)
|
||||
*)
|
||||
end
|
||||
|
||||
module type ORDERED = sig
|
||||
|
|
@ -499,17 +592,3 @@ module String = Make(struct
|
|||
List.iter (fun c -> Buffer.add_char buf c) l;
|
||||
Buffer.contents buf
|
||||
end)
|
||||
|
||||
(*$T
|
||||
String.of_list ["a", 1; "b", 2] |> String.size = 2
|
||||
String.of_list ["a", 1; "b", 2; "a", 3] |> String.size = 2
|
||||
String.of_list ["a", 1; "b", 2] |> String.find_exn "a" = 1
|
||||
String.of_list ["a", 1; "b", 2] |> String.find_exn "b" = 2
|
||||
String.of_list ["a", 1; "b", 2] |> String.find "c" = None
|
||||
|
||||
String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find_exn "cat" = 1
|
||||
String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find_exn "catogan" = 2
|
||||
String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find_exn "foo" = 3
|
||||
String.of_list ["cat", 1; "catogan", 2; "foo", 3] |> String.find "cato" = None
|
||||
*)
|
||||
|
||||
|
|
|
|||
|
|
@ -112,6 +112,10 @@ module type S = sig
|
|||
|
||||
val below : key -> 'a t -> (key * 'a) sequence
|
||||
(** All bindings whose key is smaller or equal to the given key *)
|
||||
|
||||
(**/**)
|
||||
val check_invariants: _ t -> bool
|
||||
(**/**)
|
||||
end
|
||||
|
||||
(** {2 Implementation} *)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue