Argon 0.1.0
Loading...
Searching...
No Matches
interleaved.hpp
1#pragma once
2#include <cstddef>
3#include <cstdint>
4#include <iterator>
5#include <ranges>
6#include <span>
7#include "argon.hpp"
8#include "argon/vectorize.hpp"
9#include "arm_simd/helpers/vec128.hpp"
10
// Select the namespace that `simd::` resolves to in this header: `mve` when the
// compiler defines __ARM_FEATURE_MVE, `neon` otherwise. The macro is removed
// again by the `#undef simd` at the end of the file.
11#ifdef __ARM_FEATURE_MVE
12#define simd mve
13#else
14#define simd neon
15#endif
16
17namespace argon::vectorize {
18
19template <size_t Stride, typename ScalarType>
20struct interleaved : public std::ranges::view_interface<interleaved<Stride, ScalarType>> {
21 using intrinsic_type = simd::Vec128_t<ScalarType>;
22 static constexpr size_t lanes = sizeof(intrinsic_type) / sizeof(ScalarType);
23 static constexpr size_t vectorizeable_size(size_t size) { return size & ~(lanes - 1); }
24
25 public:
26 struct Iterator {
27 using iterator_category = std::forward_iterator_tag;
28 using argon_type = Argon<ScalarType>;
29 using value_type = std::array<argon_type, Stride>;
30 using difference_type = std::ptrdiff_t;
31
32 Iterator() = default;
33 Iterator(ScalarType* ptr) : ptr{ptr}, vec{argon_type::template LoadInterleaved<Stride>(ptr)} {}
34
35 value_type& operator*() { return vec; }
36 value_type* operator->() { return &vec; }
37 const value_type& operator*() const { return vec; }
38 const value_type* operator->() const { return &vec; }
39 Iterator& operator++() {
40 store_interleaved(ptr, vec); // store before increment
41 ptr += lanes;
42 vec = argon_type::template LoadInterleaved<Stride>(ptr);
43 return *this;
44 }
45
46 Iterator operator++(int) {
47 Iterator tmp = *this;
48 ++(*this);
49 return tmp;
50 }
51
52 friend bool operator==(const Iterator& a, const Iterator& b) { return a.ptr == b.ptr; }
53 friend bool operator==(const Iterator& a, const ScalarType* ptr) { return a.ptr == ptr; }
54 friend bool operator!=(const Iterator& a, const Iterator& b) { return a.ptr != b.ptr; }
55 friend bool operator!=(const Iterator& a, const ScalarType* ptr) { return a.ptr != ptr; }
56
57 private:
58 ScalarType* ptr = nullptr;
59 value_type vec;
60 };
61 static_assert(std::input_or_output_iterator<Iterator>);
62 struct ConstIterator {
63 using iterator_category = std::forward_iterator_tag;
64 using argon_type = Argon<ScalarType>;
65 using value_type = std::array<argon_type, Stride>;
66 using difference_type = std::ptrdiff_t;
67
68 ConstIterator() = default;
69 ConstIterator(const ScalarType* ptr) : ptr{ptr}, vec{argon_type::template LoadInterleaved<Stride>(ptr)} {}
70
71 const value_type operator*() const { return vec; }
72 ConstIterator& operator++() {
73 ptr += lanes;
74 vec = argon_type::template LoadInterleaved<Stride>(ptr);
75 return *this;
76 }
77 ConstIterator operator++(int) {
78 ConstIterator tmp = *this;
79 ++(*this);
80 return tmp;
81 }
82 friend bool operator==(const ConstIterator& a, const ConstIterator& b) { return a.ptr == b.ptr; }
83 friend bool operator!=(const ConstIterator& a, const ConstIterator& b) { return a.ptr != b.ptr; }
84
85 private:
86 const ScalarType* ptr = nullptr;
87 value_type vec;
88 };
89 static_assert(std::input_iterator<ConstIterator>);
90
91 using iterator = Iterator;
92 using const_iterator = ConstIterator;
93
94 interleaved(ScalarType* start, ScalarType* end) : start_{start}, size_{vectorizeable_size(end - start)} {};
95 interleaved(ScalarType* start, const size_t size) : start_{start}, size_{vectorizeable_size(size)} {};
96 interleaved(const std::span<ScalarType> span) : start_{span.data()}, size_{vectorizeable_size(span.size())} {};
97
98 template <size_t, std::ranges::contiguous_range R>
99 interleaved(R&& r) : start_{std::ranges::begin(r)}, size_{vectorizeable_size(std::ranges::size(r))} {}
100
101 iterator begin() const { return Iterator(start_); }
102 const ScalarType* end() const { return start_ + size_; }
103 const_iterator cbegin() const { return ConstIterator(start_); }
104 const ScalarType* cend() const { return start_ + size_; }
105 size_t size() const { return size_; }
106
107 private:
108 ScalarType* start_;
109 size_t size_;
110};
111
112// template <size_t stride, std::ranges::contiguous_range R>
113// interleaved(R&& r) -> interleaved<stride, std::ranges::range_value_t<R>>;
114
115static_assert(std::ranges::range<interleaved<3, int32_t>>);
116static_assert(std::ranges::view<interleaved<3, int32_t>>);
117static_assert(std::movable<interleaved<3, int32_t>>);
118static_assert(std::ranges::viewable_range<interleaved<3, int32_t>>);
119
120} // namespace argon::vectorize
121
// Remove the `simd` macro defined near the top of this header so it is not
// visible to code that includes this file.
122#undef simd
Definition argon_full.hpp:24
Lane deconstruction feature.
Definition argon_full.hpp:302
Definition interleaved.hpp:26
Definition interleaved.hpp:20
Definition store_interleaved.hpp:15