tune CCarray.sort_generic

This commit is contained in:
Simon Cruanes 2015-10-24 00:13:02 +02:00
parent 3b1922671e
commit 2608fc90bb

View file

@ -663,13 +663,13 @@ module SortGeneric(A : MONO_ARRAY) = struct
let seed_ = [|123456|] let seed_ = [|123456|]
type state = { type state = {
rand: Rand.t; (* random state *)
cmp: A.elt -> A.elt -> int;
mutable l: int; (* left pointer *) mutable l: int; (* left pointer *)
mutable g: int; (* right pointer *) mutable g: int; (* right pointer *)
mutable k: int; mutable k: int;
} }
let rand_idx_ rand i j = i + Rand.int rand (j-i)
let swap_ a i j = let swap_ a i j =
if i=j then () if i=j then ()
else ( else (
@ -678,55 +678,54 @@ module SortGeneric(A : MONO_ARRAY) = struct
A.set a j tmp A.set a j tmp
) )
let rec insert_ ~cmp a i k = let sort ~cmp a =
let rec insert_ a i k =
if k<i then () if k<i then ()
else if cmp (A.get a k) (A.get a (k+1)) > 0 then ( else if cmp (A.get a k) (A.get a (k+1)) > 0 then (
swap_ a k (k+1); swap_ a k (k+1);
insert_ ~cmp a i (k-1) insert_ a i (k-1)
) )
in
(* recursive part of insertion sort *) (* recursive part of insertion sort *)
let rec sort_insertion_rec ~cmp a i j k = let rec sort_insertion_rec a i j k =
if k<j then ( if k<j then (
insert_ ~cmp a i (k-1); insert_ a i (k-1);
sort_insertion_rec ~cmp a i j (k+1) sort_insertion_rec a i j (k+1)
) )
in
(* insertion sort, for small slices *) (* insertion sort, for small slices *)
let sort_insertion ~cmp a i j = let sort_insertion a i j =
if j-i > 1 then sort_insertion_rec ~cmp a i j (i+1) if j-i > 1 then sort_insertion_rec a i j (i+1)
in
let rand_idx_ ~st i j = let rand = Rand.make seed_ in
i + Rand.int st.rand (j-i)
(* sort slice. (* sort slice.
There is a chance that the two pivots are equal, but it's unlikely. *) There is a chance that the two pivots are equal, but it's unlikely. *)
let rec sort_slice_ ~st a i j = let rec sort_slice_ ~st a i j =
if j-i>16 then ( if j-i>10 then (
st.l <- i; st.l <- i;
st.g <- j-1; st.g <- j-1;
st.k <- i; st.k <- i;
(* choose pivots *) (* choose pivots *)
let p = A.get a (rand_idx_ ~st i j) in let p = A.get a (rand_idx_ rand i j) in
let q = A.get a (rand_idx_ ~st i j) in let q = A.get a (rand_idx_ rand i j) in
(* invariant: st.p <= st.q, swap them otherwise *) (* invariant: st.p <= st.q, swap them otherwise *)
let p, q = if st.cmp p q > 0 then q, p else p, q in let p, q = if cmp p q > 0 then q, p else p, q in
while st.k <= st.g do while st.k <= st.g do
let cur = A.get a st.k in let cur = A.get a st.k in
if st.cmp cur p < 0 then ( if cmp cur p < 0 then (
(* insert in leftmost band *) (* insert in leftmost band *)
if st.k <> st.l then swap_ a st.k st.l; if st.k <> st.l then swap_ a st.k st.l;
st.l <- st.l + 1 st.l <- st.l + 1
) else if st.cmp cur q > 0 then ( ) else if cmp cur q > 0 then (
(* insert in rightmost band *) (* insert in rightmost band *)
while st.k < st.g && st.cmp (A.get a st.g) q > 0 do while st.k < st.g && cmp (A.get a st.g) q > 0 do
st.g <- st.g - 1 st.g <- st.g - 1
done; done;
swap_ a st.k st.g; swap_ a st.k st.g;
st.g <- st.g - 1; st.g <- st.g - 1;
(* the element swapped from the right might be in the first situation. (* the element swapped from the right might be in the first situation.
that is, < p (we know it's <= q already) *) that is, < p (we know it's <= q already) *)
if st.cmp (A.get a st.k) p < 0 then ( if cmp (A.get a st.k) p < 0 then (
if st.k <> st.l then swap_ a st.k st.l; if st.k <> st.l then swap_ a st.k st.l;
st.l <- st.l + 1 st.l <- st.l + 1
) )
@ -734,18 +733,14 @@ module SortGeneric(A : MONO_ARRAY) = struct
st.k <- st.k + 1 st.k <- st.k + 1
done; done;
(* save values before recursing *) (* save values before recursing *)
let l = st.l and g = st.g and sort_middle = st.cmp p q < 0 in let l = st.l and g = st.g and sort_middle = cmp p q < 0 in
sort_slice_ ~st a i l; sort_slice_ ~st a i l;
if sort_middle then sort_slice_ ~st a l (g+1); if sort_middle then sort_slice_ ~st a l (g+1);
sort_slice_ ~st a (g+1) j; sort_slice_ ~st a (g+1) j;
) else sort_insertion ~cmp:st.cmp a i j ) else sort_insertion a i j
in
let sort ~cmp a =
if A.length a > 0 then ( if A.length a > 0 then (
let st = { let st = { l=0; g=A.length a; k=0; } in
rand=Rand.make seed_; cmp;
l=0; g=A.length a; k=0;
} in
sort_slice_ ~st a 0 (A.length a) sort_slice_ ~st a 0 (A.length a)
) )
end end