diff --git a/src/data/CCIntMap.ml b/src/data/CCIntMap.ml index af62e8a4..26fba781 100644 --- a/src/data/CCIntMap.ml +++ b/src/data/CCIntMap.ml @@ -17,12 +17,13 @@ module Bit : sig val mask : mask:t -> int -> int (* zeroes the bit, puts all lower bits to 1 *) val lt : t -> t -> bool val gt : t -> t -> bool + val equal_int : int -> t -> bool end = struct type t = int let min_int = min_int - let equal = (=) + let equal : t -> t -> bool = Pervasives.(=) let rec highest_bit_naive x m = if x=m then m @@ -33,26 +34,45 @@ end = struct let highest x = if x<0 then min_int - else if Sys.word_size > 40 && x > mask_40_ - then (* remove least significant 40 bits *) + else if Sys.word_size > 40 && x > mask_40_ then ( + (* remove least significant 40 bits *) let x' = x land (lnot (mask_40_ -1)) in highest_bit_naive x' mask_40_ - else if x> mask_20_ - then (* small shortcut: remove least significant 20 bits *) + ) else if x> mask_20_ then ( + (* small shortcut: remove least significant 20 bits *) let x' = x land (lnot (mask_20_ -1)) in highest_bit_naive x' mask_20_ - else highest_bit_naive x 1 + ) else ( + highest_bit_naive x 1 + ) - let is_0 ~bit x = x land bit = 0 - let is_1 ~bit x = x land bit = bit + let[@inline] is_0 ~bit x = x land bit = 0 + let[@inline] is_1 ~bit x = x land bit = bit - let mask ~mask x = (x lor (mask -1)) land (lnot mask) + let[@inline] mask ~mask x = (x lor (mask -1)) land (lnot mask) (* low endian: let mask_ x ~mask = x land (mask - 1) *) - let gt a b = (b != min_int) && (a = min_int || a > b) - let lt a b = gt b a + let[@inline] gt a b = (b != min_int) && (a = min_int || a > b) + let[@inline] lt a b = gt b a + let equal_int = Pervasives.(=) end +(*$inject + let highest2 x : int = + let rec aux i = + if i=0 then i + else if 1 = (x lsr i) then 1 lsl i else aux (i-1) + in + if x<0 then min_int else aux (Sys.word_size-2) +*) + +(*$QR & ~count:1_000 + Q.int (fun x -> + if Bit.equal_int (highest2 x) (Bit.highest x) then true + else QCheck.Test.fail_reportf "x=%d, highest=%d, highest2=%d@." x + (Bit.highest x :> int) (highest2 x)) + *) + type 'a t = | E (* empty *) | L of int * 'a (* leaf *) @@ -60,10 +80,11 @@ type 'a t = let empty = E -let is_prefix_ ~prefix y ~bit = prefix = Bit.mask y ~mask:bit +let[@inline] is_prefix_ ~prefix y ~bit = + prefix = Bit.mask y ~mask:bit (*$inject - let _list_uniq = CCList.sort_uniq ~cmp:(fun a b-> Pervasives.compare (fst a)(fst b)) + let _list_uniq l = CCList.sort_uniq ~cmp:(fun a b-> Pervasives.compare (fst a)(fst b)) l *) (*$Q @@ -92,7 +113,7 @@ let is_prefix_ ~prefix y ~bit = prefix = Bit.mask y ~mask:bit *) (* low endian: let branching_bit_ a _ b _ = lowest_bit_ (a lxor b) *) -let branching_bit_ a b = Bit.highest (a lxor b) +let[@inline] branching_bit_ a b = Bit.highest (a lxor b) (* TODO use hint in branching_bit_ *) @@ -103,15 +124,14 @@ let check_invariants t = | L (k, _) -> List.for_all (fun (prefix, switch, side) -> - is_prefix_ ~prefix k ~bit:switch - && - match side with + is_prefix_ ~prefix k ~bit:switch && + begin match side with | `Left -> Bit.is_0 k ~bit:switch | `Right -> Bit.is_1 k ~bit:switch - ) path + end) + path | N (prefix, switch, l, r) -> - check_keys ((prefix, switch, `Left) :: path) l - && + check_keys ((prefix, switch, `Left) :: path) l && check_keys ((prefix, switch, `Right) :: path) r in check_keys [] t @@ -126,11 +146,13 @@ let rec find_exn k t = match t with | L (k', v) when k = k' -> v | L _ -> raise Not_found | N (prefix, m, l, r) -> - if is_prefix_ ~prefix k ~bit:m - then if Bit.is_0 k ~bit:m + if is_prefix_ ~prefix k ~bit:m then ( + if Bit.is_0 k ~bit:m then find_exn k l else find_exn k r - else raise Not_found + ) else ( + raise Not_found + ) (* XXX could test with lt_unsigned_? *) @@ -161,7 +183,7 @@ let mem k t = List.for_all (fun (k,_) -> mem k m) l) *) -let mk_node_ prefix switch l r = match l, r with +let[@inline] mk_node_ prefix switch l r = match l, r with | E, o | o, E -> o | _ -> N (prefix, switch, l, r) @@ -170,8 +192,7 @@ let mk_node_ prefix switch l r = match l, r with let join_ t1 p1 t2 p2 = let switch = branching_bit_ p1 p2 in let prefix = Bit.mask p1 ~mask:switch in - if Bit.is_0 p1 ~bit:switch - then ( + if Bit.is_0 p1 ~bit:switch then ( assert (Bit.is_1 p2 ~bit:switch); mk_node_ prefix switch t1 t2 ) else ( @@ -179,7 +200,7 @@ let join_ t1 p1 t2 p2 = mk_node_ prefix switch t2 t1 ) -let singleton k v = L (k, v) +let[@inline] singleton k v = L (k, v) (* c: conflict function *) let rec insert_ c k v t = match t with @@ -189,11 +210,13 @@ let rec insert_ c k v t = match t with then L (k, c ~old:v' v) else join_ t k' (L (k, v)) k | N (prefix, switch, l, r) -> - if is_prefix_ ~prefix k ~bit:switch - then if Bit.is_0 k ~bit:switch + if is_prefix_ ~prefix k ~bit:switch then ( + if Bit.is_0 k ~bit:switch then N(prefix, switch, insert_ c k v l, r) else N(prefix, switch, l, insert_ c k v r) - else join_ (L(k,v)) k t prefix + ) else ( + join_ (L(k,v)) k t prefix + ) let add k v t = insert_ (fun ~old:_ v -> v) k v t @@ -207,11 +230,13 @@ let rec remove k t = match t with | E -> E | L (k', _) -> if k=k' then E else t | N (prefix, switch, l, r) -> - if is_prefix_ ~prefix k ~bit:switch - then if Bit.is_0 k ~bit:switch + if is_prefix_ ~prefix k ~bit:switch then ( + if Bit.is_0 k ~bit:switch then mk_node_ prefix switch (remove k l) r else mk_node_ prefix switch l (remove k r) - else t (* not present *) + ) else ( + t (* not present *) + ) (*$Q & ~count:20 Q.(list (pair int int)) (fun l -> \ @@ -240,7 +265,9 @@ let update k f t = let doubleton k1 v1 k2 v2 = add k1 v1 (singleton k2 v2) -let rec equal ~eq a b = Pervasives.(==) a b || match a, b with +let rec equal ~eq a b = + Pervasives.(==) a b || + begin match a, b with | E, E -> true | L (ka, va), L (kb, vb) -> ka = kb && eq va vb | N (pa, sa, la, ra), N (pb, sb, lb, rb) -> @@ -248,6 +275,7 @@ let rec equal ~eq a b = Pervasives.(==) a b || match a, b with | E, _ | N _, _ | L _, _ -> false + end (*$Q Q.(list (pair int bool)) ( fun l -> \ @@ -289,26 +317,29 @@ let choose t = try Some (choose_exn t) with Not_found -> None +(** {2 Whole-collection operations} *) + let rec union f t1 t2 = - if Pervasives.(==) t1 t2 then t1 - else match t1, t2 with + match t1, t2 with | E, o | o, E -> o | L (k, v), o | o, L (k, v) -> (* insert k, v into o *) insert_ (fun ~old v -> f k old v) k v o | N (p1, m1, l1, r1), N (p2, m2, l2, r2) -> - if p1 = p2 && Bit.equal m1 m2 - then mk_node_ p1 m1 (union f l1 l2) (union f r1 r2) - else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 - then if Bit.is_0 p2 ~bit:m1 + if p1 = p2 && Bit.equal m1 m2 then ( + mk_node_ p1 m1 (union f l1 l2) (union f r1 r2) + ) else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then ( + if Bit.is_0 p2 ~bit:m1 then N (p1, m1, union f l1 t2, r1) else N (p1, m1, l1, union f r1 t2) - else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 - then if Bit.is_0 p1 ~bit:m2 + ) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then ( + if Bit.is_0 p1 ~bit:m2 then N (p2, m2, union f t1 l2, r2) else N (p2, m2, l2, union f t1 r2) - else join_ t1 p1 t2 p2 + ) else ( + join_ t1 p1 t2 p2 + ) (*$Q & ~small:(fun (a,b) -> List.length a + List.length b) Q.(pair (list (pair int bool)) (list (pair int bool))) (fun (l1,l2) -> \ @@ -344,9 +375,30 @@ let rec union f t1 t2 = equal ~eq:(=) (of_list l) (union (fun _ a _ -> a) (of_list l)(of_list l))) *) +(*$inject + let union_l l1 l2 = + let l2' = List.filter (fun (x,_) -> not @@ List.mem_assoc x l1) l2 in + _list_uniq (l1 @ l2') + + let inter_l l1 l2 = + let l2' = List.filter (fun (x,_) -> List.mem_assoc x l1) l2 in + _list_uniq l2' +*) + +(*$QR + Q.(pair (small_list (pair small_int unit)) (small_list (pair small_int unit))) + (fun (l1,l2) -> + union_l l1 l2 = _list_uniq @@ to_list (union (fun _ _ _ ->())(of_list l1) (of_list l2))) + *) + +(*$QR + Q.(pair (small_list (pair small_int unit)) (small_list (pair small_int unit))) + (fun (l1,l2) -> + inter_l l1 l2 = _list_uniq @@ to_list (inter (fun _ _ _ ->()) (of_list l1) (of_list l2))) + *) + let rec inter f a b = - if Pervasives.(==) a b then a - else match a, b with + match a, b with | E, _ | _, E -> E | L (k, v), o | o, L (k, v) -> @@ -356,17 +408,17 @@ let rec inter f a b = with Not_found -> E end | N (p1, m1, l1, r1), N (p2, m2, l2, r2) -> - if p1 = p2 && Bit.equal m1 m2 - then mk_node_ p1 m1 (inter f l1 l2) (inter f r1 r2) - else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 - then if Bit.is_0 p2 ~bit:m1 + if p1 = p2 && Bit.equal m1 m2 then ( + mk_node_ p1 m1 (inter f l1 l2) (inter f r1 r2) + ) else if Bit.gt m1 m2 && is_prefix_ ~prefix:p1 p2 ~bit:m1 then ( + if Bit.is_0 p2 ~bit:m1 then inter f l1 b else inter f r1 b - else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 - then if Bit.is_0 p1 ~bit:m2 + ) else if Bit.lt m1 m2 && is_prefix_ ~prefix:p2 p1 ~bit:m2 then ( + if Bit.is_0 p1 ~bit:m2 then inter f a l2 else inter f a r2 - else E + ) else E (*$R assert_equal ~cmp:(equal ~eq:(=)) ~printer:(CCFormat.to_string (pp CCString.pp)) @@ -465,11 +517,12 @@ let compare ~cmp a b = | Some _, None -> 1 | None, Some _ -> -1 | Some (ka, va), Some (kb, vb) -> - if ka=kb - then + if ka=kb then ( let c = cmp va vb in if c=0 then cmp_gen cmp a b else c - else compare ka kb + ) else ( + compare ka kb + ) in cmp_gen cmp (to_gen a) (to_gen b) diff --git a/src/data/CCIntMap.mli b/src/data/CCIntMap.mli index 2036ddb9..04874954 100644 --- a/src/data/CCIntMap.mli +++ b/src/data/CCIntMap.mli @@ -118,6 +118,7 @@ module Bit : sig type t = private int val min_int : t val highest : int -> t + val equal_int : int -> t -> bool end val check_invariants : _ t -> bool