vqsort-inl.h
// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Normal include guard for target-independent parts
#ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_

// Makes it harder for adversaries to predict our sampling locations, at the
// cost of 1-2% increased runtime.
#ifndef VQSORT_SECURE_RNG
#define VQSORT_SECURE_RNG 0
#endif

#if VQSORT_SECURE_RNG
#include "third_party/absl/random/random.h"
#endif

#include <string.h>  // memcpy

#include "hwy/cache_control.h"        // Prefetch
#include "hwy/contrib/sort/vqsort.h"  // Fill24Bytes

#if HWY_IS_MSAN
#include <sanitizer/msan_interface.h>
#endif

#endif  // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_

// Per-target
#if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \
    defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
#undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
#else
#define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
#endif

#include "hwy/contrib/sort/shared-inl.h"
#include "hwy/contrib/sort/sorting_networks-inl.h"
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {
namespace detail {

using Constants = hwy::SortConstants;

// ------------------------------ HeapSort

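// SiftDown restores the heap property of an implicit binary max-heap (in
// `st`'s sort order) whose keys each span N1 lanes: the children of the key
// starting at lane `start` begin at lanes 2*start + N1 and 2*start + 2*N1,
// matching `left`/`right` below.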
template <class Traits, typename T>
void SiftDown(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
              size_t start) {
  constexpr size_t N1 = st.LanesPerKey();
  const FixedTag<T, N1> d;

  while (start < num_lanes) {
    const size_t left = 2 * start + N1;
    const size_t right = 2 * start + 2 * N1;
    if (left >= num_lanes) break;
    size_t idx_larger = start;
    const auto key_j = st.SetKey(d, lanes + start);
    if (AllTrue(d, st.Compare(d, key_j, st.SetKey(d, lanes + left)))) {
      idx_larger = left;
    }
    if (right < num_lanes &&
        AllTrue(d, st.Compare(d, st.SetKey(d, lanes + idx_larger),
                              st.SetKey(d, lanes + right)))) {
      idx_larger = right;
    }
    if (idx_larger == start) break;
    st.Swap(lanes + start, lanes + idx_larger);
    start = idx_larger;
  }
}

// Heapsort: O(1) space, O(N*logN) worst-case comparisons.
// Based on LLVM sanitizer_common.h, licensed under Apache-2.0.
template <class Traits, typename T>
void HeapSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes) {
  constexpr size_t N1 = st.LanesPerKey();

  if (num_lanes < 2 * N1) return;

  // Build heap.
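  // The loop below sifts down every parent node, starting from the last one.
  // Because `i` is unsigned, decrementing past zero wraps it to (size_t)-N1,
  // which terminates the loop.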
  for (size_t i = ((num_lanes - N1) / N1 / 2) * N1; i != (size_t)-N1; i -= N1) {
    SiftDown(st, lanes, num_lanes, i);
  }

  for (size_t i = num_lanes - N1; i != 0; i -= N1) {
    // Swap root with last
    st.Swap(lanes + 0, lanes + i);

    // Sift down the new root.
    SiftDown(st, lanes, i, 0);
  }
}

#if VQSORT_ENABLED || HWY_IDE

// ------------------------------ BaseCase

// Sorts `keys` within the range [0, num) via sorting network.
template <class D, class Traits, typename T>
HWY_NOINLINE void BaseCase(D d, Traits st, T* HWY_RESTRICT keys,
                           T* HWY_RESTRICT keys_end, size_t num,
                           T* HWY_RESTRICT buf) {
  const size_t N = Lanes(d);
  using V = decltype(Zero(d));

  // _Nonzero32 requires num - 1 != 0.
  if (HWY_UNLIKELY(num <= 1)) return;

  // Reshape into a matrix with kMaxRows rows, and columns limited by the
  // 1D `num`, which is upper-bounded by the vector width (see BaseCaseNum).
  const size_t num_pow2 = size_t{1} << (32 - Num0BitsAboveMS1Bit_Nonzero32(
                                                 static_cast<uint32_t>(num - 1)));
  HWY_DASSERT(num <= num_pow2 && num_pow2 <= Constants::BaseCaseNum(N));
  const size_t cols =
      HWY_MAX(st.LanesPerKey(), num_pow2 >> Constants::kMaxRowsLog2);
  HWY_DASSERT(cols <= N);
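  // Example (single-lane keys, assuming kMaxRows == 16): num = 100 rounds up
  // to num_pow2 = 128, so cols = 128 >> 4 = 8 and the network sorts a 16 x 8
  // matrix; padding fills the unused slots.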

  // We can avoid padding and load/store directly to `keys` after checking the
  // original input array has enough space. Except at the right border, it's OK
  // to sort more than the current sub-array. Even if we sort across a previous
  // partition point, we know that keys will not migrate across it. However, we
  // must use the maximum size of the sorting network, because the StoreU of its
  // last vector would otherwise write invalid data starting at kMaxRows * cols.
  const size_t N_sn = Lanes(CappedTag<T, Constants::kMaxCols>());
  if (HWY_LIKELY(keys + N_sn * Constants::kMaxRows <= keys_end)) {
    SortingNetwork(st, keys, N_sn);
    return;
  }

  // Copy `keys` to `buf`.
  size_t i;
  for (i = 0; i + N <= num; i += N) {
    Store(LoadU(d, keys + i), d, buf + i);
  }
  SafeCopyN(num - i, d, keys + i, buf + i);
  i = num;

  // Fill with padding - last in sort order, not copied to keys.
  const V kPadding = st.LastValue(d);
  // Initialize an extra vector because SortingNetwork loads full vectors,
  // which may exceed cols*kMaxRows.
  for (; i < (cols * Constants::kMaxRows + N); i += N) {
    StoreU(kPadding, d, buf + i);
  }

  SortingNetwork(st, buf, cols);

  for (i = 0; i + N <= num; i += N) {
    StoreU(Load(d, buf + i), d, keys + i);
  }
  SafeCopyN(num - i, d, buf + i, keys + i);
}

// ------------------------------ Partition

// Consumes from `left` until a multiple of kUnroll*N remains.
// Temporarily stores the right side into `buf`, then moves behind `right`.
template <class D, class Traits, class T>
HWY_NOINLINE void PartitionToMultipleOfUnroll(D d, Traits st,
                                              T* HWY_RESTRICT keys,
                                              size_t& left, size_t& right,
                                              const Vec<D> pivot,
                                              T* HWY_RESTRICT buf) {
  constexpr size_t kUnroll = Constants::kPartitionUnroll;
  const size_t N = Lanes(d);
  size_t readL = left;
  size_t bufR = 0;
  const size_t num = right - left;
  // Partition requires both a multiple of kUnroll*N and at least
  // 2*kUnroll*N for the initial loads. If less, consume all here.
  const size_t num_rem =
      (num < 2 * kUnroll * N) ? num : (num & (kUnroll * N - 1));
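  // Example (assuming kUnroll = 4, N = 8): num = 100 is at least 64, so
  // num_rem = 100 & 31 = 4 lanes are consumed here; had num been below 64,
  // all of it would be consumed here.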
  size_t i = 0;
  for (; i + N <= num_rem; i += N) {
    const Vec<D> vL = LoadU(d, keys + readL);
    readL += N;

    const auto comp = st.Compare(d, pivot, vL);
    left += CompressBlendedStore(vL, Not(comp), d, keys + left);
    bufR += CompressStore(vL, comp, d, buf + bufR);
  }
  // Last iteration: only use valid lanes.
  if (HWY_LIKELY(i != num_rem)) {
    const auto mask = FirstN(d, num_rem - i);
    const Vec<D> vL = LoadU(d, keys + readL);

    const auto comp = st.Compare(d, pivot, vL);
    left += CompressBlendedStore(vL, AndNot(comp, mask), d, keys + left);
    bufR += CompressStore(vL, And(comp, mask), d, buf + bufR);
  }

  // MSAN seems not to understand CompressStore. buf[0, bufR) are valid.
#if HWY_IS_MSAN
  __msan_unpoison(buf, bufR * sizeof(T));
#endif

  // Everything we loaded was put into buf, or behind the new `left`, after
  // which there is space for bufR items. First move items from `right` to
  // `left` to free up space, then copy `buf` into the vacated `right`.
  // A loop with masked loads from `buf` is insufficient - we would also need to
  // mask from `right`. Combining a loop with memcpy for the remainders is
  // slower than just memcpy, so we use that for simplicity.
  right -= bufR;
  memcpy(keys + left, keys + right, bufR * sizeof(T));
  memcpy(keys + right, buf, bufR * sizeof(T));
}

template <class D, class Traits, typename T>
HWY_INLINE void StoreLeftRight(D d, Traits st, const Vec<D> v,
                               const Vec<D> pivot, T* HWY_RESTRICT keys,
                               size_t& writeL, size_t& writeR) {
  const size_t N = Lanes(d);

  const auto comp = st.Compare(d, pivot, v);

  if (CompressIsPartition<T>::value ||
      (HWY_MAX_BYTES == 16 && st.Is128())) {
    // Non-native Compress (e.g. AVX2): we are able to partition a vector using
    // a single Compress+two StoreU instead of two Compress[Blended]Store. The
    // latter are more expensive. Because we store entire vectors, the contents
    // between the updated writeL and writeR are ignored and will be overwritten
    // by subsequent calls. This works because writeL and writeR are at least
    // two vectors apart.
    const auto lr = st.CompressKeys(v, comp);
    const size_t num_right = CountTrue(d, comp);
    const size_t num_left = N - num_right;
    StoreU(lr, d, keys + writeL);
    writeL += num_left;
    // Now write the right-side elements (if any), such that the previous writeR
    // is one past the end of the newly written right elements, then advance.
    StoreU(lr, d, keys + writeR - N);
    writeR -= num_right;
  } else {
    // Native Compress[Store] (e.g. AVX3), which only keep the left or right
    // side, not both, hence we require two calls.
    const size_t num_left = CompressStore(v, Not(comp), d, keys + writeL);
    writeL += num_left;

    writeR -= (N - num_left);
    (void)CompressBlendedStore(v, comp, d, keys + writeR);
  }
}

template <class D, class Traits, typename T>
HWY_INLINE void StoreLeftRight4(D d, Traits st, const Vec<D> v0,
                                const Vec<D> v1, const Vec<D> v2,
                                const Vec<D> v3, const Vec<D> pivot,
                                T* HWY_RESTRICT keys, size_t& writeL,
                                size_t& writeR) {
  StoreLeftRight(d, st, v0, pivot, keys, writeL, writeR);
  StoreLeftRight(d, st, v1, pivot, keys, writeL, writeR);
  StoreLeftRight(d, st, v2, pivot, keys, writeL, writeR);
  StoreLeftRight(d, st, v3, pivot, keys, writeL, writeR);
}

// Moves "<= pivot" keys to the front, and others to the back. pivot is
// broadcasted. Time-critical!
//
// Aligned loads do not seem to be worthwhile (not bottlenecked by load ports).
template <class D, class Traits, typename T>
HWY_NOINLINE size_t Partition(D d, Traits st, T* HWY_RESTRICT keys, size_t left,
                              size_t right, const Vec<D> pivot,
                              T* HWY_RESTRICT buf) {
  using V = decltype(Zero(d));
  const size_t N = Lanes(d);

  // StoreLeftRight will CompressBlendedStore ending at `writeR`. Unless all
  // lanes happen to be in the right-side partition, this will overrun `keys`,
  // which triggers asan errors. Avoid by special-casing the last vector.
  HWY_DASSERT(right - left > 2 * N);  // ensured by HandleSpecialCases
  right -= N;
  const size_t last = right;
  const V vlast = LoadU(d, keys + last);

  PartitionToMultipleOfUnroll(d, st, keys, left, right, pivot, buf);
  constexpr size_t kUnroll = Constants::kPartitionUnroll;

  // Invariant: [left, writeL) and [writeR, right) are already partitioned.
  size_t writeL = left;
  size_t writeR = right;

  const size_t num = right - left;
  // Cannot load if there were fewer than 2 * kUnroll * N.
  if (HWY_LIKELY(num != 0)) {
    HWY_DASSERT(num >= 2 * kUnroll * N);
    HWY_DASSERT((num & (kUnroll * N - 1)) == 0);

    // Make space for writing in-place by reading from left and right.
    const V vL0 = LoadU(d, keys + left + 0 * N);
    const V vL1 = LoadU(d, keys + left + 1 * N);
    const V vL2 = LoadU(d, keys + left + 2 * N);
    const V vL3 = LoadU(d, keys + left + 3 * N);
    left += kUnroll * N;
    right -= kUnroll * N;
    const V vR0 = LoadU(d, keys + right + 0 * N);
    const V vR1 = LoadU(d, keys + right + 1 * N);
    const V vR2 = LoadU(d, keys + right + 2 * N);
    const V vR3 = LoadU(d, keys + right + 3 * N);

    // The left/right updates may consume all inputs, so check before the loop.
    while (left != right) {
      V v0, v1, v2, v3;

      // Free up capacity for writing by loading from the side that has less.
      // Data-dependent but branching is faster than forcing branch-free.
      const size_t capacityL = left - writeL;
      const size_t capacityR = writeR - right;
      HWY_DASSERT(capacityL <= num && capacityR <= num);  // >= 0
      if (capacityR < capacityL) {
        right -= kUnroll * N;
        v0 = LoadU(d, keys + right + 0 * N);
        v1 = LoadU(d, keys + right + 1 * N);
        v2 = LoadU(d, keys + right + 2 * N);
        v3 = LoadU(d, keys + right + 3 * N);
        hwy::Prefetch(keys + right - 3 * kUnroll * N);
      } else {
        v0 = LoadU(d, keys + left + 0 * N);
        v1 = LoadU(d, keys + left + 1 * N);
        v2 = LoadU(d, keys + left + 2 * N);
        v3 = LoadU(d, keys + left + 3 * N);
        left += kUnroll * N;
        hwy::Prefetch(keys + left + 3 * kUnroll * N);
      }

      StoreLeftRight4(d, st, v0, v1, v2, v3, pivot, keys, writeL, writeR);
    }

    // Now finish writing the initial left/right to the middle.
    StoreLeftRight4(d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, writeR);
    StoreLeftRight4(d, st, vR0, vR1, vR2, vR3, pivot, keys, writeL, writeR);
  }

  // We have partitioned [left, right) such that writeL is the boundary.
  HWY_DASSERT(writeL == writeR);
  // Make space for inserting vlast: move up to N of the first right-side keys
  // into the unused space starting at last. If we have fewer, ensure they are
  // the last items in that vector by subtracting from the *load* address,
  // which is safe because we have at least two vectors (checked above).
  const size_t totalR = last - writeL;
  const size_t startR = totalR < N ? writeL + totalR - N : writeL;
  StoreU(LoadU(d, keys + startR), d, keys + last);

  // Partition vlast: write L, then R, into the single-vector gap at writeL.
  const auto comp = st.Compare(d, pivot, vlast);
  writeL += CompressBlendedStore(vlast, Not(comp), d, keys + writeL);
  (void)CompressBlendedStore(vlast, comp, d, keys + writeL);

  return writeL;
}

// ------------------------------ Pivot

template <class Traits, class V>
HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) {
  const DFromV<V> d;
  // Slightly faster for 128-bit, apparently because not serially dependent.
  if (st.Is128()) {
    // Median = XOR-sum 'minus' the first and last. Calling First twice is
    // slightly faster than Compare + 2 IfThenElse or even IfThenElse + XOR.
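    // Worked example with lane values 3, 7, 5: sum = 3^7^5 = 1, first = 3,
    // last = 7, and (1 ^ 3) ^ 7 = 5, the median.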
    const auto sum = Xor(Xor(v0, v1), v2);
    const auto first = st.First(d, st.First(d, v0, v1), v2);
    const auto last = st.Last(d, st.Last(d, v0, v1), v2);
    return Xor(Xor(sum, first), last);
  }
  st.Sort2(d, v0, v2);
  v1 = st.Last(d, v0, v1);
  v1 = st.First(d, v1, v2);
  return v1;
}

// Replaces triplets with their median and recurses until less than 3 keys
// remain. Ignores leftover values (non-whole triplets)!
template <class D, class Traits, typename T>
Vec<D> RecursiveMedianOf3(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
                          T* HWY_RESTRICT buf) {
  const size_t N = Lanes(d);
  constexpr size_t N1 = st.LanesPerKey();

  if (num < 3 * N1) return st.SetKey(d, keys);

  size_t read = 0;
  size_t written = 0;

  // Triplets of vectors
  for (; read + 3 * N <= num; read += 3 * N) {
    const auto v0 = Load(d, keys + read + 0 * N);
    const auto v1 = Load(d, keys + read + 1 * N);
    const auto v2 = Load(d, keys + read + 2 * N);
    Store(MedianOf3(st, v0, v1, v2), d, buf + written);
    written += N;
  }

  // Triplets of keys
  for (; read + 3 * N1 <= num; read += 3 * N1) {
    const auto v0 = st.SetKey(d, keys + read + 0 * N1);
    const auto v1 = st.SetKey(d, keys + read + 1 * N1);
    const auto v2 = st.SetKey(d, keys + read + 2 * N1);
    StoreU(MedianOf3(st, v0, v1, v2), d, buf + written);
    written += N1;
  }

  // Tail recursion; swap buffers
  return RecursiveMedianOf3(d, st, buf, written, keys);
}

#if VQSORT_SECURE_RNG
using Generator = absl::BitGen;
#else
// Based on https://github.com/numpy/numpy/issues/16313#issuecomment-641897028
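// State: a_ and b_ are mixed and rotated each call, while w_ is a Weyl
// counter advanced by the odd increment k_; operator() returns the previous
// a_ XORed with the advanced counter.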
#pragma pack(push, 1)
class Generator {
 public:
  Generator(const void* heap, size_t num) {
    Sorter::Fill24Bytes(heap, num, &a_);
    k_ = 1;  // stream index: must be odd
  }

  explicit Generator(uint64_t seed) {
    a_ = b_ = w_ = seed;
    k_ = 1;
  }

  uint64_t operator()() {
    const uint64_t b = b_;
    w_ += k_;
    const uint64_t next = a_ ^ w_;
    a_ = (b + (b << 3)) ^ (b >> 11);
    const uint64_t rot = (b << 24) | (b >> 40);
    b_ = rot + next;
    return next;
  }

 private:
  uint64_t a_;
  uint64_t b_;
  uint64_t w_;
  uint64_t k_;  // increment
};
#pragma pack(pop)

#endif  // !VQSORT_SECURE_RNG

// Returns slightly biased random index of a chunk in [0, num_chunks).
// See https://www.pcg-random.org/posts/bounded-rands.html.
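// Example: bits = 0xFFFFFFFF and num_chunks = 8 map to
// (0xFFFFFFFF * 8) >> 32 = 7, the last chunk; smaller `bits` values scale
// proportionally.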
HWY_INLINE size_t RandomChunkIndex(const uint32_t num_chunks, uint32_t bits) {
  const uint64_t chunk_index = (static_cast<uint64_t>(bits) * num_chunks) >> 32;
  HWY_DASSERT(chunk_index < num_chunks);
  return static_cast<size_t>(chunk_index);
}

template <class D, class Traits, typename T>
HWY_NOINLINE Vec<D> ChoosePivot(D d, Traits st, T* HWY_RESTRICT keys,
                                const size_t begin, const size_t end,
                                T* HWY_RESTRICT buf, Generator& rng) {
  using V = decltype(Zero(d));
  const size_t N = Lanes(d);

  // Power of two
  const size_t lanes_per_chunk = Constants::LanesPerChunk(sizeof(T), N);

  keys += begin;
  size_t num = end - begin;

  // Align start of keys to chunks. We always have at least 2 chunks because the
  // base case would have handled anything up to 16 vectors, i.e. >= 4 chunks.
  HWY_DASSERT(num >= 2 * lanes_per_chunk);
  const size_t misalign =
      (reinterpret_cast<uintptr_t>(keys) / sizeof(T)) & (lanes_per_chunk - 1);
  if (misalign != 0) {
    const size_t consume = lanes_per_chunk - misalign;
    keys += consume;
    num -= consume;
  }

  // Generate enough random bits for 9 uint32
  uint64_t* bits64 = reinterpret_cast<uint64_t*>(buf);
  for (size_t i = 0; i < 5; ++i) {
    bits64[i] = rng();
  }
  const uint32_t* bits = reinterpret_cast<const uint32_t*>(buf);

  const uint32_t lpc32 = static_cast<uint32_t>(lanes_per_chunk);
  // Avoid division
  const size_t log2_lpc = Num0BitsBelowLS1Bit_Nonzero32(lpc32);
  const size_t num_chunks64 = num >> log2_lpc;
  // Clamp to uint32 for RandomChunkIndex
  const uint32_t num_chunks =
      static_cast<uint32_t>(HWY_MIN(num_chunks64, 0xFFFFFFFFull));

  const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) << log2_lpc;
  const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) << log2_lpc;
  const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) << log2_lpc;
  const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) << log2_lpc;
  const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) << log2_lpc;
  const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) << log2_lpc;
  const size_t offset6 = RandomChunkIndex(num_chunks, bits[6]) << log2_lpc;
  const size_t offset7 = RandomChunkIndex(num_chunks, bits[7]) << log2_lpc;
  const size_t offset8 = RandomChunkIndex(num_chunks, bits[8]) << log2_lpc;
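  // Each iteration takes lane-wise medians of three of the nine sampled
  // chunks, so buf[0, 3 * lanes_per_chunk) ends up holding medians that
  // RecursiveMedianOf3 then reduces to a single pivot key.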
  for (size_t i = 0; i < lanes_per_chunk; i += N) {
    const V v0 = Load(d, keys + offset0 + i);
    const V v1 = Load(d, keys + offset1 + i);
    const V v2 = Load(d, keys + offset2 + i);
    const V medians0 = MedianOf3(st, v0, v1, v2);
    Store(medians0, d, buf + i);

    const V v3 = Load(d, keys + offset3 + i);
    const V v4 = Load(d, keys + offset4 + i);
    const V v5 = Load(d, keys + offset5 + i);
    const V medians1 = MedianOf3(st, v3, v4, v5);
    Store(medians1, d, buf + i + lanes_per_chunk);

    const V v6 = Load(d, keys + offset6 + i);
    const V v7 = Load(d, keys + offset7 + i);
    const V v8 = Load(d, keys + offset8 + i);
    const V medians2 = MedianOf3(st, v6, v7, v8);
    Store(medians2, d, buf + i + lanes_per_chunk * 2);
  }

  return RecursiveMedianOf3(d, st, buf, 3 * lanes_per_chunk,
                            buf + 3 * lanes_per_chunk);
}

// Compute exact min/max to detect all-equal partitions. Only called after a
// degenerate Partition (none in the right partition).
template <class D, class Traits, typename T>
HWY_NOINLINE void ScanMinMax(D d, Traits st, const T* HWY_RESTRICT keys,
                             size_t num, T* HWY_RESTRICT buf, Vec<D>& first,
                             Vec<D>& last) {
  const size_t N = Lanes(d);

  first = st.LastValue(d);
  last = st.FirstValue(d);

  size_t i = 0;
  for (; i + N <= num; i += N) {
    const Vec<D> v = LoadU(d, keys + i);
    first = st.First(d, v, first);
    last = st.Last(d, v, last);
  }
  if (HWY_LIKELY(i != num)) {
    HWY_DASSERT(num >= N);  // See HandleSpecialCases
    const Vec<D> v = LoadU(d, keys + num - N);
    first = st.First(d, v, first);
    last = st.Last(d, v, last);
  }

  first = st.FirstOfLanes(d, first, buf);
  last = st.LastOfLanes(d, last, buf);
}

template <class D, class Traits, typename T>
void Recurse(D d, Traits st, T* HWY_RESTRICT keys, T* HWY_RESTRICT keys_end,
             const size_t begin, const size_t end, const Vec<D> pivot,
             T* HWY_RESTRICT buf, Generator& rng, size_t remaining_levels) {
  HWY_DASSERT(begin + 1 < end);
  const size_t num = end - begin;  // >= 2

  // Too many degenerate partitions. This is extremely unlikely to happen
  // because we select pivots from large (though still O(1)) samples.
  if (HWY_UNLIKELY(remaining_levels == 0)) {
    HeapSort(st, keys + begin, num);  // Slow but N*logN.
    return;
  }

  const ptrdiff_t base_case_num =
      static_cast<ptrdiff_t>(Constants::BaseCaseNum(Lanes(d)));
  const size_t bound = Partition(d, st, keys, begin, end, pivot, buf);

  const ptrdiff_t num_left =
      static_cast<ptrdiff_t>(bound) - static_cast<ptrdiff_t>(begin);
  const ptrdiff_t num_right =
      static_cast<ptrdiff_t>(end) - static_cast<ptrdiff_t>(bound);

  // Check for degenerate partitions (i.e. Partition did not move any keys):
  if (HWY_UNLIKELY(num_right == 0)) {
    // Because the pivot is one of the keys, it must have been equal to the
    // first or last key in sort order. Scan for the actual min/max:
    // passing the current pivot as the new bound is insufficient because one of
    // the partitions might not actually include that key.
    Vec<D> first, last;
    ScanMinMax(d, st, keys + begin, num, buf, first, last);
    if (AllTrue(d, Eq(first, last))) return;

    // Separate recursion to make sure that we don't pick `last` as the
    // pivot - that would again lead to a degenerate partition.
    Recurse(d, st, keys, keys_end, begin, end, first, buf, rng,
            remaining_levels - 1);
    return;
  }

  if (HWY_UNLIKELY(num_left <= base_case_num)) {
    BaseCase(d, st, keys + begin, keys_end, static_cast<size_t>(num_left), buf);
  } else {
    const Vec<D> next_pivot = ChoosePivot(d, st, keys, begin, bound, buf, rng);
    Recurse(d, st, keys, keys_end, begin, bound, next_pivot, buf, rng,
            remaining_levels - 1);
  }
  if (HWY_UNLIKELY(num_right <= base_case_num)) {
    BaseCase(d, st, keys + bound, keys_end, static_cast<size_t>(num_right),
             buf);
  } else {
    const Vec<D> next_pivot = ChoosePivot(d, st, keys, bound, end, buf, rng);
    Recurse(d, st, keys, keys_end, bound, end, next_pivot, buf, rng,
            remaining_levels - 1);
  }
}

// Returns true if sorting is finished.
template <class D, class Traits, typename T>
bool HandleSpecialCases(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
                        T* HWY_RESTRICT buf) {
  const size_t N = Lanes(d);
  const size_t base_case_num = Constants::BaseCaseNum(N);

  // 128-bit keys require vectors with at least two u64 lanes, which is always
  // the case unless `d` requests partial vectors (e.g. fraction = 1/2) AND the
  // hardware vector width is less than 128bit / fraction.
  const bool partial_128 = !IsFull(d) && N < 2 && st.Is128();
  // Partition assumes its input is at least two vectors. If vectors are huge,
  // base_case_num may actually be smaller. If so, which is only possible on
  // RVV, pass a capped or partial d (LMUL < 1). Use HWY_MAX_BYTES instead of
  // HWY_LANES to account for the largest possible LMUL.
  constexpr bool kPotentiallyHuge =
      2 * (HWY_MAX_BYTES / sizeof(T)) >
      Constants::BaseCaseNum(HWY_MAX_BYTES / sizeof(T));
  const bool huge_vec = kPotentiallyHuge && (2 * N > base_case_num);
  if (partial_128 || huge_vec) {
    // PERFORMANCE WARNING: falling back to HeapSort.
    HeapSort(st, keys, num);
    return true;
  }

  // Small arrays: use sorting network, no need for other checks.
  if (HWY_UNLIKELY(num <= base_case_num)) {
    BaseCase(d, st, keys, keys + num, num, buf);
    return true;
  }

  // We could also check for already sorted/reverse/equal, but that's probably
  // counterproductive if vqsort is used as a base case.

  return false;  // not finished sorting
}

#endif  // VQSORT_ENABLED
}  // namespace detail

// Sorts `keys[0..num-1]` according to the order defined by `st.Compare`.
// In-place i.e. O(1) additional storage. Worst-case N*logN comparisons.
// Non-stable (order of equal keys may change), except for the common case where
// the upper bits of T are the key, and the lower bits are a sequential or at
// least unique ID.
// There is no upper limit on `num`, but note that pivots may be chosen by
// sampling only from the first 256 GiB.
//
// `d` is typically SortTag<T> (chooses between full and partial vectors).
// `st` is SharedTraits<Traits*<Order*>>. This abstraction layer bridges
// differences in sort order and single-lane vs 128-bit keys.
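// Illustrative call sketch (the exact traits class names live in the
// traits*-inl.h headers and may differ; treat OrderAscending/TraitsLane here
// as hypothetical):
//   const SortTag<uint64_t> d;
//   detail::SharedTraits<detail::TraitsLane<detail::OrderAscending>> st;
//   Sort(d, st, keys, num, buf);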
template <class D, class Traits, typename T>
void Sort(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
          T* HWY_RESTRICT buf) {
#if VQSORT_ENABLED || HWY_IDE
#if !HWY_HAVE_SCALABLE
  // On targets with fixed-size vectors, avoid _using_ the allocated memory.
  // We avoid (potentially expensive for small input sizes) allocations on
  // platforms where no targets are scalable. For 512-bit vectors, this fits on
  // the stack (several KiB).
  HWY_ALIGN T storage[SortConstants::BufNum<T>(HWY_LANES(T))] = {};
  static_assert(sizeof(storage) <= 8192, "Unexpectedly large, check size");
  buf = storage;
#endif  // !HWY_HAVE_SCALABLE

  if (detail::HandleSpecialCases(d, st, keys, num, buf)) return;

#if HWY_MAX_BYTES > 64
  // sorting_networks-inl and traits assume no more than 512 bit vectors.
  if (Lanes(d) > 64 / sizeof(T)) {
    return Sort(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, buf);
  }
#endif  // HWY_MAX_BYTES > 64

  // Pulled out of the recursion so we can special-case degenerate partitions.
  detail::Generator rng(keys, num);
  const Vec<D> pivot = detail::ChoosePivot(d, st, keys, 0, num, buf, rng);

  // Introspection: switch to worst-case N*logN heapsort after this many.
  const size_t max_levels = 2 * hwy::CeilLog2(num) + 4;
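  // e.g. num = 1 << 20 keys: CeilLog2(num) = 20, so at most 44 levels of
  // Recurse before falling back to HeapSort.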

  detail::Recurse(d, st, keys, keys + num, 0, num, pivot, buf, rng, max_levels);
#else
  (void)d;
  (void)buf;
  // PERFORMANCE WARNING: vqsort is not enabled for the non-SIMD target
  return detail::HeapSort(st, keys, num);
#endif  // VQSORT_ENABLED
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

#endif  // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE