1//===- RaggedArray.h - 2D array with different inner lengths ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Support/LLVM.h"
10#include "llvm/ADT/STLExtras.h"
11#include "llvm/ADT/SmallVector.h"
12#include <iterator>
13
14namespace mlir {
15/// A 2D array where each row may have different length. Elements of each row
16/// are stored contiguously, but rows don't have a fixed order in the storage.
17template <typename T>
18class RaggedArray {
19public:
20 /// Returns the number of rows in the 2D array.
21 size_t size() const { return slices.size(); }
22
23 /// Returns true if the are no rows in the 2D array. Note that an array with a
24 /// non-zero number of empty rows is *NOT* empty.
25 bool empty() const { return slices.empty(); }
26
27 /// Accesses `pos`-th row.
28 ArrayRef<T> operator[](size_t pos) const { return at(pos); }
29 ArrayRef<T> at(size_t pos) const {
30 if (slices[pos].first == static_cast<size_t>(-1))
31 return ArrayRef<T>();
32 return ArrayRef<T>(storage).slice(slices[pos].first, slices[pos].second);
33 }
34 MutableArrayRef<T> operator[](size_t pos) { return at(pos); }
35 MutableArrayRef<T> at(size_t pos) {
36 if (slices[pos].first == static_cast<size_t>(-1))
37 return MutableArrayRef<T>();
38 return MutableArrayRef<T>(storage).slice(slices[pos].first,
39 slices[pos].second);
40 }
41
42 /// Iterator over the rows.
43 class iterator
44 : public llvm::iterator_facade_base<
45 iterator, std::forward_iterator_tag, MutableArrayRef<T>,
46 std::ptrdiff_t, MutableArrayRef<T> *, MutableArrayRef<T>> {
47 public:
48 /// Creates the start iterator.
49 explicit iterator(RaggedArray &ragged) : ragged(ragged), pos(0) {}
50
51 /// Creates the end iterator.
52 iterator(RaggedArray &ragged, size_t pos) : ragged(ragged), pos(pos) {}
53
54 /// Dereferences the current iterator. Assumes in-bounds.
55 MutableArrayRef<T> operator*() const { return ragged[pos]; }
56
57 /// Increments the iterator.
58 iterator &operator++() {
59 if (pos < ragged.slices.size())
60 ++pos;
61 return *this;
62 }
63
64 /// Compares the two iterators. Iterators into different ragged arrays
65 /// compare not equal.
66 bool operator==(const iterator &other) const {
67 return &ragged == &other.ragged && pos == other.pos;
68 }
69
70 private:
71 RaggedArray &ragged;
72 size_t pos;
73 };
74
75 /// Constant iterator over the rows.
76 class const_iterator
77 : public llvm::iterator_facade_base<
78 const_iterator, std::forward_iterator_tag, ArrayRef<T>,
79 std::ptrdiff_t, ArrayRef<T> *, ArrayRef<T>> {
80 public:
81 /// Creates the start iterator.
82 explicit const_iterator(const RaggedArray &ragged)
83 : ragged(ragged), pos(0) {}
84
85 /// Creates the end iterator.
86 const_iterator(const RaggedArray &ragged, size_t pos)
87 : ragged(ragged), pos(pos) {}
88
89 /// Dereferences the current iterator. Assumes in-bounds.
90 ArrayRef<T> operator*() const { return ragged[pos]; }
91
92 /// Increments the iterator.
93 const_iterator &operator++() {
94 if (pos < ragged.slices.size())
95 ++pos;
96 return *this;
97 }
98
99 /// Compares the two iterators. Iterators into different ragged arrays
100 /// compare not equal.
101 bool operator==(const const_iterator &other) const {
102 return &ragged == &other.ragged && pos == other.pos;
103 }
104
105 private:
106 const RaggedArray &ragged;
107 size_t pos;
108 };
109
110 /// Iterator over rows.
111 const_iterator begin() const { return const_iterator(*this); }
112 const_iterator end() const { return const_iterator(*this, slices.size()); }
113 iterator begin() { return iterator(*this); }
114 iterator end() { return iterator(*this, slices.size()); }
115
116 /// Reserve space to store `size` rows with `nestedSize` elements each.
117 void reserve(size_t size, size_t nestedSize = 0) {
118 slices.reserve(N: size);
119 storage.reserve(size * nestedSize);
120 }
121
122 /// Appends the given range of elements as a new row to the 2D array. May
123 /// invalidate the end iterator.
124 template <typename Range>
125 void push_back(Range &&elements) {
126 slices.push_back(Elt: appendToStorage(std::forward<Range>(elements)));
127 }
128
129 /// Replaces the `pos`-th row in the 2D array with the given range of
130 /// elements. Invalidates iterators and references to `pos`-th and all
131 /// succeeding rows.
132 template <typename Range>
133 void replace(size_t pos, Range &&elements) {
134 if (slices[pos].first != static_cast<size_t>(-1)) {
135 auto from = std::next(storage.begin(), slices[pos].first);
136 auto to = std::next(from, slices[pos].second);
137 auto newFrom = storage.erase(from, to);
138 // Update the array refs after the underlying storage was shifted.
139 for (size_t i = pos + 1, e = size(); i < e; ++i) {
140 slices[i] = std::make_pair(std::distance(storage.begin(), newFrom),
141 slices[i].second);
142 std::advance(newFrom, slices[i].second);
143 }
144 }
145 slices[pos] = appendToStorage(std::forward<Range>(elements));
146 }
147
148 /// Appends `num` empty rows to the array.
149 void appendEmptyRows(size_t num) {
150 slices.resize(N: slices.size() + num, NV: std::pair<size_t, size_t>(-1, 0));
151 }
152
153 /// Removes the first subarray in-place. Invalidates iterators to all rows.
154 void removeFront() { slices.erase(CI: slices.begin()); }
155
156private:
157 /// Appends the given elements to the storage and returns an ArrayRef
158 /// pointing to them in the storage.
159 template <typename Range>
160 std::pair<size_t, size_t> appendToStorage(Range &&elements) {
161 size_t start = storage.size();
162 llvm::append_range(storage, std::forward<Range>(elements));
163 return std::make_pair(start, storage.size() - start);
164 }
165
166 /// Outer elements of the ragged array. Each entry is an (offset, length)
167 /// pair identifying a contiguous segment in the `storage` list that
168 /// contains the actual elements. This allows for elements to be stored
169 /// contiguously without nested vectors and for different segments to be set
170 /// or replaced in any order.
171 SmallVector<std::pair<size_t, size_t>> slices;
172
173 /// Dense storage for ragged array elements.
174 SmallVector<T> storage;
175};
176} // namespace mlir
177

// Source: mlir/include/mlir/Dialect/Transform/Utils/RaggedArray.h