17#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
18#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
22#ifndef VQSORT_SECURE_RNG
23#define VQSORT_SECURE_RNG 0
27#include "third_party/absl/random/random.h"
36#include <sanitizer/msan_interface.h>
42#if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \
43 defined(HWY_TARGET_TOGGLE)
44#ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
45#undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
47#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
63template <
class Traits,
typename T>
66 constexpr size_t N1 = st.LanesPerKey();
69 while (start < num_lanes) {
70 const size_t left = 2 * start + N1;
71 const size_t right = 2 * start + 2 * N1;
72 if (left >= num_lanes)
break;
73 size_t idx_larger = start;
74 const auto key_j = st.SetKey(
d, lanes + start);
75 if (
AllTrue(
d, st.Compare(
d, key_j, st.SetKey(
d, lanes + left)))) {
78 if (right < num_lanes &&
79 AllTrue(
d, st.Compare(
d, st.SetKey(
d, lanes + idx_larger),
80 st.SetKey(
d, lanes + right)))) {
83 if (idx_larger == start)
break;
84 st.Swap(lanes + start, lanes + idx_larger);
91template <
class Traits,
typename T>
93 constexpr size_t N1 = st.LanesPerKey();
95 if (num_lanes < 2 * N1)
return;
98 for (
size_t i = ((num_lanes - N1) / N1 / 2) * N1; i != (size_t)-N1; i -= N1) {
102 for (
size_t i = num_lanes - N1; i != 0; i -= N1) {
104 st.Swap(lanes + 0, lanes + i);
111#if VQSORT_ENABLED || HWY_IDE
116template <
class D,
class Traits,
typename T>
121 using V =
decltype(
Zero(
d));
128 const size_t num_pow2 =
size_t{1}
130 static_cast<uint32_t
>(num - 1)));
142 const size_t N_sn =
Lanes(CappedTag<T, Constants::kMaxCols>());
144 SortingNetwork(st, keys, N_sn);
150 for (i = 0; i +
N <= num; i +=
N) {
157 const V kPadding = st.LastValue(
d);
164 SortingNetwork(st, buf, cols);
166 for (i = 0; i +
N <= num; i +=
N) {
176template <
class D,
class Traits,
class T>
179 size_t& left,
size_t& right,
186 const size_t num = right - left;
189 const size_t num_rem =
190 (num < 2 * kUnroll *
N) ? num : (num & (kUnroll *
N - 1));
192 for (; i +
N <= num_rem; i +=
N) {
193 const Vec<D> vL =
LoadU(
d, keys + readL);
196 const auto comp = st.Compare(
d, pivot, vL);
202 const auto mask =
FirstN(
d, num_rem - i);
203 const Vec<D> vL =
LoadU(
d, keys + readL);
205 const auto comp = st.Compare(
d, pivot, vL);
212 __msan_unpoison(buf, bufR *
sizeof(T));
222 memcpy(keys + left, keys + right, bufR *
sizeof(T));
223 memcpy(keys + right, buf, bufR *
sizeof(T));
226template <
class D,
class Traits,
typename T>
227HWY_INLINE void StoreLeftRight(D
d, Traits st,
const Vec<D>
v,
229 size_t& writeL,
size_t& writeR) {
232 const auto comp = st.Compare(
d, pivot,
v);
242 const auto lr = st.CompressKeys(
v, comp);
244 const size_t num_left =
N - num_right;
257 writeR -= (
N - num_left);
262template <
class D,
class Traits,
typename T>
263HWY_INLINE void StoreLeftRight4(D
d, Traits st,
const Vec<D> v0,
264 const Vec<D> v1,
const Vec<D> v2,
265 const Vec<D> v3,
const Vec<D> pivot,
268 StoreLeftRight(
d, st, v0, pivot, keys, writeL, writeR);
269 StoreLeftRight(
d, st, v1, pivot, keys, writeL, writeR);
270 StoreLeftRight(
d, st, v2, pivot, keys, writeL, writeR);
271 StoreLeftRight(
d, st, v3, pivot, keys, writeL, writeR);
278template <
class D,
class Traits,
typename T>
280 size_t right,
const Vec<D> pivot,
282 using V =
decltype(
Zero(
d));
290 const size_t last = right;
291 const V vlast =
LoadU(
d, keys + last);
293 PartitionToMultipleOfUnroll(
d, st, keys, left, right, pivot, buf);
297 size_t writeL = left;
298 size_t writeR = right;
300 const size_t num = right - left;
307 const V vL0 =
LoadU(
d, keys + left + 0 *
N);
308 const V vL1 =
LoadU(
d, keys + left + 1 *
N);
309 const V vL2 =
LoadU(
d, keys + left + 2 *
N);
310 const V vL3 =
LoadU(
d, keys + left + 3 *
N);
312 right -= kUnroll *
N;
313 const V vR0 =
LoadU(
d, keys + right + 0 *
N);
314 const V vR1 =
LoadU(
d, keys + right + 1 *
N);
315 const V vR2 =
LoadU(
d, keys + right + 2 *
N);
316 const V vR3 =
LoadU(
d, keys + right + 3 *
N);
319 while (left != right) {
324 const size_t capacityL = left - writeL;
325 const size_t capacityR = writeR - right;
327 if (capacityR < capacityL) {
328 right -= kUnroll *
N;
329 v0 =
LoadU(
d, keys + right + 0 *
N);
330 v1 =
LoadU(
d, keys + right + 1 *
N);
331 v2 =
LoadU(
d, keys + right + 2 *
N);
332 v3 =
LoadU(
d, keys + right + 3 *
N);
335 v0 =
LoadU(
d, keys + left + 0 *
N);
336 v1 =
LoadU(
d, keys + left + 1 *
N);
337 v2 =
LoadU(
d, keys + left + 2 *
N);
338 v3 =
LoadU(
d, keys + left + 3 *
N);
343 StoreLeftRight4(
d, st, v0, v1, v2, v3, pivot, keys, writeL, writeR);
347 StoreLeftRight4(
d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, writeR);
348 StoreLeftRight4(
d, st, vR0, vR1, vR2, vR3, pivot, keys, writeL, writeR);
357 const size_t totalR = last - writeL;
358 const size_t startR = totalR <
N ? writeL + totalR -
N : writeL;
362 const auto comp = st.Compare(
d, pivot, vlast);
371template <
class Traits,
class V>
372HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) {
378 const auto sum =
Xor(
Xor(v0, v1), v2);
379 const auto first = st.First(
d, st.First(
d, v0, v1), v2);
380 const auto last = st.Last(
d, st.Last(
d, v0, v1), v2);
381 return Xor(
Xor(sum, first), last);
384 v1 = st.Last(
d, v0, v1);
385 v1 = st.First(
d, v1, v2);
391template <
class D,
class Traits,
typename T>
392Vec<D> RecursiveMedianOf3(D
d, Traits st, T*
HWY_RESTRICT keys,
size_t num,
395 constexpr size_t N1 = st.LanesPerKey();
397 if (num < 3 * N1)
return st.SetKey(
d, keys);
403 for (; read + 3 *
N <= num; read += 3 *
N) {
404 const auto v0 =
Load(
d, keys + read + 0 *
N);
405 const auto v1 =
Load(
d, keys + read + 1 *
N);
406 const auto v2 =
Load(
d, keys + read + 2 *
N);
407 Store(MedianOf3(st, v0, v1, v2),
d, buf + written);
412 for (; read + 3 * N1 <= num; read += 3 * N1) {
413 const auto v0 = st.SetKey(
d, keys + read + 0 * N1);
414 const auto v1 = st.SetKey(
d, keys + read + 1 * N1);
415 const auto v2 = st.SetKey(
d, keys + read + 2 * N1);
416 StoreU(MedianOf3(st, v0, v1, v2),
d, buf + written);
421 return RecursiveMedianOf3(
d, st, buf, written, keys);
425using Generator = absl::BitGen;
431 Generator(
const void* heap,
size_t num) {
436 explicit Generator(uint64_t seed) {
441 uint64_t operator()() {
442 const uint64_t b = b_;
444 const uint64_t next = a_ ^ w_;
445 a_ = (b + (b << 3)) ^ (b >> 11);
446 const uint64_t rot = (b << 24) | (b >> 40);
463HWY_INLINE size_t RandomChunkIndex(
const uint32_t num_chunks, uint32_t bits) {
464 const uint64_t chunk_index = (
static_cast<uint64_t
>(bits) * num_chunks) >> 32;
466 return static_cast<size_t>(chunk_index);
469template <
class D,
class Traits,
typename T>
471 const size_t begin,
const size_t end,
473 using V =
decltype(
Zero(
d));
480 size_t num = end - begin;
485 const size_t misalign =
486 (
reinterpret_cast<uintptr_t
>(keys) /
sizeof(T)) & (lanes_per_chunk - 1);
488 const size_t consume = lanes_per_chunk - misalign;
494 uint64_t* bits64 =
reinterpret_cast<uint64_t*
>(buf);
495 for (
size_t i = 0; i < 5; ++i) {
498 const uint32_t* bits =
reinterpret_cast<const uint32_t*
>(buf);
500 const uint32_t lpc32 =
static_cast<uint32_t
>(lanes_per_chunk);
503 const size_t num_chunks64 = num >> log2_lpc;
505 const uint32_t num_chunks =
506 static_cast<uint32_t
>(
HWY_MIN(num_chunks64, 0xFFFFFFFFull));
508 const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) << log2_lpc;
509 const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) << log2_lpc;
510 const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) << log2_lpc;
511 const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) << log2_lpc;
512 const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) << log2_lpc;
513 const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) << log2_lpc;
514 const size_t offset6 = RandomChunkIndex(num_chunks, bits[6]) << log2_lpc;
515 const size_t offset7 = RandomChunkIndex(num_chunks, bits[7]) << log2_lpc;
516 const size_t offset8 = RandomChunkIndex(num_chunks, bits[8]) << log2_lpc;
517 for (
size_t i = 0; i < lanes_per_chunk; i +=
N) {
518 const V v0 =
Load(
d, keys + offset0 + i);
519 const V v1 =
Load(
d, keys + offset1 + i);
520 const V v2 =
Load(
d, keys + offset2 + i);
521 const V medians0 = MedianOf3(st, v0, v1, v2);
522 Store(medians0,
d, buf + i);
524 const V v3 =
Load(
d, keys + offset3 + i);
525 const V v4 =
Load(
d, keys + offset4 + i);
526 const V v5 =
Load(
d, keys + offset5 + i);
527 const V medians1 = MedianOf3(st, v3, v4, v5);
528 Store(medians1,
d, buf + i + lanes_per_chunk);
530 const V v6 =
Load(
d, keys + offset6 + i);
531 const V v7 =
Load(
d, keys + offset7 + i);
532 const V v8 =
Load(
d, keys + offset8 + i);
533 const V medians2 = MedianOf3(st, v6, v7, v8);
534 Store(medians2,
d, buf + i + lanes_per_chunk * 2);
537 return RecursiveMedianOf3(
d, st, buf, 3 * lanes_per_chunk,
538 buf + 3 * lanes_per_chunk);
543template <
class D,
class Traits,
typename T>
549 first = st.LastValue(
d);
550 last = st.FirstValue(
d);
553 for (; i +
N <= num; i +=
N) {
554 const Vec<D>
v =
LoadU(
d, keys + i);
555 first = st.First(
d,
v, first);
556 last = st.Last(
d,
v, last);
560 const Vec<D>
v =
LoadU(
d, keys + num -
N);
561 first = st.First(
d,
v, first);
562 last = st.Last(
d,
v, last);
565 first = st.FirstOfLanes(
d, first, buf);
566 last = st.LastOfLanes(
d, last, buf);
569template <
class D,
class Traits,
typename T>
571 const size_t begin,
const size_t end,
const Vec<D> pivot,
572 T*
HWY_RESTRICT buf, Generator& rng,
size_t remaining_levels) {
574 const size_t num = end - begin;
583 const ptrdiff_t base_case_num =
585 const size_t bound = Partition(
d, st, keys, begin, end, pivot, buf);
587 const ptrdiff_t num_left =
588 static_cast<ptrdiff_t
>(bound) -
static_cast<ptrdiff_t
>(begin);
589 const ptrdiff_t num_right =
590 static_cast<ptrdiff_t
>(end) -
static_cast<ptrdiff_t
>(bound);
599 ScanMinMax(
d, st, keys + begin, num, buf, first, last);
604 Recurse(
d, st, keys, keys_end, begin, end, first, buf, rng,
605 remaining_levels - 1);
610 BaseCase(
d, st, keys + begin, keys_end,
static_cast<size_t>(num_left), buf);
612 const Vec<D> next_pivot = ChoosePivot(
d, st, keys, begin, bound, buf, rng);
613 Recurse(
d, st, keys, keys_end, begin, bound, next_pivot, buf, rng,
614 remaining_levels - 1);
617 BaseCase(
d, st, keys + bound, keys_end,
static_cast<size_t>(num_right),
620 const Vec<D> next_pivot = ChoosePivot(
d, st, keys, bound, end, buf, rng);
621 Recurse(
d, st, keys, keys_end, bound, end, next_pivot, buf, rng,
622 remaining_levels - 1);
627template <
class D,
class Traits,
typename T>
628bool HandleSpecialCases(D
d, Traits st, T*
HWY_RESTRICT keys,
size_t num,
636 const bool partial_128 = !
IsFull(
d) &&
N < 2 && st.Is128();
641 constexpr bool kPotentiallyHuge =
643 const bool huge_vec = kPotentiallyHuge && (2 *
N > base_case_num);
644 if (partial_128 || huge_vec) {
652 BaseCase(
d, st, keys, keys + num, num, buf);
676template <
class D,
class Traits,
typename T>
679#if VQSORT_ENABLED || HWY_IDE
680#if !HWY_HAVE_SCALABLE
686 static_assert(
sizeof(storage) <= 8192,
"Unexpectedly large, check size");
690 if (detail::HandleSpecialCases(
d, st, keys, num, buf))
return;
692#if HWY_MAX_BYTES > 64
694 if (
Lanes(
d) > 64 /
sizeof(T)) {
695 return Sort(
CappedTag<T, 64 /
sizeof(T)>(), st, keys, num, buf);
700 detail::Generator rng(keys, num);
701 const Vec<D> pivot = detail::ChoosePivot(
d, st, keys, 0, num, buf, rng);
706 detail::Recurse(
d, st, keys, keys + num, 0, num, pivot, buf, rng, max_levels);
#define HWY_MAX(a, b)
Definition: base.h:126
#define HWY_RESTRICT
Definition: base.h:61
#define HWY_NOINLINE
Definition: base.h:63
#define HWY_MIN(a, b)
Definition: base.h:125
#define HWY_INLINE
Definition: base.h:62
#define HWY_DASSERT(condition)
Definition: base.h:191
#define HWY_LIKELY(expr)
Definition: base.h:66
#define HWY_UNLIKELY(expr)
Definition: base.h:67
static void Fill24Bytes(const void *seed_heap, size_t seed_num, void *bytes)
void SiftDown(Traits st, T *HWY_RESTRICT lanes, const size_t num_lanes, size_t start)
Definition: vqsort-inl.h:64
HWY_INLINE bool AllTrue(hwy::SizeTag< 1 >, const Mask128< T > m)
Definition: wasm_128-inl.h:3578
HWY_INLINE Mask128< T, N > And(hwy::SizeTag< 1 >, const Mask128< T, N > a, const Mask128< T, N > b)
Definition: x86_128-inl.h:818
void HeapSort(Traits st, T *HWY_RESTRICT lanes, const size_t num_lanes)
Definition: vqsort-inl.h:92
HWY_INLINE Mask128< T, N > AndNot(hwy::SizeTag< 1 >, const Mask128< T, N > a, const Mask128< T, N > b)
Definition: x86_128-inl.h:855
HWY_INLINE size_t CountTrue(hwy::SizeTag< 1 >, const Mask128< T > mask)
Definition: arm_neon-inl.h:5207
constexpr bool IsFull(Simd< T, N, kPow2 >)
Definition: ops/shared-inl.h:103
HWY_INLINE Mask512< T > Not(hwy::SizeTag< 1 >, const Mask512< T > m)
Definition: x86_512-inl.h:1574
HWY_INLINE Mask128< T, N > Xor(hwy::SizeTag< 1 >, const Mask128< T, N > a, const Mask128< T, N > b)
Definition: x86_128-inl.h:929
d
Definition: rvv-inl.h:1742
HWY_API auto Eq(V a, V b) -> decltype(a==b)
Definition: arm_neon-inl.h:6301
HWY_API Mask128< T, N > FirstN(const Simd< T, N, 0 > d, size_t num)
Definition: arm_neon-inl.h:2409
typename detail::CappedTagChecker< T, kLimit >::type CappedTag
Definition: ops/shared-inl.h:172
HWY_API constexpr size_t Lanes(Simd< T, N, kPow2 >)
Definition: arm_sve-inl.h:236
HWY_API Vec128< T, N > Load(Simd< T, N, 0 > d, const T *HWY_RESTRICT p)
Definition: arm_neon-inl.h:2706
void Sort(D d, Traits st, T *HWY_RESTRICT keys, size_t num, T *HWY_RESTRICT buf)
Definition: vqsort-inl.h:677
HWY_API void StoreU(const Vec128< uint8_t > v, Full128< uint8_t >, uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2725
HWY_API Vec128< uint8_t > LoadU(Full128< uint8_t >, const uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2544
HWY_API Vec128< T, N > Zero(Simd< T, N, 0 > d)
Definition: arm_neon-inl.h:1011
HWY_API size_t CompressBlendedStore(Vec128< T, N > v, Mask128< T, N > m, Simd< T, N, 0 > d, T *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:5846
typename detail::FixedTagChecker< T, kNumLanes >::type FixedTag
Definition: ops/shared-inl.h:188
HWY_API void SafeCopyN(const size_t num, D d, const T *HWY_RESTRICT from, T *HWY_RESTRICT to)
Definition: generic_ops-inl.h:103
N
Definition: rvv-inl.h:1742
HWY_API size_t CompressStore(Vec128< T, N > v, const Mask128< T, N > mask, Simd< T, N, 0 > d, T *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:5837
HWY_API void Store(Vec128< T, N > v, Simd< T, N, 0 > d, T *HWY_RESTRICT aligned)
Definition: arm_neon-inl.h:2882
const vfloat64m1_t v
Definition: rvv-inl.h:1742
decltype(Zero(D())) Vec
Definition: generic_ops-inl.h:32
Definition: aligned_allocator.h:27
HWY_INLINE HWY_ATTR_CACHE void Prefetch(const T *p)
Definition: cache_control.h:77
HWY_API size_t Num0BitsAboveMS1Bit_Nonzero32(const uint32_t x)
Definition: base.h:709
HWY_API size_t Num0BitsBelowLS1Bit_Nonzero32(const uint32_t x)
Definition: base.h:674
constexpr size_t CeilLog2(TI x)
Definition: base.h:777
#define HWY_MAX_BYTES
Definition: set_macros-inl.h:84
#define HWY_LANES(T)
Definition: set_macros-inl.h:85
#define HWY_ALIGN
Definition: set_macros-inl.h:83
#define HWY_NAMESPACE
Definition: set_macros-inl.h:82
Definition: arm_neon-inl.h:5318
Definition: contrib/sort/shared-inl.h:28
static constexpr size_t kMaxCols
Definition: contrib/sort/shared-inl.h:34
static constexpr size_t kMaxRows
Definition: contrib/sort/shared-inl.h:43
static constexpr HWY_INLINE size_t BaseCaseNum(size_t N)
Definition: contrib/sort/shared-inl.h:45
static constexpr size_t kMaxRowsLog2
Definition: contrib/sort/shared-inl.h:42
static constexpr size_t kPartitionUnroll
Definition: contrib/sort/shared-inl.h:54
static constexpr HWY_INLINE size_t LanesPerChunk(size_t sizeof_t, size_t N)
Definition: contrib/sort/shared-inl.h:68