Argon 0.1.0
Loading...
Searching...
No Matches
load_interleaved.hpp
1#pragma once
2#include <ranges>
3#include "argon/argon_full.hpp"
4#include "arm_simd/helpers/vec128.hpp"
5
// Pick the SIMD backend namespace for this header: `simd` expands to `mve` when
// __ARM_FEATURE_MVE is defined and to `neon` otherwise. The macro is removed again
// with `#undef simd` at the end of the file so it does not leak to includers.
6#ifdef __ARM_FEATURE_MVE
7#define simd mve
8#else
9#define simd neon
10#endif
11
12namespace argon::vectorize {
13
14template <typename ScalarType, size_t Stride = 2>
15struct load_interleaved : std::ranges::view_interface<load_interleaved<ScalarType, Stride>> {
16 using intrinsic_type = simd::Vec128_t<ScalarType>;
17 static constexpr size_t lanes = sizeof(intrinsic_type) / sizeof(ScalarType);
18 static constexpr size_t vectorizeable_size(size_t size) { return size & ~(lanes - 1); }
19
20 static_assert(Stride > 1 && Stride < 5, "Interleaving Loads can only be performed with a stride of 2, 3, or 4");
21
22 public:
23 struct LoadInterleavedIterator {
24 using iterator_category = std::bidirectional_iterator_tag;
25 using difference_type = std::ptrdiff_t;
26 using reference_type = LoadInterleavedIterator&;
27 using value_type = std::array<Argon<ScalarType>, Stride>;
28
29 LoadInterleavedIterator() = default;
30 LoadInterleavedIterator(const ScalarType* ptr) : ptr_{ptr} {}
31
32 std::array<Argon<ScalarType>, Stride> operator*() const {
33 return Argon<ScalarType>::template LoadInterleaved<Stride>(ptr_);
34 }
35
36 LoadInterleavedIterator& operator+=(int n) {
37 ptr_ += n * lanes * Stride;
38 return *this;
39 }
40
41 LoadInterleavedIterator& operator-=(int n) {
42 ptr_ -= n * lanes * Stride;
43 return *this;
44 }
45
46 LoadInterleavedIterator operator++(int) {
47 LoadInterleavedIterator tmp = *this;
48 ++(*this);
49 return tmp;
50 }
51
52 LoadInterleavedIterator& operator++() {
53 ptr_ += lanes * Stride;
54 return *this;
55 }
56
57 LoadInterleavedIterator operator--(int) {
58 LoadInterleavedIterator tmp = *this;
59 --(*this);
60 return tmp;
61 }
62
63 LoadInterleavedIterator& operator--() {
64 ptr_ -= lanes * Stride;
65 return *this;
66 }
67
68 LoadInterleavedIterator operator+(int n) const {
69 LoadInterleavedIterator it = *this;
70 it += n;
71 return it;
72 }
73
74 LoadInterleavedIterator operator-(int n) const {
75 LoadInterleavedIterator it = *this;
76 it -= n;
77 return it;
78 }
79
80 LoadInterleavedIterator& operator[](int n) const { return *(*this + n); }
81
82 difference_type operator-(const LoadInterleavedIterator& other) const { return ptr_ - other.ptr_; }
83
84 friend bool operator==(const LoadInterleavedIterator& a, const LoadInterleavedIterator& b) {
85 return a.ptr_ == b.ptr_;
86 }
87 friend bool operator==(const LoadInterleavedIterator& a, const ScalarType* ptr) { return a.ptr_ == ptr; }
88 friend bool operator!=(const LoadInterleavedIterator& a, const LoadInterleavedIterator& b) {
89 return a.ptr_ != b.ptr_;
90 }
91 friend bool operator!=(const LoadInterleavedIterator& a, const ScalarType* ptr) { return a.ptr_ != ptr; }
92 friend auto operator<=>(const LoadInterleavedIterator& a, const LoadInterleavedIterator& b) {
93 return a.ptr_ <=> b.ptr_;
94 }
95
96 friend LoadInterleavedIterator operator+(const int n, const LoadInterleavedIterator& it) { return it + n; }
97
98 private:
99 const ScalarType* ptr_;
100 };
101 static_assert(std::sized_sentinel_for<LoadInterleavedIterator, LoadInterleavedIterator>);
102 static_assert(std::bidirectional_iterator<LoadInterleavedIterator>);
103 static_assert(std::input_iterator<LoadInterleavedIterator>);
104
105 using iterator = LoadInterleavedIterator;
106 using sentinel = const ScalarType*;
107
108 iterator begin() { return start_; }
109 const ScalarType* end() { return start_ + size_; }
110 size_t size() const { return size_ / (lanes * Stride); }
111
112 template <std::ranges::contiguous_range R>
113 load_interleaved(R&& r) : start_{&*std::ranges::begin(r)}, size_{vectorizeable_size(std::ranges::size(r))} {}
114
115 private:
116 const ScalarType* start_;
117 size_t size_;
118};
119
// Compile-time checks that the int32_t, stride-2 instantiation of load_interleaved
// models both std::ranges::range and std::ranges::view.
120static_assert(std::ranges::range<load_interleaved<int32_t, 2>>);
121static_assert(std::ranges::view<load_interleaved<int32_t, 2>>);
122
123template <std::ranges::contiguous_range R, size_t stride = 2>
125
126} // namespace argon::vectorize
127#undef simd
Definition argon_full.hpp:24
Lane deconstruction feature.
Definition argon_full.hpp:302
Definition load_interleaved.hpp:15
Definition ptr.hpp:74