diff --git a/enum.ml b/enum.ml index 8a72da94..8c80f019 100644 --- a/enum.ml +++ b/enum.ml @@ -309,10 +309,106 @@ let merge enum = x in next +(** {3 Mutable heap (taken from heap.ml to avoid dependencies)} *) +module Heap = struct + type 'a t = { + mutable tree : 'a tree; + cmp : 'a -> 'a -> int; + } (** A mutable min-heap backed by a splay tree: [tree] is the current tree and is overwritten by [insert] and [pop]; [cmp] is the comparison function that orders the elements *) + and 'a tree = + | Empty + | Node of ('a tree * 'a * 'a tree) + (** A splay tree: [Node (l, x, r)] as built by [insert] has the elements of [l] comparing <= [x] and those of [r] comparing > [x]; values are of type 'a *) + + let empty ~cmp = { + tree = Empty; + cmp; + } (* a new heap with no element, ordered by cmp *) + + let is_empty h = + match h.tree with + | Empty -> true + | Node _ -> false + + (** [partition ~cmp pivot tree] splits [tree] into a pair [(small, big)]: [small] holds the elements [x] with [cmp x pivot <= 0], [big] holds the others. Each call inspects two levels of the tree and rebuilds the nodes it went through in rotated form. Relies on [tree] being ordered by [cmp]. Not tail-recursive. *) + let rec partition ~cmp pivot tree = + match tree with + | Empty -> Empty, Empty + | Node (a, x, b) -> + if cmp x pivot <= 0 + then begin + match b with + | Empty -> (tree, Empty) (* nothing to the right of x: given the ordering invariant the whole tree is <= pivot *) + | Node (b1, y, b2) -> + if cmp y pivot <= 0 + then + let small, big = partition ~cmp pivot b2 in + Node (Node (a, x, b1), y, small), big + else + let small, big = partition ~cmp pivot b1 in + Node (a, x, small), Node (big, y, b2) + end else begin + match a with + | Empty -> (Empty, tree) (* x > pivot and nothing to its left: given the ordering invariant the whole tree is > pivot *) + | Node (a1, y, a2) -> + if cmp y pivot <= 0 + then + let small, big = partition ~cmp pivot a2 in + Node (a1, y, small), Node (big, x, b) + else + let small, big = partition ~cmp pivot a1 in + small, Node (big, y, Node (a2, x, b)) + end + + (** [insert h x] adds [x] to [h] in place: the current tree is partitioned around [x], which becomes the new root. Elements that compare equal to [x] land in the left subtree, so [pop] returns them before [x]. *) + let insert h x = + let small, big = partition ~cmp:h.cmp x h.tree in + let tree' = Node (small, x, big) in + h.tree <- tree' + + (** [pop h] removes the minimum element of [h] (the leftmost node of the tree) and returns it; [h] is updated in place. @raise Not_found if [h] is empty. *) + let pop h = + let rec delete_min tree = match tree with + | Empty -> raise Not_found + | Node (Empty, x, b) -> x, b + | Node (Node (Empty, x, b), y, c) -> + x, Node (b, y, c) (* x is the minimum: its right subtree b takes its place under y *) + | Node (Node (a, x, b), y, c) -> + let m, a' = delete_min a in + m, Node (a', x, Node (b, y, c)) (* right rotation: x moves above y while descending to the left *) + in + let m, tree' = delete_min h.tree in + h.tree <- tree'; + m +end + (** Assuming 
subsequences are sorted in increasing order, merge them into an increasing sequence *) let merge_sorted ?(cmp=compare) enum = - failwith "not implemented" + fun () -> (* everything below runs once per traversal of the merged enum, so each traversal builds its own heap *) + (* heap of (current head value, generator it came from), ordered on the head value only *) + let cmp (v1,_) (v2,_) = cmp v1 v2 in + let heap = Heap.empty ~cmp in + (* seed the heap with the first element of every sub-enum; presumably [iter] walks the whole outer enum eagerly, so it has to be finite -- confirm against [iter] *) + iter + (fun enum' -> + let gen = enum' () in + try + let x = gen () in + Heap.insert heap (x, gen) + with EOG -> ()) (* this sub-enum is empty: skip it. NOTE(review): the handler also encloses Heap.insert, so an EOG raised by cmp would be swallowed here as well *) + enum; + fun () -> (* the generator: each call pops the smallest head and refills the heap from the same sub-generator *) + if Heap.is_empty heap then raise EOG (* no entry left: every sub-generator is exhausted *) + else begin + let x, gen = Heap.pop heap in + try + let y = gen () in + Heap.insert heap (y, gen); (* re-queue gen under its next value *) + x + with EOG -> + x (* gen is exhausted: x was its last value, nothing to re-queue *) + end (** {3 Mutable double-linked list, similar to {! Deque.t} *) module MList = struct diff --git a/enum.mli b/enum.mli index db87070e..96832752 100644 --- a/enum.mli +++ b/enum.mli @@ -143,6 +143,15 @@ val merge : 'a t t -> 'a t (** Pick elements fairly in each sub-enum. The given enum must be finite (not its elements, though). *) +(** {3 Mutable heap (taken from heap.ml to avoid dependencies)} *) +module Heap : sig + type 'a t (** A mutable min-heap, ordered by the comparison function given to [empty], containing values of type 'a *) + val empty : cmp:('a -> 'a -> int) -> 'a t (** [empty ~cmp] is a new heap with no element, ordered by [cmp] *) + val insert : 'a t -> 'a -> unit (** [insert h x] adds [x] to [h] in place *) + val is_empty : 'a t -> bool (** [is_empty h] is [true] iff [h] holds no element *) + val pop : 'a t -> 'a (** [pop h] removes the smallest element of [h] and returns it. @raise Not_found if [h] is empty *) +end + val merge_sorted : ?cmp:('a -> 'a -> int) -> 'a t t -> 'a t (** Assuming subsequences are sorted in increasing order, merge them into an increasing sequence *) diff --git a/tests/test_enum.ml b/tests/test_enum.ml index 064359f1..463cf055 100644 --- a/tests/test_enum.ml +++ b/tests/test_enum.ml @@ -86,6 +86,12 @@ let test_big_rr () = OUnit.assert_equal [333;333;333] l'; () +let test_merge_sorted () = (* three sorted inputs with duplicates inside and across sub-enums; the expected list keeps every duplicate (3 + 7 + 3 = 13 values) *) + Enum.of_list [Enum.of_list [1;3;5]; Enum.of_list [0;1;1;3;4;6;10]; Enum.of_list [2;2;11]] + |> Enum.merge_sorted ?cmp:None + |> Enum.to_list + |> OUnit.assert_equal ~printer:Helpers.print_int_list [0;1;1;1;2;2;3;3;4;5;6;10;11] + let test_interleave () = let e1 = Enum.of_list [1;3;5;7;9] in let e2 = Enum.of_list [2;4;6;8;10] in 
@@ -117,6 +123,7 @@ let suite = "test_persistent" >:: test_persistent; "test_round_robin" >:: test_round_robin; "test_big_rr" >:: test_big_rr; + "test_merge_sorted" >:: test_merge_sorted; "test_interleave" >:: test_interleave; "test_intersperse" >:: test_intersperse; "test_product" >:: test_product;