Argon 0.1.0
Loading...
Searching...
No Matches
argon_full.hpp
1#pragma once
2#include <array>
3#include <numeric>
4#include <type_traits>
5#include "arm_simd.hpp"
6#include "helpers.hpp"
7#include "helpers/argon_for.hpp"
8#include "vector.hpp"
9
10#ifdef __ARM_FEATURE_MVE
11#define simd mve
12#else
13#define simd neon
14#endif
15
16#ifdef __clang__
17#define ace [[nodiscard]] [[gnu::always_inline]] constexpr
18#else
19#define ace [[nodiscard]] [[gnu::always_inline]] inline
20#endif
21
27template <typename ScalarType>
28 requires std::is_arithmetic_v<ScalarType>
29class Argon : public argon::Vector<simd::Vec128_t<ScalarType>> {
31
32 public:
33 using vector_type = simd::Vec128_t<ScalarType>;
34 using lane_type = const argon::Lane<vector_type>;
35
36 static_assert(simd::is_quadword_v<vector_type>);
37
38 static constexpr size_t bytes = 16;
39 static constexpr size_t lanes = bytes / sizeof(ScalarType);
40
41 using argon::Vector<vector_type>::Vector;
45 ace Argon(std::array<ScalarType, 4> value_list) : T{T::Load(value_list.data())} {};
47 ace Argon(ArgonHalf<ScalarType> low, ArgonHalf<ScalarType> high) : T{Combine(low, high)} {};
48
49#ifndef ARGON_PLATFORM_MVE
50 ace Argon(argon::Lane<vector_type> b) : T{b} {};
51 ace Argon(argon::ConstLane<0, vector_type> b) : T{b} {};
52#endif
53
54 template <simd::is_vector_type intrinsic_type>
55 ace Argon(argon::Lane<intrinsic_type> b) : T{b} {};
56
60 template <typename new_scalar_type>
62 return simd::reinterpret<simd::Vec128_t<new_scalar_type>>(this->vec_);
63 }
64
65#ifndef ARGON_PLATFORM_MVE
70 return simd::combine(low, high);
71 }
72
76 auto rev_half = this->Reverse64bit();
77 return Combine(rev_half.GetHigh(), rev_half.GetLow());
78 }
79
82 template <typename U>
84 std::is_same_v<U, typename argon::helpers::NextSmaller<ScalarType>::type>
86 return simd::multiply_add_long(this->vec_, b, c);
87 }
88
89 template <typename U, typename C>
91 std::is_same_v<C, simd::Vec64_t<argon::helpers::NextSmaller_t<ScalarType>>>
93 return simd::multiply_add_long(this->vec_, b, c);
94 }
95
97 template <typename U>
99 std::is_same_v<U, typename argon::helpers::NextSmaller<ScalarType>::type>
101 return simd::multiply_add_long(this->vec_, b, c);
102 }
103
105 template <typename U>
107 std::is_same_v<U, typename argon::helpers::NextSmaller<ScalarType>::type>
109 return simd::multiply_add_long(this->vec_, b, c.vec(), c.lane());
110 }
111
113 template <typename U>
115 std::is_same_v<U, typename argon::helpers::NextSmaller<ScalarType>::type>
117 return simd::multiply_subtract_long(this->vec_, b, c);
118 }
119
121 template <typename U>
123 std::is_same_v<U, typename argon::helpers::NextSmaller<ScalarType>::type>
125 return simd::multiply_subtract_long(this->vec_, b, c);
126 }
127
129 template <typename U>
131 std::is_same_v<U, typename argon::helpers::NextSmaller<ScalarType>::type>
133 return simd::multiply_subtract_long(this->vec_, b, c.vec(), c.lane());
134 }
135
138 ace auto AddNarrow(Argon<ScalarType> b) const
140 {
141 auto result = simd::add_narrow(this->vec_, b);
142 return argon::helpers::ArgonFor_t<decltype(result)>{result};
143 }
144
148 {
149 auto result = simd::add_round_narrow(this->vec_, b);
150 return argon::helpers::ArgonFor_t<decltype(result)>{result};
151 }
152
156 {
157 auto result = simd::subtract_narrow(this->vec_, b);
158 return argon::helpers::ArgonFor_t<decltype(result)>{result};
159 }
160
164 {
165 auto result = simd::subtract_round_narrow(this->vec_, b);
166 return argon::helpers::ArgonFor_t<decltype(result)>{result};
167 }
168
171 template <size_t n>
173 ace auto ShiftRightNarrow() const {
174 auto result = simd::shift_right_narrow<n>(this->vec_);
175 return argon::helpers::ArgonFor_t<decltype(result)>{result};
176 }
177
180 template <size_t n>
182 ace auto ShiftRightSaturateNarrow() const {
183 auto result = simd::shift_right_saturate_narrow<n>(this->vec_);
184 return argon::helpers::ArgonFor_t<decltype(result)>{result};
185 }
186
189 template <size_t n>
192 auto result = simd::shift_right_round_saturate_narrow<n>(this->vec_);
193 return argon::helpers::ArgonFor_t<decltype(result)>{result};
194 }
195
198 template <size_t n>
200 ace auto ShiftRightRoundNarrow() const {
201 auto result = simd::shift_right_round_narrow<n>(this->vec_);
202 return argon::helpers::ArgonFor_t<decltype(result)>{result};
203 }
204
206 ace auto Narrow() const
207 requires argon::helpers::has_smaller_v<ScalarType>
208 {
209 auto result = simd::move_narrow(this->vec_);
210 return argon::helpers::ArgonFor_t<decltype(result)>{result};
211 }
212
214 ace auto SaturateNarrow() const
215 requires argon::helpers::has_smaller_v<ScalarType>
216 {
217 auto result = simd::move_saturate_narrow(this->vec_);
218 return argon::helpers::ArgonFor_t<decltype(result)>{result};
219 }
220
222 template <typename NextSmallerType>
224 std::is_same_v<NextSmallerType, argon::helpers::NextSmaller_t<ScalarType>>
226 return neon::multiply_double_add_saturate_long(this->vec_, b, c);
227 }
228
230 ace ArgonHalf<ScalarType> GetHigh() const { return simd::get_high(this->vec_); }
232 ace ArgonHalf<ScalarType> GetLow() const { return simd::get_low(this->vec_); }
233#endif
234
237 template <typename U>
238 ace Argon<U> ConvertTo() const {
239 return simd::convert<typename simd::Vec128<U>::type>(this->vec_);
240 }
241
245 template <typename U, int fracbits>
246 requires(std::is_same_v<U, uint32_t> || std::is_same_v<U, int32_t> || std::is_same_v<U, float>)
247 ace Argon<U> ConvertTo() const {
248 if constexpr (std::is_same_v<U, float>) {
249 return simd::convert_n<fracbits>(this->vec_);
250 } else if constexpr (std::is_unsigned_v<U>) {
251 return simd::convert_n_unsigned<fracbits>(this->vec_);
252 } else if constexpr (std::is_signed_v<U>) {
253 return simd::convert_n_signed<fracbits>(this->vec_);
254 }
255 }
256
260 Argon<ScalarType> rev = this->Reverse64bit(); // rev within dword
261 return Argon{rev.GetHigh(), rev.GetLow()}; // swap dwords
262 }
263
268 template <typename CommutableOpType>
269 ScalarType Reduce(CommutableOpType op) {
270 auto rev = this->SwapDoublewords();
271 auto sum = op(*this, rev);
272 if constexpr (lanes == 16) {
273 sum = op(sum, sum.Reverse16bit());
274 }
275 if constexpr (lanes == 8 || lanes == 16) {
276 sum = op(sum, sum.Reverse32bit());
277 }
278 if constexpr (lanes == 4 || lanes == 8 || lanes == 16) {
279 sum = op(sum, sum.Reverse64bit());
280 }
281 return sum[0];
282 }
283
285 ScalarType ReduceAdd() {
286#ifdef __aarch64__
287 return simd::reduce_add(this->vec_);
288#else
289 return this->Reduce([](auto a, auto b) { return a + b; });
290#endif
291 }
292
294 ScalarType ReduceMax() {
295#ifdef __aarch64__
296 return simd::reduce_max(this->vec_);
297#else
298 return this->Reduce([](auto a, auto b) { return a.Max(b); });
299#endif
300 }
301
303 ScalarType ReduceMin() {
304#ifdef __aarch64__
305 return simd::reduce_min(this->vec_);
306#else
307 auto arr = this->to_array();
308 return std::reduce(arr.begin(), arr.end(), arr[0], [](auto a, auto b) { return std::min(a, b); });
309#endif
310 }
311
312#ifndef ARGON_PLATFORM_MVE
315#endif
316
317#if ARGON_HAS_CRYPTO && !defined(ARGON_PLATFORM_MVE)
322 ace Argon<ScalarType> AesEncrypt(Argon<ScalarType> key) const
323 requires std::is_same_v<ScalarType, uint8_t>
324 {
325 return neon::aes_encrypt(this->vec_, key.vec_);
326 }
327
332 ace Argon<ScalarType> AesDecrypt(Argon<ScalarType> key) const
333 requires std::is_same_v<ScalarType, uint8_t>
334 {
335 return neon::aes_decrypt(this->vec_, key.vec_);
336 }
337
341 ace Argon<ScalarType> AesMixColumns() const
342 requires std::is_same_v<ScalarType, uint8_t>
343 {
344 return neon::aes_mix_columns(this->vec_);
345 }
346
350 ace Argon<ScalarType> AesInverseMixColumns() const
351 requires std::is_same_v<ScalarType, uint8_t>
352 {
353 return neon::aes_inverse_mix_columns(this->vec_);
354 }
355#endif
356};
357
358template <typename... arg_types>
359 requires(sizeof...(arg_types) > 1)
360// Argon(arg_types...) -> Argon<arg_types...[0]>;
361Argon(arg_types...) -> Argon<std::tuple_element_t<0, std::tuple<arg_types...>>>;
362
363#ifndef ARGON_PLATFORM_MVE
364template <typename VectorType>
366
367template <typename VectorType>
369#endif
370
371template <typename ScalarType>
372 requires std::is_scalar_v<ScalarType>
373Argon(ScalarType) -> Argon<ScalarType>;
374
375template <typename V>
376 requires std::is_scalar_v<V>
377ace Argon<V> operator+(const V a, const Argon<V> b) {
378 return b.Add(a);
379}
380
381template <typename V>
382 requires std::is_scalar_v<V>
383ace Argon<V> operator-(const V a, const Argon<V> b) {
384 return Argon<V>{a}.Subtract(b);
385}
386
387template <typename V>
388 requires std::is_scalar_v<V>
389ace Argon<V> operator*(const V a, const Argon<V> b) {
390 return b.Multiply(a);
391}
392
393template <typename V>
394 requires std::is_scalar_v<V>
395ace Argon<V> operator/(const V a, const Argon<V> b) {
396 return Argon<V>{a}.Divide(b);
397}
398
399namespace std {
400
401template <typename T>
402struct tuple_size<Argon<T>> {
403 static constexpr size_t value = Argon<T>::lanes;
404};
405
406template <size_t Index, typename T>
407struct tuple_element<Index, Argon<T>> {
408 static_assert(Index < Argon<T>::lanes);
410};
411} // namespace std
412
413#undef ace
414#undef simd
Provides utility templates and concepts for type traits and compile-time iteration.
Definition argon_half.hpp:11
A 128-bit SIMD vector wrapping a scalar type, providing arithmetic, logical, and data-movement operations.
Definition argon_full.hpp:29
ace auto ShiftRightRoundNarrow() const
Shift right, round, and narrow.
Definition argon_full.hpp:200
ace Argon< ScalarType > MultiplyAddLong(ArgonHalf< U > b, typename ArgonHalf< U >::lane_type c)
Multiply a half-vector by a lane and add the widened result (vector × lane).
Definition argon_full.hpp:108
ace Argon< ScalarType > MultiplyAddLong(ArgonHalf< U > b, U c)
Multiply a half-vector by a scalar and add the widened result (vector × scalar).
Definition argon_full.hpp:100
ScalarType Reduce(CommutableOpType op)
Fold all lanes into a single scalar using a commutative binary operation.
Definition argon_full.hpp:269
ScalarType ReduceAdd()
Sum all lanes and return the scalar result.
Definition argon_full.hpp:285
ace Argon< ScalarType > MultiplySubtractLong(ArgonHalf< U > b, typename ArgonHalf< U >::lane_type c)
Multiply a half-vector by a lane and subtract the widened result (vector × lane).
Definition argon_full.hpp:132
ace Argon< ScalarType > Reverse() const
Reverse the order of all elements in the 128-bit vector.
Definition argon_full.hpp:259
ace Argon< ScalarType > SwapDoublewords()
Swap the upper and lower 64-bit halves of the vector.
Definition argon_full.hpp:314
ScalarType ReduceMax()
Return the maximum value across all lanes.
Definition argon_full.hpp:294
ace Argon< ScalarType > MultiplySubtractLong(ArgonHalf< U > b, ArgonHalf< U > c)
Multiply two narrower half-vectors and subtract the widened result from this vector (vector × vector).
Definition argon_full.hpp:116
ace Argon(ArgonHalf< ScalarType > low, ArgonHalf< ScalarType > high)
Construct by combining a low and high 64-bit half-vector.
Definition argon_full.hpp:47
ace auto Narrow() const
Truncate each lane to the next-smaller element type (no saturation).
Definition argon_full.hpp:206
ace auto AddRoundNarrow(Argon< ScalarType > b) const
Add, round, and narrow: add b, round the result, and truncate to the next-smaller type.
Definition argon_full.hpp:146
ace auto ShiftRightSaturateNarrow() const
Shift right, saturate, and narrow: shift each lane right by n bits with unsigned saturation.
Definition argon_full.hpp:182
ace Argon(argon::Vector< vector_type > vec)
Construct from an underlying argon::Vector.
Definition argon_full.hpp:43
ace Argon< U > ConvertTo() const
Convert each lane to a different element type.
Definition argon_full.hpp:238
static ace Argon< next_larger > Combine(ArgonHalf< next_larger > low, ArgonHalf< next_larger > high)
Definition argon_full.hpp:69
ScalarType ReduceMin()
Return the minimum value across all lanes.
Definition argon_full.hpp:303
ace auto SaturateNarrow() const
Saturate and narrow each lane to the next-smaller element type.
Definition argon_full.hpp:214
ace Argon< new_scalar_type > As() const
Reinterpret the vector bits as a vector of a different element type.
Definition argon_full.hpp:61
ace ArgonHalf< ScalarType > GetLow() const
Return the lower 64 bits as an ArgonHalf.
Definition argon_full.hpp:232
ace Argon< ScalarType > Reverse()
Reverse the order of all elements across the full 128-bit vector.
Definition argon_full.hpp:75
ace Argon< ScalarType > MultiplyAddLong(ArgonHalf< U > b, C c)
Multiply two narrower vectors and add the widened result (vector × raw intrinsic).
Definition argon_full.hpp:92
ace ArgonHalf< ScalarType > GetHigh() const
Return the upper 64 bits as an ArgonHalf.
Definition argon_full.hpp:230
ace Argon< ScalarType > MultiplyDoubleAddSaturateLong(ArgonHalf< NextSmallerType > b, ArgonHalf< NextSmallerType > c)
Multiply, double, add, and saturate long: this + saturate(2 * b * c) widening to ScalarType.
Definition argon_full.hpp:225
ace auto ShiftRightNarrow() const
Shift each lane right by n bits and narrow the result to the next-smaller type.
Definition argon_full.hpp:173
ace Argon(std::array< ScalarType, 4 > value_list)
Construct from a four-element array (loaded as a 128-bit vector).
Definition argon_full.hpp:45
ace Argon< ScalarType > MultiplySubtractLong(ArgonHalf< U > b, U c)
Multiply a half-vector by a scalar and subtract the widened result (vector × scalar).
Definition argon_full.hpp:124
ace auto SubtractRoundNarrow(Argon< ScalarType > b) const
Subtract, round, and narrow: subtract b, round the result, and truncate to the next-smaller type.
Definition argon_full.hpp:162
ace Argon< ScalarType > MultiplyAddLong(ArgonHalf< U > b, ArgonHalf< U > c)
Multiply two narrower half-vectors and add the widened result to this vector (vector × vector).
Definition argon_full.hpp:85
ace auto AddNarrow(Argon< ScalarType > b) const
Add and narrow: add b to this vector and truncate each lane to the next-smaller type.
Definition argon_full.hpp:138
ace auto SubtractNarrow(Argon< ScalarType > b) const
Subtract and narrow: subtract b and truncate each lane to the next-smaller type.
Definition argon_full.hpp:154
ace Argon< U > ConvertTo() const
Convert each lane to a different type using a fixed-point fractional bit count.
Definition argon_full.hpp:247
ace auto ShiftRightRoundSaturateNarrow() const
Shift right, round, saturate, and narrow.
Definition argon_full.hpp:191
Represents a single lane of a SIMD vector with the lane index known at compile time.
Definition lane.hpp:46
Represents a single lane of a SIMD vector with a runtime-determined index.
Definition lane.hpp:116
Represents a SIMD vector with various operations.
Definition vector.hpp:50
constexpr VectorType vec() const
Get the underlying SIMD vector.
Definition vector.hpp:275
ace argon_type Multiply(argon_type b) const
Multiply two vectors.
Definition vector.hpp:418
static ace argon_type Load(const scalar_type *ptr)
Load a vector from a pointer.
Definition vector.hpp:863
ace argon_type Divide(argon_type b) const
Divide two vectors.
Definition vector.hpp:679
ace std::array< scalar_type, lanes > to_array()
Definition vector.hpp:282
ace argon_type Add(argon_type b) const
Add two vectors.
Definition vector.hpp:366
ace argon_type Subtract(argon_type b) const
Subtract two vectors.
Definition vector.hpp:388
typename ArgonFor< std::remove_cv_t< T > >::type ArgonFor_t
Helper alias to get the Argon type for a given vector type.
Definition argon_for.hpp:45
constexpr bool has_smaller_v
Helper template to determine if a type has a smaller corresponding type.
Definition helpers.hpp:18
Lane destructuring support: std::tuple_size and std::tuple_element specializations for Argon.
Definition argon_full.hpp:399