diff --git a/src/data/CCTrie.ml b/src/data/CCTrie.ml index 003283c5..0ff580c9 100644 --- a/src/data/CCTrie.ml +++ b/src/data/CCTrie.ml @@ -54,10 +54,12 @@ module type S = sig (** Fold on key/value bindings. Will use {!WORD.of_list} to rebuild keys. *) val mapi : (key -> 'a -> 'b) -> 'a t -> 'b t - (** Map values in the try. Will use {!WORD.of_list} to rebuild keys. *) + (** Map values, giving both key and value. Will use {!WORD.of_list} to rebuild keys. + @since NEXT_RELEASE *) val map : ('a -> 'b) -> 'a t -> 'b t - (** Map values in the try, not giving keys to the mapping function. *) + (** Map values, giving only the value. + @since NEXT_RELEASE *) val iter : (key -> 'a -> unit) -> 'a t -> unit (** Same as {!fold}, but for effectful functions *) @@ -91,10 +93,12 @@ module type S = sig (** {6 Ranges} *) val above : key -> 'a t -> (key * 'a) sequence - (** All bindings whose key is bigger or equal to the given key *) + (** All bindings whose key is bigger or equal to the given key, in + ascending order *) val below : key -> 'a t -> (key * 'a) sequence - (** All bindings whose key is smaller or equal to the given key *) + (** All bindings whose key is smaller or equal to the given key, + in decreasing order *) (**/**) val check_invariants: _ t -> bool @@ -175,12 +179,17 @@ module Make(W : WORD) = struct | None -> () | Some y -> k y) - let _seq_append_list l seq = + let _seq_map f seq k = seq (fun x -> k (f x)) + + let _seq_append_list_rev l seq = let l = ref l in seq (fun x -> l := x :: !l); !l - let _seq_map map k = + let _seq_append_list l seq = + List.rev_append (_seq_append_list_rev [] seq) l + + let seq_of_map map k = M.iter (fun key v -> k (key,v)) map (* return common prefix, and disjoint suffixes *) @@ -312,7 +321,11 @@ module Make(W : WORD) = struct try Some (find_exn k t) with Not_found -> None - let _difflist_add f x = fun l' -> f (x :: l') + type 'a difflist = 'a list -> 'a list + + let _difflist_add + : 'a difflist -> 'a -> 'a difflist + = fun f x -> fun l' -> f (x :: l') (* fold that also keeps the path from the root, so as to provide the list of chars that lead to a value. The path is a difference list, ie @@ -333,8 +346,8 @@ module Make(W : WORD) = struct _fold (fun acc path v -> let key = W.of_list (path []) in - f acc key v - ) _id t acc + f acc key v) + _id t acc (*$T T.fold (fun acc k v -> (k,v) :: acc) [] t1 \ @@ -503,10 +516,42 @@ module Make(W : WORD) = struct (** {6 Ranges} *) + (* stack of actions for [above] and [below] *) + type 'a alternative = + | Yield of 'a * char_ difflist + | Explore of 'a t * char_ difflist + + type direction = + | Above + | Below + + let rec explore ~dir k alt = match alt with + | Yield (v,prefix) -> k (W.of_list (prefix[]), v) + | Explore (Empty, _) -> () + | Explore (Cons (c,t), prefix) -> + explore ~dir k (Explore (t, _difflist_add prefix c)) + | Explore (Node (o,map), prefix) -> + (* if above, yield value now *) + begin match o, dir with + | Some v, Above -> k (W.of_list (prefix[]), v) + | _ -> () + end; + let seq = + seq_of_map map + |> _seq_map (fun (c,t') -> Explore (t', _difflist_add prefix c)) + in + let l' = match o, dir with + | _, Above -> _seq_append_list [] seq + | None, Below -> _seq_append_list_rev [] seq + | Some v, Below -> + _seq_append_list_rev [Yield (v, prefix)] seq + in + List.iter (explore ~dir k) l' + (* range above (if [above = true]) or below a threshold . [p c c'] must return [true] if [c'], in the tree, meets some criterion w.r.t [c] which is a part of the key. *) - let _half_range ~above ~p key t k = + let _half_range ~dir ~p key t k = (* at subtree [cur = Some (t,trail)] or [None], alternatives above [alternatives], and char [c] in [key]. *) let on_char (cur, alternatives) c = @@ -518,22 +563,30 @@ module Make(W : WORD) = struct then Some (t', _difflist_add trail c), alternatives else None, alternatives | Some (Node (o, map), trail) -> - (* if [not above], [o]'s key is below [key] so add it *) - begin match o with - | Some v when not above -> k (W.of_list (trail []), v) - | _ -> () - end; - let alternatives = - let seq = _seq_map map in - let seq = _filter_map_seq - (fun (c', t') -> if p c c' - then Some (t', _difflist_add trail c') - else None - ) seq - in - _seq_append_list alternatives seq + (* if [dir=Below], [o]'s key is below [key] and the other + alternatives in [map] *) + let alternatives = match o, dir with + | Some v, Below -> Yield (v, trail) :: alternatives + | _ -> alternatives in - begin try + let alternatives = + let seq = seq_of_map map in + let seq = _filter_map_seq + (fun (c', t') -> + if p ~cur:c ~other:c' + then Some (Explore (t', _difflist_add trail c')) + else None) + seq + in + (* ordering: + - Above: explore alternatives in increasing order + - Below: explore alternatives in decreasing order *) + match dir with + | Above -> _seq_append_list alternatives seq + | Below -> _seq_append_list_rev alternatives seq + in + begin + try let t' = M.find c map in Some (t', _difflist_add trail c), alternatives with Not_found -> @@ -542,39 +595,37 @@ module Make(W : WORD) = struct (* run through the current path (if any) and alternatives *) and finish (cur,alternatives) = - begin match cur with - | Some (t, prefix) when above -> + begin match cur, dir with + | Some (t, prefix), Above -> (* subtree prefixed by input key, therefore above key *) _iter_prefix ~prefix (fun key' v -> k (key', v)) t - | Some (Node (Some v, _), prefix) when not above -> + | Some (Node (Some v, _), prefix), Below -> (* yield the value for key *) assert (W.of_list (prefix []) = key); k (key, v) - | Some _ - | None -> () + | Some _, _ + | None, _ -> () end; - List.iter - (fun (t,prefix) -> _iter_prefix ~prefix (fun key' v -> k (key', v)) t) - alternatives + List.iter (explore ~dir k) alternatives in let word = W.to_seq key in _fold_seq_and_then on_char ~finish (Some(t,_id), []) word let above key t = - _half_range ~above:true ~p:(fun c c' -> W.compare c c' < 0) key t + _half_range ~dir:Above ~p:(fun ~cur ~other -> W.compare cur other < 0) key t let below key t = - _half_range ~above:false ~p:(fun c c' -> W.compare c c' > 0) key t + _half_range ~dir:Below ~p:(fun ~cur ~other -> W.compare cur other > 0) key t (*$= & ~printer:CCPrint.(to_string (list (pair (list int) string))) [ [1], "1"; [1;2], "12"; [1;2;3], "123"; [2;1], "21" ] \ - (T.above [1] t1 |> Sequence.sort |> Sequence.to_list) + (T.above [1] t1 |> Sequence.to_list) [ [1;2], "12"; [1;2;3], "123"; [2;1], "21" ] \ - (T.above [1;1] t1 |> Sequence.sort |> Sequence.to_list) - [ [], "[]"; [1], "1"; [1;2], "12" ] \ - (T.below [1;2] t1 |> Sequence.sort |> Sequence.to_list) - [ [], "[]"; [1], "1" ] \ - (T.below [1;1] t1 |> Sequence.sort |> Sequence.to_list) + (T.above [1;1] t1 |> Sequence.to_list) + [ [1;2], "12"; [1], "1"; [], "[]" ] \ + (T.below [1;2] t1 |> Sequence.to_list) + [ [1], "1"; [], "[]" ] \ + (T.below [1;1] t1 |> Sequence.to_list) *) (*$Q & ~count:30 @@ -583,7 +634,14 @@ module Make(W : WORD) = struct S.check_invariants t) *) - (*$Q & ~count:20 + (*$inject + let rec sorted ~rev = function + | [] | [_] -> true + | x :: ((y ::_) as tl) -> + (if rev then x >= y else x <= y) && sorted ~rev tl + *) + + (*$Q & ~count:200 Q.(list_of_size Gen.(1 -- 20) (pair printable_string small_int)) \ (fun l -> let t = String.of_list l in \ List.for_all (fun (k,_) -> \ @@ -594,6 +652,16 @@ module Make(W : WORD) = struct List.for_all (fun (k,_) -> \ String.below k t |> Sequence.for_all (fun (k',v) -> k' <= k)) \ l) + Q.(list_of_size Gen.(1 -- 20) (pair printable_string small_int)) \ + (fun l -> let t = String.of_list l in \ + List.for_all (fun (k,_) -> \ + String.above k t |> Sequence.to_list |> sorted ~rev:false) \ + l) + Q.(list_of_size Gen.(1 -- 20) (pair printable_string small_int)) \ + (fun l -> let t = String.of_list l in \ + List.for_all (fun (k,_) -> \ + String.below k t |> Sequence.to_list |> sorted ~rev:true) \ + l) *) end diff --git a/src/data/CCTrie.mli b/src/data/CCTrie.mli index 28c0cc4f..dc8567aa 100644 --- a/src/data/CCTrie.mli +++ b/src/data/CCTrie.mli @@ -97,7 +97,8 @@ module type S = sig ascending order *) val below : key -> 'a t -> (key * 'a) sequence - (** All bindings whose key is smaller or equal to the given key *) + (** All bindings whose key is smaller or equal to the given key, + in decreasing order *) (**/**) val check_invariants: _ t -> bool