diff --git a/heap.ml b/heap.ml index 61b267ae..26ba5244 100644 --- a/heap.ml +++ b/heap.ml @@ -25,43 +25,140 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. (** {1 Imperative priority queue} *) -type 'a t = ('a, unit) SplayTree.t ref +module Tree = struct + + type 'a t = 'a tree * ('a -> 'a -> int) + (** A splay tree with the given comparison function *) + and 'a tree = + | Empty + | Node of ('a tree * 'a * 'a tree) + (** A splay tree containing values of type 'a *) + + let empty ~cmp = + (Empty, cmp) + + let is_empty (tree, _) = + match tree with + | Empty -> true + | Node _ -> false + + (** Partition the tree into (elements <= pivot, elements > pivot) *) + let rec partition ~cmp pivot tree = + match tree with + | Empty -> Empty, Empty + | Node (a, x, b) -> + if cmp x pivot <= 0 + then begin + match b with + | Empty -> (tree, Empty) + | Node (b1, y, b2) -> + if cmp y pivot <= 0 + then + let small, big = partition ~cmp pivot b2 in + Node (Node (a, x, b1), y, small), big + else + let small, big = partition ~cmp pivot b1 in + Node (a, x, small), Node (big, y, b2) + end else begin + match a with + | Empty -> (Empty, tree) + | Node (a1, y, a2) -> + if cmp y pivot <= 0 + then + let small, big = partition ~cmp pivot a2 in + Node (a1, y, small), Node (big, x, b) + else + let small, big = partition ~cmp pivot a1 in + small, Node (big, y, Node (a2, x, b)) + end + + (** Insert the element in the tree *) + let insert (tree, cmp) x = + let small, big = partition ~cmp x tree in + let tree' = Node (small, x, big) in + tree', cmp + + (** Returns the top value, or raise Not_found is empty *) + let top (tree, _) = + match tree with + | Empty -> raise Not_found + | Node (_, x, _) -> x + + (** Access minimum value *) + let min (tree, _) = + let rec min tree = + match tree with + | Empty -> raise Not_found + | Node (Empty, x, _) -> x + | Node (l, _, _) -> min l + in min tree + + (** Get minimum value and remove it from the tree *) + let delete_min (tree, cmp) = + let rec delete_min tree = match tree with + | Empty -> raise Not_found + | Node (Empty, x, b) -> x, b + | Node (Node (Empty, x, b), y, c) -> + x, Node (b, y, c) (* rebalance *) + | Node (Node (a, x, b), y, c) -> + let m, a' = delete_min a in + m, Node (a', x, Node (b, y, c)) + in + let m, tree' = delete_min tree in + m, (tree', cmp) + + (** Iterate on elements *) + let iter (tree, _) f = + let rec iter tree = + match tree with + | Empty -> () + | Node (a, x, b) -> + iter a; f x; iter b + in iter tree +end + +type 'a t = 'a Tree.t ref (** The heap is a reference to a splay tree *) (** Create an empty heap *) let empty ~cmp = - ref (SplayTree.empty ~cmp) + ref (Tree.empty ~cmp) (** Insert a value in the heap *) let insert heap x = - heap := SplayTree.insert !heap x () + heap := Tree.insert !heap x (** Check whether the heap is empty *) let is_empty heap = - SplayTree.is_empty !heap + Tree.is_empty !heap (** Access the minimal value of the heap, or raises Empty *) let min (heap : 'a t) : 'a = - let elt, _ = SplayTree.min !heap in + let elt = Tree.min !heap in elt (** Discard the minimal element *) let junk heap = - let _, (), tree' = SplayTree.delete_min !heap in + let _, tree' = Tree.delete_min !heap in heap := tree' (** Remove and return the mininal value (or raise Invalid_argument) *) let pop heap = - let elt, (), tree' = SplayTree.delete_min !heap in + let elt, tree' = Tree.delete_min !heap in heap := tree'; elt (** Iterate on the elements, in an unspecified order *) let iter heap k = - SplayTree.iter !heap (fun elt _ -> k elt) + Tree.iter !heap (fun elt -> k elt) + +let size heap = + let r = ref 0 in + iter heap (fun _ -> incr r); + !r let to_seq heap = - Sequence.from_iter (fun k -> iter heap k) + fun k -> iter heap k let of_seq heap seq = - Sequence.iter (fun elt -> insert heap elt) seq + seq (fun elt -> insert heap elt) diff --git a/heap.mli b/heap.mli index b0f3d609..24db4b77 100644 --- a/heap.mli +++ b/heap.mli @@ -49,6 +49,8 @@ val pop : 'a t -> 'a val iter : 'a t -> ('a -> unit) -> unit (** Iterate on the elements, in an unspecified order *) +val size : _ t -> int + val to_seq : 'a t -> 'a Sequence.t val of_seq : 'a t -> 'a Sequence.t -> unit diff --git a/tests/test_heap.ml b/tests/test_heap.ml index 3b24fcdb..ba6b1878 100644 --- a/tests/test_heap.ml +++ b/tests/test_heap.ml @@ -5,28 +5,32 @@ open Helpers let test_empty () = let h = Heap.empty ~cmp:(fun x y -> x - y) in - OUnit.assert_bool "is_empty empty" (Heap.is_empty h) + OUnit.assert_bool "is_empty empty" (Heap.is_empty h); + Heap.insert h 42; + OUnit.assert_bool "not empty" (not (Heap.is_empty h)); + () let test_sort () = let h = Heap.empty ~cmp:(fun x y -> x - y) in (* Heap sort *) let l = [3;4;2;1;6;5;0;7;10;9;8] in Heap.of_seq h (Sequence.of_list l); + OUnit.assert_equal ~printer:string_of_int 11 (Heap.size h); let l' = Sequence.to_list (Heap.to_seq h) in - OUnit.assert_equal ~printer:print_int_list l' [0;1;2;3;4;5;6;7;8;9;10] + OUnit.assert_equal ~printer:print_int_list [0;1;2;3;4;5;6;7;8;9;10] l' let test_remove () = let h = Heap.empty ~cmp:(fun x y -> x - y) in let l = [3;4;2;1;6;5;0;7;10;9;8] in Heap.of_seq h (Sequence.of_list l); (* check pop *) - OUnit.assert_equal (Heap.pop h) 0; - OUnit.assert_equal (Heap.pop h) 1; - OUnit.assert_equal (Heap.pop h) 2; - OUnit.assert_equal (Heap.pop h) 3; + OUnit.assert_equal 0 (Heap.pop h); + OUnit.assert_equal 1 (Heap.pop h); + OUnit.assert_equal 2 (Heap.pop h); + OUnit.assert_equal 3 (Heap.pop h); (* check that elements have been removed *) let l' = Sequence.to_list (Heap.to_seq h) in - OUnit.assert_equal ~printer:print_int_list l' [4;5;6;7;8;9;10] + OUnit.assert_equal ~printer:print_int_list [4;5;6;7;8;9;10] l' let suite = "test_heaps" >:::