19#if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == \
20 defined(HWY_TARGET_TOGGLE)
21#ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
22#undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
24#define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
51 template <
int kAssumptions,
class D,
typename T = TFromD<D>,
52 HWY_IF_NOT_LANE_SIZE_D(D, 2)>
55 const size_t num_elements) {
56 static_assert(IsFloat<T>(),
"MulAdd requires float type");
57 using V =
decltype(
Zero(
d));
62 constexpr bool kIsAtLeastOneVector =
64 constexpr bool kIsMultipleOfVector =
66 constexpr bool kIsPaddedToVector = (kAssumptions &
kPaddedToVector) != 0;
69 if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
74 for (; i + 2 <= num_elements; i += 2) {
75 sum0 += pa[i + 0] * pb[i + 0];
76 sum1 += pa[i + 1] * pb[i + 1];
78 if (i < num_elements) {
79 sum1 += pa[i] * pb[i];
95 for (; i + 4 *
N <= num_elements; ) {
96 const auto a0 =
LoadU(
d, pa + i);
97 const auto b0 =
LoadU(
d, pb + i);
99 sum0 =
MulAdd(a0, b0, sum0);
100 const auto a1 =
LoadU(
d, pa + i);
101 const auto b1 =
LoadU(
d, pb + i);
103 sum1 =
MulAdd(a1, b1, sum1);
104 const auto a2 =
LoadU(
d, pa + i);
105 const auto b2 =
LoadU(
d, pb + i);
107 sum2 =
MulAdd(a2, b2, sum2);
108 const auto a3 =
LoadU(
d, pa + i);
109 const auto b3 =
LoadU(
d, pb + i);
111 sum3 =
MulAdd(a3, b3, sum3);
115 for (; i +
N <= num_elements; i +=
N) {
116 const auto a =
LoadU(
d, pa + i);
117 const auto b =
LoadU(
d, pb + i);
118 sum0 =
MulAdd(a, b, sum0);
121 if (!kIsMultipleOfVector) {
122 const size_t remaining = num_elements - i;
123 if (remaining != 0) {
124 if (kIsPaddedToVector) {
125 const auto mask =
FirstN(
d, remaining);
126 const auto a =
LoadU(
d, pa + i);
127 const auto b =
LoadU(
d, pb + i);
135 const auto skip =
FirstN(
d,
N - remaining);
136 const auto a =
LoadU(
d, pa + i);
137 const auto b =
LoadU(
d, pb + i);
144 sum0 =
Add(sum0, sum1);
145 sum2 =
Add(sum2, sum3);
146 sum0 =
Add(sum0, sum2);
152 template <
int kAssumptions,
class D>
156 const size_t num_elements) {
160 using V =
decltype(
Zero(df32));
164 constexpr bool kIsAtLeastOneVector =
166 constexpr bool kIsMultipleOfVector =
168 constexpr bool kIsPaddedToVector = (kAssumptions &
kPaddedToVector) != 0;
171 if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
175 for (; i + 2 <= num_elements; i += 2) {
179 if (i < num_elements) {
193 for (; i + 2 *
N <= num_elements; ) {
194 const auto a0 =
LoadU(
d, pa + i);
195 const auto b0 =
LoadU(
d, pb + i);
198 const auto a1 =
LoadU(
d, pa + i);
199 const auto b1 =
LoadU(
d, pb + i);
205 if (i +
N <= num_elements) {
206 const auto a0 =
LoadU(
d, pa + i);
207 const auto b0 =
LoadU(
d, pb + i);
212 if (!kIsMultipleOfVector) {
213 const size_t remaining = num_elements - i;
214 if (remaining != 0) {
215 if (kIsPaddedToVector) {
216 const auto mask =
FirstN(du16, remaining);
217 const auto va =
LoadU(
d, pa + i);
218 const auto vb =
LoadU(
d, pb + i);
229 const auto skip =
FirstN(du16,
N - remaining);
230 const auto va =
LoadU(
d, pa + i);
231 const auto vb =
LoadU(
d, pb + i);
240 sum0 =
Add(sum0, sum1);
241 sum2 =
Add(sum2, sum3);
242 sum0 =
Add(sum0, sum2);
#define HWY_RESTRICT
Definition: base.h:61
#define HWY_INLINE
Definition: base.h:62
#define HWY_DASSERT(condition)
Definition: base.h:191
#define HWY_UNLIKELY(expr)
Definition: base.h:67
d
Definition: rvv-inl.h:1742
HWY_API Mask128< T, N > FirstN(const Simd< T, N, 0 > d, size_t num)
Definition: arm_neon-inl.h:2409
HWY_API Vec128< float, N > MulAdd(const Vec128< float, N > mul, const Vec128< float, N > x, const Vec128< float, N > add)
Definition: arm_neon-inl.h:1784
HWY_API Vec128< T, N > SumOfLanes(Simd< T, N, 0 >, const Vec128< T, N > v)
Definition: arm_neon-inl.h:4932
Rebind< MakeUnsigned< TFromD< D > >, D > RebindToUnsigned
Definition: ops/shared-inl.h:200
HWY_API Vec128< float, N > ReorderWidenMulAccumulate(Simd< float, N, 0 > df32, Vec128< bfloat16_t, 2 *N > a, Vec128< bfloat16_t, 2 *N > b, const Vec128< float, N > sum0, Vec128< float, N > &sum1)
Definition: arm_neon-inl.h:4203
HWY_API Vec128< T, N > IfThenElseZero(const Mask128< T, N > mask, const Vec128< T, N > yes)
Definition: arm_neon-inl.h:2212
HWY_API V Add(V a, V b)
Definition: arm_neon-inl.h:6274
HWY_API constexpr size_t Lanes(Simd< T, N, kPow2 >)
Definition: arm_sve-inl.h:236
HWY_API Vec128< uint8_t > LoadU(Full128< uint8_t >, const uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2544
HWY_API Vec128< T, N > BitCast(Simd< T, N, 0 > d, Vec128< FromT, N *sizeof(T)/sizeof(FromT)> v)
Definition: arm_neon-inl.h:988
HWY_API Vec128< T, N > Zero(Simd< T, N, 0 > d)
Definition: arm_neon-inl.h:1011
HWY_API Vec128< T, N > IfThenZeroElse(const Mask128< T, N > mask, const Vec128< T, N > no)
Definition: arm_neon-inl.h:2219
HWY_API TFromV< V > GetLane(const V v)
Definition: arm_neon-inl.h:1061
typename D::template Repartition< T > Repartition
Definition: ops/shared-inl.h:206
N
Definition: rvv-inl.h:1742
Definition: aligned_allocator.h:27
HWY_API float F32FromBF16(bfloat16_t bf)
Definition: base.h:831
#define HWY_NAMESPACE
Definition: set_macros-inl.h:82
static HWY_INLINE T Compute(const D d, const T *const HWY_RESTRICT pa, const T *const HWY_RESTRICT pb, const size_t num_elements)
Definition: dot-inl.h:53
static HWY_INLINE float Compute(const D d, const bfloat16_t *const HWY_RESTRICT pa, const bfloat16_t *const HWY_RESTRICT pb, const size_t num_elements)
Definition: dot-inl.h:153
Assumptions
Definition: dot-inl.h:37
@ kMultipleOfVector
Definition: dot-inl.h:43
@ kPaddedToVector
Definition: dot-inl.h:46
@ kAtLeastOneVector
Definition: dot-inl.h:39