diff --git a/splayTree.ml b/splayTree.ml index 6777ed85..f520e56a 100644 --- a/splayTree.ml +++ b/splayTree.ml @@ -43,23 +43,98 @@ let is_empty (tree, _) = | Empty -> true | Node _ -> false -let rec bigger ~cmp pivot tree = +(** Partition the tree into (elements <= pivot, elements > pivot) *) +let rec partition ~cmp pivot tree = match tree with - | Empty -> Empty + | Empty -> Empty, Empty | Node (a, x, x_val, b) -> if cmp x pivot <= 0 - then bigger ~cmp pivot b - else match a with - | Empty -> Node (Empty, x, x_val, b) + then begin + match b with + | Empty -> (tree, Empty) + | Node (b1, y, y_val, b2) -> + if cmp y pivot <= 0 + then + let small, big = partition ~cmp pivot b2 in + Node (Node (a, x, x_val, b1), y, y_val, small), big + else + let small, big = partition ~cmp pivot b1 in + Node (a, x, x_val, small), Node (big, y, y_val, b2) + end else begin + match a with + | Empty -> (Empty, tree) | Node (a1, y, y_val, a2) -> if cmp y pivot <= 0 - then Node (bigger ~cmp pivot a2, x, x_val, b) - else Node (bigger ~cmp pivot a1, y, y_val, Node (a2, x, x_val, b)) - -let rec smaller ~cmp pivot tree = + then + let small, big = partition ~cmp pivot a2 in + Node (a1, y, y_val, small), Node (big, x, x_val, b) + else + let small, big = partition ~cmp pivot a1 in + small, Node (big, y, y_val, Node (a2, x, x_val, b)) + end (** Insert the pair (key -> value) in the tree *) let insert (tree, cmp) k v = - let tree' = Node (smaller ~cmp k tree, k, v, bigger ~cmp k tree) in + let small, big = partition ~cmp k tree in + let tree' = Node (small, k, v, big) in tree', cmp +let remove (tree, cmp) k = failwith "not implemented" + +let replace (tree, cmp) k = failwith "not implemented" + +(** Returns the top value, or raise Not_found is empty *) +let top (tree, _) = + match tree with + | Empty -> raise Not_found + | Node (_, k, v, _) -> k, v + +(** Access minimum value *) +let min (tree, _) = + let rec min tree = + match tree with + | Empty -> raise Not_found + | Node (Empty, k, v, _) -> k, v + | Node (l, _, _, _) -> min l + in min tree + +(** Get minimum value and remove it from the tree *) +let delete_min (tree, cmp) = + let rec delete_min tree = match tree with + | Empty -> raise Not_found + | Node (Empty, x, x_val, b) -> x, x_val, b + | Node (Node (Empty, x, x_val, b), y, y_val, c) -> + x, x_val, Node (b, y, y_val, c) (* rebalance *) + | Node (Node (a, x, x_val, b), y, y_val, c) -> + let m, m_val, a' = delete_min a in + m, m_val, Node (a', x, x_val, Node (b, y, y_val, c)) + in + let m, m_val, tree' = delete_min tree in + m, m_val, (tree', cmp) + +(** Find the value for the given key (or raise Not_found). + It also returns the splayed tree *) +let find (tree, cmp) k = + failwith "not implemented" + +let find_fold (tree, cmp) k f acc = + acc (* TODO *) + +(** Iterate on elements *) +let iter (tree, _) f = + let rec iter tree = + match tree with + | Empty -> () + | Node (a, x, x_val, b) -> + iter a; + f x x_val; + iter b + in iter tree + +(** Number of elements (linear) *) +let size t = + let r = ref 0 in + iter t (fun _ _ -> incr r); + !r + +let get_cmp (_, cmp) = cmp diff --git a/splayTree.mli b/splayTree.mli index 08adcc44..dab3b18e 100644 --- a/splayTree.mli +++ b/splayTree.mli @@ -31,7 +31,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. type ('a, 'b) t (** A functional splay tree *) -val empty : cmp:('a -> 'a -> bool) -> ('a, 'b) t +val empty : cmp:('a -> 'a -> int) -> ('a, 'b) t (** Empty splay tree using the given comparison function *) val is_empty : (_, _) t -> bool @@ -40,26 +40,34 @@ val is_empty : (_, _) t -> bool val insert : ('a, 'b) t -> 'a -> 'b -> ('a, 'b) t (** Insert the pair (key -> value) in the tree *) +val remove : ('a, 'b) t -> 'a -> ('a, 'b) t + (** Remove an element by its key, returns the splayed tree *) + val replace : ('a, 'b) t -> 'a -> 'b -> ('a, 'b) t (** Insert the pair (key -> value) into the tree, replacing the previous binding (if any). It replaces at most one binding. *) -val remove : ('a, 'b) t -> 'a -> ('a, 'b) t - (** Remove an element by its key, returns the splayed tree *) - -val top : ('a, b') t -> 'a * 'b +val top : ('a, 'b) t -> 'a * 'b (** Returns the top value, or raise Not_found is empty *) -val min : ('a, 'b) t -> 'a * 'b * ('a, b') t +val min : ('a, 'b) t -> 'a * 'b (** Access minimum value *) +val delete_min : ('a, 'b) t -> 'a * 'b * ('a, 'b) t + (** Get minimum value and remove it from the tree *) + val find : ('a, 'b) t -> 'a -> 'b * ('a, 'b) t (** Find the value for the given key (or raise Not_found). It also returns the splayed tree *) +val find_fold : ('a, 'b) t -> 'a -> ('c -> 'b -> 'c) -> 'c -> 'c + (** Fold on all values associated with the given key *) + +val iter : ('a, 'b) t -> ('a -> 'b -> unit) -> unit + (** Iterate on elements *) + val size : (_, _) t -> int (** Number of elements (linear) *) -val iter : ('a -> 'b -> unit) -> ('a, 'b) t -> unit - (** Iterate on elements *) +val get_cmp : ('a, _) t -> ('a -> 'a -> int)