fix bugs in CCFlatHashtbl

This commit is contained in:
Simon Cruanes 2016-06-20 16:10:07 +02:00
parent 4b0371d7c6
commit c5303919bd

View file

@ -101,6 +101,9 @@ module Make(X : HASHABLE) = struct
let i = min Sys.max_array_length (max i 8) in
{ arr=Array.make i Empty; size=0; }
(* TODO: enforce that [tbl.arr] has a power of 2 as length, then
initial_index is just a mask with (length-1)? *)
(* initial index for a value with hash [h] *)
let _initial_idx tbl h =
h mod Array.length tbl.arr
@ -109,18 +112,15 @@ module Make(X : HASHABLE) = struct
let i' = i+1 in
if i' = Array.length tbl.arr then 0 else i'
let _pred tbl i =
if i = 0 then Array.length tbl.arr - 1 else i-1
(* distance to initial bucket, at index [i] with hash [h] *)
let _dib tbl h i =
let _dib tbl h ~i =
let i0 = _initial_idx tbl h in
if i>=i0
then i-i0
else i+ (Array.length tbl.arr - i0 - 1)
then i - i0
else i + (Array.length tbl.arr - i0)
(* insert k->v in [tbl], currently at index [i] *)
let rec _linear_probe tbl k v h_k i =
(* insert k->v in [tbl], currently at index [i] and distance [dib] *)
let rec _linear_probe tbl k v h_k i dib =
match tbl.arr.(i) with
| Empty ->
(* add binding *)
@ -131,14 +131,16 @@ module Make(X : HASHABLE) = struct
assert (h_k = h_k');
tbl.arr.(i) <- Key (k, v, h_k)
| Key (k', v', h_k') ->
if _dib tbl h_k i < _dib tbl h_k' i
let dib' = _dib tbl h_k' ~i in
if dib > dib'
then (
(* replace *)
tbl.arr.(i) <- Key (k, v, h_k);
_linear_probe tbl k' v' h_k' (_succ tbl i)
) else
_linear_probe tbl k' v' h_k' (_succ tbl i) (dib'+1)
) else (
(* go further *)
_linear_probe tbl k v h_k (_succ tbl i)
_linear_probe tbl k v h_k (_succ tbl i) (dib+1)
)
(* resize table: put a bigger array in it, then insert values
from the old array *)
@ -152,65 +154,73 @@ module Make(X : HASHABLE) = struct
Array.iter
(function
| Empty -> ()
| Key (k, v, h_k) -> _linear_probe tbl k v h_k (_initial_idx tbl h_k)
) old_arr
| Key (k, v, h_k) ->
_linear_probe tbl k v h_k (_initial_idx tbl h_k) 0)
old_arr
let add tbl k v =
if _reached_max_load tbl
then _resize tbl;
if _reached_max_load tbl then _resize tbl;
(* insert value *)
let h_k = X.hash k in
_linear_probe tbl k v h_k (_initial_idx tbl h_k)
_linear_probe tbl k v h_k (_initial_idx tbl h_k) 0
(* shift back elements that have a DIB > 0 until an empty bucket is
met, or some element doesn't need shifting *)
let rec _backward_shift tbl i =
(* shift back elements that have a DIB > 0 until an empty bucket
or a bucket that doesn't need shifting is met *)
let rec _backward_shift tbl ~prev:prev_i i =
match tbl.arr.(i) with
| Empty -> ()
| Key (_, _, h_k) when _dib tbl h_k i = 0 ->
() (* stop *)
| Key (_k, _v, h_k) as bucket ->
assert (_dib tbl h_k i > 0);
(* shift backward *)
tbl.arr.(_pred tbl i) <- bucket;
tbl.arr.(i) <- Empty;
_backward_shift tbl (_succ tbl i)
| Empty ->
tbl.arr.(prev_i) <- Empty;
| Key (_, _, h_k) as bucket ->
let d = _dib tbl h_k ~i in
assert (d >= 0);
if d > 0 then (
(* shift backward *)
tbl.arr.(prev_i) <- bucket;
_backward_shift tbl ~prev:i (_succ tbl i)
) else (
tbl.arr.(prev_i) <- Empty;
)
(* linear probing for removal of [k] *)
let rec _linear_probe_remove tbl k h_k i =
(* linear probing for removal of [k]: find the bucket containing [k],
if any, and perform backward shift from there *)
let rec _linear_probe_remove tbl k h_k i dib =
match tbl.arr.(i) with
| Empty -> ()
| Key (k', _, _) when X.equal k k' ->
tbl.arr.(i) <- Empty;
tbl.size <- tbl.size - 1;
_backward_shift tbl (_succ tbl i)
(* shift all elements that follow and have a DIB > 0;
it will also erase the last shifted bucket, and erase [i] in
any case *)
_backward_shift tbl ~prev:i (_succ tbl i)
| Key (_, _, h_k') ->
if _dib tbl h_k' i < _dib tbl h_k i
then () (* [k] not present, would be here otherwise *)
else _linear_probe_remove tbl k h_k (_succ tbl i)
if dib > _dib tbl h_k' ~i
then () (* [k] not present, would be here otherwise *)
else _linear_probe_remove tbl k h_k (_succ tbl i) (dib+1)
let remove tbl k =
let h_k = X.hash k in
_linear_probe_remove tbl k h_k (_initial_idx tbl h_k)
_linear_probe_remove tbl k h_k (_initial_idx tbl h_k) 0
let rec _get_exn tbl k h_k i dib =
let rec get_exn_rec tbl k h_k i dib =
match tbl.arr.(i) with
| Empty -> raise Not_found
| Key (k', v', _) when X.equal k k' -> v'
| Key (_, _, h_k') ->
if _dib tbl h_k' i < dib
then raise Not_found (* [k] would be here otherwise *)
else _get_exn tbl k h_k (_succ tbl i) (dib+1)
if dib > _dib tbl h_k' ~i
then raise Not_found (* [k] would be here otherwise *)
else get_exn_rec tbl k h_k (_succ tbl i) (dib+1)
let get_exn k tbl =
let h_k = X.hash k in
let i0 = _initial_idx tbl h_k in
(* unroll a few steps *)
match tbl.arr.(i0) with
| Empty -> raise Not_found
| Key (k', v, _) ->
if X.equal k k' then v
else let i1 = _succ tbl i0 in
match tbl.arr.(i1) with
else
let i1 = _succ tbl i0 in
match tbl.arr.(i1) with
| Empty -> raise Not_found
| Key (k', v, _) ->
if X.equal k k' then v
@ -220,7 +230,7 @@ module Make(X : HASHABLE) = struct
| Empty -> raise Not_found
| Key (k', v, _) ->
if X.equal k k' then v
else _get_exn tbl k h_k (_succ tbl i2) 3
else get_exn_rec tbl k h_k (_succ tbl i2) 3
let get k tbl =
try Some (get_exn k tbl)
@ -245,8 +255,8 @@ module Make(X : HASHABLE) = struct
Array.fold_left
(fun acc bucket -> match bucket with
| Empty -> acc
| Key (k,v,_) -> (k,v)::acc
) [] tbl.arr
| Key (k,v,_) -> (k,v)::acc)
[] tbl.arr
let of_seq seq =
let tbl = create 16 in