mirror of
https://github.com/c-cube/ocaml-containers.git
synced 2025-12-08 12:15:32 -05:00
479 lines
12 KiB
OCaml
479 lines
12 KiB
OCaml
(* "Fast Mergeable Integer Maps", Okasaki & Gill.
   We use big-endian trees. *)
|
|
|
|
(** Masks with exactly one bit active.

    The type is [private]: a mask can be read as an [int], but can only be
    built through {!Bit.highest} or {!Bit.min_int}. *)
module Bit : sig
  type t = private int

  (** [highest x] is the mask of the most significant bit set in [x].
      A negative [x] has its sign bit set, so the result is {!min_int}. *)
  val highest : int -> t

  (** The mask made of the sign bit only. *)
  val min_int : t

  (** Equality on masks. *)
  val equal : t -> t -> bool

  (** [is_0 ~bit x] holds iff the bit selected by [bit] is cleared in [x]. *)
  val is_0 : bit:t -> int -> bool

  (** [is_1 ~bit x] holds iff the bit selected by [bit] is set in [x]. *)
  val is_1 : bit:t -> int -> bool

  (** [mask ~mask x] zeroes the bit [mask] in [x] and puts all lower bits
      to 1. *)
  val mask : mask:t -> int -> int

  (** [lt a b] is [gt b a]. *)
  val lt : t -> t -> bool

  (** [gt a b] orders masks by significance: {!min_int} (the sign bit) is
      greater than every other mask and no mask is greater than it;
      otherwise masks are compared as integers. *)
  val gt : t -> t -> bool

  (** [equal_int i b] tests whether the integer [i] is equal to the mask [b]. *)
  val equal_int : int -> t -> bool
end = struct
  type t = int

  let min_int = min_int
  let equal : t -> t -> bool = Stdlib.( = )

  (* Walk the candidate bit [m] upwards, clearing bit [m] of [x] at each step,
     and stop as soon as [x] has been reduced to exactly [m]. *)
  let rec highest_bit_naive x m =
    if x = m then
      m
    else
      highest_bit_naive (x land lnot m) (2 * m)

  let mask_20_ = 1 lsl 20
  let mask_40_ = 1 lsl 40

  let highest x =
    if x < 0 then
      (* the sign bit is set *)
      min_int
    else if Sys.word_size > 40 && x >= mask_40_ then (
      (* shortcut: drop the 40 least significant bits and start the search
         at bit 40; [>=] so that [x = 1 lsl 40] also takes this path *)
      let x' = x land lnot (mask_40_ - 1) in
      highest_bit_naive x' mask_40_
    ) else if x >= mask_20_ then (
      (* small shortcut: drop the 20 least significant bits *)
      let x' = x land lnot (mask_20_ - 1) in
      highest_bit_naive x' mask_20_
    ) else
      highest_bit_naive x 1

  let[@inline] is_0 ~bit x = x land bit = 0
  let[@inline] is_1 ~bit x = x land bit = bit

  (* [land] and [lor] share the same precedence and associate to the left, so
     this is [(x lor (mask - 1)) land (lnot mask)]. *)
  let mask ~mask x = x lor (mask - 1) land lnot mask
  (* small endian: let mask_ x ~mask = x land (mask - 1) *)

  (* structural [<>] rather than physical [!=]: these are plain integers *)
  let gt a b = b <> min_int && (a = min_int || a > b)
  let lt a b = gt b a
  let equal_int : int -> int -> bool = Stdlib.( = )
end
|
|
|
|
(** Tree representation of a map from [int] keys to ['a] values. *)
type +'a t =
  | E  (** empty *)
  | L of int * 'a  (** leaf: a key and its value *)
  | N of int * Bit.t * 'a t * 'a t
      (** inner node: common prefix, bit switch, then the two subtrees *)
|
|
|
|
(** The empty map. *)
let empty = E
|
|
|
|
(** [is_empty t] is [true] iff [t] is the empty tree [E]. *)
let[@inline] is_empty t =
  match t with
  | E -> true
  | L _ | N _ -> false
|
|
|
|
(** [is_prefix_ ~prefix y ~bit] tests whether masking [y] with [bit]
    (see [Bit.mask]) yields exactly [prefix]. *)
let[@inline] is_prefix_ ~prefix y ~bit = prefix = Bit.mask y ~mask:bit
|
|
|
|
(* small endian: let branching_bit_ a _ b _ = lowest_bit_ (a lxor b) *)
|
|
(** [branching_bit_ a b] is the highest bit of [a lxor b], i.e. the most
    significant bit on which [a] and [b] differ. *)
let branching_bit_ a b = Bit.highest (a lxor b)
|
|
|
|
(* TODO use hint in branching_bit_ *)
|
|
|
|
(** [check_invariants t] checks that every key stored in [t] is compatible
    with each inner node on the path from the root to its leaf: the key must
    match the node's prefix, and its switch bit must agree with the side
    (left = 0, right = 1) that was taken. *)
let check_invariants t =
  (* does key [k] agree with one ancestor, recorded as (prefix, switch, side)? *)
  let key_ok k (prefix, switch, side) =
    is_prefix_ ~prefix k ~bit:switch
    &&
    match side with
    | `Left -> Bit.is_0 k ~bit:switch
    | `Right -> Bit.is_1 k ~bit:switch
  in
  (* [path] accumulates the ancestors crossed so far, innermost first *)
  let rec check_keys path = function
    | E -> true
    | L (k, _) -> List.for_all (key_ok k) path
    | N (prefix, switch, l, r) ->
      check_keys ((prefix, switch, `Left) :: path) l
      && check_keys ((prefix, switch, `Right) :: path) r
  in
  check_keys [] t
|
|
|
|
(** [find_exn k t] returns the value bound to [k] in [t].
    @raise Not_found if [k] is not bound in [t]. *)
let rec find_exn k = function
  | E -> raise Not_found
  | L (k', v) -> if k = k' then v else raise Not_found
  | N (prefix, m, l, r) ->
    if not (is_prefix_ ~prefix k ~bit:m) then
      (* [k] does not share this node's prefix: it cannot be below it *)
      raise Not_found
    else if Bit.is_0 k ~bit:m then
      find_exn k l
    else
      find_exn k r
|
|
|
|
(** [find k t] is [Some v] if [k] is bound to [v] in [t], and [None] when
    [find_exn] raises [Not_found]. *)
let find k t = try Some (find_exn k t) with Not_found -> None
|
|
|
|
(** [mem k t] is [true] iff [k] is bound in [t]. *)
let mem k t =
  match find_exn k t with
  | _ -> true
  | exception Not_found -> false
|
|
|
|
(** [mk_node_ prefix switch l r] builds the inner node [N (prefix, switch, l, r)],
    unless one of the subtrees is empty, in which case the other subtree is
    returned as is. *)
let mk_node_ prefix switch l r =
  match l, r with
  | E, other | other, E -> other
  | (L _ | N _), (L _ | N _) -> N (prefix, switch, l, r)
|
|
|
|
(* [join_ t1 p1 t2 p2] joins trees [t1] and [t2], whose prefixes are [p1] and
   [p2] respectively ([p1] and [p2] do not overlap). The new node switches on
   the highest bit on which the two prefixes differ; the tree whose prefix has
   that bit cleared becomes the left child. *)
let join_ t1 p1 t2 p2 =
  let switch = branching_bit_ p1 p2 in
  let prefix = Bit.mask p1 ~mask:switch in
  let p1_goes_left = Bit.is_0 p1 ~bit:switch in
  (* the prefixes must disagree on [switch] *)
  assert (
    if p1_goes_left then
      Bit.is_1 p2 ~bit:switch
    else
      Bit.is_0 p2 ~bit:switch);
  if p1_goes_left then
    mk_node_ prefix switch t1 t2
  else
    mk_node_ prefix switch t2 t1
|
|
|
|
(** [singleton k v] is the map made of the single leaf binding [k] to [v]. *)
let singleton k v = L (k, v)
|
|
|
|
(* [insert_ c k v t] binds [k] to [v] in [t].
   [c] is the conflict function: when [k] is already bound to [v'], the new
   binding is [c ~old:v' v]. *)
let rec insert_ c k v t =
  match t with
  | E -> L (k, v)
  | L (k', v') when k = k' -> L (k, c ~old:v' v)
  | L (k', _) -> join_ t k' (L (k, v)) k
  | N (prefix, switch, l, r) when is_prefix_ ~prefix k ~bit:switch ->
    (* [k] belongs under this node: descend on the side given by its bit *)
    if Bit.is_0 k ~bit:switch then
      N (prefix, switch, insert_ c k v l, r)
    else
      N (prefix, switch, l, insert_ c k v r)
  | N (prefix, _, _, _) ->
    (* [k] does not share the prefix: put a fresh leaf next to this node *)
    join_ (L (k, v)) k t prefix
|
|
|
|
(** [add k v t] binds [k] to [v] in [t]; if [k] was already bound, the
    conflict function keeps the new value [v]. *)
let add k v t = insert_ (fun ~old:_ v -> v) k v t
|
|
|
|
(** [remove k t] is [t] without any binding for [k]. When [k] is not bound,
    [t] itself is returned at the point where the search stops. *)
let rec remove k t =
  match t with
  | E -> E
  | L (k', _) -> if k = k' then E else t
  | N (prefix, switch, l, r) ->
    if not (is_prefix_ ~prefix k ~bit:switch) then
      t (* not present *)
    else if Bit.is_0 k ~bit:switch then
      mk_node_ prefix switch (remove k l) r
    else
      mk_node_ prefix switch l (remove k r)
|
|
|
|
(** [update k f t] looks [k] up in [t] and calls [f] exactly once:
    - with [Some v] if [k] is bound to [v]: a result of [None] removes the
      binding, [Some v'] rebinds [k] to [v'];
    - with [None] if [k] is unbound: a result of [None] returns [t] unchanged,
      [Some v] adds the binding.

    Only the lookup is guarded against [Not_found]. An exception raised by [f]
    (including [Not_found]) propagates to the caller instead of being mistaken
    for a missing key. *)
let update k f t =
  match find_exn k t with
  | v ->
    (match f (Some v) with
    | None -> remove k t
    | Some v' -> add k v' t)
  | exception Not_found ->
    (match f None with
    | None -> t
    | Some v -> add k v t)
|
|
|
|
(** [doubleton k1 v1 k2 v2] is [singleton k2 v2] extended with the binding
    [k1 -> v1] through [add] (so [v1] wins if [k1 = k2]). *)
let doubleton k1 v1 k2 v2 = add k1 v1 (singleton k2 v2)
|
|
|
|
(** [equal ~eq a b] tests whether [a] and [b] have the same shape, the same
    keys, prefixes and switches, and values that are pairwise equal according
    to [eq]. Physically equal trees are accepted without being traversed. *)
let rec equal ~eq a b =
  if Stdlib.( == ) a b then
    true
  else (
    match a, b with
    | E, E -> true
    | L (ka, va), L (kb, vb) -> ka = kb && eq va vb
    | N (pa, sa, la, ra), N (pb, sb, lb, rb) ->
      pa = pb && Bit.equal sa sb && equal ~eq la lb && equal ~eq ra rb
    | E, (L _ | N _) | L _, (E | N _) | N _, (E | L _) -> false
  )
|
|
|
|
(** [iter f t] calls [f k v] on every binding of [t], visiting the left
    subtree of each node before its right subtree. *)
let rec iter f = function
  | E -> ()
  | L (k, v) -> f k v
  | N (_, _, left, right) ->
    iter f left;
    iter f right
|
|
|
|
(** [fold f t acc] threads [acc] through [f k v] for every binding of [t],
    left subtree first, then right subtree. *)
let rec fold f t acc =
  match t with
  | E -> acc
  | L (k, v) -> f k v acc
  | N (_, _, left, right) -> fold f right (fold f left acc)
|
|
|
|
(** [cardinal t] is the number of bindings of [t], obtained by folding over
    the whole tree. *)
let cardinal t = fold (fun _ _ n -> n + 1) t 0
|
|
|
|
(** [mapi f t] has the same keys and shape as [t]; each value [v] bound to
    [k] is replaced with [f k v]. *)
let rec mapi f = function
  | E -> E
  | L (k, v) -> L (k, f k v)
  | N (prefix, switch, left, right) ->
    N (prefix, switch, mapi f left, mapi f right)
|
|
|
|
(** [map f t] has the same keys and shape as [t]; each value [v] is replaced
    with [f v]. *)
let rec map f = function
  | E -> E
  | L (k, v) -> L (k, f v)
  | N (prefix, switch, left, right) ->
    N (prefix, switch, map f left, map f right)
|
|
|
|
(** [choose_exn t] returns the binding found by always following left
    subtrees down to a leaf.
    @raise Not_found if an empty tree is reached. *)
let rec choose_exn t =
  match t with
  | L (k, v) -> k, v
  | N (_, _, left, _) -> choose_exn left
  | E -> raise Not_found
|
|
|
|
(** [choose t] is [Some] of the binding returned by [choose_exn t], or [None]
    when [choose_exn] raises [Not_found]. *)
let choose t = try Some (choose_exn t) with Not_found -> None
|
|
|
|
(** {2 Whole-collection operations} *)
|
|
|
|
(** [union f t1 t2] contains every binding of [t1] and of [t2]. A key [k]
    bound to [v1] in [t1] and to [v2] in [t2] is bound to [f k v1 v2]. *)
let rec union f t1 t2 =
  match t1, t2 with
  | E, other | other, E -> other
  | L (k, v1), rhs ->
    (* on conflict [old] comes from [rhs]; keep [v1] as the left argument *)
    insert_ (fun ~old v -> f k v old) k v1 rhs
  | lhs, L (k, v2) ->
    (* on conflict [old] comes from [lhs]; it is the left argument *)
    insert_ (fun ~old v -> f k old v) k v2 lhs
  | N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
    if p1 = p2 && Bit.equal m1 m2 then
      (* same node on both sides: merge children pairwise *)
      mk_node_ p1 m1 (union f l1 l2) (union f r1 r2)
    else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then (
      (* [t2] fits entirely below one child of [t1] *)
      if Bit.is_0 p2 ~bit:m1 then
        N (p1, m1, union f l1 t2, r1)
      else
        N (p1, m1, l1, union f r1 t2)
    ) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then (
      (* [t1] fits entirely below one child of [t2] *)
      if Bit.is_0 p1 ~bit:m2 then
        N (p2, m2, union f t1 l2, r2)
      else
        N (p2, m2, l2, union f t1 r2)
    ) else
      (* unrelated prefixes: put the two trees side by side *)
      join_ t1 p1 t2 p2
|
|
|
|
(** [inter f a b] keeps only the keys bound in both [a] and [b]. A key [k]
    bound to [v1] in [a] and to [v2] in [b] is bound to [f k v1 v2].

    Only the lookups are guarded against [Not_found]. An exception raised by
    [f] (including [Not_found]) propagates to the caller instead of silently
    dropping the binding. *)
let rec inter f a b =
  match a, b with
  | E, _ | _, E -> E
  | L (k, v1), o2 ->
    (match find_exn k o2 with
    | v2 -> L (k, f k v1 v2)
    | exception Not_found -> E)
  | o1, L (k, v2) ->
    (match find_exn k o1 with
    | v1 -> L (k, f k v1 v2)
    | exception Not_found -> E)
  | N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
    if p1 = p2 && Bit.equal m1 m2 then
      mk_node_ p1 m1 (inter f l1 l2) (inter f r1 r2)
    else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then (
      (* [b] lies below one child of [a]; the other child cannot intersect *)
      if Bit.is_0 p2 ~bit:m1 then
        inter f l1 b
      else
        inter f r1 b
    ) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then (
      (* [a] lies below one child of [b] *)
      if Bit.is_0 p1 ~bit:m2 then
        inter f a l2
      else
        inter f a r2
    ) else
      (* unrelated prefixes: no common key *)
      E
|
|
|
|
(* [disjoint_union_ t1 t2] merges two trees that are expected to have no key
   in common: reaching the conflict function is an [assert false]. *)
let rec disjoint_union_ t1 t2 : _ t =
  match t1, t2 with
  | E, other | other, E -> other
  | L (k, v), other | other, L (k, v) ->
    insert_ (fun ~old:_ _ -> assert false) k v other
  | N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
    if p1 = p2 && Bit.equal m1 m2 then
      mk_node_ p1 m1 (disjoint_union_ l1 l2) (disjoint_union_ r1 r2)
    else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then (
      (* [t2] goes below one child of [t1] *)
      if Bit.is_0 p2 ~bit:m1 then
        mk_node_ p1 m1 (disjoint_union_ l1 t2) r1
      else
        mk_node_ p1 m1 l1 (disjoint_union_ r1 t2)
    ) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then (
      (* [t1] goes below one child of [t2] *)
      if Bit.is_0 p1 ~bit:m2 then
        mk_node_ p2 m2 (disjoint_union_ t1 l2) r2
      else
        mk_node_ p2 m2 l2 (disjoint_union_ t1 r2)
    ) else
      join_ t1 p1 t2 p2
|
|
|
|
(** [filter f m] keeps the bindings [(k, v)] of [m] for which [f k v] holds.
    Inner nodes are rebuilt from their filtered subtrees. *)
let rec filter f m =
  match m with
  | E -> E
  | L (k, v) when f k v -> m
  | L _ -> E
  | N (_, _, l, r) -> disjoint_union_ (filter f l) (filter f r)
|
|
|
|
(** [filter_map f m] applies [f k v] to every binding of [m]: a result of
    [Some v'] rebinds [k] to [v'], a result of [None] drops the binding. *)
let rec filter_map f = function
  | E -> E
  | L (k, v) ->
    (match f k v with
    | Some v' -> L (k, v')
    | None -> E)
  | N (_, _, l, r) -> disjoint_union_ (filter_map f l) (filter_map f r)
|
|
|
|
(** [merge ~f t1 t2] builds a map by calling [f] on each key of [t1] and [t2]:
    - [f k (`Left v)] when [k] is bound (to [v]) in [t1] only;
    - [f k (`Right v)] when [k] is bound in [t2] only;
    - [f k (`Both (v1, v2))] when [k] is bound in both.

    A result of [Some v] binds [k] to [v] in the output; [None] leaves [k]
    unbound.

    Only the lookups are guarded against [Not_found]. An exception raised by
    [f] (including [Not_found]) propagates to the caller rather than causing
    [f] to be called a second time with a different argument. *)
let rec merge ~f t1 t2 : _ t =
  (* process a subtree that exists only on the left / only on the right *)
  let merge1 t = filter_map (fun k v -> f k (`Left v)) t in
  let merge2 t = filter_map (fun k v -> f k (`Right v)) t in
  (* add [k -> v] if [opt = Some v]; [k] must not be bound in [m] *)
  let add_some k opt m =
    match opt with
    | None -> m
    | Some v -> insert_ (fun ~old:_ _ -> assert false) k v m
  in
  match t1, t2 with
  | E, o -> merge2 o
  | o, E -> merge1 o
  | L (k, v), o ->
    (* handle every other key of [o] first, then [k] itself *)
    let others = merge2 (remove k o) in
    let at_k =
      match find_exn k o with
      | v' -> f k (`Both (v, v'))
      | exception Not_found -> f k (`Left v)
    in
    add_some k at_k others
  | o, L (k, v) ->
    let others = merge1 (remove k o) in
    let at_k =
      match find_exn k o with
      | v' -> f k (`Both (v', v))
      | exception Not_found -> f k (`Right v)
    in
    add_some k at_k others
  | N (p1, m1, l1, r1), N (p2, m2, l2, r2) ->
    if p1 = p2 && Bit.equal m1 m2 then
      mk_node_ p1 m1 (merge ~f l1 l2) (merge ~f r1 r2)
    else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then (
      (* [t2] lies below one child of [t1]; the other child is left-only *)
      if Bit.is_0 p2 ~bit:m1 then
        mk_node_ p1 m1 (merge ~f l1 t2) (merge1 r1)
      else
        mk_node_ p1 m1 (merge1 l1) (merge ~f r1 t2)
    ) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then (
      (* [t1] lies below one child of [t2]; the other child is right-only *)
      if Bit.is_0 p1 ~bit:m2 then
        mk_node_ p2 m2 (merge ~f t1 l2) (merge2 r2)
      else
        mk_node_ p2 m2 (merge2 l2) (merge ~f t1 r2)
    ) else
      (* unrelated prefixes: no key in common *)
      join_ (merge1 t1) p1 (merge2 t2) p2
|
|
|
|
(** {2 Conversions} *)
|
|
|
|
(** An iterator over ['a]: a function that feeds each element to the callback
    it receives. *)
type 'a iter = ('a -> unit) -> unit
|
|
(** A generator of ['a]: each call produces an optional next element; this
    module stops consuming a generator at the first [None]. *)
type 'a gen = unit -> 'a option
|
|
|
|
(** [add_list t l] adds every pair [(k, v)] of [l] to [t] with [add], from
    the head of the list to its end. *)
let add_list t l = List.fold_left (fun t (k, v) -> add k v t) t l
|
|
(** [of_list l] is [add_list empty l]. *)
let of_list l = add_list empty l
|
|
(** [to_list t] is the list of bindings of [t]. Each binding visited by
    [fold] is consed in front, so the list is in reverse traversal order. *)
let to_list t = fold (fun k v l -> (k, v) :: l) t []
|
|
|
|
(** [add_iter t iter] adds to [t], with [add], every pair [(k, v)] that
    [iter] passes to its callback. *)
let add_iter t iter =
  let acc = ref t in
  iter (fun (k, v) -> acc := add k v !acc);
  !acc
|
|
|
|
(** [of_iter iter] is [add_iter empty iter]. *)
let of_iter iter = add_iter empty iter
|
|
(** [to_iter t] iterates on the bindings of [t], passing each one to [yield]
    as a pair [(k, v)]. *)
let to_iter t yield = iter (fun k v -> yield (k, v)) t
|
|
(** [keys t] iterates on the keys of [t], passing each one to [yield]. *)
let keys t yield = iter (fun k _ -> yield k) t
|
|
(** [values t] iterates on the values of [t], passing each one to [yield]. *)
let values t yield = iter (fun _ v -> yield v) t
|
|
|
|
(** [add_gen m g] adds to [m], with [add], every pair produced by the
    generator [g], stopping at the first [None]. *)
let rec add_gen m g =
  match g () with
  | Some (k, v) -> add_gen (add k v m) g
  | None -> m
|
|
|
|
(** [of_gen g] is [add_gen empty g]. *)
let of_gen g = add_gen empty g
|
|
|
|
(** [to_gen m] is a generator over the bindings of [m]. It keeps a mutable
    stack of subtrees still to visit, so the returned closure is stateful:
    each call yields the next binding, then [None] once the stack is empty. *)
let to_gen m =
  let pending = Stack.create () in
  Stack.push m pending;
  (* walk down [n], saving right subtrees for later *)
  let rec descend n =
    match n with
    | E -> resume () (* backtrack *)
    | L (k, v) -> Some (k, v)
    | N (_, _, l, r) ->
      Stack.push r pending;
      descend l
  (* pick the next saved subtree, if any *)
  and resume () =
    if Stack.is_empty pending then
      None
    else
      descend (Stack.pop pending)
  in
  resume
|
|
|
|
(* [compare ~cmp a b] compares the sequences of bindings produced by
   [to_gen a] and [to_gen b], pair by pair: at the first position where they
   differ, keys are compared first, then values with [cmp]; a sequence that
   ends first is smaller. *)
let compare ~cmp a b =
  let rec loop ga gb =
    match ga (), gb () with
    | None, None -> 0
    | Some _, None -> 1
    | None, Some _ -> -1
    | Some (ka, va), Some (kb, vb) ->
      if ka <> kb then
        (* the enclosing [compare] is not recursive, so this is the
           comparison function that was in scope before it *)
        compare ka kb
      else (
        match cmp va vb with
        | 0 -> loop ga gb
        | c -> c
      )
  in
  loop (to_gen a) (to_gen b)
|
|
|
|
(** [add_seq m l] adds to [m], with [add], every pair of the sequence [l],
    forcing [l] up to its end. *)
let rec add_seq m l =
  match l () with
  | Seq.Cons ((k, v), rest) -> add_seq (add k v m) rest
  | Seq.Nil -> m
|
|
|
|
(** [of_seq l] is [add_seq empty l]. *)
let of_seq l = add_seq empty l
|
|
|
|
(** [to_seq m] is the sequence of bindings of [m], computed on demand: the
    tree is only explored as far as needed to produce the next binding. *)
let to_seq m =
  (* [todo]: stack of subtrees still to visit once the current one is done *)
  let rec walk todo tree () =
    match tree with
    | E -> continue todo ()
    | L (k, v) -> Seq.Cons ((k, v), continue todo)
    | N (_, _, l, r) -> walk (r :: todo) l ()
  and continue todo () =
    match todo with
    | [] -> Seq.Nil
    | next :: rest -> walk rest next ()
  in
  (* same as resuming from the one-element stack [[ m ]] *)
  walk [] m
|
|
|
|
(** A tree computed on demand: forcing it gives either [`Nil] or a [`Node]
    made of a label and the list of its subtrees. *)
type 'a tree = unit -> [ `Nil | `Node of 'a * 'a tree list ]
|
|
|
|
(** [as_tree t] exposes the internal structure of [t] as a lazily explored
    tree: leaves are labelled [`Leaf (k, v)] and have no children, inner nodes
    are labelled [`Node (prefix, switch)] and have their two subtrees as
    children. An empty tree is [`Nil]. *)
let rec as_tree t () =
  match t with
  | E -> `Nil
  | L (k, v) -> `Node (`Leaf (k, v), [])
  | N (prefix, switch, l, r) ->
    let label = `Node (prefix, (switch :> int)) in
    `Node (label, [ as_tree l; as_tree r ])
|
|
|
|
(** {2 IO} *)
|
|
|
|
(** A pretty-printer for values of type ['a] on a [Format] formatter. *)
type 'a printer = Format.formatter -> 'a -> unit
|
|
|
|
(** [pp pp_x out m] prints [m] on [out] inside a box, as [intmap {...}] where
    each binding is printed as [key -> value] using [pp_x] for values.
    Bindings after the first one are preceded by [", "], and each binding is
    followed by a cut hint. *)
let pp pp_x out m =
  Format.fprintf out "@[<hov2>intmap {@,";
  let is_first = ref true in
  let pp_binding k v =
    (* no separator in front of the very first binding *)
    if !is_first then
      is_first := false
    else
      Format.pp_print_string out ", ";
    Format.fprintf out "%d -> " k;
    pp_x out v;
    Format.pp_print_cut out ()
  in
  iter pp_binding m;
  Format.fprintf out "}@]"
|
|
|
|
(* Some thorough tests from Jan Midtgaard
   https://github.com/jmid/qc-ptrees
*)
|