Argon 0.1.0
Loading...
Searching...
No Matches
lane_helpers.hpp
1#pragma once
2#include <utility>
3#include "argon/features.h"
4#include "concepts.hpp"
5#include "scalar.hpp"
6
// Chooses the namespace name the lane helpers below are generated into:
// `neon` unless the compiler advertises Arm MVE via __ARM_FEATURE_MVE, in
// which case `mve`. The macro is #undef'd at the end of this header.
#if !defined(__ARM_FEATURE_MVE)
#define simd neon
#else
#define simd mve
#endif
12
// Function-specifier prefix placed in front of every helper generated by the
// macros in this header:
//   - ARGON_PLATFORM_SIMDE defined: expands to nothing.
//   - Clang:                         [[gnu::always_inline]] constexpr
//   - any other compiler:            [[gnu::always_inline]] inline
// `#elif defined(__clang__)` is used rather than the C++23 `#elifdef`: a
// preprocessor that does not know `#elifdef` treats it as an unknown directive
// inside the skipped `#ifdef` group and ignores it, silently dropping the
// Clang branch. The macro is #undef'd at the end of this header.
#ifdef ARGON_PLATFORM_SIMDE
#define nce
#elif defined(__clang__)
#define nce [[gnu::always_inline]] constexpr
#else
#define nce [[gnu::always_inline]] inline
#endif
20
// make_lane_helper_dword_1arg(lane_func)
//
// Defines, in the current namespace, a runtime-lane overload
//   T lane_func(T a, U vec, const int lane)
// where `vec` satisfies is_doubleword and has the same scalar type as `a`.
// It switches on `lane` and forwards to the compile-time-lane template
// `lane_func<Lane>(a, vec)`, which must be declared elsewhere.
//
// lanes = sizeof(U) / sizeof(Scalar_t<U>) is handled for 1, 2, 4 and 8. The
// 1-lane case (a 64-bit vector of 64-bit scalars) previously had no branch and
// fell off the end of the function; any other count is now a compile error.
// A `lane` outside [0, lanes) is undefined behavior (std::unreachable()).
#define make_lane_helper_dword_1arg(lane_func) \
  template <is_vector_type T, is_doubleword U> \
    requires std::is_same_v<Scalar_t<T>, Scalar_t<U>> \
  nce T lane_func(T a, U vec, const int lane) { \
    constexpr int lanes = sizeof(U) / sizeof(Scalar_t<U>); \
    static_assert(lanes == 1 || lanes == 2 || lanes == 4 || lanes == 8, \
                  "unsupported lane count in runtime-lane helper " #lane_func); \
    if constexpr (lanes == 1) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, vec); \
        default: \
          std::unreachable(); \
      } \
    } else if constexpr (lanes == 2) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, vec); \
        case 1: \
          return lane_func<1>(a, vec); \
        default: \
          std::unreachable(); \
      } \
    } else if constexpr (lanes == 4) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, vec); \
        case 1: \
          return lane_func<1>(a, vec); \
        case 2: \
          return lane_func<2>(a, vec); \
        case 3: \
          return lane_func<3>(a, vec); \
        default: \
          std::unreachable(); \
      } \
    } else if constexpr (lanes == 8) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, vec); \
        case 1: \
          return lane_func<1>(a, vec); \
        case 2: \
          return lane_func<2>(a, vec); \
        case 3: \
          return lane_func<3>(a, vec); \
        case 4: \
          return lane_func<4>(a, vec); \
        case 5: \
          return lane_func<5>(a, vec); \
        case 6: \
          return lane_func<6>(a, vec); \
        case 7: \
          return lane_func<7>(a, vec); \
        default: \
          std::unreachable(); \
      } \
    } \
  }
71
// make_lane_helper_dword_2arg(lane_func)
//
// Defines, in the current namespace, a runtime-lane overload
//   T lane_func(T a, T b, U vec, const int lane)
// where `vec` satisfies is_doubleword and has the same scalar type as `a`.
// It switches on `lane` and forwards to the compile-time-lane template
// `lane_func<Lane>(a, b, vec)`, which must be declared elsewhere.
//
// lanes = sizeof(U) / sizeof(Scalar_t<U>) is handled for 1, 2, 4 and 8. The
// 1-lane case (a 64-bit vector of 64-bit scalars) previously had no branch and
// fell off the end of the function; any other count is now a compile error.
// A `lane` outside [0, lanes) is undefined behavior (std::unreachable()).
#define make_lane_helper_dword_2arg(lane_func) \
  template <is_vector_type T, is_doubleword U> \
    requires std::is_same_v<Scalar_t<T>, Scalar_t<U>> \
  nce T lane_func(T a, T b, U vec, const int lane) { \
    constexpr int lanes = sizeof(U) / sizeof(Scalar_t<U>); \
    static_assert(lanes == 1 || lanes == 2 || lanes == 4 || lanes == 8, \
                  "unsupported lane count in runtime-lane helper " #lane_func); \
    if constexpr (lanes == 1) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, b, vec); \
        default: \
          std::unreachable(); \
      } \
    } else if constexpr (lanes == 2) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, b, vec); \
        case 1: \
          return lane_func<1>(a, b, vec); \
        default: \
          std::unreachable(); \
      } \
    } else if constexpr (lanes == 4) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, b, vec); \
        case 1: \
          return lane_func<1>(a, b, vec); \
        case 2: \
          return lane_func<2>(a, b, vec); \
        case 3: \
          return lane_func<3>(a, b, vec); \
        default: \
          std::unreachable(); \
      } \
    } else if constexpr (lanes == 8) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, b, vec); \
        case 1: \
          return lane_func<1>(a, b, vec); \
        case 2: \
          return lane_func<2>(a, b, vec); \
        case 3: \
          return lane_func<3>(a, b, vec); \
        case 4: \
          return lane_func<4>(a, b, vec); \
        case 5: \
          return lane_func<5>(a, b, vec); \
        case 6: \
          return lane_func<6>(a, b, vec); \
        case 7: \
          return lane_func<7>(a, b, vec); \
        default: \
          std::unreachable(); \
      } \
    } \
  }
122
// make_lane_helper_qword_1arg(lane_func)
//
// Defines, in the current namespace, a runtime-lane overload
//   T lane_func(T a, U vec, const int lane)
// where `vec` is any is_vector_type with the same scalar type as `a`.
// It switches on `lane` and forwards to the compile-time-lane template
// `lane_func<Lane>(a, vec)`, which must be declared elsewhere.
//
// lanes = sizeof(U) / sizeof(Scalar_t<U>) is handled for 1, 2, 4, 8 and 16.
// The 1-lane case previously had no branch and fell off the end of the
// function; any other count is now a compile error.
// A `lane` outside [0, lanes) is undefined behavior (std::unreachable()).
#define make_lane_helper_qword_1arg(lane_func) \
  template <is_vector_type T, is_vector_type U> \
    requires std::is_same_v<Scalar_t<T>, Scalar_t<U>> \
  nce T lane_func(T a, U vec, const int lane) { \
    constexpr int lanes = sizeof(U) / sizeof(Scalar_t<U>); \
    static_assert(lanes == 1 || lanes == 2 || lanes == 4 || lanes == 8 || lanes == 16, \
                  "unsupported lane count in runtime-lane helper " #lane_func); \
    if constexpr (lanes == 1) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, vec); \
        default: \
          std::unreachable(); \
      } \
    } else if constexpr (lanes == 2) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, vec); \
        case 1: \
          return lane_func<1>(a, vec); \
        default: \
          std::unreachable(); \
      } \
    } else if constexpr (lanes == 4) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, vec); \
        case 1: \
          return lane_func<1>(a, vec); \
        case 2: \
          return lane_func<2>(a, vec); \
        case 3: \
          return lane_func<3>(a, vec); \
        default: \
          std::unreachable(); \
      } \
    } else if constexpr (lanes == 8) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, vec); \
        case 1: \
          return lane_func<1>(a, vec); \
        case 2: \
          return lane_func<2>(a, vec); \
        case 3: \
          return lane_func<3>(a, vec); \
        case 4: \
          return lane_func<4>(a, vec); \
        case 5: \
          return lane_func<5>(a, vec); \
        case 6: \
          return lane_func<6>(a, vec); \
        case 7: \
          return lane_func<7>(a, vec); \
        default: \
          std::unreachable(); \
      } \
    } else if constexpr (lanes == 16) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, vec); \
        case 1: \
          return lane_func<1>(a, vec); \
        case 2: \
          return lane_func<2>(a, vec); \
        case 3: \
          return lane_func<3>(a, vec); \
        case 4: \
          return lane_func<4>(a, vec); \
        case 5: \
          return lane_func<5>(a, vec); \
        case 6: \
          return lane_func<6>(a, vec); \
        case 7: \
          return lane_func<7>(a, vec); \
        case 8: \
          return lane_func<8>(a, vec); \
        case 9: \
          return lane_func<9>(a, vec); \
        case 10: \
          return lane_func<10>(a, vec); \
        case 11: \
          return lane_func<11>(a, vec); \
        case 12: \
          return lane_func<12>(a, vec); \
        case 13: \
          return lane_func<13>(a, vec); \
        case 14: \
          return lane_func<14>(a, vec); \
        case 15: \
          return lane_func<15>(a, vec); \
        default: \
          std::unreachable(); \
      } \
    } \
  }
210
// make_lane_helper_qword_2arg(lane_func)
//
// Defines, in the current namespace, a runtime-lane overload
//   T lane_func(T a, T b, U vec, const int lane)
// where `vec` is any is_vector_type with the same scalar type as `a`.
// It switches on `lane` and forwards to the compile-time-lane template
// `lane_func<Lane>(a, b, vec)`, which must be declared elsewhere.
//
// lanes = sizeof(U) / sizeof(Scalar_t<U>) is handled for 1, 2, 4, 8 and 16.
// The 1-lane case previously had no branch and fell off the end of the
// function; any other count is now a compile error.
// A `lane` outside [0, lanes) is undefined behavior (std::unreachable()).
#define make_lane_helper_qword_2arg(lane_func) \
  template <is_vector_type T, is_vector_type U> \
    requires std::is_same_v<Scalar_t<T>, Scalar_t<U>> \
  nce T lane_func(T a, T b, U vec, const int lane) { \
    constexpr int lanes = sizeof(U) / sizeof(Scalar_t<U>); \
    static_assert(lanes == 1 || lanes == 2 || lanes == 4 || lanes == 8 || lanes == 16, \
                  "unsupported lane count in runtime-lane helper " #lane_func); \
    if constexpr (lanes == 1) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, b, vec); \
        default: \
          std::unreachable(); \
      } \
    } else if constexpr (lanes == 2) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, b, vec); \
        case 1: \
          return lane_func<1>(a, b, vec); \
        default: \
          std::unreachable(); \
      } \
    } else if constexpr (lanes == 4) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, b, vec); \
        case 1: \
          return lane_func<1>(a, b, vec); \
        case 2: \
          return lane_func<2>(a, b, vec); \
        case 3: \
          return lane_func<3>(a, b, vec); \
        default: \
          std::unreachable(); \
      } \
    } else if constexpr (lanes == 8) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, b, vec); \
        case 1: \
          return lane_func<1>(a, b, vec); \
        case 2: \
          return lane_func<2>(a, b, vec); \
        case 3: \
          return lane_func<3>(a, b, vec); \
        case 4: \
          return lane_func<4>(a, b, vec); \
        case 5: \
          return lane_func<5>(a, b, vec); \
        case 6: \
          return lane_func<6>(a, b, vec); \
        case 7: \
          return lane_func<7>(a, b, vec); \
        default: \
          std::unreachable(); \
      } \
    } else if constexpr (lanes == 16) { \
      switch (lane) { \
        case 0: \
          return lane_func<0>(a, b, vec); \
        case 1: \
          return lane_func<1>(a, b, vec); \
        case 2: \
          return lane_func<2>(a, b, vec); \
        case 3: \
          return lane_func<3>(a, b, vec); \
        case 4: \
          return lane_func<4>(a, b, vec); \
        case 5: \
          return lane_func<5>(a, b, vec); \
        case 6: \
          return lane_func<6>(a, b, vec); \
        case 7: \
          return lane_func<7>(a, b, vec); \
        case 8: \
          return lane_func<8>(a, b, vec); \
        case 9: \
          return lane_func<9>(a, b, vec); \
        case 10: \
          return lane_func<10>(a, b, vec); \
        case 11: \
          return lane_func<11>(a, b, vec); \
        case 12: \
          return lane_func<12>(a, b, vec); \
        case 13: \
          return lane_func<13>(a, b, vec); \
        case 14: \
          return lane_func<14>(a, b, vec); \
        case 15: \
          return lane_func<15>(a, b, vec); \
        default: \
          std::unreachable(); \
      } \
    } \
  }
298
299namespace simd {
300#ifndef ARGON_PLATFORM_MVE
301make_lane_helper_dword_1arg(multiply_lane);
302make_lane_helper_dword_1arg(multiply_long_lane);
303make_lane_helper_dword_1arg(multiply_double_saturate_long_lane);
304make_lane_helper_dword_1arg(multiply_double_saturate_high_lane);
305make_lane_helper_dword_1arg(multiply_double_round_saturate_high_lane);
306make_lane_helper_dword_2arg(multiply_add_lane);
307make_lane_helper_dword_2arg(multiply_subtract_lane);
308make_lane_helper_dword_2arg(multiply_add_long_lane);
309make_lane_helper_dword_2arg(multiply_subtract_long_lane);
310make_lane_helper_dword_2arg(multiply_double_add_saturate_long_lane);
311make_lane_helper_dword_2arg(multiply_double_subtract_saturate_long_lane);
312#endif
313
314#ifdef __aarch64__
315make_lane_helper_qword_1arg(multiply_lane);
316make_lane_helper_qword_1arg(multiply_long_lane);
317make_lane_helper_qword_1arg(multiply_double_saturate_long_lane);
318make_lane_helper_qword_1arg(multiply_double_saturate_high_lane);
319make_lane_helper_qword_1arg(multiply_double_round_saturate_high_lane);
320make_lane_helper_qword_2arg(multiply_add_lane);
321make_lane_helper_qword_2arg(multiply_subtract_lane);
322make_lane_helper_qword_2arg(multiply_add_long_lane);
323make_lane_helper_qword_2arg(multiply_subtract_long_lane);
324make_lane_helper_qword_2arg(multiply_double_add_saturate_long_lane);
325make_lane_helper_qword_2arg(multiply_double_subtract_saturate_long_lane);
326
327#endif
328} // namespace simd
// Remove every helper macro this header defined so none of them leak into
// files that include it.
#undef make_lane_helper_dword_1arg
#undef make_lane_helper_dword_2arg
#undef make_lane_helper_qword_1arg
#undef make_lane_helper_qword_2arg
#undef nce
#undef simd
Header file for SIMD features and platform detection.