(* Mirror of https://github.com/c-cube/ocaml-containers.git
   (synced 2025-12-10 13:13:56 -05:00; 519 lines, 14 KiB, OCaml). *)
(* This file is free software, part of containers. See file "license" for more details. *)

(** {1 Weight-Balanced Tree}

    Most of this comes from "Implementing Sets Efficiently in a Functional
    Language", Stephen Adams.

    The coefficients 5/2, 3/2 for balancing come from "Balancing
    Weight-Balanced Trees" (Hirai and Yamamoto). *)

(** Internal iterator: a function that feeds each of its elements, in turn,
    to the callback it is given. *)
type 'a iter = ('a -> unit) -> unit

(** Generator: each call returns the next element, or [None] to signal
    that there is nothing more to return. *)
type 'a gen = unit -> 'a option

(** Pretty-printer for values of type ['a], usable with ["%a"]. *)
type 'a printer = Format.formatter -> 'a -> unit
(** Totally ordered type, as expected by {!Make}. *)
module type ORD = sig
  type t
  (** The type being ordered. *)

  val compare : t -> t -> int
  (** [compare a b] is [0] if [a] and [b] are equal, negative if [a] comes
      before [b], and positive otherwise. *)
end
(** Ordered keys that additionally carry a weight, as expected by
    {!MakeFull}. *)
module type KEY = sig
  include ORD

  val weight : t -> int
  (** [weight k] is the weight of the key [k]. Weights are summed to decide
      how the tree is balanced, to index bindings in [nth], and to bias
      [random_choose]. NOTE(review): weights are presumably expected to be
      strictly positive; confirm. *)
end

(** {2 Signature} *)

module type S = sig
  type key
  (** Type of the keys. *)

  type +'a t
  (** Persistent map from {!key} to values of type ['a]. *)

  val empty : 'a t
  (** The map without any binding. *)

  val is_empty : _ t -> bool
  (** [is_empty m] is [true] iff [m] has no binding. *)

  val singleton : key -> 'a -> 'a t
  (** [singleton k v] is the map whose only binding is [k -> v]. *)

  val mem : key -> _ t -> bool
  (** [mem k m] is [true] iff [k] is bound in [m]. *)

  val get : key -> 'a t -> 'a option
  (** [get k m] is [Some v] if [k] is bound to [v] in [m], [None]
      otherwise. *)

  val get_exn : key -> 'a t -> 'a
  (** Same as {!get}, but without the option.
      @raise Not_found if the key is not present *)

  val nth : int -> 'a t -> (key * 'a) option
  (** [nth i m] returns the [i]-th [key, value] in the ascending
      order. Complexity is [O(log (cardinal m))] *)

  val nth_exn : int -> 'a t -> key * 'a
  (** Same as {!nth}, but without the option.
      @raise Not_found if the index is invalid *)

  val get_rank : key -> 'a t -> [ `At of int | `After of int | `First ]
  (** [get_rank k m] looks for the rank of [k] in [m], i.e. the index
      of [k] in the sorted list of bindings of [m].
      If [get_rank k m = `At n] then [nth_exn n m] should be the binding
      of [k] in [m]. When [k] is not bound, the result is [`First] if no
      key of [m] is smaller than [k], and [`After n] otherwise.
      @since 1.4 *)

  val add : key -> 'a -> 'a t -> 'a t
  (** [add k v m] binds [k] to [v] in [m], replacing the previous binding
      of [k] if there was one. *)

  val remove : key -> 'a t -> 'a t
  (** [remove k m] is [m] without the binding of [k], if any. *)

  val update : key -> ('a option -> 'a option) -> 'a t -> 'a t
  (** [update k f m] calls [f (Some v)] if [get k m = Some v], [f None]
      otherwise. Then, if [f] returns [Some v'] it binds [k] to [v'],
      if [f] returns [None] it removes [k] *)

  val cardinal : _ t -> int
  (** Number of bindings in the map. *)

  val weight : _ t -> int
  (** Sum of the weights of all the keys of the map; [0] for {!empty}. *)

  val fold : f:('b -> key -> 'a -> 'b) -> x:'b -> 'a t -> 'b
  (** [fold ~f ~x m] threads an accumulator, starting from [x], through
      the bindings of [m] in ascending key order. *)

  val mapi : f:(key -> 'a -> 'b) -> 'a t -> 'b t
  (** Map values, giving both key and value.
      @since 0.17
  *)

  val map : f:('a -> 'b) -> 'a t -> 'b t
  (** Map values, giving only the value.
      @since 0.17
  *)

  val iter : f:(key -> 'a -> unit) -> 'a t -> unit
  (** [iter ~f m] calls [f k v] on every binding of [m], in ascending key
      order. *)

  val split : key -> 'a t -> 'a t * 'a option * 'a t
  (** [split k t] returns [l, o, r] where [l] is the part of the map
      with keys smaller than [k], [r] has keys bigger than [k],
      and [o = Some v] if [k, v] belonged to the map *)

  val merge :
    f:(key -> 'a option -> 'b option -> 'c option) -> 'a t -> 'b t -> 'c t
  (** Like {!Map.S.merge} *)

  val extract_min : 'a t -> key * 'a * 'a t
  (** [extract_min m] returns [k, v, m'] where [k,v] is the pair with the
      smallest key in [m], and [m'] does not contain [k].
      @raise Not_found if the map is empty *)

  val extract_max : 'a t -> key * 'a * 'a t
  (** [extract_max m] returns [k, v, m'] where [k,v] is the pair with the
      highest key in [m], and [m'] does not contain [k].
      @raise Not_found if the map is empty *)

  val choose : 'a t -> (key * 'a) option
  (** [choose m] is some binding of [m], or [None] if [m] is empty. *)

  val choose_exn : 'a t -> key * 'a
  (** Same as {!choose}, but without the option.
      @raise Not_found if the tree is empty *)

  val random_choose : Random.State.t -> 'a t -> key * 'a
  (** Randomly choose a (key,value) pair within the tree, using weights
      as probability weights
      @raise Not_found if the tree is empty *)

  val add_list : 'a t -> (key * 'a) list -> 'a t
  (** [add_list m l] adds the bindings of [l] to [m], from first to last. *)

  val of_list : (key * 'a) list -> 'a t
  (** [of_list l] is [add_list empty l]. *)

  val to_list : 'a t -> (key * 'a) list
  (** [to_list m] is the list of the bindings of [m]. *)

  val add_iter : 'a t -> (key * 'a) iter -> 'a t
  (** [add_iter m it] adds to [m] the bindings produced by [it], in the
      order they are produced. *)

  val of_iter : (key * 'a) iter -> 'a t
  (** [of_iter it] is [add_iter empty it]. *)

  val to_iter : 'a t -> (key * 'a) iter
  (** [to_iter m] iterates on the bindings of [m]. *)

  val add_gen : 'a t -> (key * 'a) gen -> 'a t
  (** [add_gen m g] adds to [m] the bindings returned by [g], until [g]
      returns [None]. *)

  val of_gen : (key * 'a) gen -> 'a t
  (** [of_gen g] is [add_gen empty g]. *)

  val to_gen : 'a t -> (key * 'a) gen
  (** [to_gen m] is a generator over the bindings of [m]. *)

  val pp :
    ?pp_start:unit printer ->
    ?pp_stop:unit printer ->
    ?pp_arrow:unit printer ->
    ?pp_sep:unit printer ->
    key printer ->
    'a printer ->
    'a t printer
  (** [pp pp_key pp_value] is a printer for maps. [pp_start] and [pp_stop]
      are called before and after the bindings, [pp_arrow] between a key
      and its value, and [pp_sep] between two consecutive bindings. *)

  (**/**)

  val node_ : key -> 'a -> 'a t -> 'a t -> 'a t
  val balanced : _ t -> bool

  (**/**)
end
(** Weight-balanced map whose keys carry an explicit weight [K.weight].
    Weights drive the balancing, the indices used by [nth]/[get_rank], and
    the distribution of [random_choose]. *)
module MakeFull (K : KEY) : S with type key = K.t = struct
  type key = K.t
  type weight = int

  (* [E] is the empty tree. [N (k, v, l, r, w)] binds [k] to [v]; [l] holds
     the smaller keys, [r] the bigger ones, and [w] caches the total weight
     of the subtree (see [mk_node_]). *)
  type +'a t = E | N of key * 'a * 'a t * 'a t * weight

  let empty = E

  let is_empty = function
    | E -> true
    | N _ -> false

  (* binary search for [k]; raises [Not_found] if [k] is not bound *)
  let rec get_exn k m =
    match m with
    | E -> raise Not_found
    | N (k', v, l, r, _) ->
      (match K.compare k k' with
      | 0 -> v
      | n when n < 0 -> get_exn k l
      | _ -> get_exn k r)

  let get k m = try Some (get_exn k m) with Not_found -> None

  let mem k m =
    try
      ignore (get_exn k m);
      true
    with Not_found -> false

  let singleton k v = N (k, v, E, E, K.weight k)

  (* total weight of the tree, read from the cache stored at the root *)
  let weight = function
    | E -> 0
    | N (_, _, _, _, w) -> w

  (* balancing parameters.

     We take the parameters from "Balancing weight-balanced trees", as they
     are rational and efficient. *)

  (* delta=5/2
     delta × (weight l + 1) ≥ weight r + 1
  *)
  let is_balanced l r = 5 * (weight l + 1) >= 2 * (weight r + 1)

  (* gamma = 3/2
     weight l + 1 < gamma × (weight r + 1) *)
  let is_single l r = 2 * (weight l + 1) < 3 * (weight r + 1)

  (* debug function: checks the balance criterion, in both directions,
     at every node *)
  let rec balanced = function
    | E -> true
    | N (_, _, l, r, _) ->
      is_balanced l r && is_balanced r l && balanced l && balanced r

  (* smart constructor: computes the cached weight, does not rebalance *)
  let mk_node_ k v l r = N (k, v, l, r, weight l + weight r + K.weight k)

  (* single left rotation: the root of [t2] becomes the new root *)
  let single_l k1 v1 t1 t2 =
    match t2 with
    | E -> assert false
    | N (k2, v2, t2, t3, _) -> mk_node_ k2 v2 (mk_node_ k1 v1 t1 t2) t3

  (* double left rotation: the root of the left child of [t2] becomes the
     new root *)
  let double_l k1 v1 t1 t2 =
    match t2 with
    | N (k2, v2, N (k3, v3, t2, t3, _), t4, _) ->
      mk_node_ k3 v3 (mk_node_ k1 v1 t1 t2) (mk_node_ k2 v2 t3 t4)
    | _ -> assert false

  (* [r] is too heavy: pick the single or double rotation, depending on the
     relative weights of the children of [r] *)
  let rotate_l k v l r =
    match r with
    | E -> assert false
    | N (_, _, rl, rr, _) ->
      if is_single rl rr then
        single_l k v l r
      else
        double_l k v l r

  (* balance towards left *)
  let balance_l k v l r =
    if is_balanced l r then
      mk_node_ k v l r
    else
      rotate_l k v l r

  (* single right rotation: the root of [t1] becomes the new root *)
  let single_r k1 v1 t1 t2 =
    match t1 with
    | E -> assert false
    | N (k2, v2, t11, t12, _) -> mk_node_ k2 v2 t11 (mk_node_ k1 v1 t12 t2)

  (* double right rotation: the root of the right child of [t1] becomes the
     new root *)
  let double_r k1 v1 t1 t2 =
    match t1 with
    | N (k2, v2, t11, N (k3, v3, t121, t122, _), _) ->
      mk_node_ k3 v3 (mk_node_ k2 v2 t11 t121) (mk_node_ k1 v1 t122 t2)
    | _ -> assert false

  (* [l] is too heavy: mirror image of [rotate_l] *)
  let rotate_r k v l r =
    match l with
    | E -> assert false
    | N (_, _, ll, lr, _) ->
      if is_single lr ll then
        single_r k v l r
      else
        double_r k v l r

  (* balance toward right *)
  let balance_r k v l r =
    if is_balanced r l then
      mk_node_ k v l r
    else
      rotate_r k v l r

  (* insert or replace the binding of [k], rebalancing on the way up *)
  let rec add k v m =
    match m with
    | E -> singleton k v
    | N (k', v', l, r, _) ->
      (match K.compare k k' with
      | 0 -> mk_node_ k v l r
      | n when n < 0 -> balance_r k' v' (add k v l) r
      | _ -> balance_l k' v' l (add k v r))

  (* extract min binding of the tree *)
  let rec extract_min m =
    match m with
    | E -> raise Not_found
    | N (k, v, E, r, _) -> k, v, r
    | N (k, v, l, r, _) ->
      let k', v', l' = extract_min l in
      k', v', balance_l k v l' r

  (* extract max binding of the tree *)
  let rec extract_max m =
    match m with
    | E -> raise Not_found
    | N (k, v, l, E, _) -> k, v, l
    | N (k, v, l, r, _) ->
      let k', v', r' = extract_max r in
      k', v', balance_r k v l r'

  let rec remove k m =
    match m with
    | E -> E
    | N (k', v', l, r, _) ->
      (match K.compare k k' with
      | 0 ->
        (match l, r with
        | E, E -> E
        | E, o | o, E -> o
        | _, _ ->
          if weight l > weight r then (
            (* remove max element of [l] and put it at the root,
               then rebalance towards the left if needed *)
            let k', v', l' = extract_max l in
            balance_l k' v' l' r
          ) else (
            (* remove min element of [r] and rebalance *)
            let k', v', r' = extract_min r in
            balance_r k' v' l r'
          ))
      | n when n < 0 -> balance_l k' v' (remove k l) r
      | _ -> balance_r k' v' l (remove k r))

  (* look [k] up, then add/remove according to what [f] answers; the map is
     returned as is when [k] is unbound and [f] returns [None] *)
  let update k f m =
    let maybe_v = get k m in
    match maybe_v, f maybe_v with
    | None, None -> m
    | Some _, None -> remove k m
    | _, Some v -> add k v m

  (* Indices are weighted: within [N (k, v, l, r, w)] the indices below
     [weight l] belong to [l], the next [K.weight k] ones (up to
     [w - weight r], excluded) designate [(k, v)], and the remaining ones
     belong to [r]. Raises [Not_found] if the index is out of range. *)
  let rec nth_exn i m =
    match m with
    | E -> raise Not_found
    | N (k, v, l, r, w) ->
      let c = i - weight l in
      (match c with
      | 0 -> k, v
      | n when n < 0 -> nth_exn i l (* search left *)
      | _ ->
        (* [i] is still within the span of indices owned by [k] *)
        if i < w - weight r then
          k, v
        else
          nth_exn (i + weight r - w) r)

  let nth i m = try Some (nth_exn i m) with Not_found -> None

  let get_rank k m =
    (* [i] is the total weight of the bindings already known to be smaller
       than [k] *)
    let rec aux i k m =
      match m with
      | E ->
        if i = 0 then
          `First
        else
          `After i
      | N (k', _, l, r, _) ->
        (match K.compare k k' with
        | 0 -> `At (i + weight l)
        | n when n < 0 -> aux i k l
        | _ ->
          (* skip [l] and [k'] itself. [k'] counts for [K.weight k'], not
             for 1, so that ranks agree with the indices of [nth_exn] *)
          aux (K.weight k' + weight l + i) k r)
    in
    aux 0 k m

  (* in-order traversal: bindings are visited in ascending key order *)
  let rec fold ~f ~x:acc m =
    match m with
    | E -> acc
    | N (k, v, l, r, _) ->
      let acc = fold ~f ~x:acc l in
      let acc = f acc k v in
      fold ~f ~x:acc r

  (* keys and cached weights are unchanged, so the shape is preserved *)
  let rec mapi ~f = function
    | E -> E
    | N (k, v, l, r, w) -> N (k, f k v, mapi ~f l, mapi ~f r, w)

  let rec map ~f = function
    | E -> E
    | N (k, v, l, r, w) -> N (k, f v, map ~f l, map ~f r, w)

  (* in-order traversal, like [fold] *)
  let rec iter ~f m =
    match m with
    | E -> ()
    | N (k, v, l, r, _) ->
      iter ~f l;
      f k v;
      iter ~f r

  (* the binding stored at the root *)
  let choose_exn = function
    | E -> raise Not_found
    | N (k, v, _, _, _) -> k, v

  let choose = function
    | E -> None
    | N (k, v, _, _, _) -> Some (k, v)

  (* pick an index within [0.. weight m-1] and get the element with
     this index *)
  let random_choose st m =
    let w = weight m in
    if w = 0 then raise Not_found;
    nth_exn (Random.State.int st w) m

  (* make a node (k,v,l,r) but balances on whichever side requires it *)
  let node_shallow_ k v l r =
    if is_balanced l r then
      if is_balanced r l then
        mk_node_ k v l r
      else
        balance_r k v l r
    else
      balance_l k v l r

  (* assume keys of [l] are smaller than [k] and [k] smaller than keys of [r],
     but do not assume anything about weights.
     returns a tree with l, r, and (k,v) *)
  let rec node_ k v l r =
    match l, r with
    | E, E -> singleton k v
    | E, o | o, E -> add k v o
    | N (kl, vl, ll, lr, _), N (kr, vr, rl, rr, _) ->
      let left = is_balanced l r in
      if left && is_balanced r l then
        mk_node_ k v l r
      else if not left then
        (* [r] is too heavy: push [k] and [l] down its left spine *)
        node_shallow_ kr vr (node_ k v l rl) rr
      else
        (* [l] is too heavy: push [k] and [r] down its right spine *)
        node_shallow_ kl vl ll (node_ k v lr r)

  (* join two trees, assuming all keys of [l] are smaller than keys of [r] *)
  let join_ l r =
    match l, r with
    | E, E -> E
    | E, o | o, E -> o
    | N _, N _ ->
      (* take the new root out of the heavier side *)
      if weight l <= weight r then (
        let k, v, r' = extract_min r in
        node_ k v l r'
      ) else (
        let k, v, l' = extract_max l in
        node_ k v l' r
      )

  (* if [o_v = Some v], behave like [mk_node k v l r]
     else behave like [join_ l r] *)
  let mk_node_or_join_ k o_v l r =
    match o_v with
    | None -> join_ l r
    | Some v -> node_ k v l r

  let rec split k m =
    match m with
    | E -> E, None, E
    | N (k', v', l, r, _) ->
      (match K.compare k k' with
      | 0 -> l, Some v', r
      | n when n < 0 ->
        let ll, o, lr = split k l in
        ll, o, node_ k' v' lr r
      | _ ->
        let rl, o, rr = split k r in
        node_ k' v' l rl, o, rr)

  (* the root of one tree is used as a pivot to split the other one; both
     halves are then merged recursively *)
  let rec merge ~f a b =
    match a, b with
    | E, E -> E
    | E, N (k, v, l, r, _) ->
      let v' = f k None (Some v) in
      mk_node_or_join_ k v' (merge ~f E l) (merge ~f E r)
    | N (k, v, l, r, _), E ->
      let v' = f k (Some v) None in
      mk_node_or_join_ k v' (merge ~f l E) (merge ~f r E)
    | N (k1, v1, l1, r1, w1), N (k2, v2, l2, r2, w2) ->
      if K.compare k1 k2 = 0 then
        (* easy case *)
        mk_node_or_join_ k1
          (f k1 (Some v1) (Some v2))
          (merge ~f l1 l2) (merge ~f r1 r2)
      else if w1 <= w2 then (
        (* split left tree *)
        let l1', v1', r1' = split k2 a in
        mk_node_or_join_ k2
          (f k2 v1' (Some v2))
          (merge ~f l1' l2) (merge ~f r1' r2)
      ) else (
        (* split right tree *)
        let l2', v2', r2' = split k1 b in
        mk_node_or_join_ k1
          (f k1 (Some v1) v2')
          (merge ~f l1 l2') (merge ~f r1 r2')
      )

  (* counts the bindings by traversing the whole tree *)
  let cardinal m = fold ~f:(fun acc _ _ -> acc + 1) ~x:0 m
  let add_list m l = List.fold_left (fun acc (k, v) -> add k v acc) m l
  let of_list l = add_list empty l

  (* bindings are consed during an ascending fold, hence the resulting list
     is in decreasing key order *)
  let to_list m = fold ~f:(fun acc k v -> (k, v) :: acc) ~x:[] m

  let add_iter m seq =
    let m = ref m in
    seq (fun (k, v) -> m := add k v !m);
    !m

  let of_iter s = add_iter empty s
  let to_iter m yield = iter ~f:(fun k v -> yield (k, v)) m

  let rec add_gen m g =
    match g () with
    | None -> m
    | Some (k, v) -> add_gen (add k v m) g

  let of_gen g = add_gen empty g

  (* explicit stack of the subtrees that remain to be visited; a node is
     yielded before its subtrees (left one first), so bindings do not come
     out in key order *)
  let to_gen m =
    let st = Stack.create () in
    Stack.push m st;
    let rec next () =
      if Stack.is_empty st then
        None
      else (
        match Stack.pop st with
        | E -> next ()
        | N (k, v, l, r, _) ->
          Stack.push r st;
          Stack.push l st;
          Some (k, v)
      )
    in
    next

  (* prints the bindings in ascending key order; [pp_sep] is only called
     between two bindings, and a cut is emitted after each binding *)
  let pp ?(pp_start = fun _ () -> ()) ?(pp_stop = fun _ () -> ())
      ?(pp_arrow = fun fmt () -> Format.fprintf fmt "@ -> ")
      ?(pp_sep = fun fmt () -> Format.fprintf fmt ",@ ") pp_k pp_v fmt m =
    pp_start fmt ();
    let first = ref true in
    iter m ~f:(fun k v ->
        if !first then
          first := false
        else
          pp_sep fmt ();
        pp_k fmt k;
        pp_arrow fmt ();
        pp_v fmt v;
        Format.pp_print_cut fmt ());
    pp_stop fmt ()
end
(** Map over plain ordered keys: every key is given the weight [1]. *)
module Make (X : ORD) = MakeFull (struct
  type t = X.t

  let compare = X.compare

  (* uniform weights: the weight of a tree is then its number of bindings *)
  let weight (_ : t) = 1
end)