diff --git a/src/core/CCArray.ml b/src/core/CCArray.ml index 66d8d1dd..17b351d5 100644 --- a/src/core/CCArray.ml +++ b/src/core/CCArray.ml @@ -663,13 +663,13 @@ module SortGeneric(A : MONO_ARRAY) = struct let seed_ = [|123456|] type state = { - rand: Rand.t; (* random state *) - cmp: A.elt -> A.elt -> int; mutable l: int; (* left pointer *) mutable g: int; (* right pointer *) mutable k: int; } + let rand_idx_ rand i j = i + Rand.int rand (j-i) + let swap_ a i j = if i=j then () else ( @@ -678,74 +678,69 @@ module SortGeneric(A : MONO_ARRAY) = struct A.set a j tmp ) - let rec insert_ ~cmp a i k = - if k 0 then ( - swap_ a k (k+1); - insert_ ~cmp a i (k-1) - ) - - (* recursive part of insertion sort *) - let rec sort_insertion_rec ~cmp a i j k = - if k 1 then sort_insertion_rec ~cmp a i j (i+1) - - let rand_idx_ ~st i j = - i + Rand.int st.rand (j-i) - - (* sort slice. - There is a chance that the two pivots are equal, but it's unlikely. *) - let rec sort_slice_ ~st a i j = - if j-i>16 then ( - st.l <- i; - st.g <- j-1; - st.k <- i; - (* choose pivots *) - let p = A.get a (rand_idx_ ~st i j) in - let q = A.get a (rand_idx_ ~st i j) in - (* invariant: st.p <= st.q, swap them otherwise *) - let p, q = if st.cmp p q > 0 then q, p else p, q in - while st.k <= st.g do - let cur = A.get a st.k in - if st.cmp cur p < 0 then ( - (* insert in leftmost band *) - if st.k <> st.l then swap_ a st.k st.l; - st.l <- st.l + 1 - ) else if st.cmp cur q > 0 then ( - (* insert in rightmost band *) - while st.k < st.g && st.cmp (A.get a st.g) q > 0 do - st.g <- st.g - 1 - done; - swap_ a st.k st.g; - st.g <- st.g - 1; - (* the element swapped from the right might be in the first situation. - that is, < p (we know it's <= q already) *) - if st.cmp (A.get a st.k) p < 0 then ( + let sort ~cmp a = + let rec insert_ a i k = + if k 0 then ( + swap_ a k (k+1); + insert_ a i (k-1) + ) + in + (* recursive part of insertion sort *) + let rec sort_insertion_rec a i j k = + if k 1 then sort_insertion_rec a i j (i+1) + in + let rand = Rand.make seed_ in + (* sort slice. + There is a chance that the two pivots are equal, but it's unlikely. *) + let rec sort_slice_ ~st a i j = + if j-i>10 then ( + st.l <- i; + st.g <- j-1; + st.k <- i; + (* choose pivots *) + let p = A.get a (rand_idx_ rand i j) in + let q = A.get a (rand_idx_ rand i j) in + (* invariant: st.p <= st.q, swap them otherwise *) + let p, q = if cmp p q > 0 then q, p else p, q in + while st.k <= st.g do + let cur = A.get a st.k in + if cmp cur p < 0 then ( + (* insert in leftmost band *) if st.k <> st.l then swap_ a st.k st.l; st.l <- st.l + 1 - ) - ); - st.k <- st.k + 1 - done; - (* save values before recursing *) - let l = st.l and g = st.g and sort_middle = st.cmp p q < 0 in - sort_slice_ ~st a i l; - if sort_middle then sort_slice_ ~st a l (g+1); - sort_slice_ ~st a (g+1) j; - ) else sort_insertion ~cmp:st.cmp a i j - - let sort ~cmp a = + ) else if cmp cur q > 0 then ( + (* insert in rightmost band *) + while st.k < st.g && cmp (A.get a st.g) q > 0 do + st.g <- st.g - 1 + done; + swap_ a st.k st.g; + st.g <- st.g - 1; + (* the element swapped from the right might be in the first situation. + that is, < p (we know it's <= q already) *) + if cmp (A.get a st.k) p < 0 then ( + if st.k <> st.l then swap_ a st.k st.l; + st.l <- st.l + 1 + ) + ); + st.k <- st.k + 1 + done; + (* save values before recursing *) + let l = st.l and g = st.g and sort_middle = cmp p q < 0 in + sort_slice_ ~st a i l; + if sort_middle then sort_slice_ ~st a l (g+1); + sort_slice_ ~st a (g+1) j; + ) else sort_insertion a i j + in if A.length a > 0 then ( - let st = { - rand=Rand.make seed_; cmp; - l=0; g=A.length a; k=0; - } in + let st = { l=0; g=A.length a; k=0; } in sort_slice_ ~st a 0 (A.length a) ) end