fix bugs in CCFlatHashtbl

This commit is contained in:
Simon Cruanes 2016-06-20 16:10:07 +02:00
parent 4b0371d7c6
commit c5303919bd

View file

@ -101,6 +101,9 @@ module Make(X : HASHABLE) = struct
let i = min Sys.max_array_length (max i 8) in let i = min Sys.max_array_length (max i 8) in
{ arr=Array.make i Empty; size=0; } { arr=Array.make i Empty; size=0; }
(* TODO: enforce that [tbl.arr] has a power of 2 as length, then
initial_index is just a mask with (length-1)? *)
(* initial index for a value with hash [h] *) (* initial index for a value with hash [h] *)
let _initial_idx tbl h = let _initial_idx tbl h =
h mod Array.length tbl.arr h mod Array.length tbl.arr
@ -109,18 +112,15 @@ module Make(X : HASHABLE) = struct
let i' = i+1 in let i' = i+1 in
if i' = Array.length tbl.arr then 0 else i' if i' = Array.length tbl.arr then 0 else i'
let _pred tbl i =
if i = 0 then Array.length tbl.arr - 1 else i-1
(* distance to initial bucket, at index [i] with hash [h] *) (* distance to initial bucket, at index [i] with hash [h] *)
let _dib tbl h i = let _dib tbl h ~i =
let i0 = _initial_idx tbl h in let i0 = _initial_idx tbl h in
if i>=i0 if i>=i0
then i-i0 then i - i0
else i+ (Array.length tbl.arr - i0 - 1) else i + (Array.length tbl.arr - i0)
(* insert k->v in [tbl], currently at index [i] *) (* insert k->v in [tbl], currently at index [i] and distance [dib] *)
let rec _linear_probe tbl k v h_k i = let rec _linear_probe tbl k v h_k i dib =
match tbl.arr.(i) with match tbl.arr.(i) with
| Empty -> | Empty ->
(* add binding *) (* add binding *)
@ -131,14 +131,16 @@ module Make(X : HASHABLE) = struct
assert (h_k = h_k'); assert (h_k = h_k');
tbl.arr.(i) <- Key (k, v, h_k) tbl.arr.(i) <- Key (k, v, h_k)
| Key (k', v', h_k') -> | Key (k', v', h_k') ->
if _dib tbl h_k i < _dib tbl h_k' i let dib' = _dib tbl h_k' ~i in
if dib > dib'
then ( then (
(* replace *) (* replace *)
tbl.arr.(i) <- Key (k, v, h_k); tbl.arr.(i) <- Key (k, v, h_k);
_linear_probe tbl k' v' h_k' (_succ tbl i) _linear_probe tbl k' v' h_k' (_succ tbl i) (dib'+1)
) else ) else (
(* go further *) (* go further *)
_linear_probe tbl k v h_k (_succ tbl i) _linear_probe tbl k v h_k (_succ tbl i) (dib+1)
)
(* resize table: put a bigger array in it, then insert values (* resize table: put a bigger array in it, then insert values
from the old array *) from the old array *)
@ -152,64 +154,72 @@ module Make(X : HASHABLE) = struct
Array.iter Array.iter
(function (function
| Empty -> () | Empty -> ()
| Key (k, v, h_k) -> _linear_probe tbl k v h_k (_initial_idx tbl h_k) | Key (k, v, h_k) ->
) old_arr _linear_probe tbl k v h_k (_initial_idx tbl h_k) 0)
old_arr
let add tbl k v = let add tbl k v =
if _reached_max_load tbl if _reached_max_load tbl then _resize tbl;
then _resize tbl;
(* insert value *) (* insert value *)
let h_k = X.hash k in let h_k = X.hash k in
_linear_probe tbl k v h_k (_initial_idx tbl h_k) _linear_probe tbl k v h_k (_initial_idx tbl h_k) 0
(* shift back elements that have a DIB > 0 until an empty bucket is (* shift back elements that have a DIB > 0 until an empty bucket
met, or some element doesn't need shifting *) or a bucket that doesn't need shifting is met *)
let rec _backward_shift tbl i = let rec _backward_shift tbl ~prev:prev_i i =
match tbl.arr.(i) with match tbl.arr.(i) with
| Empty -> () | Empty ->
| Key (_, _, h_k) when _dib tbl h_k i = 0 -> tbl.arr.(prev_i) <- Empty;
() (* stop *) | Key (_, _, h_k) as bucket ->
| Key (_k, _v, h_k) as bucket -> let d = _dib tbl h_k ~i in
assert (_dib tbl h_k i > 0); assert (d >= 0);
if d > 0 then (
(* shift backward *) (* shift backward *)
tbl.arr.(_pred tbl i) <- bucket; tbl.arr.(prev_i) <- bucket;
tbl.arr.(i) <- Empty; _backward_shift tbl ~prev:i (_succ tbl i)
_backward_shift tbl (_succ tbl i) ) else (
tbl.arr.(prev_i) <- Empty;
)
(* linear probing for removal of [k] *) (* linear probing for removal of [k]: find the bucket containing [k],
let rec _linear_probe_remove tbl k h_k i = if any, and perform backward shift from there *)
let rec _linear_probe_remove tbl k h_k i dib =
match tbl.arr.(i) with match tbl.arr.(i) with
| Empty -> () | Empty -> ()
| Key (k', _, _) when X.equal k k' -> | Key (k', _, _) when X.equal k k' ->
tbl.arr.(i) <- Empty;
tbl.size <- tbl.size - 1; tbl.size <- tbl.size - 1;
_backward_shift tbl (_succ tbl i) (* shift all elements that follow and have a DIB > 0;
it will also erase the last shifted bucket, and erase [i] in
any case *)
_backward_shift tbl ~prev:i (_succ tbl i)
| Key (_, _, h_k') -> | Key (_, _, h_k') ->
if _dib tbl h_k' i < _dib tbl h_k i if dib > _dib tbl h_k' ~i
then () (* [k] not present, would be here otherwise *) then () (* [k] not present, would be here otherwise *)
else _linear_probe_remove tbl k h_k (_succ tbl i) else _linear_probe_remove tbl k h_k (_succ tbl i) (dib+1)
let remove tbl k = let remove tbl k =
let h_k = X.hash k in let h_k = X.hash k in
_linear_probe_remove tbl k h_k (_initial_idx tbl h_k) _linear_probe_remove tbl k h_k (_initial_idx tbl h_k) 0
let rec _get_exn tbl k h_k i dib = let rec get_exn_rec tbl k h_k i dib =
match tbl.arr.(i) with match tbl.arr.(i) with
| Empty -> raise Not_found | Empty -> raise Not_found
| Key (k', v', _) when X.equal k k' -> v' | Key (k', v', _) when X.equal k k' -> v'
| Key (_, _, h_k') -> | Key (_, _, h_k') ->
if _dib tbl h_k' i < dib if dib > _dib tbl h_k' ~i
then raise Not_found (* [k] would be here otherwise *) then raise Not_found (* [k] would be here otherwise *)
else _get_exn tbl k h_k (_succ tbl i) (dib+1) else get_exn_rec tbl k h_k (_succ tbl i) (dib+1)
let get_exn k tbl = let get_exn k tbl =
let h_k = X.hash k in let h_k = X.hash k in
let i0 = _initial_idx tbl h_k in let i0 = _initial_idx tbl h_k in
(* unroll a few steps *)
match tbl.arr.(i0) with match tbl.arr.(i0) with
| Empty -> raise Not_found | Empty -> raise Not_found
| Key (k', v, _) -> | Key (k', v, _) ->
if X.equal k k' then v if X.equal k k' then v
else let i1 = _succ tbl i0 in else
let i1 = _succ tbl i0 in
match tbl.arr.(i1) with match tbl.arr.(i1) with
| Empty -> raise Not_found | Empty -> raise Not_found
| Key (k', v, _) -> | Key (k', v, _) ->
@ -220,7 +230,7 @@ module Make(X : HASHABLE) = struct
| Empty -> raise Not_found | Empty -> raise Not_found
| Key (k', v, _) -> | Key (k', v, _) ->
if X.equal k k' then v if X.equal k k' then v
else _get_exn tbl k h_k (_succ tbl i2) 3 else get_exn_rec tbl k h_k (_succ tbl i2) 3
let get k tbl = let get k tbl =
try Some (get_exn k tbl) try Some (get_exn k tbl)
@ -245,8 +255,8 @@ module Make(X : HASHABLE) = struct
Array.fold_left Array.fold_left
(fun acc bucket -> match bucket with (fun acc bucket -> match bucket with
| Empty -> acc | Empty -> acc
| Key (k,v,_) -> (k,v)::acc | Key (k,v,_) -> (k,v)::acc)
) [] tbl.arr [] tbl.arr
let of_seq seq = let of_seq seq =
let tbl = create 16 in let tbl = create 16 in