1 | //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <utility> |
10 | |
11 | #include "Detail/DimLvlMapParser.h" |
12 | |
13 | #include "mlir/Dialect/SparseTensor/IR/Enums.h" |
14 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
15 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" |
16 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
17 | |
18 | #include "mlir/Dialect/Arith/IR/Arith.h" |
19 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
20 | #include "mlir/Dialect/Complex/IR/Complex.h" |
21 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
22 | #include "mlir/IR/Builders.h" |
23 | #include "mlir/IR/DialectImplementation.h" |
24 | #include "mlir/IR/Matchers.h" |
25 | #include "mlir/IR/OpImplementation.h" |
26 | #include "mlir/IR/PatternMatch.h" |
27 | #include "llvm/ADT/Bitset.h" |
28 | #include "llvm/ADT/TypeSwitch.h" |
29 | #include "llvm/Support/FormatVariadic.h" |
30 | |
31 | #define GET_ATTRDEF_CLASSES |
32 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" |
33 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc" |
34 | |
35 | // Forward declarations: the following custom print/parsing methods are |
36 | // referenced by the code generated from SparseTensorTypes.td. |
37 | static mlir::ParseResult parseLevelRange(mlir::AsmParser &, |
38 | mlir::sparse_tensor::Level &, |
39 | mlir::sparse_tensor::Level &); |
40 | static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, |
41 | mlir::sparse_tensor::Level); |
42 | |
43 | #define GET_TYPEDEF_CLASSES |
44 | #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" |
45 | |
46 | using namespace mlir; |
47 | using namespace mlir::sparse_tensor; |
48 | |
49 | // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as |
50 | // well. |
51 | namespace mlir::sparse_tensor { |
52 | llvm::hash_code hash_value(LevelType lt) { |
53 | return llvm::hash_value(static_cast<uint64_t>(lt)); |
54 | } |
55 | } // namespace mlir::sparse_tensor |
56 | |
57 | //===----------------------------------------------------------------------===// |
58 | // Local Convenience Methods. |
59 | //===----------------------------------------------------------------------===// |
60 | |
61 | static constexpr bool acceptBitWidth(unsigned bitWidth) { |
62 | switch (bitWidth) { |
63 | case 0: |
64 | case 8: |
65 | case 16: |
66 | case 32: |
67 | case 64: |
68 | return true; |
69 | default: |
70 | return false; |
71 | } |
72 | } |
73 | |
74 | static SmallVector<Size> |
75 | getSparseFieldShape(const SparseTensorEncodingAttr enc, |
76 | std::optional<ArrayRef<int64_t>> dimShape) { |
77 | assert(enc); |
78 | // With only the encoding, we cannot determine the static shape of leading |
79 | // batch levels; we therefore return a dynamic shape memref instead. |
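// For example (assuming an identity dimToLvl), an encoding with one leading
// batch level and a provided dimension shape [4, ?] yields the field shape
// [4, ?]: the batch dimension is refined to 4, and the trailing dynamic
// dimension appended below stores the sparse level.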
80 | SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic); |
81 | if (dimShape.has_value()) { |
82 | // If the actual tensor shape is provided, we can then refine the leading |
83 | // batch dimensions. |
84 | SmallVector<int64_t> lvlShape = |
85 | enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl); |
86 | memrefShape.assign(lvlShape.begin(), |
87 | lvlShape.begin() + enc.getBatchLvlRank()); |
88 | } |
89 | // Another dynamic dimension to store the sparse level. |
90 | memrefShape.push_back(ShapedType::kDynamic); |
91 | return memrefShape; |
92 | } |
93 | |
94 | //===----------------------------------------------------------------------===// |
95 | // SparseTensorDialect StorageLayout. |
96 | //===----------------------------------------------------------------------===// |
97 | |
98 | static constexpr Level kInvalidLevel = -1u; |
99 | static constexpr Level kInvalidFieldIndex = -1u; |
100 | static constexpr FieldIndex kDataFieldStartingIdx = 0; |
101 | |
102 | void StorageLayout::foreachField( |
103 | llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level, |
104 | LevelType)> |
105 | callback) const { |
106 | const auto lvlTypes = enc.getLvlTypes(); |
107 | const Level lvlRank = enc.getLvlRank(); |
108 | SmallVector<COOSegment> cooSegs = enc.getCOOSegments(); |
109 | FieldIndex fieldIdx = kDataFieldStartingIdx; |
110 | |
111 | ArrayRef cooSegsRef = cooSegs; |
112 | // Per-level storage. |
113 | for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) { |
114 | const auto lt = lvlTypes[l]; |
115 | if (isWithPosLT(lt)) { |
116 | if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt))) |
117 | return; |
118 | } |
119 | if (isWithCrdLT(lt)) { |
120 | if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt))) |
121 | return; |
122 | } |
123 | if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) { |
124 | if (!cooSegsRef.front().isSoA) { |
125 | // AoS COO: all singleton levels are fused into one memref. Skip the |
126 | // entire COO segment. |
127 | l = cooSegsRef.front().lvlRange.second; |
128 | } else { |
129 | // SoA COO: each singleton level has its own memref. |
130 | l++; |
131 | } |
132 | // Expire handled COO segment. |
133 | cooSegsRef = cooSegsRef.drop_front(); |
134 | } else { |
135 | // Non COO levels. |
136 | l++; |
137 | } |
138 | } |
139 | // The values array. |
140 | if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel, |
141 | LevelFormat::Undef))) |
142 | return; |
143 | // Put metadata at the end. |
144 | if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel, |
145 | LevelFormat::Undef))) |
146 | return; |
147 | } |
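// For example, for a two-level CSR-like encoding with
// lvlTypes = [dense, compressed], foreachField visits the fields in order:
//   0: PosMemRef (level 1), 1: CrdMemRef (level 1), 2: ValMemRef,
//   3: StorageSpec
// i.e. per-level overhead memrefs first, then the values array, then the
// metadata specifier at the end.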
148 | |
149 | void sparse_tensor::foreachFieldAndTypeInSparseTensor( |
150 | SparseTensorType stt, |
151 | llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level, |
152 | LevelType)> |
153 | callback) { |
154 | assert(stt.hasEncoding()); |
155 | |
156 | SmallVector<int64_t> memrefShape = |
157 | getSparseFieldShape(stt.getEncoding(), stt.getDimShape()); |
158 | |
159 | const Type specType = StorageSpecifierType::get(stt.getEncoding()); |
160 | // memref<[batch] x ? x pos> positions |
161 | const Type posMemType = MemRefType::get(memrefShape, stt.getPosType()); |
162 | // memref<[batch] x ? x crd> coordinates |
163 | const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType()); |
164 | // memref<[batch] x ? x eltType> values |
165 | const Type valMemType = MemRefType::get(memrefShape, stt.getElementType()); |
166 | |
167 | StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType, |
168 | callback](FieldIndex fieldIdx, |
169 | SparseTensorFieldKind fieldKind, |
170 | Level lvl, LevelType lt) -> bool { |
171 | switch (fieldKind) { |
172 | case SparseTensorFieldKind::StorageSpec: |
173 | return callback(specType, fieldIdx, fieldKind, lvl, lt); |
174 | case SparseTensorFieldKind::PosMemRef: |
175 | return callback(posMemType, fieldIdx, fieldKind, lvl, lt); |
176 | case SparseTensorFieldKind::CrdMemRef: |
177 | return callback(crdMemType, fieldIdx, fieldKind, lvl, lt); |
178 | case SparseTensorFieldKind::ValMemRef: |
179 | return callback(valMemType, fieldIdx, fieldKind, lvl, lt); |
180 | }; |
181 | llvm_unreachable("unrecognized field kind"); |
182 | }); |
183 | } |
184 | |
185 | unsigned StorageLayout::getNumFields() const { |
186 | unsigned numFields = 0; |
187 | foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level, |
188 | LevelType) -> bool { |
189 | numFields++; |
190 | return true; |
191 | }); |
192 | return numFields; |
193 | } |
194 | |
195 | unsigned StorageLayout::getNumDataFields() const { |
196 | unsigned numFields = 0; // one value memref |
197 | foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level, |
198 | LevelType) -> bool { |
199 | if (fidx >= kDataFieldStartingIdx) |
200 | numFields++; |
201 | return true; |
202 | }); |
203 | numFields -= 1; // the last field is StorageSpecifier |
204 | assert(numFields == getNumFields() - kDataFieldStartingIdx - 1); |
205 | return numFields; |
206 | } |
207 | |
208 | std::pair<FieldIndex, unsigned> |
209 | StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind, |
210 | std::optional<Level> lvl) const { |
211 | FieldIndex fieldIdx = kInvalidFieldIndex; |
212 | unsigned stride = 1; |
213 | if (kind == SparseTensorFieldKind::CrdMemRef) { |
214 | assert(lvl.has_value()); |
215 | const Level cooStart = enc.getAoSCOOStart(); |
216 | const Level lvlRank = enc.getLvlRank(); |
217 | if (lvl.value() >= cooStart && lvl.value() < lvlRank) { |
218 | lvl = cooStart; |
219 | stride = lvlRank - cooStart; |
220 | } |
221 | } |
222 | foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx, |
223 | SparseTensorFieldKind fKind, Level fLvl, |
224 | LevelType lt) -> bool { |
225 | if ((lvl && fLvl == lvl.value() && kind == fKind) || |
226 | (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) { |
227 | fieldIdx = fIdx; |
228 | // Returns false to break the iteration. |
229 | return false; |
230 | } |
231 | return true; |
232 | }); |
233 | assert(fieldIdx != kInvalidFieldIndex); |
234 | return std::pair<FieldIndex, unsigned>(fieldIdx, stride); |
235 | } |
236 | |
237 | //===----------------------------------------------------------------------===// |
238 | // SparseTensorDialect Attribute Methods. |
239 | //===----------------------------------------------------------------------===// |
240 | |
241 | std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) { |
242 | return isDynamic(v) ? std::nullopt |
243 | : std::make_optional(static_cast<uint64_t>(v)); |
244 | } |
245 | |
246 | std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const { |
247 | return getStatic(getOffset()); |
248 | } |
249 | |
250 | std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const { |
251 | return getStatic(getStride()); |
252 | } |
253 | |
254 | std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const { |
255 | return getStatic(getSize()); |
256 | } |
257 | |
258 | bool SparseTensorDimSliceAttr::isCompletelyDynamic() const { |
259 | return isDynamic(getOffset()) && isDynamic(getStride()) && |
260 | isDynamic(getSize()); |
261 | } |
262 | |
263 | std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) { |
264 | return isDynamic(v) ? "?" : std::to_string(v); |
265 | } |
266 | |
267 | void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const { |
268 | assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr"); |
269 | os << '('; |
270 | os << getStaticString(getOffset()); |
271 | os << ", "; |
272 | os << getStaticString(getSize()); |
273 | os << ", "; |
274 | os << getStaticString(getStride()); |
275 | os << ')'; |
276 | } |
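// For example, a slice with static offset 0, static size 4, and a dynamic
// stride prints as "(0, 4, ?)", while a completely dynamic slice prints as
// "(?, ?, ?)".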
277 | |
278 | void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const { |
279 | print(printer.getStream()); |
280 | } |
281 | |
282 | static ParseResult parseOptionalStaticSlice(int64_t &result, |
283 | AsmParser &parser) { |
284 | auto parseResult = parser.parseOptionalInteger(result); |
285 | if (parseResult.has_value()) { |
286 | if (parseResult.value().succeeded() && result < 0) { |
287 | parser.emitError( |
288 | parser.getCurrentLocation(), |
289 | "expect positive value or ? for slice offset/size/stride"); |
290 | return failure(); |
291 | } |
292 | return parseResult.value(); |
293 | } |
294 | |
295 | // Otherwise, a '?' represents a dynamic slice value. |
296 | result = SparseTensorDimSliceAttr::kDynamic; |
297 | return parser.parseQuestion(); |
298 | } |
299 | |
300 | Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) { |
301 | int64_t offset = kDynamic, size = kDynamic, stride = kDynamic; |
302 | |
303 | if (failed(parser.parseLParen()) || |
304 | failed(parseOptionalStaticSlice(offset, parser)) || |
305 | failed(parser.parseComma()) || |
306 | failed(parseOptionalStaticSlice(size, parser)) || |
307 | failed(parser.parseComma()) || |
308 | failed(parseOptionalStaticSlice(stride, parser)) || |
309 | failed(parser.parseRParen())) |
310 | return {}; |
311 | |
312 | return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(), |
313 | offset, size, stride); |
314 | } |
315 | |
316 | LogicalResult |
317 | SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
318 | int64_t offset, int64_t size, int64_t stride) { |
319 | if (!isDynamic(offset) && offset < 0) |
320 | return emitError() << "expect non-negative value or ? for slice offset"; |
321 | if (!isDynamic(size) && size <= 0) |
322 | return emitError() << "expect positive value or ? for slice size"; |
323 | if (!isDynamic(stride) && stride <= 0) |
324 | return emitError() << "expect positive value or ? for slice stride"; |
325 | return success(); |
326 | } |
327 | |
328 | SparseTensorEncodingAttr |
329 | SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const { |
330 | assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); |
331 | return SparseTensorEncodingAttr::get( |
332 | getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(), |
333 | getCrdWidth(), getExplicitVal(), getImplicitVal()); |
334 | } |
335 | |
336 | SparseTensorEncodingAttr |
337 | SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const { |
338 | return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap()); |
339 | } |
340 | |
341 | SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const { |
342 | return withDimToLvl(AffineMap()); |
343 | } |
344 | |
345 | SparseTensorEncodingAttr |
346 | SparseTensorEncodingAttr::withBitWidths(unsigned posWidth, |
347 | unsigned crdWidth) const { |
348 | assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); |
349 | return SparseTensorEncodingAttr::get( |
350 | getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth, |
351 | crdWidth, getExplicitVal(), getImplicitVal()); |
352 | } |
353 | |
354 | SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { |
355 | return withBitWidths(0, 0); |
356 | } |
357 | |
358 | SparseTensorEncodingAttr |
359 | SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const { |
360 | assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); |
361 | return SparseTensorEncodingAttr::get( |
362 | getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), |
363 | getCrdWidth(), explicitVal, getImplicitVal()); |
364 | } |
365 | |
366 | SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const { |
367 | return withExplicitVal(Attribute()); |
368 | } |
369 | |
370 | SparseTensorEncodingAttr |
371 | SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const { |
372 | assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); |
373 | return SparseTensorEncodingAttr::get( |
374 | getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), |
375 | getCrdWidth(), getExplicitVal(), implicitVal); |
376 | } |
377 | |
378 | SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const { |
379 | return withImplicitVal(Attribute()); |
380 | } |
381 | |
382 | SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices( |
383 | ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { |
384 | return SparseTensorEncodingAttr::get( |
385 | getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), |
386 | getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices); |
387 | } |
388 | |
389 | SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const { |
390 | return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{}); |
391 | } |
392 | |
393 | uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const { |
394 | ArrayRef<LevelType> lvlTypes = getLvlTypes(); |
395 | auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); |
396 | return std::distance(lastBatch, lvlTypes.rend()); |
397 | } |
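// For example, getBatchLvlRank() returns 2 for
// lvlTypes = [batch, batch, dense, compressed], and 0 when there are no
// leading batch levels.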
398 | |
399 | bool SparseTensorEncodingAttr::isAllDense() const { |
400 | return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT); |
401 | } |
402 | |
403 | bool SparseTensorEncodingAttr::isAllOrdered() const { |
404 | return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT); |
405 | } |
406 | |
407 | Type SparseTensorEncodingAttr::getCrdElemType() const { |
408 | if (!getImpl()) |
409 | return nullptr; |
410 | if (getCrdWidth()) |
411 | return IntegerType::get(getContext(), getCrdWidth()); |
412 | return IndexType::get(getContext()); |
413 | } |
414 | |
415 | Type SparseTensorEncodingAttr::getPosElemType() const { |
416 | if (!getImpl()) |
417 | return nullptr; |
418 | if (getPosWidth()) |
419 | return IntegerType::get(getContext(), getPosWidth()); |
420 | return IndexType::get(getContext()); |
421 | } |
422 | |
423 | MemRefType SparseTensorEncodingAttr::getCrdMemRefType( |
424 | std::optional<ArrayRef<int64_t>> dimShape) const { |
425 | SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); |
426 | return MemRefType::get(shape, getCrdElemType()); |
427 | } |
428 | |
429 | MemRefType SparseTensorEncodingAttr::getPosMemRefType( |
430 | std::optional<ArrayRef<int64_t>> dimShape) const { |
431 | SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); |
432 | return MemRefType::get(shape, getPosElemType()); |
433 | } |
434 | |
435 | bool SparseTensorEncodingAttr::isIdentity() const { |
436 | return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity(); |
437 | } |
438 | |
439 | bool SparseTensorEncodingAttr::isPermutation() const { |
440 | return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation(); |
441 | } |
442 | |
443 | Dimension SparseTensorEncodingAttr::getDimRank() const { |
444 | assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); |
445 | const auto dimToLvl = getDimToLvl(); |
446 | return dimToLvl ? dimToLvl.getNumDims() : getLvlRank(); |
447 | } |
448 | |
449 | Level SparseTensorEncodingAttr::getLvlRank() const { |
450 | assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); |
451 | return getLvlTypes().size(); |
452 | } |
453 | |
454 | LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { |
455 | if (!getImpl()) |
456 | return LevelFormat::Batch; |
457 | assert(l < getLvlRank() && "Level is out of bounds"); |
458 | return getLvlTypes()[l]; |
459 | } |
460 | |
461 | bool SparseTensorEncodingAttr::isSlice() const { |
462 | assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); |
463 | return !getDimSlices().empty(); |
464 | } |
465 | |
466 | SparseTensorDimSliceAttr |
467 | SparseTensorEncodingAttr::getDimSlice(Dimension dim) const { |
468 | assert(isSlice() && "Is not a slice"); |
469 | const auto dimSlices = getDimSlices(); |
470 | assert(dim < dimSlices.size() && "Dimension is out of bounds"); |
471 | return dimSlices[dim]; |
472 | } |
473 | |
474 | std::optional<uint64_t> |
475 | SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { |
476 | return getDimSlice(dim).getStaticOffset(); |
477 | } |
478 | |
479 | std::optional<uint64_t> |
480 | SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { |
481 | return getDimSlice(dim).getStaticStride(); |
482 | } |
483 | |
484 | std::optional<uint64_t> |
485 | SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { |
486 | return getStaticDimSliceOffset(toDim(*this, lvl)); |
487 | } |
488 | |
489 | std::optional<uint64_t> |
490 | SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { |
491 | return getStaticDimSliceStride(toDim(*this, lvl)); |
492 | } |
493 | |
494 | SmallVector<int64_t> |
495 | SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape, |
496 | CrdTransDirectionKind dir) const { |
497 | if (isIdentity()) |
498 | return SmallVector<int64_t>(srcShape); |
499 | |
500 | SmallVector<int64_t> ret; |
501 | unsigned rank = |
502 | dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank(); |
503 | ret.reserve(rank); |
504 | |
505 | if (isPermutation()) { |
506 | for (unsigned r = 0; r < rank; r++) { |
507 | unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r) |
508 | : toLvl(*this, r); |
509 | ret.push_back(srcShape[trans]); |
510 | } |
511 | return ret; |
512 | } |
513 | |
514 | // Handle non-permutation maps. |
515 | AffineMap transMap = |
516 | dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim(); |
517 | |
518 | SmallVector<AffineExpr> dimRep; |
519 | dimRep.reserve(srcShape.size()); |
520 | for (int64_t sz : srcShape) { |
521 | if (!ShapedType::isDynamic(sz)) { |
522 | // Push back the max coordinate for the given dimension/level size. |
523 | dimRep.push_back(getAffineConstantExpr(sz - 1, getContext())); |
524 | } else { |
525 | // A dynamic size: use an AffineDimExpr to symbolize the value. |
526 | dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext())); |
527 | } |
528 | }; |
529 | |
530 | for (AffineExpr exp : transMap.getResults()) { |
531 | // Do constant propagation on the affine map. |
532 | AffineExpr evalExp = |
533 | simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0); |
534 | // use llvm namespace here to avoid ambiguity |
535 | if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) { |
536 | ret.push_back(c.getValue() + 1); |
537 | } else { |
538 | if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp); |
539 | mod && mod.getKind() == AffineExprKind::Mod) { |
540 | // We can still infer a static bound for expressions in form |
541 | // "d % constant" since d % constant \in [0, constant). |
542 | if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) { |
543 | ret.push_back(bound.getValue()); |
544 | continue; |
545 | } |
546 | } |
547 | ret.push_back(ShapedType::kDynamic); |
548 | } |
549 | } |
550 | assert(ret.size() == rank); |
551 | return ret; |
552 | } |
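// Illustrative example for translateShape() (assuming a 2x3 block-sparse map):
//   dimToLvl = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// translates the dimension shape [6, 12] (dim2lvl) to the level shape
// [3, 4, 2, 3]: the floordiv sizes follow from the maximum coordinates 5 and
// 11, and the mod sizes are bounded by the constants 2 and 3.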
553 | |
554 | ValueRange |
555 | SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc, |
556 | ValueRange crds, |
557 | CrdTransDirectionKind dir) const { |
558 | if (!getImpl()) |
559 | return crds; |
560 | |
561 | SmallVector<Type> retType( |
562 | dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(), |
563 | builder.getIndexType()); |
564 | auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this); |
565 | return transOp.getOutCrds(); |
566 | } |
567 | |
568 | Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { |
569 | // Open "<{" part. |
570 | if (failed(parser.parseLess())) |
571 | return {}; |
572 | if (failed(parser.parseLBrace())) |
573 | return {}; |
574 | |
575 | // Process the data from the parsed dictionary value into struct-like data. |
576 | SmallVector<LevelType> lvlTypes; |
577 | SmallVector<SparseTensorDimSliceAttr> dimSlices; |
578 | AffineMap dimToLvl = {}; |
579 | AffineMap lvlToDim = {}; |
580 | unsigned posWidth = 0; |
581 | unsigned crdWidth = 0; |
582 | Attribute explicitVal; |
583 | Attribute implicitVal; |
584 | StringRef attrName; |
585 | SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth", |
586 | "explicitVal", "implicitVal"}; |
587 | while (succeeded(parser.parseOptionalKeyword(&attrName))) { |
588 | // Detect admissible keyword. |
589 | auto *it = find(keys, attrName); |
590 | if (it == keys.end()) { |
591 | parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName; |
592 | return {}; |
593 | } |
594 | unsigned keyWordIndex = it - keys.begin(); |
595 | // Consume the `=` after keys |
596 | if (failed(parser.parseEqual())) |
597 | return {}; |
598 | // Dispatch on keyword. |
599 | switch (keyWordIndex) { |
600 | case 0: { // map |
601 | ir_detail::DimLvlMapParser cParser(parser); |
602 | auto res = cParser.parseDimLvlMap(); |
603 | if (failed(res)) |
604 | return {}; |
605 | const auto &dlm = *res; |
606 | |
607 | const Level lvlRank = dlm.getLvlRank(); |
608 | for (Level lvl = 0; lvl < lvlRank; lvl++) |
609 | lvlTypes.push_back(dlm.getLvlType(lvl)); |
610 | |
611 | const Dimension dimRank = dlm.getDimRank(); |
612 | for (Dimension dim = 0; dim < dimRank; dim++) |
613 | dimSlices.push_back(dlm.getDimSlice(dim)); |
614 | // NOTE: the old syntax requires an all-or-nothing approach to |
615 | // `dimSlices`; therefore, if any slice actually exists then we need |
616 | // to convert null-DSA into default/nop DSA. |
617 | const auto isDefined = [](SparseTensorDimSliceAttr slice) { |
618 | return static_cast<bool>(slice.getImpl()); |
619 | }; |
620 | if (llvm::any_of(dimSlices, isDefined)) { |
621 | const auto defaultSlice = |
622 | SparseTensorDimSliceAttr::get(parser.getContext()); |
623 | for (Dimension dim = 0; dim < dimRank; dim++) |
624 | if (!isDefined(dimSlices[dim])) |
625 | dimSlices[dim] = defaultSlice; |
626 | } else { |
627 | dimSlices.clear(); |
628 | } |
629 | |
630 | dimToLvl = dlm.getDimToLvlMap(parser.getContext()); |
631 | lvlToDim = dlm.getLvlToDimMap(parser.getContext()); |
632 | break; |
633 | } |
634 | case 1: { // posWidth |
635 | Attribute attr; |
636 | if (failed(parser.parseAttribute(attr))) |
637 | return {}; |
638 | auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); |
639 | if (!intAttr) { |
640 | parser.emitError(parser.getNameLoc(), |
641 | "expected an integral position bitwidth"); |
642 | return {}; |
643 | } |
644 | posWidth = intAttr.getInt(); |
645 | break; |
646 | } |
647 | case 2: { // crdWidth |
648 | Attribute attr; |
649 | if (failed(parser.parseAttribute(attr))) |
650 | return {}; |
651 | auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); |
652 | if (!intAttr) { |
653 | parser.emitError(parser.getNameLoc(), |
654 | "expected an integral index bitwidth"); |
655 | return {}; |
656 | } |
657 | crdWidth = intAttr.getInt(); |
658 | break; |
659 | } |
660 | case 3: { // explicitVal |
661 | Attribute attr; |
662 | if (failed(parser.parseAttribute(attr))) |
663 | return {}; |
664 | if (auto result = llvm::dyn_cast<FloatAttr>(attr)) { |
665 | explicitVal = result; |
666 | } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { |
667 | explicitVal = result; |
668 | } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) { |
669 | explicitVal = result; |
670 | } else { |
671 | parser.emitError(parser.getNameLoc(), |
672 | "expected a numeric value for explicitVal"); |
673 | return {}; |
674 | } |
675 | break; |
676 | } |
677 | case 4: { // implicitVal |
678 | Attribute attr; |
679 | if (failed(parser.parseAttribute(attr))) |
680 | return {}; |
681 | if (auto result = llvm::dyn_cast<FloatAttr>(attr)) { |
682 | implicitVal = result; |
683 | } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { |
684 | implicitVal = result; |
685 | } else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) { |
686 | implicitVal = result; |
687 | } else { |
688 | parser.emitError(parser.getNameLoc(), |
689 | "expected a numeric value for implicitVal"); |
690 | return {}; |
691 | } |
692 | break; |
693 | } |
694 | } // switch |
695 | // Only last item can omit the comma. |
696 | if (parser.parseOptionalComma().failed()) |
697 | break; |
698 | } |
699 | |
700 | // Close "}>" part. |
701 | if (failed(parser.parseRBrace())) |
702 | return {}; |
703 | if (failed(parser.parseGreater())) |
704 | return {}; |
705 | |
706 | // Construct struct-like storage for attribute. |
707 | if (!lvlToDim || lvlToDim.isEmpty()) { |
708 | lvlToDim = inferLvlToDim(dimToLvl, parser.getContext()); |
709 | } |
710 | return parser.getChecked<SparseTensorEncodingAttr>( |
711 | parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth, |
712 | explicitVal, implicitVal, dimSlices); |
713 | } |
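// A minimal example of the encoding syntax parsed above (and printed below),
// describing CSR with 32-bit overhead types:
//
//   #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 32,
//     crdWidth = 32
//   }>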
714 | |
715 | void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { |
716 | auto map = static_cast<AffineMap>(getDimToLvl()); |
717 | // Empty affine map indicates identity map |
718 | if (!map) |
719 | map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext()); |
720 | printer << "<{ map = "; |
721 | printSymbols(map, printer); |
722 | printer << '('; |
723 | printDimensions(map, printer, getDimSlices()); |
724 | printer << ") -> ("; |
725 | printLevels(map, printer, getLvlTypes()); |
726 | printer << ')'; |
727 | // Print remaining members only for non-default values. |
728 | if (getPosWidth()) |
729 | printer << ", posWidth = " << getPosWidth(); |
730 | if (getCrdWidth()) |
731 | printer << ", crdWidth = " << getCrdWidth(); |
732 | if (getExplicitVal()) { |
733 | printer << ", explicitVal = " << getExplicitVal(); |
734 | } |
735 | if (getImplicitVal()) |
736 | printer << ", implicitVal = " << getImplicitVal(); |
737 | printer << " }>"; |
738 | } |
739 | |
740 | void SparseTensorEncodingAttr::printSymbols(AffineMap &map, |
741 | AsmPrinter &printer) const { |
742 | if (map.getNumSymbols() == 0) |
743 | return; |
744 | printer << '['; |
745 | for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++) |
746 | printer << 's' << i << ", "; |
747 | if (map.getNumSymbols() >= 1) |
748 | printer << 's' << map.getNumSymbols() - 1; |
749 | printer << ']'; |
750 | } |
751 | |
752 | void SparseTensorEncodingAttr::printDimensions( |
753 | AffineMap &map, AsmPrinter &printer, |
754 | ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { |
755 | if (!dimSlices.empty()) { |
756 | for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) |
757 | printer << 'd' << i << " : " << dimSlices[i] << ", "; |
758 | if (map.getNumDims() >= 1) { |
759 | printer << 'd' << map.getNumDims() - 1 << " : " |
760 | << dimSlices[map.getNumDims() - 1]; |
761 | } |
762 | } else { |
763 | for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) |
764 | printer << 'd' << i << ", "; |
765 | if (map.getNumDims() >= 1) |
766 | printer << 'd' << map.getNumDims() - 1; |
767 | } |
768 | } |
769 | |
770 | void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer, |
771 | ArrayRef<LevelType> lvlTypes) const { |
772 | for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) { |
773 | map.getResult(i).print(printer.getStream()); |
774 | printer << " : " << toMLIRString(lvlTypes[i]) << ", "; |
775 | } |
776 | if (map.getNumResults() >= 1) { |
777 | auto lastIndex = map.getNumResults() - 1; |
778 | map.getResult(lastIndex).print(printer.getStream()); |
779 | printer << " : " << toMLIRString(lvlTypes[lastIndex]); |
780 | } |
781 | } |
782 | |
783 | LogicalResult SparseTensorEncodingAttr::verify( |
784 | function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes, |
785 | AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth, |
786 | unsigned crdWidth, Attribute explicitVal, Attribute implicitVal, |
787 | ArrayRef<SparseTensorDimSliceAttr> dimSlices) { |
788 | if (!acceptBitWidth(posWidth)) |
789 | return emitError() << "unexpected position bitwidth: " << posWidth; |
790 | if (!acceptBitWidth(crdWidth)) |
791 | return emitError() << "unexpected coordinate bitwidth: " << crdWidth; |
792 | |
793 | // Verify every COO segment. |
794 | auto *it = llvm::find_if(lvlTypes, isSingletonLT); |
795 | while (it != lvlTypes.end()) { |
796 | if (it == lvlTypes.begin() || |
797 | !(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) |
798 | return emitError() << "expected compressed or loose_compressed level " |
799 | "before singleton level"; |
800 | |
801 | auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT); |
802 | if (!std::all_of(it, curCOOEnd, isSingletonLT)) |
803 | return emitError() << "expected all singleton lvlTypes " |
804 | "following a singleton level"; |
805 | // We can potentially support mixed SoA/AoS singleton levels. |
806 | if (!std::all_of(it, curCOOEnd, [it](LevelType i) { |
807 | return it->isa<LevelPropNonDefault::SoA>() == |
808 | i.isa<LevelPropNonDefault::SoA>(); |
809 | })) { |
810 | return emitError() << "expected all singleton lvlTypes stored in the " |
811 | "same memory layout (SoA vs AoS)."; |
812 | } |
813 | it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT); |
814 | } |
815 | |
816 | auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); |
817 | if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT)) |
818 | return emitError() << "Batch lvlType can only be leading levels."; |
819 | |
820 | // SoA property can only be applied on singleton level. |
821 | auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) { |
822 | return lt.isa<LevelPropNonDefault::SoA>(); |
823 | }); |
824 | if (llvm::any_of(soaLvls, [](LevelType lt) { |
825 | return !lt.isa<LevelFormat::Singleton>(); |
826 | })) { |
827 | return emitError() << "SoA is only applicable to singleton lvlTypes."; |
828 | } |
829 | |
830 | // TODO: audit formats that actually are supported by backend. |
831 | if (auto it = llvm::find_if(lvlTypes, isNOutOfMLT); |
832 | it != std::end(lvlTypes)) { |
833 | if (it != lvlTypes.end() - 1) |
834 | return emitError() << "expected n_out_of_m to be the last level type"; |
835 | if (!std::all_of(lvlTypes.begin(), it, isDenseLT)) |
836 | return emitError() << "expected all dense lvlTypes " |
837 | "before a n_out_of_m level"; |
838 | if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) { |
839 | if (!isBlockSparsity(dimToLvl)) { |
840 | return emitError() |
841 | << "expected 1xm block structure for n_out_of_m level"; |
842 | } |
843 | auto sizes = getBlockSize(dimToLvl); |
844 | unsigned coefficient = 0; |
845 | for (const auto &elem : sizes) { |
846 | if (elem != 0) { |
847 | if (elem != coefficient && coefficient != 0) { |
848 | return emitError() << "expected only one blocked level " |
849 | "with the same coefficients"; |
850 | } |
851 | coefficient = elem; |
852 | } |
853 | } |
854 | if (coefficient != getM(*it)) { |
855 | return emitError() << "expected coefficients of affine expressions " |
856 | "to be equal to m of n_out_of_m level"; |
857 | } |
858 | } |
859 | } |
860 | // Before we can check that the level-rank is consistent/coherent |
861 | // across all fields, we need to define it. The source-of-truth for |
862 | // the `getLvlRank` method is the length of the level-types array, |
863 | // since it must always be provided and have full rank; therefore we |
864 | // use that same source-of-truth here. |
865 | const Level lvlRank = lvlTypes.size(); |
866 | if (lvlRank == 0) |
867 | return emitError() << "expected a non-empty array for lvlTypes"; |
868 | // We save `dimRank` here because we'll also need it to verify `dimSlices`. |
869 | const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank; |
870 | if (dimToLvl) { |
871 | if (dimToLvl.getNumResults() != lvlRank) |
872 | return emitError() |
873 | << "level-rank mismatch between dimToLvl and lvlTypes: " |
874 | << dimToLvl.getNumResults() << " != " << lvlRank; |
875 | auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext()); |
876 | // Symbols can't be inferred but are acceptable. |
877 | if (!inferRes && dimToLvl.getNumSymbols() == 0) |
878 | return emitError() << "failed to infer lvlToDim from dimToLvl"; |
879 | if (lvlToDim && (inferRes != lvlToDim)) |
880 | return emitError() << "expected lvlToDim to be an inverse of dimToLvl"; |
881 | if (dimRank > lvlRank) |
882 | return emitError() << "unexpected dimToLvl mapping from " << dimRank |
883 | << " to " << lvlRank; |
884 | } |
885 | if (!dimSlices.empty()) { |
886 | if (dimSlices.size() != dimRank) |
887 | return emitError() |
888 | << "dimension-rank mismatch between dimSlices and dimToLvl: " |
889 | << dimSlices.size() << " != " << dimRank; |
890 | // Compiler support for `dimSlices` currently requires that the two |
891 | // ranks agree. (However, it does allow `dimToLvl` to be a permutation.) |
892 | if (dimRank != lvlRank) |
893 | return emitError() |
894 | << "dimSlices expected dimension-rank to match level-rank: " |
895 | << dimRank << " != " << lvlRank; |
896 | } |
897 | return success(); |
898 | } |
899 | |
900 | LogicalResult SparseTensorEncodingAttr::verifyEncoding( |
901 | ArrayRef<Size> dimShape, Type elementType, |
902 | function_ref<InFlightDiagnostic()> emitError) const { |
903 | // Check structural integrity. In particular, this ensures that the |
904 | // level-rank is coherent across all the fields. |
905 | if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(), |
906 | getPosWidth(), getCrdWidth(), getExplicitVal(), |
907 | getImplicitVal(), getDimSlices()))) |
908 | return failure(); |
909 | // Check integrity with tensor type specifics. In particular, we |
910 | // need only check that the dimension-rank of the tensor agrees with |
911 | // the dimension-rank of the encoding. |
912 | const Dimension dimRank = dimShape.size(); |
913 | if (dimRank == 0) |
914 | return emitError() << "expected non-scalar sparse tensor"; |
915 | if (getDimRank() != dimRank) |
916 | return emitError() |
917 | << "dimension-rank mismatch between encoding and tensor shape: " |
918 | << getDimRank() << " != " << dimRank; |
919 | if (auto expVal = getExplicitVal()) { |
920 | Type attrType = llvm::dyn_cast<TypedAttr>(expVal).getType(); |
921 | if (attrType != elementType) { |
922 | return emitError() << "explicit value type mismatch between encoding and " |
923 | << "tensor element type: "<< attrType |
924 | << " != "<< elementType; |
925 | } |
926 | } |
927 | if (auto impVal = getImplicitVal()) { |
928 | Type attrType = llvm::dyn_cast<TypedAttr>(impVal).getType(); |
929 | if (attrType != elementType) { |
930 | return emitError() << "implicit value type mismatch between encoding and " |
931 | << "tensor element type: "<< attrType |
932 | << " != "<< elementType; |
933 | } |
934 | // Currently, we only support zero as the implicit value. |
935 | auto impFVal = llvm::dyn_cast<FloatAttr>(impVal); |
936 | auto impIntVal = llvm::dyn_cast<IntegerAttr>(impVal); |
937 | auto impComplexVal = llvm::dyn_cast<complex::NumberAttr>(impVal); |
938 | if ((impFVal && impFVal.getValue().isNonZero()) || |
939 | (impIntVal && !impIntVal.getValue().isZero()) || |
940 | (impComplexVal && (impComplexVal.getImag().isNonZero() || |
941 | impComplexVal.getReal().isNonZero()))) { |
942 | return emitError() << "implicit value must be zero"; |
943 | } |
944 | } |
945 | return success(); |
946 | } |
947 | |
948 | Level mlir::sparse_tensor::SparseTensorEncodingAttr::getAoSCOOStart() const { |
949 | SmallVector<COOSegment> coo = getCOOSegments(); |
950 | assert(coo.size() == 1 || coo.empty()); |
951 | if (!coo.empty() && coo.front().isAoS()) { |
952 | return coo.front().lvlRange.first; |
953 | } |
954 | return getLvlRank(); |
955 | } |
956 | |
957 | SmallVector<COOSegment> |
958 | mlir::sparse_tensor::SparseTensorEncodingAttr::getCOOSegments() const { |
959 | SmallVector<COOSegment> ret; |
960 | if (getLvlRank() <= 1) |
961 | return ret; |
962 | |
963 | ArrayRef<LevelType> lts = getLvlTypes(); |
964 | Level l = 0; |
965 | while (l < getLvlRank()) { |
966 | auto lt = lts[l]; |
967 | if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) { |
968 | auto cur = lts.begin() + l; |
969 | auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) { |
970 | return !lt.isa<LevelFormat::Singleton>(); |
971 | }); |
972 | unsigned cooLen = std::distance(cur, end); |
973 | if (cooLen > 1) { |
974 | // To support mixed SoA/AoS COO we would have to break the segment when |
975 | // the storage scheme changes; for now we rely on the verified STEA to |
976 | // guarantee that all consecutive singleton levels share the same storage |
977 | // format. |
978 | ret.push_back(COOSegment{std::make_pair(l, l + cooLen), |
979 | lts[l + 1].isa<LevelPropNonDefault::SoA>()}); |
980 | } |
981 | l += cooLen; |
982 | } else { |
983 | l++; |
984 | } |
985 | } |
986 | return ret; |
987 | } |
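// For example, getCOOSegments() on lvlTypes = [dense, compressed(nonunique),
// singleton] yields a single COO segment covering the level range [1, 3);
// since the singleton level is not SoA, the segment is AoS and
// getAoSCOOStart() above returns 1.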
988 | |
989 | //===----------------------------------------------------------------------===// |
990 | // SparseTensorType Methods. |
991 | //===----------------------------------------------------------------------===// |
992 | |
993 | bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, |
994 | bool isUnique) const { |
995 | if (!hasEncoding()) |
996 | return false; |
997 | if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl)) |
998 | return false; |
999 | for (Level l = startLvl + 1; l < lvlRank; ++l) |
1000 | if (!isSingletonLvl(l)) |
1001 | return false; |
1002 | // If isUnique is true, then make sure that the last level is unique, |
1003 | // that is, when lvlRank == 1, the only compressed level is unique, |
1004 | // and when lvlRank > 1, the last singleton is unique. |
1005 | return !isUnique || isUniqueLvl(lvlRank - 1); |
1006 | } |
1007 | |
1008 | RankedTensorType |
1009 | mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const { |
1010 | SmallVector<LevelType> lvlTypes; |
1011 | lvlTypes.reserve(lvlRank); |
1012 | // A non-unique compressed level at beginning (unless this is |
1013 | // also the last level, then it is unique). |
1014 | lvlTypes.push_back( |
1015 | *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); |
1016 | if (lvlRank > 1) { |
1017 | // Followed by n-2 non-unique singleton levels. |
1018 | std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2, |
1019 | *buildLevelType(LevelFormat::Singleton, ordered, false)); |
1020 | // Ends with a unique singleton level. |
1021 | lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true)); |
1022 | } |
1023 | auto enc = SparseTensorEncodingAttr::get( |
1024 | getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(), |
1025 | getCrdWidth(), getExplicitVal(), getImplicitVal()); |
1026 | return RankedTensorType::get(getDimShape(), getElementType(), enc); |
1027 | } |
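// For example, getCOOType() on a 3-level source type produces an encoding
// with lvlTypes = [compressed(nonunique), singleton(nonunique), singleton]:
// a non-unique compressed level, followed by non-unique singletons, ending
// with a unique singleton level (unordered variants are used when `ordered`
// is false).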
1028 | |
1029 | //===----------------------------------------------------------------------===// |
1030 | // Convenience Methods. |
1031 | //===----------------------------------------------------------------------===// |
1032 | |
1033 | SparseTensorEncodingAttr |
1034 | mlir::sparse_tensor::getSparseTensorEncoding(Type type) { |
1035 | if (auto ttp = llvm::dyn_cast<RankedTensorType>(type)) |
1036 | return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding()); |
1037 | if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type)) |
1038 | return mdtp.getEncoding(); |
1039 | return nullptr; |
1040 | } |
1041 | |
1042 | AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl, |
1043 | MLIRContext *context) { |
1044 | auto map = static_cast<AffineMap>(dimToLvl); |
1045 | AffineMap lvlToDim; |
1046 | // Return an empty lvlToDim when inference is not successful. |
1047 | if (!map || map.getNumSymbols() != 0) { |
1048 | lvlToDim = AffineMap(); |
1049 | } else if (map.isPermutation()) { |
1050 | lvlToDim = inversePermutation(map); |
1051 | } else if (isBlockSparsity(map)) { |
1052 | lvlToDim = inverseBlockSparsity(map, context); |
1053 | } |
1054 | return lvlToDim; |
1055 | } |
1056 | |
1057 | AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl, |
1058 | MLIRContext *context) { |
1059 | SmallVector<AffineExpr> lvlExprs; |
1060 | auto numLvls = dimToLvl.getNumResults(); |
1061 | lvlExprs.reserve(numLvls); |
1062 | // lvlExprComponents stores information of the floordiv and mod operations |
1063 | // applied to the same dimension, so as to build the lvlToDim map. |
1064 | std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents; |
1065 | for (unsigned i = 0, n = numLvls; i < n; i++) { |
1066 | auto result = dimToLvl.getResult(i); |
1067 | if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { |
1068 | if (result.getKind() == AffineExprKind::FloorDiv) { |
1069 | // Position of the dimension in dimToLvl. |
1070 | auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); |
1071 | assert(lvlExprComponents.find(pos) == lvlExprComponents.end() && |
1072 | "expected only one floordiv for each dimension"); |
1073 | SmallVector<AffineExpr, 3> components; |
1074 | // Level variable for floordiv. |
1075 | components.push_back(getAffineDimExpr(i, context)); |
1076 | // Multiplier. |
1077 | components.push_back(binOp.getRHS()); |
1078 | // Map key is the position of the dimension. |
1079 | lvlExprComponents[pos] = components; |
1080 | } else if (result.getKind() == AffineExprKind::Mod) { |
1081 | auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); |
1082 | assert(lvlExprComponents.find(pos) != lvlExprComponents.end() && |
1083 | "expected floordiv before mod"); |
1084 | // Add level variable for mod to the same vector |
1085 | // of the corresponding floordiv. |
1086 | lvlExprComponents[pos].push_back(getAffineDimExpr(i, context)); |
1087 | } else { |
1088 | assert(false && "expected floordiv or mod"); |
1089 | } |
1090 | } else { |
1091 | lvlExprs.push_back(getAffineDimExpr(i, context)); |
1092 | } |
1093 | } |
1094 | // Build lvlExprs from lvlExprComponents. |
1095 | // For example, for il = i floordiv 2 and ii = i mod 2, the components |
1096 | // would be [il, 2, ii]. It could be used to build the AffineExpr |
1097 | // i = il * 2 + ii in lvlToDim. |
1098 | for (auto &components : lvlExprComponents) { |
1099 | assert(components.second.size() == 3 && |
1100 | "expected 3 components to build lvlExprs"); |
1101 | auto mulOp = getAffineBinaryOpExpr( |
1102 | AffineExprKind::Mul, components.second[0], components.second[1]); |
1103 | auto addOp = |
1104 | getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]); |
1105 | lvlExprs.push_back(addOp); |
1106 | } |
1107 | return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context); |
1108 | } |
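// Illustrative example for inverseBlockSparsity() (assuming a 2x3 block map):
// inverting
//   dimToLvl = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// yields
//   lvlToDim = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
// which reconstructs each dimension coordinate from its block and in-block
// components.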
1109 | |
1110 | SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) { |
1111 | assert(isBlockSparsity(dimToLvl) && |
1112 | "expected dimToLvl to be block sparsity for calling getBlockSize"); |
1113 | SmallVector<unsigned> blockSize; |
1114 | for (auto result : dimToLvl.getResults()) { |
1115 | if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { |
1116 | if (result.getKind() == AffineExprKind::Mod) { |
1117 | blockSize.push_back( |
1118 | dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue()); |
1119 | } |
1120 | } else { |
1121 | blockSize.push_back(0); |
1122 | } |
1123 | } |
1124 | return blockSize; |
1125 | } |
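// For the same 2x3 example map, getBlockSize() returns [2, 3]: each "mod"
// result contributes its constant, while an unblocked (plain dimension)
// result would contribute 0.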
1126 | |
1127 | bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) { |
1128 | if (!dimToLvl) |
1129 | return false; |
1130 | std::map<unsigned, int64_t> coefficientMap; |
1131 | bool hasBlock = false; |
1132 | for (auto result : dimToLvl.getResults()) { |
1133 | if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { |
1134 | // Check for "dim op const". |
1135 | auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS()); |
1136 | auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS()); |
1137 | if (!dimOp || !conOp || conOp.getValue() <= 0) |
1138 | return false; |
1139 | // Inspect "dim / const" or "dim % const". |
1140 | auto pos = dimOp.getPosition(); |
1141 | if (binOp.getKind() == AffineExprKind::FloorDiv) { |
1142 | // Expect only one floordiv for each dimension. |
1143 | auto [it, inserted] = coefficientMap.try_emplace(pos); |
1144 | if (!inserted) |
1145 | return false; |
1146 | // Record coefficient of the floordiv. |
1147 | it->second = conOp.getValue(); |
1148 | } else if (binOp.getKind() == AffineExprKind::Mod) { |
1149 | // Expect floordiv before mod. |
1150 | auto it = coefficientMap.find(pos); |
1151 | if (it == coefficientMap.end()) |
1152 | return false; |
1153 | // Expect mod to have the same coefficient as floordiv. |
1154 | if (conOp.getValue() != it->second) |
1155 | return false; |
1156 | hasBlock = true; |
1157 | } else { |
1158 | return false; |
1159 | } |
1160 | } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) { |
1161 | auto pos = dimOp.getPosition(); |
1162 | // Expect dim to be unset. |
1163 | if (!coefficientMap.try_emplace(pos, 0).second) |
1164 | return false; |
1165 | } else { |
1166 | return false; |
1167 | } |
1168 | } |
1169 | return hasBlock; |
1170 | } |
1171 | |
1172 | bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) { |
1173 | auto hasNonIdentityMap = [](Value v) { |
1174 | auto stt = tryGetSparseTensorType(v); |
1175 | return stt && !stt->isIdentity(); |
1176 | }; |
1177 | |
1178 | return llvm::any_of(op->getOperands(), hasNonIdentityMap) || |
1179 | llvm::any_of(op->getResults(), hasNonIdentityMap); |
1180 | } |
1181 | |
1182 | Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { |
1183 | if (enc) { |
1184 | assert(enc.isPermutation() && "Non permutation map not supported"); |
1185 | if (const auto dimToLvl = enc.getDimToLvl()) |
1186 | return dimToLvl.getDimPosition(l); |
1187 | } |
1188 | return l; |
1189 | } |
1190 | |
1191 | Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) { |
1192 | if (enc) { |
1193 | assert(enc.isPermutation() && "Non permutation map not supported"); |
1194 | if (const auto lvlToDim = enc.getLvlToDim()) |
1195 | return lvlToDim.getDimPosition(d); |
1196 | } |
1197 | return d; |
1198 | } |
1199 | |
1200 | /// We normalize the sparse tensor encoding attribute by always using |
1201 | /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well |
1202 | /// as other variants) lead to the same storage specifier type, and by stripping |
1203 | /// irrelevant fields that do not alter the sparse tensor memory layout. |
1204 | static SparseTensorEncodingAttr |
1205 | getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { |
1206 | SmallVector<LevelType> lts; |
1207 | for (auto lt : enc.getLvlTypes()) |
1208 | lts.push_back(lt.stripStorageIrrelevantProperties()); |
1209 | |
1210 | return SparseTensorEncodingAttr::get( |
1211 | enc.getContext(), lts, |
1212 | AffineMap(), // dimToLvl (irrelevant to storage specifier) |
1213 | AffineMap(), // lvlToDim (irrelevant to storage specifier) |
1214 | // Always use `index` for memSize and lvlSize instead of reusing |
1215 | // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA |
1216 | // value for different bitwidths; it also avoids casting between index and |
1217 | // integer (returned by DimOp). |
1218 | 0, 0, |
1219 | Attribute(), // explicitVal (irrelevant to storage specifier) |
1220 | Attribute(), // implicitVal (irrelevant to storage specifier) |
1221 | enc.getDimSlices()); |
1222 | } |
1223 | |
1224 | StorageSpecifierType |
1225 | StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) { |
1226 | return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding)); |
1227 | } |
1228 | |
1229 | StorageSpecifierType |
1230 | StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError, |
1231 | MLIRContext *ctx, |
1232 | SparseTensorEncodingAttr encoding) { |
1233 | return Base::getChecked(emitError, ctx, |
1234 | getNormalizedEncodingForSpecifier(encoding)); |
1235 | } |
1236 | |
1237 | //===----------------------------------------------------------------------===// |
1238 | // SparseTensorDialect Operations. |
1239 | //===----------------------------------------------------------------------===// |
1240 | |
1241 | static LogicalResult lvlIsInBounds(Level lvl, Value tensor) { |
1242 | return success(lvl < getSparseTensorType(tensor).getLvlRank()); |
1243 | } |
1244 | |
1245 | static LogicalResult isMatchingWidth(Value mem, unsigned width) { |
1246 | const Type etp = getMemRefType(mem).getElementType(); |
1247 | return success(width == 0 ? etp.isIndex() : etp.isInteger(width)); |
1248 | } |
1249 | |
1250 | static LogicalResult verifySparsifierGetterSetter( |
1251 | StorageSpecifierKind mdKind, std::optional<Level> lvl, |
1252 | TypedValue<StorageSpecifierType> md, Operation *op) { |
1253 | if (mdKind == StorageSpecifierKind::ValMemSize && lvl) { |
1254 | return op->emitError( |
1255 | "redundant level argument for querying value memory size"); |
1256 | } |
1257 | |
1258 | const auto enc = md.getType().getEncoding(); |
1259 | const Level lvlRank = enc.getLvlRank(); |
1260 | |
1261 | if (mdKind == StorageSpecifierKind::DimOffset || |
1262 | mdKind == StorageSpecifierKind::DimStride) |
1263 | if (!enc.isSlice()) |
1264 | return op->emitError("requested slice data on non-slice tensor"); |
1265 | |
1266 | if (mdKind != StorageSpecifierKind::ValMemSize) { |
1267 | if (!lvl) |
1268 | return op->emitError("missing level argument"); |
1269 | |
1270 | const Level l = lvl.value(); |
1271 | if (l >= lvlRank) |
1272 | return op->emitError("requested level is out of bounds"); |
1273 | |
1274 | if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l)) |
1275 | return op->emitError( |
1276 | "requested position memory size on a singleton level"); |
1277 | } |
1278 | return success(); |
1279 | } |
1280 | |
1281 | static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) { |
1282 | switch (kind) { |
1283 | case SparseTensorFieldKind::CrdMemRef: |
1284 | return stt.getCrdType(); |
1285 | case SparseTensorFieldKind::PosMemRef: |
1286 | return stt.getPosType(); |
1287 | case SparseTensorFieldKind::ValMemRef: |
1288 | return stt.getElementType(); |
1289 | case SparseTensorFieldKind::StorageSpec: |
1290 | return nullptr; |
1291 | } |
1292 | llvm_unreachable("Unrecognizable FieldKind"); |
1293 | } |
1294 | |
1295 | static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, |
1296 | SparseTensorType stt, |
1297 | RankedTensorType valTp, |
1298 | TypeRange lvlTps) { |
1299 | if (requiresStaticShape && !stt.hasStaticDimShape()) |
1300 | return op->emitError("the sparse-tensor must have static shape"); |
1301 | if (!stt.hasEncoding()) |
1302 | return op->emitError("the sparse-tensor must have an encoding attribute"); |
1303 | |
1304 | // Verifies the trailing COO. |
1305 | Level cooStartLvl = stt.getAoSCOOStart(); |
1306 | if (cooStartLvl < stt.getLvlRank()) { |
1307 | // We only support trailing COO for now; it must be the last input. |
1308 | auto cooTp = llvm::cast<ShapedType>(lvlTps.back()); |
1309 | // The coordinates should be in shape of <? x rank> |
1310 | unsigned expCOORank = stt.getLvlRank() - cooStartLvl; |
1311 | if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) { |
1312 | return op->emitError("input/output trailing COO level-ranks don't match"); |
1313 | } |
1314 | } |
1315 | |
1316 | // Verifies that all types match. |
1317 | StorageLayout layout(stt.getEncoding()); |
1318 | if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref |
1319 | return op->emitError("inconsistent number of fields between input/output"); |
1320 | |
1321 | unsigned idx = 0; |
1322 | bool misMatch = false; |
1323 | layout.foreachField([&idx, &misMatch, stt, valTp, |
1324 | lvlTps](FieldIndex fid, SparseTensorFieldKind fKind, |
1325 | Level lvl, LevelType lt) -> bool { |
1326 | if (fKind == SparseTensorFieldKind::StorageSpec) |
1327 | return true; |
1328 | |
1329 | Type inputTp = nullptr; |
1330 | if (fKind == SparseTensorFieldKind::ValMemRef) { |
1331 | inputTp = valTp; |
1332 | } else { |
1333 | assert(fid == idx && stt.getLvlType(lvl) == lt); |
1334 | inputTp = lvlTps[idx++]; |
1335 | } |
1336 | // The input element type and expected element type should match. |
1337 | Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType(); |
1338 | Type expElemTp = getFieldElemType(stt, fKind); |
1339 | if (inpElemTp != expElemTp) { |
1340 | misMatch = true; |
1341 | return false; // to terminate the iteration |
1342 | } |
1343 | return true; |
1344 | }); |
1345 | |
1346 | if (misMatch) |
1347 | return op->emitError("input/output element-types don't match"); |
1348 | return success(); |
1349 | } |
1350 | |
1351 | LogicalResult AssembleOp::verify() { |
1352 | RankedTensorType valuesTp = getValues().getType(); |
1353 | const auto lvlsTp = getLevels().getTypes(); |
1354 | const auto resTp = getSparseTensorType(getResult()); |
1355 | return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp); |
1356 | } |
1357 | |
1358 | LogicalResult DisassembleOp::verify() { |
1359 | if (getOutValues().getType() != getRetValues().getType()) |
1360 | return emitError("output values and return value type mismatch"); |
1361 | |
1362 | for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels())) |
1363 | if (ot.getType() != rt.getType()) |
1364 | return emitError("output levels and return levels type mismatch"); |
1365 | |
1366 | RankedTensorType valuesTp = getRetValues().getType(); |
1367 | const auto lvlsTp = getRetLevels().getTypes(); |
1368 | const auto srcTp = getSparseTensorType(getTensor()); |
1369 | return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp); |
1370 | } |
1371 | |
1372 | LogicalResult ConvertOp::verify() { |
1373 | RankedTensorType tp1 = getSource().getType(); |
1374 | RankedTensorType tp2 = getDest().getType(); |
1375 | if (tp1.getRank() != tp2.getRank()) |
1376 | return emitError("unexpected conversion mismatch in rank"); |
1377 | auto dstEnc = |
1378 | llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding()); |
1379 | if (dstEnc && dstEnc.isSlice()) |
1380 | return emitError("cannot convert to a sparse tensor slice"); |
1381 | |
1382 | auto shape1 = tp1.getShape(); |
1383 | auto shape2 = tp2.getShape(); |
1384 | // Accept size matches between the source and the destination type |
1385 | // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or |
1386 | // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). |
1387 | for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++) |
1388 | if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic) |
1389 | return emitError("unexpected conversion mismatch in dimension ") << d; |
1390 | return success(); |
1391 | } |
1392 | |
1393 | OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { |
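 | // A conversion to the exact same type is a no-op; fold to the source value. |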
1394 | if (getType() == getSource().getType()) |
1395 | return getSource(); |
1396 | return {}; |
1397 | } |
1398 | |
1399 | bool ConvertOp::needsExtraSort() { |
1400 | SparseTensorType srcStt = getSparseTensorType(getSource()); |
1401 | SparseTensorType dstStt = getSparseTensorType(getDest()); |
1402 | |
1403 | // We do not need an extra sort when returning unordered sparse tensors or |
1404 | // dense tensors, since dense tensors support random access. |
1405 | if (dstStt.isAllDense() || !dstStt.isAllOrdered()) |
1406 | return false; |
1407 | |
1408 | if (srcStt.isAllOrdered() && dstStt.isAllOrdered() && |
1409 | srcStt.hasSameDimToLvl(dstStt)) { |
1410 | return false; |
1411 | } |
1412 | |
1413 | // Source and dest tensors are ordered in different ways. We only do a direct |
1414 | // dense-to-sparse conversion when the dense input is defined by a sparse |
1415 | // constant. Note that we could theoretically always convert directly from |
1416 | // dense inputs by rotating the dense loops, but that leads to bad cache |
1417 | // locality and hurts performance. |
1418 | if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>()) |
1419 | if (isa<SparseElementsAttr>(constOp.getValue())) |
1420 | return false; |
1421 | |
1422 | return true; |
1423 | } |
1424 | |
1425 | LogicalResult CrdTranslateOp::verify() { |
1426 | uint64_t inRank = getEncoder().getLvlRank(); |
1427 | uint64_t outRank = getEncoder().getDimRank(); |
1428 | |
1429 | if (getDirection() == CrdTransDirectionKind::dim2lvl) |
1430 | std::swap(inRank, outRank); |
1431 | |
1432 | if (inRank != getInCrds().size() || outRank != getOutCrds().size()) |
1433 | return emitError("Coordinate rank mismatch with encoding"); |
1434 | |
1435 | return success(); |
1436 | } |
1437 | |
1438 | LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor, |
1439 | SmallVectorImpl<OpFoldResult> &results) { |
1440 | if (getEncoder().isIdentity()) { |
1441 | results.assign(getInCrds().begin(), getInCrds().end()); |
1442 | return success(); |
1443 | } |
1444 | if (getEncoder().isPermutation()) { |
1445 | AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl |
1446 | ? getEncoder().getDimToLvl() |
1447 | : getEncoder().getLvlToDim(); |
1448 | for (AffineExpr exp : perm.getResults()) |
1449 | results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]); |
1450 | return success(); |
1451 | } |
1452 | |
1453 | // Fuse dim2lvl/lvl2dim pairs. |
1454 | auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>(); |
1455 | bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) { |
1456 | return v.getDefiningOp() == def; |
1457 | }); |
1458 | if (!sameDef) |
1459 | return failure(); |
1460 | |
1461 | bool oppositeDir = def.getDirection() != getDirection(); |
1462 | bool sameOracle = |
1463 | def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl(); |
1464 | bool sameCount = def.getNumResults() == getInCrds().size(); |
1465 | if (!oppositeDir || !sameOracle || !sameCount) |
1466 | return failure(); |
1467 | |
1468 | // The definition produces the coordinates in the same order as the input |
1469 | // coordinates. |
1470 | bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()), |
1471 | [](auto valuePair) { |
1472 | auto [lhs, rhs] = valuePair; |
1473 | return lhs == rhs; |
1474 | }); |
1475 | |
1476 | if (!sameOrder) |
1477 | return failure(); |
1478 | // l1 = dim2lvl (lvl2dim l0) |
1479 | // ==> l0 |
1480 | results.append(def.getInCrds().begin(), def.getInCrds().end()); |
1481 | return success(); |
1482 | } |
1483 | |
1484 | void LvlOp::build(OpBuilder &builder, OperationState &state, Value source, |
1485 | int64_t index) { |
1486 | Value val = builder.create<arith::ConstantIndexOp>(state.location, index); |
1487 | return build(builder, state, source, val); |
1488 | } |
1489 | |
1490 | LogicalResult LvlOp::verify() { |
1491 | if (std::optional<uint64_t> lvl = getConstantLvlIndex()) { |
1492 | auto stt = getSparseTensorType(getSource()); |
1493 | if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank()) |
1494 | return emitError( |
1495 | "Level index exceeds the rank of the input sparse tensor"); |
1496 | } |
1497 | return success(); |
1498 | } |
1499 | |
1500 | std::optional<uint64_t> LvlOp::getConstantLvlIndex() { |
1501 | return getConstantIntValue(getIndex()); |
1502 | } |
1503 | |
1504 | Speculation::Speculatability LvlOp::getSpeculatability() { |
1505 | auto constantIndex = getConstantLvlIndex(); |
1506 | if (!constantIndex) |
1507 | return Speculation::NotSpeculatable; |
1508 | |
1509 | assert(constantIndex < |
1510 | cast<RankedTensorType>(getSource().getType()).getRank()); |
1511 | return Speculation::Speculatable; |
1512 | } |
1513 | |
1514 | OpFoldResult LvlOp::fold(FoldAdaptor adaptor) { |
1515 | auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex()); |
1516 | if (!lvlIndex) |
1517 | return {}; |
1518 | |
1519 | Level lvl = lvlIndex.getAPSInt().getZExtValue(); |
1520 | auto stt = getSparseTensorType(getSource()); |
1521 | if (lvl >= stt.getLvlRank()) { |
1522 | // Follows the same convention used by the tensor.dim operation. Out-of-bounds |
1523 | // indices produce undefined behavior but are still valid IR, so don't choke |
1524 | // on them. |
1525 | return {}; |
1526 | } |
1527 | |
1528 | // Helper lambda to build an IndexAttr. |
1529 | auto getIndexAttr = [this](int64_t lvlSz) { |
1530 | return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz)); |
1531 | }; |
1532 | |
1533 | SmallVector<Size> lvlShape = stt.getLvlShape(); |
1534 | if (!ShapedType::isDynamic(lvlShape[lvl])) |
1535 | return getIndexAttr(lvlShape[lvl]); |
1536 | |
1537 | return {}; |
1538 | } |
1539 | |
1540 | void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
1541 | SparseTensorEncodingAttr dstEnc, Value source) { |
1542 | auto srcStt = getSparseTensorType(source); |
1543 | SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape(); |
1544 | SmallVector<int64_t> dstDimShape = |
1545 | dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim); |
1546 | auto dstTp = |
1547 | RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc); |
1548 | return build(odsBuilder, odsState, dstTp, source); |
1549 | } |
1550 | |
1551 | LogicalResult ReinterpretMapOp::verify() { |
1552 | auto srcStt = getSparseTensorType(getSource()); |
1553 | auto dstStt = getSparseTensorType(getDest()); |
1554 | ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes(); |
1555 | ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes(); |
1556 | |
1557 | if (srcLvlTps.size() != dstLvlTps.size()) |
1558 | return emitError("Level rank mismatch between source/dest tensors"); |
1559 | |
1560 | for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps)) |
1561 | if (srcLvlTp != dstLvlTp) |
1562 | return emitError("Level type mismatch between source/dest tensors"); |
1563 | |
1564 | if (srcStt.getPosWidth() != dstStt.getPosWidth() || |
1565 | srcStt.getCrdWidth() != dstStt.getCrdWidth()) { |
1566 | return emitError("Crd/Pos width mismatch between source/dest tensors"); |
1567 | } |
1568 | |
1569 | if (srcStt.getElementType() != dstStt.getElementType()) |
1570 | return emitError("Element type mismatch between source/dest tensors"); |
1571 | |
1572 | SmallVector<Size> srcLvlShape = srcStt.getLvlShape(); |
1573 | SmallVector<Size> dstLvlShape = dstStt.getLvlShape(); |
1574 | for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) { |
1575 | if (srcLvlSz != dstLvlSz) { |
1576 | // Should we allow one side to have dynamic sizes, e.g., should <?x?> be |
1577 | // compatible with <3x4>? For now, we require all level sizes to match |
1578 | // *exactly* for simplicity. |
1579 | return emitError("Level size mismatch between source/dest tensors"); |
1580 | } |
1581 | } |
1582 | |
1583 | return success(); |
1584 | } |
1585 | |
1586 | OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) { |
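 | // Reinterpreting to the identical type is a no-op. |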
1587 | if (getSource().getType() == getDest().getType()) |
1588 | return getSource(); |
1589 | |
1590 | if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) { |
1591 | // A -> B, B -> A ==> A |
1592 | if (def.getSource().getType() == getDest().getType()) |
1593 | return def.getSource(); |
1594 | } |
1595 | return {}; |
1596 | } |
1597 | |
1598 | template <typename ToBufferOp> |
1599 | static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, |
1600 | OpaqueProperties prop, |
1601 | RegionRange region, |
1602 | SmallVectorImpl<mlir::Type> &ret) { |
1603 | typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region); |
1604 | SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); |
1605 | Type elemTp = nullptr; |
1606 | bool withStride = false; |
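 | // Select the buffer element type (and whether a strided layout is needed) |
 | // based on which to_* operation the return type is being inferred for. |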
1607 | if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) { |
1608 | elemTp = stt.getPosType(); |
1609 | } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> || |
1610 | std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) { |
1611 | elemTp = stt.getCrdType(); |
1612 | if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>) |
1613 | withStride = stt.getAoSCOOStart() <= adaptor.getLevel(); |
1614 | } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) { |
1615 | elemTp = stt.getElementType(); |
1616 | } |
1617 | |
1618 | assert(elemTp && "unhandled operation."); |
1619 | SmallVector<int64_t> bufShape = stt.getBatchLvlShape(); |
1620 | bufShape.push_back(ShapedType::kDynamic); |
1621 | |
1622 | auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get( |
1623 | stt.getContext(), ShapedType::kDynamic, |
1624 | {ShapedType::kDynamic}) |
1625 | : StridedLayoutAttr(); |
1626 | ret.emplace_back(MemRefType::get(bufShape, elemTp, layout)); |
1627 | return success(); |
1628 | } |
1629 | |
1630 | LogicalResult ToPositionsOp::verify() { |
1631 | auto stt = getSparseTensorType(getTensor()); |
1632 | if (failed(lvlIsInBounds(getLevel(), getTensor()))) |
1633 | return emitError("requested level is out of bounds"); |
1634 | if (failed(isMatchingWidth(getResult(), stt.getPosWidth()))) |
1635 | return emitError("unexpected type for positions"); |
1636 | return success(); |
1637 | } |
1638 | |
1639 | LogicalResult |
1640 | ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, |
1641 | ValueRange ops, DictionaryAttr attr, |
1642 | OpaqueProperties prop, RegionRange region, |
1643 | SmallVectorImpl<mlir::Type> &ret) { |
1644 | return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret); |
1645 | } |
1646 | |
1647 | LogicalResult ToCoordinatesOp::verify() { |
1648 | auto stt = getSparseTensorType(getTensor()); |
1649 | if (failed(lvlIsInBounds(getLevel(), getTensor()))) |
1650 | return emitError("requested level is out of bounds"); |
1651 | if (failed(isMatchingWidth(getResult(), stt.getCrdWidth()))) |
1652 | return emitError("unexpected type for coordinates"); |
1653 | return success(); |
1654 | } |
1655 | |
1656 | LogicalResult |
1657 | ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, |
1658 | ValueRange ops, DictionaryAttr attr, |
1659 | OpaqueProperties prop, RegionRange region, |
1660 | SmallVectorImpl<mlir::Type> &ret) { |
1661 | return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret); |
1662 | } |
1663 | |
1664 | LogicalResult ToCoordinatesBufferOp::verify() { |
1665 | auto stt = getSparseTensorType(getTensor()); |
1666 | if (stt.getAoSCOOStart() >= stt.getLvlRank()) |
1667 | return emitError("expected sparse tensor with a COO region"); |
1668 | return success(); |
1669 | } |
1670 | |
1671 | LogicalResult ToCoordinatesBufferOp::inferReturnTypes( |
1672 | MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, |
1673 | DictionaryAttr attr, OpaqueProperties prop, RegionRange region, |
1674 | SmallVectorImpl<mlir::Type> &ret) { |
1675 | return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region, |
1676 | ret); |
1677 | } |
1678 | |
1679 | LogicalResult ToValuesOp::verify() { |
1680 | auto stt = getSparseTensorType(getTensor()); |
1681 | auto mtp = getMemRefType(getResult()); |
1682 | if (stt.getElementType() != mtp.getElementType()) |
1683 | return emitError("unexpected mismatch in element types"); |
1684 | return success(); |
1685 | } |
1686 | |
1687 | LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx, |
1688 | std::optional<Location> loc, |
1689 | ValueRange ops, DictionaryAttr attr, |
1690 | OpaqueProperties prop, |
1691 | RegionRange region, |
1692 | SmallVectorImpl<mlir::Type> &ret) { |
1693 | return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret); |
1694 | } |
1695 | |
1696 | LogicalResult ToSliceOffsetOp::verify() { |
1697 | auto rank = getSlice().getType().getRank(); |
1698 | if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) |
1699 | return emitError("requested dimension out of bound"); |
1700 | return success(); |
1701 | } |
1702 | |
1703 | LogicalResult ToSliceStrideOp::verify() { |
1704 | auto rank = getSlice().getType().getRank(); |
1705 | if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) |
1706 | return emitError("requested dimension out of bound"); |
1707 | return success(); |
1708 | } |
1709 | |
1710 | LogicalResult GetStorageSpecifierOp::verify() { |
1711 | return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), |
1712 | getSpecifier(), getOperation()); |
1713 | } |
1714 | |
1715 | template <typename SpecifierOp> |
1716 | static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) { |
1717 | return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>(); |
1718 | } |
1719 | |
1720 | OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) { |
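 | // Walk the chain of set_storage_specifier ops feeding this getter; if one of |
 | // them sets the same (kind, level) entry, fold to the value it stored. |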
1721 | const StorageSpecifierKind kind = getSpecifierKind(); |
1722 | const auto lvl = getLevel(); |
1723 | for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op)) |
1724 | if (kind == op.getSpecifierKind() && lvl == op.getLevel()) |
1725 | return op.getValue(); |
1726 | return {}; |
1727 | } |
1728 | |
1729 | LogicalResult SetStorageSpecifierOp::verify() { |
1730 | return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), |
1731 | getSpecifier(), getOperation()); |
1732 | } |
1733 | |
1734 | template <class T> |
1735 | static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, |
1736 | const char *regionName, |
1737 | TypeRange inputTypes, Type outputType) { |
1738 | unsigned numArgs = region.getNumArguments(); |
1739 | unsigned expectedNum = inputTypes.size(); |
1740 | if (numArgs != expectedNum) |
1741 | return op->emitError() << regionName << " region must have exactly " |
1742 | << expectedNum << " arguments"; |
1743 | |
1744 | for (unsigned i = 0; i < numArgs; i++) { |
1745 | Type typ = region.getArgument(i).getType(); |
1746 | if (typ != inputTypes[i]) |
1747 | return op->emitError() << regionName << " region argument " << (i + 1) |
1748 | << " type mismatch"; |
1749 | } |
1750 | Operation *term = region.front().getTerminator(); |
1751 | YieldOp yield = dyn_cast<YieldOp>(term); |
1752 | if (!yield) |
1753 | return op->emitError() << regionName |
1754 | << " region must end with sparse_tensor.yield"; |
1755 | if (!yield.hasSingleResult() || |
1756 | yield.getSingleResult().getType() != outputType) |
1757 | return op->emitError() << regionName << " region yield type mismatch"; |
1758 | |
1759 | return success(); |
1760 | } |
1761 | |
1762 | LogicalResult BinaryOp::verify() { |
1763 | NamedAttrList attrs = (*this)->getAttrs(); |
1764 | Type leftType = getX().getType(); |
1765 | Type rightType = getY().getType(); |
1766 | Type outputType = getOutput().getType(); |
1767 | Region &overlap = getOverlapRegion(); |
1768 | Region &left = getLeftRegion(); |
1769 | Region &right = getRightRegion(); |
1770 | |
1771 | // Check correct number of block arguments and return type for each |
1772 | // non-empty region. |
1773 | if (!overlap.empty()) { |
1774 | if (failed(verifyNumBlockArgs(this, overlap, "overlap", |
1775 | TypeRange{leftType, rightType}, outputType))) |
1776 | return failure(); |
1777 | } |
1778 | if (!left.empty()) { |
1779 | if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, |
1780 | outputType))) |
1781 | return failure(); |
1782 | } else if (getLeftIdentity()) { |
1783 | if (leftType != outputType) |
1784 | return emitError("left=identity requires first argument to have the same " |
1785 | "type as the output"); |
1786 | } |
1787 | if (!right.empty()) { |
1788 | if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType}, |
1789 | outputType))) |
1790 | return failure(); |
1791 | } else if (getRightIdentity()) { |
1792 | if (rightType != outputType) |
1793 | return emitError("right=identity requires second argument to have the " |
1794 | "same type as the output"); |
1795 | } |
1796 | return success(); |
1797 | } |
1798 | |
1799 | LogicalResult UnaryOp::verify() { |
1800 | Type inputType = getX().getType(); |
1801 | Type outputType = getOutput().getType(); |
1802 | |
1803 | // Check correct number of block arguments and return type for each |
1804 | // non-empty region. |
1805 | Region &present = getPresentRegion(); |
1806 | if (!present.empty()) { |
1807 | if (failed(verifyNumBlockArgs(this, present, "present", |
1808 | TypeRange{inputType}, outputType))) |
1809 | return failure(); |
1810 | } |
1811 | Region &absent = getAbsentRegion(); |
1812 | if (!absent.empty()) { |
1813 | if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{}, |
1814 | outputType))) |
1815 | return failure(); |
1816 | // Absent branch can only yield invariant values. |
1817 | Block *absentBlock = &absent.front(); |
1818 | Block *parent = getOperation()->getBlock(); |
1819 | Value absentVal = |
1820 | cast<YieldOp>(absentBlock->getTerminator()).getSingleResult(); |
1821 | if (auto arg = dyn_cast<BlockArgument>(absentVal)) { |
1822 | if (arg.getOwner() == parent) |
1823 | return emitError("absent region cannot yield linalg argument"); |
1824 | } else if (Operation *def = absentVal.getDefiningOp()) { |
1825 | if (!isa<arith::ConstantOp>(def) && |
1826 | (def->getBlock() == absentBlock || def->getBlock() == parent)) |
1827 | return emitError("absent region cannot yield locally computed value"); |
1828 | } |
1829 | } |
1830 | return success(); |
1831 | } |
1832 | |
1833 | bool ConcatenateOp::needsExtraSort() { |
1834 | SparseTensorType dstStt = getSparseTensorType(*this); |
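 | // No extra sort is needed when the destination is all dense or allowed to be |
 | // unordered, since the ordering of the result does not matter in those cases. |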
1835 | if (dstStt.isAllDense() || !dstStt.isAllOrdered()) |
1836 | return false; |
1837 | |
1838 | bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) { |
1839 | return getSparseTensorType(op).hasSameDimToLvl(dstStt); |
1840 | }); |
1841 | // TODO: When conDim != 0, as long as conDim corresponds to the first level |
1842 | // in all input/output buffers, and all input/output buffers have the same |
1843 | // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenating |
1844 | // CSC matrices along the column). |
1845 | bool directLowerable = |
1846 | allSameOrdered && getDimension() == 0 && dstStt.isIdentity(); |
1847 | return !directLowerable; |
1848 | } |
1849 | |
1850 | LogicalResult ConcatenateOp::verify() { |
1851 | const auto dstTp = getSparseTensorType(*this); |
1852 | const Dimension concatDim = getDimension(); |
1853 | const Dimension dimRank = dstTp.getDimRank(); |
1854 | |
1855 | if (getInputs().size() <= 1) |
1856 | return emitError("Need at least two tensors to concatenate."); |
1857 | |
1858 | if (concatDim >= dimRank) |
1859 | return emitError(llvm::formatv( |
1860 | "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})", |
1861 | concatDim, dimRank)); |
1862 | |
1863 | for (const auto &it : llvm::enumerate(getInputs())) { |
1864 | const auto i = it.index(); |
1865 | const auto srcTp = getSparseTensorType(it.value()); |
1866 | if (srcTp.hasDynamicDimShape()) |
1867 | return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i)); |
1868 | const Dimension srcDimRank = srcTp.getDimRank(); |
1869 | if (srcDimRank != dimRank) |
1870 | return emitError( |
1871 | llvm::formatv("Input tensor ${0} has a different rank (rank={1}) " |
1872 | "from the output tensor (rank={2}).", |
1873 | i, srcDimRank, dimRank)); |
1874 | } |
1875 | |
1876 | for (Dimension d = 0; d < dimRank; d++) { |
1877 | const Size dstSh = dstTp.getDimShape()[d]; |
1878 | if (d == concatDim) { |
1879 | if (!ShapedType::isDynamic(dstSh)) { |
1880 | // If we reach here, then all inputs have static shapes. So we |
1881 | // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)` |
1882 | // to avoid redundant assertions in the loop. |
1883 | Size sumSz = 0; |
1884 | for (const auto src : getInputs()) |
1885 | sumSz += getSparseTensorType(src).getDimShape()[d]; |
1886 | // If all dimensions are statically known, the sum of all the input |
1887 | // dimensions should equal the output dimension. |
1888 | if (sumSz != dstSh) |
1889 | return emitError( |
1890 | "The concatenation dimension of the output tensor should be the " |
1891 | "sum of all the concatenation dimensions of the input tensors."); |
1892 | } |
1893 | } else { |
1894 | Size prev = dstSh; |
1895 | for (const auto src : getInputs()) { |
1896 | const auto sh = getSparseTensorType(src).getDimShape()[d]; |
1897 | if (!ShapedType::isDynamic(prev) && sh != prev) |
1898 | return emitError("All dimensions (except for the concatenating one) " |
1899 | "should be equal."); |
1900 | prev = sh; |
1901 | } |
1902 | } |
1903 | } |
1904 | |
1905 | return success(); |
1906 | } |
1907 | |
1908 | void PushBackOp::build(OpBuilder &builder, OperationState &result, |
1909 | Value curSize, Value inBuffer, Value value) { |
1910 | build(builder, result, curSize, inBuffer, value, Value()); |
1911 | } |
1912 | |
1913 | LogicalResult PushBackOp::verify() { |
1914 | if (Value n = getN()) { |
1915 | std::optional<int64_t> nValue = getConstantIntValue(n); |
1916 | if (nValue && nValue.value() < 1) |
1917 | return emitOpError("n must be not less than 1"); |
1918 | } |
1919 | return success(); |
1920 | } |
1921 | |
1922 | LogicalResult CompressOp::verify() { |
1923 | const auto stt = getSparseTensorType(getTensor()); |
1924 | if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size())) |
1925 | return emitOpError("incorrect number of coordinates"); |
1926 | return success(); |
1927 | } |
1928 | |
1929 | void ForeachOp::build( |
1930 | OpBuilder &builder, OperationState &result, Value tensor, |
1931 | ValueRange initArgs, AffineMapAttr order, |
1932 | function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)> |
1933 | bodyBuilder) { |
1934 | build(builder, result, initArgs.getTypes(), tensor, initArgs, order); |
1935 | // Builds foreach body. |
1936 | if (!bodyBuilder) |
1937 | return; |
1938 | const auto stt = getSparseTensorType(tensor); |
1939 | const Dimension dimRank = stt.getDimRank(); |
1940 | |
1941 | // Starts with `dimRank`-many coordinates. |
1942 | SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType()); |
1943 | // Followed by one value. |
1944 | blockArgTypes.push_back(stt.getElementType()); |
1945 | // Followed by the reduction variables. |
1946 | blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end()); |
1947 | |
1948 | SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc()); |
1949 | |
1950 | OpBuilder::InsertionGuard guard(builder); |
1951 | auto ®ion = *result.regions.front(); |
1952 | Block *bodyBlock = |
1953 | builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); |
1954 | bodyBuilder(builder, result.location, |
1955 | bodyBlock->getArguments().slice(0, dimRank), |
1956 | bodyBlock->getArguments()[dimRank], |
1957 | bodyBlock->getArguments().drop_front(dimRank + 1)); |
1958 | } |
1959 | |
1960 | LogicalResult ForeachOp::verify() { |
1961 | const auto t = getSparseTensorType(getTensor()); |
1962 | const Dimension dimRank = t.getDimRank(); |
1963 | const auto args = getBody()->getArguments(); |
1964 | |
1965 | if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank()) |
1966 | return emitError("Level traverse order does not match tensor's level rank"); |
1967 | |
1968 | if (dimRank + 1 + getInitArgs().size() != args.size()) |
1969 | return emitError("Unmatched number of arguments in the block"); |
1970 | |
1971 | if (getNumResults() != getInitArgs().size()) |
1972 | return emitError("Mismatch in number of init arguments and results"); |
1973 | |
1974 | if (getResultTypes() != getInitArgs().getTypes()) |
1975 | return emitError("Mismatch in types of init arguments and results"); |
1976 | |
1977 | // Cannot mark this const, because the getters aren't. |
1978 | auto yield = cast<YieldOp>(getBody()->getTerminator()); |
1979 | if (yield.getNumOperands() != getNumResults() || |
1980 | yield.getOperands().getTypes() != getResultTypes()) |
1981 | return emitError("Mismatch in types of yield values and results"); |
1982 | |
1983 | const auto iTp = IndexType::get(getContext()); |
1984 | for (Dimension d = 0; d < dimRank; d++) |
1985 | if (args[d].getType() != iTp) |
1986 | return emitError( |
1987 | llvm::formatv("Expecting Index type for argument at index {0}", d)); |
1988 | |
1989 | const auto elemTp = t.getElementType(); |
1990 | const auto valueTp = args[dimRank].getType(); |
1991 | if (elemTp != valueTp) |
1992 | return emitError( |
1993 | llvm::formatv("Unmatched element type between input tensor and " |
1994 | "block argument, expected:{0}, got: {1}", |
1995 | elemTp, valueTp)); |
1996 | return success(); |
1997 | } |
1998 | |
1999 | OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) { |
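 | // Reordering into the exact same encoding is a no-op; fold to the input COO. |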
2000 | if (getSparseTensorEncoding(getInputCoo().getType()) == |
2001 | getSparseTensorEncoding(getResultCoo().getType())) |
2002 | return getInputCoo(); |
2003 | |
2004 | return {}; |
2005 | } |
2006 | |
2007 | LogicalResult ReorderCOOOp::verify() { |
2008 | SparseTensorType srcStt = getSparseTensorType(getInputCoo()); |
2009 | SparseTensorType dstStt = getSparseTensorType(getResultCoo()); |
2010 | |
2011 | if (!srcStt.isCOOType() || !dstStt.isCOOType()) |
2012 | return emitError("Expected COO sparse tensors only"); |
2013 | |
2014 | if (!srcStt.hasSameDimToLvl(dstStt)) |
2015 | return emitError("Unmatched dim2lvl map between input and result COO"); |
2016 | |
2017 | if (srcStt.getPosType() != dstStt.getPosType() || |
2018 | srcStt.getCrdType() != dstStt.getCrdType() || |
2019 | srcStt.getElementType() != dstStt.getElementType()) |
2020 | return emitError("Unmatched storage format between input and result COO"); |
2021 | |
2022 | return success(); |
2023 | } |
2024 | |
2025 | LogicalResult ReduceOp::verify() { |
2026 | Type inputType = getX().getType(); |
2027 | Region &formula = getRegion(); |
2028 | return verifyNumBlockArgs(this, formula, "reduce", |
2029 | TypeRange{inputType, inputType}, inputType); |
2030 | } |
2031 | |
2032 | LogicalResult SelectOp::verify() { |
2033 | Builder b(getContext()); |
2034 | Type inputType = getX().getType(); |
2035 | Type boolType = b.getI1Type(); |
2036 | Region &formula = getRegion(); |
2037 | return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType}, |
2038 | boolType); |
2039 | } |
2040 | |
2041 | LogicalResult SortOp::verify() { |
2042 | AffineMap xPerm = getPermMap(); |
2043 | uint64_t nx = xPerm.getNumDims(); |
2044 | if (nx < 1) |
2045 | return emitError(llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx)); |
2046 | |
2047 | if (!xPerm.isPermutation()) |
2048 | return emitError( |
2049 | llvm::formatv("Expected a permutation map, got {0}", xPerm)); |
2050 | |
2051 | // We can't check the size of the buffers when n or buffer dimensions aren't |
2052 | // compile-time constants. |
2053 | std::optional<int64_t> cn = getConstantIntValue(getN()); |
2054 | if (!cn) |
2055 | return success(); |
2056 | |
2057 | // Verify dimensions. |
2058 | const auto checkDim = [&](Value v, Size minSize, |
2059 | const char *message) -> LogicalResult { |
2060 | const Size sh = getMemRefType(v).getShape()[0]; |
2061 | if (!ShapedType::isDynamic(sh) && sh < minSize) |
2062 | return emitError( |
2063 | llvm::formatv("{0} got {1} < {2}", message, sh, minSize)); |
2064 | return success(); |
2065 | }; |
2066 | uint64_t n = cn.value(); |
2067 | uint64_t ny = 0; |
2068 | if (auto nyAttr = getNyAttr()) |
2069 | ny = nyAttr.getInt(); |
2070 | if (failed(checkDim(getXy(), n * (nx + ny), |
2071 | "Expected dimension(xy) >= n * (rank(perm_map) + ny)"))) |
2072 | return failure(); |
2073 | for (Value opnd : getYs()) |
2074 | if (failed(checkDim(opnd, n, "Expected dimension(y) >= n"))) |
2075 | return failure(); |
2076 | |
2077 | return success(); |
2078 | } |
2079 | |
2080 | //===----------------------------------------------------------------------===// |
2081 | // Sparse Tensor Iteration Operations. |
2082 | //===----------------------------------------------------------------------===// |
2083 | |
2084 | IterSpaceType IteratorType::getIterSpaceType() const { |
2085 | return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(), |
2086 | getHiLvl()); |
2087 | } |
2088 | |
2089 | IteratorType IterSpaceType::getIteratorType() const { |
2090 | return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl()); |
2091 | } |
2092 | |
2093 | /// Parses a level range in the form "$lo `to` $hi" |
2094 | /// or simply "$lo" if $hi - $lo = 1 |
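 | /// E.g., "2 to 4" yields lvlLo = 2 and lvlHi = 4, while a plain "2" yields |
 | /// lvlLo = 2 and lvlHi = 3. |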
2095 | static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo, |
2096 | Level &lvlHi) { |
2097 | if (parser.parseInteger(lvlLo)) |
2098 | return failure(); |
2099 | |
2100 | if (succeeded(parser.parseOptionalKeyword("to"))) { |
2101 | if (parser.parseInteger(lvlHi)) |
2102 | return failure(); |
2103 | } else { |
2104 | lvlHi = lvlLo + 1; |
2105 | } |
2106 | |
2107 | if (lvlHi <= lvlLo) |
2108 | return parser.emitError(parser.getNameLoc(), |
2109 | "expect larger level upper bound than lower bound"); |
2110 | |
2111 | return success(); |
2112 | } |
2113 | |
2114 | /// Parses a level range in the form "$lo `to` $hi" |
2115 | /// or simply "$lo" if $hi - $lo = 1 |
2116 | static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr, |
2117 | IntegerAttr &lvlHiAttr) { |
2118 | Level lvlLo, lvlHi; |
2119 | if (parseLevelRange(parser, lvlLo, lvlHi)) |
2120 | return failure(); |
2121 | |
2122 | lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo); |
2123 | lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi); |
2124 | return success(); |
2125 | } |
2126 | |
2127 | /// Prints a level range in the form "$lo `to` $hi" |
2128 | /// or simply "$lo" if $hi - $lo = 1 |
2129 | static void printLevelRange(AsmPrinter &p, Level lo, Level hi) { |
2130 | |
2131 | if (lo + 1 == hi) |
2132 | p << lo; |
2133 | else |
2134 | p << lo << " to " << hi; |
2135 | } |
2136 | |
2137 | /// Prints a level range in the form "$lo `to` $hi" |
2138 | /// or simply "$lo" if $hi - $lo = 1 |
2139 | static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo, |
2140 | IntegerAttr lvlHi) { |
2141 | unsigned lo = lvlLo.getValue().getZExtValue(); |
2142 | unsigned hi = lvlHi.getValue().getZExtValue(); |
2143 | printLevelRange(p, lo, hi); |
2144 | } |
2145 | |
2146 | /// Parses a list of optionally-defined values in the form of |
2147 | /// "(%val0, _, %val1, ...)", where `_` is used to annotate that the |
2148 | /// corresponding value is not defined (e.g., to represent an undefined |
2149 | /// coordinate in the sparse iteration space). |
2150 | static ParseResult parseOptionalDefinedList( |
2151 | OpAsmParser &parser, OperationState &state, I64BitSet &definedSet, |
2152 | SmallVectorImpl<OpAsmParser::Argument> &definedArgs, |
2153 | unsigned maxCnt = std::numeric_limits<unsigned>::max(), |
2154 | OpAsmParser::Delimiter delimiter = OpAsmParser::Delimiter::Paren) { |
2155 | unsigned cnt = 0; |
2156 | ParseResult crdList = |
2157 | parser.parseCommaSeparatedList(delimiter, [&]() -> ParseResult { |
2158 | if (parser.parseOptionalKeyword("_")) { |
2159 | if (parser.parseArgument(definedArgs.emplace_back())) |
2160 | return failure(); |
2161 | definedSet.set(cnt); |
2162 | } |
2163 | cnt += 1; |
2164 | return success(); |
2165 | }); |
2166 | |
2167 | if (cnt > maxCnt) |
2168 | return parser.emitError(parser.getNameLoc(), |
2169 | "parsed more values than expected."); |
2170 | |
2171 | if (failed(crdList)) { |
2172 | return parser.emitError( |
2173 | parser.getNameLoc(), |
2174 | "expecting SSA value or \"_\" for level coordinates"); |
2175 | } |
2176 | assert(definedArgs.size() == definedSet.count()); |
2177 | return success(); |
2178 | } |
2179 | |
2180 | static void printOptionalDefinedList(OpAsmPrinter &p, unsigned size, |
2181 | Block::BlockArgListType blocksArgs, |
2182 | I64BitSet definedSet) { |
2183 | if (definedSet.empty()) |
2184 | return; |
2185 | |
2186 | for (unsigned i = 0; i < size; i++) { |
2187 | if (definedSet[i]) { |
2188 | p << blocksArgs.front(); |
2189 | blocksArgs = blocksArgs.drop_front(); |
2190 | } else { |
2191 | p << "_"; |
2192 | } |
2193 | if (i != size - 1) |
2194 | p << ", "; |
2195 | } |
2196 | assert(blocksArgs.empty()); |
2197 | } |
2198 | |
2199 | static ParseResult |
2200 | parseUsedCoordList(OpAsmParser &parser, OperationState &state, |
2201 | SmallVectorImpl<OpAsmParser::Argument> &coords) { |
2202 | // Parse "at(%crd0, _, ...)" |
2203 | I64BitSet crdUsedLvlSet; |
2204 | if (succeeded(parser.parseOptionalKeyword("at")) && |
2205 | failed(parseOptionalDefinedList(parser, state, crdUsedLvlSet, coords))) |
2206 | return failure(); |
2207 | |
2208 | // Always use IndexType for the coordinate. |
2209 | for (auto &coord : coords) |
2210 | coord.type = parser.getBuilder().getIndexType(); |
2211 | |
2212 | // Set the CrdUsedLvl bitset. |
2213 | state.addAttribute("crdUsedLvls", |
2214 | parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet)); |
2215 | return success(); |
2216 | } |
2217 | |
2218 | static ParseResult |
2219 | parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, |
2220 | SmallVectorImpl<OpAsmParser::Argument> &iterators, |
2221 | SmallVectorImpl<OpAsmParser::Argument> &blockArgs) { |
2222 | SmallVector<OpAsmParser::UnresolvedOperand> spaces; |
2223 | SmallVector<OpAsmParser::UnresolvedOperand> initArgs; |
2224 | |
2225 | // Parse "%iters, ... in %spaces, ..." |
2226 | if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") || |
2227 | parser.parseOperandList(spaces)) |
2228 | return failure(); |
2229 | |
2230 | if (iterators.size() != spaces.size()) |
2231 | return parser.emitError( |
2232 | parser.getNameLoc(), |
2233 | "mismatch in number of sparse iterators and sparse spaces"); |
2234 | |
2235 | SmallVector<OpAsmParser::Argument> coords; |
2236 | if (failed(parseUsedCoordList(parser, state, coords))) |
2237 | return failure(); |
2238 | size_t numCrds = coords.size(); |
2239 | |
2240 | // Parse "iter_args(%arg = %init, ...)" |
2241 | bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); |
2242 | if (hasIterArgs) |
2243 | if (parser.parseAssignmentList(blockArgs, initArgs)) |
2244 | return failure(); |
2245 | |
2246 | blockArgs.append(coords); |
2247 | |
2248 | SmallVector<Type> iterSpaceTps; |
2249 | // parse ": sparse_tensor.iter_space -> ret" |
2250 | if (parser.parseColon() || parser.parseTypeList(iterSpaceTps)) |
2251 | return failure(); |
2252 | if (iterSpaceTps.size() != spaces.size()) |
2253 | return parser.emitError(parser.getNameLoc(), |
2254 | "mismatch in number of iteration space operands " |
2255 | "and iteration space types"); |
2256 | |
2257 | for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) { |
2258 | IterSpaceType spaceTp = llvm::dyn_cast<IterSpaceType>(tp); |
2259 | if (!spaceTp) |
2260 | return parser.emitError(parser.getNameLoc(), |
2261 | "expected sparse_tensor.iter_space type for " |
2262 | "iteration space operands"); |
2263 | it.type = spaceTp.getIteratorType(); |
2264 | } |
2265 | |
2266 | if (hasIterArgs) |
2267 | if (parser.parseArrowTypeList(state.types)) |
2268 | return failure(); |
2269 | |
2270 | // Resolves input operands. |
2271 | if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(), |
2272 | state.operands)) |
2273 | return failure(); |
2274 | |
2275 | if (hasIterArgs) { |
2276 | // Strip off the trailing args that are used for coordinates. |
2277 | MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); |
2278 | if (args.size() != initArgs.size() || args.size() != state.types.size()) { |
2279 | return parser.emitError( |
2280 | parser.getNameLoc(), |
2281 | "mismatch in number of iteration arguments and return values"); |
2282 | } |
2283 | |
2284 | for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) { |
2285 | it.type = tp; |
2286 | if (parser.resolveOperand(init, tp, state.operands)) |
2287 | return failure(); |
2288 | } |
2289 | } |
2290 | return success(); |
2291 | } |
2292 | |
2293 | static ParseResult |
2294 | parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state, |
2295 | SmallVectorImpl<Value> &spacesVals, |
2296 | SmallVectorImpl<OpAsmParser::Argument> &blockArgs) { |
2297 | |
2298 | // Parse "(%spaces, ...)" |
2299 | SmallVector<OpAsmParser::UnresolvedOperand> spaces; |
2300 | if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren)) |
2301 | return failure(); |
2302 | |
2303 | SmallVector<OpAsmParser::Argument> coords; |
2304 | if (failed(parseUsedCoordList(parser, state, coords))) |
2305 | return failure(); |
2306 | size_t numCrds = coords.size(); |
2307 | |
2308 | // Parse "iter_args(%arg = %init, ...)" |
2309 | SmallVector<OpAsmParser::UnresolvedOperand> initArgs; |
2310 | bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); |
2311 | if (hasIterArgs) |
2312 | if (parser.parseAssignmentList(blockArgs, initArgs)) |
2313 | return failure(); |
2314 | blockArgs.append(coords); |
2315 | |
2316 | SmallVector<Type> iterSpaceTps; |
2317 | // parse ": (sparse_tensor.iter_space, ...) -> ret" |
2318 | if (parser.parseColon() || parser.parseLParen() || |
2319 | parser.parseTypeList(iterSpaceTps) || parser.parseRParen()) |
2320 | return failure(); |
2321 | |
2322 | if (iterSpaceTps.size() != spaces.size()) |
2323 | return parser.emitError(parser.getNameLoc(), |
2324 | "mismatch in number of iteration space operands " |
2325 | "and iteration space types"); |
2326 | |
2327 | if (hasIterArgs) |
2328 | if (parser.parseArrowTypeList(state.types)) |
2329 | return failure(); |
2330 | |
2331 | // Resolves input sparse iteration spaces. |
2332 | if (parser.resolveOperands(spaces, iterSpaceTps, parser.getNameLoc(), |
2333 | spacesVals)) |
2334 | return failure(); |
2335 | state.operands.append(spacesVals); |
2336 | |
2337 | if (hasIterArgs) { |
2338 | // Strip off the trailing args that are used for coordinates. |
2339 | MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); |
2340 | if (args.size() != initArgs.size() || args.size() != state.types.size()) { |
2341 | return parser.emitError( |
2342 | parser.getNameLoc(), |
2343 | "mismatch in number of iteration arguments and return values"); |
2344 | } |
2345 | |
2346 | for (auto [it, init, tp] : llvm::zip_equal(args, initArgs, state.types)) { |
2347 | it.type = tp; |
2348 | if (parser.resolveOperand(init, tp, state.operands)) |
2349 | return failure(); |
2350 | } |
2351 | } |
2352 | return success(); |
2353 | } |
2354 | |
2355 | LogicalResult ExtractIterSpaceOp::inferReturnTypes( |
2356 | MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, |
2357 | DictionaryAttr attr, OpaqueProperties prop, RegionRange region, |
2358 | SmallVectorImpl<mlir::Type> &ret) { |
2359 | |
2360 | ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region); |
2361 | SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); |
2362 | ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(), |
2363 | adaptor.getHiLvl())); |
2364 | return success(); |
2365 | } |
2366 | |
2367 | LogicalResult ExtractIterSpaceOp::verify() { |
2368 | if (getLoLvl() >= getHiLvl()) |
2369 | return emitOpError("expected smaller level low than level high"); |
2370 | |
2371 | TypedValue<IteratorType> pIter = getParentIter(); |
2372 | if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) { |
2373 | return emitOpError( |
2374 | "parent iterator should be specified iff level lower bound equals 0"); |
2375 | } |
2376 | |
2377 | if (pIter) { |
2378 | IterSpaceType spaceTp = getExtractedSpace().getType(); |
2379 | if (pIter.getType().getEncoding() != spaceTp.getEncoding()) |
2380 | return emitOpError( |
2381 | "mismatch in parent iterator encoding and iteration space encoding."); |
2382 | |
2383 | if (spaceTp.getLoLvl() != pIter.getType().getHiLvl()) |
2384 | return emitOpError("parent iterator should be used to extract an " |
2385 | "iteration space from a consecutive level."); |
2386 | } |
2387 | |
2388 | return success(); |
2389 | } |
2390 | |
2391 | LogicalResult ExtractValOp::verify() { |
2392 | auto stt = getSparseTensorType(getTensor()); |
2393 | auto itTp = getIterator().getType(); |
2394 | |
2395 | if (stt.getEncoding() != itTp.getEncoding()) |
2396 | return emitOpError("mismatch in tensor encoding and iterator encoding."); |
2397 | |
2398 | if (stt.getLvlRank() != itTp.getHiLvl()) |
2399 | return emitOpError("must use last-level iterator to extract values. "); |
2400 | |
2401 | return success(); |
2402 | } |
2403 | |
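 | /// Canonicalization pattern that removes level-coordinate block arguments |
 | /// that are never used inside an IterateOp body and shrinks the crdUsedLvls |
 | /// bit set accordingly. |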
2404 | struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> { |
2405 | using OpRewritePattern::OpRewritePattern; |
2406 | |
2407 | LogicalResult matchAndRewrite(IterateOp iterateOp, |
2408 | PatternRewriter &rewriter) const override { |
2409 | I64BitSet newUsedLvls(0); |
2410 | llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments()); |
2411 | for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) { |
2412 | if (auto crd = iterateOp.getLvlCrd(i)) { |
2413 | if (crd->getUsers().empty()) |
2414 | toRemove.set(crd->getArgNumber()); |
2415 | else |
2416 | newUsedLvls.set(i); |
2417 | } |
2418 | } |
2419 | |
2420 | // All coordinates are used. |
2421 | if (toRemove.none()) |
2422 | return failure(); |
2423 | |
2424 | rewriter.startOpModification(iterateOp); |
2425 | iterateOp.setCrdUsedLvls(newUsedLvls); |
2426 | iterateOp.getBody()->eraseArguments(toRemove); |
2427 | rewriter.finalizeOpModification(iterateOp); |
2428 | return success(); |
2429 | } |
2430 | }; |
2431 | |
2432 | void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, |
2433 | mlir::MLIRContext *context) { |
2434 | results.add<RemoveUnusedLvlCrds>(context); |
2435 | } |
2436 | |
2437 | void IterateOp::build(OpBuilder &builder, OperationState &odsState, |
2438 | Value iterSpace, ValueRange initArgs) { |
2439 | unsigned rank = llvm::cast<IterSpaceType>(iterSpace.getType()).getSpaceDim(); |
2440 | // All ones. |
2441 | I64BitSet set((1 << rank) - 1); |
2442 | return build(builder, odsState, iterSpace, initArgs, set); |
2443 | } |
2444 | |
2445 | void IterateOp::build(OpBuilder &builder, OperationState &odsState, |
2446 | Value iterSpace, ValueRange initArgs, |
2447 | I64BitSet crdUsedLvls) { |
2448 | OpBuilder::InsertionGuard guard(builder); |
2449 | |
2450 | odsState.addOperands(iterSpace); |
2451 | odsState.addOperands(initArgs); |
2452 | odsState.getOrAddProperties<Properties>().crdUsedLvls = |
2453 | builder.getIntegerAttr(builder.getIntegerType(64), crdUsedLvls); |
2454 | Region *bodyRegion = odsState.addRegion(); |
2455 | odsState.addTypes(initArgs.getTypes()); |
2456 | Block *bodyBlock = builder.createBlock(bodyRegion); |
2457 | |
2458 | // Starts with a list of user-provided loop arguments. |
2459 | for (Value v : initArgs) |
2460 | bodyBlock->addArgument(v.getType(), v.getLoc()); |
2461 | |
2462 | // Followed by a list of used coordinates. |
2463 | for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++) |
2464 | bodyBlock->addArgument(builder.getIndexType(), odsState.location); |
2465 | |
2466 | // Ends with the sparse iterator. |
2467 | bodyBlock->addArgument( |
2468 | llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(), |
2469 | odsState.location); |
2470 | } |
2471 | |
2472 | ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { |
2473 | OpAsmParser::Argument iterator; |
2474 | OpAsmParser::UnresolvedOperand iterSpace; |
2475 | |
2476 | SmallVector<OpAsmParser::Argument> iters, iterArgs; |
2477 | if (parseSparseIterateLoop(parser, result, iters, iterArgs)) |
2478 | return failure(); |
2479 | if (iters.size() != 1) |
2480 | return parser.emitError(parser.getNameLoc(), |
2481 | "expected only one iterator/iteration space"); |
2482 | |
2483 | iterArgs.append(iters); |
2484 | Region *body = result.addRegion(); |
2485 | if (parser.parseRegion(*body, iterArgs)) |
2486 | return failure(); |
2487 | |
2488 | IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location); |
2489 | |
2490 | // Parse the optional attribute list. |
2491 | if (parser.parseOptionalAttrDict(result.attributes)) |
2492 | return failure(); |
2493 | |
2494 | return success(); |
2495 | } |
2496 | |
2497 | /// Prints the initialization list in the form of |
2498 | /// <prefix>(%inner = %outer, %inner2 = %outer2, <...>) |
2499 | /// where 'inner' values are assumed to be region arguments and 'outer' values |
2500 | /// are regular SSA values. |
2501 | static void printInitializationList(OpAsmPrinter &p, |
2502 | Block::BlockArgListType blocksArgs, |
2503 | ValueRange initializers, |
2504 | StringRef prefix = "") { |
2505 | assert(blocksArgs.size() == initializers.size() && |
2506 | "expected same length of arguments and initializers"); |
2507 | if (initializers.empty()) |
2508 | return; |
2509 | |
2510 | p << prefix << '('; |
2511 | llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) { |
2512 | p << std::get<0>(it) << " = " << std::get<1>(it); |
2513 | }); |
2514 | p << ")"; |
2515 | } |
2516 | |
2517 | template <typename SparseLoopOp> |
2518 | static LogicalResult verifySparseLoopOp(SparseLoopOp op) { |
2519 | if (op.getInitArgs().size() != op.getNumResults()) { |
2520 | return op.emitOpError( |
2521 | "mismatch in number of loop-carried values and defined values"); |
2522 | } |
2523 | if (op.getCrdUsedLvls().max() > op.getSpaceDim()) |
2524 | return op.emitOpError("required out-of-bound coordinates"); |
2525 | |
2526 | return success(); |
2527 | } |
2528 | |
2529 | LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); } |
2530 | LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); } |
2531 | |
2532 | void IterateOp::print(OpAsmPrinter &p) { |
2533 | p << " " << getIterator() << " in " << getIterSpace(); |
2534 | if (!getCrdUsedLvls().empty()) { |
2535 | p << " at("; |
2536 | printOptionalDefinedList(p, getSpaceDim(), getCrds(), getCrdUsedLvls()); |
2537 | p << ")"; |
2538 | } |
2539 | printInitializationList(p, getRegionIterArgs(), getInitArgs(), " iter_args"); |
2540 | |
2541 | p << " : " << getIterSpace().getType() << " "; |
2542 | if (!getInitArgs().empty()) |
2543 | p.printArrowTypeList(getInitArgs().getTypes()); |
2544 | |
2545 | p << " "; |
2546 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, |
2547 | /*printBlockTerminators=*/!getInitArgs().empty()); |
2548 | } |
2549 | |
2550 | LogicalResult IterateOp::verifyRegions() { |
2551 | if (getIterator().getType() != getIterSpace().getType().getIteratorType()) |
2552 | return emitOpError("mismatch in iterator and iteration space type"); |
2553 | if (getNumRegionIterArgs() != getNumResults()) |
2554 | return emitOpError( |
2555 | "mismatch in number of basic block args and defined values"); |
2556 | |
2557 | auto initArgs = getInitArgs(); |
2558 | auto iterArgs = getRegionIterArgs(); |
2559 | auto yieldVals = getYieldedValues(); |
2560 | auto opResults = getResults(); |
2561 | if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(), |
2562 | opResults.size()})) { |
2563 | return emitOpError() << "number mismatch between iter args and results."; |
2564 | } |
2565 | |
2566 | for (auto [i, init, iter, yield, ret] : |
2567 | llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) { |
2568 | if (init.getType() != ret.getType()) |
2569 | return emitOpError() << "types mismatch between " << i |
2570 | << "th iter operand and defined value"; |
2571 | if (iter.getType() != ret.getType()) |
2572 | return emitOpError() << "types mismatch between " << i |
2573 | << "th iter region arg and defined value"; |
2574 | if (yield.getType() != ret.getType()) |
2575 | return emitOpError() << "types mismatch between " << i |
2576 | << "th yield value and defined value"; |
2577 | } |
2578 | |
2579 | return success(); |
2580 | } |
2581 | |
2582 | /// OpInterfaces' methods implemented by IterateOp. |
2583 | SmallVector<Region *> IterateOp::getLoopRegions() { return {&getRegion()}; } |
2584 | |
2585 | MutableArrayRef<OpOperand> IterateOp::getInitsMutable() { |
2586 | return getInitArgsMutable(); |
2587 | } |
2588 | |
2589 | Block::BlockArgListType IterateOp::getRegionIterArgs() { |
2590 | return getRegion().getArguments().take_front(getNumRegionIterArgs()); |
2591 | } |
2592 | |
2593 | std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() { |
2594 | return cast<sparse_tensor::YieldOp>( |
2595 | getRegion().getBlocks().front().getTerminator()) |
2596 | .getResultsMutable(); |
2597 | } |
2598 | |
2599 | std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); } |
2600 | |
2601 | OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { |
2602 | return getInitArgs(); |
2603 | } |
2604 | |
2605 | void IterateOp::getSuccessorRegions(RegionBranchPoint point, |
2606 | SmallVectorImpl<RegionSuccessor> ®ions) { |
2607 | // Both the operation itself and the region may be branching into the body |
2608 | // or back into the operation itself. |
2609 | regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); |
2610 | // It is possible for the loop not to enter the body. |
2611 | regions.push_back(RegionSuccessor(getResults())); |
2612 | } |
2613 | |
2614 | void CoIterateOp::build(OpBuilder &builder, OperationState &odsState, |
2615 | ValueRange iterSpaces, ValueRange initArgs, |
2616 | unsigned numCases) { |
2617 | unsigned rank = |
2618 | cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim(); |
2619 | // All ones. |
2620 | I64BitSet set((1 << rank) - 1); |
2621 | // Generates all-zero case bits (they only serve as placeholders), which are |
2622 | // supposed to be overridden later. We need to preallocate all the regions |
2623 | // because an mlir::Region cannot be added dynamically after the operation |
2624 | // is created. |
2625 | SmallVector<int64_t> caseBits(numCases, 0); |
2626 | ArrayAttr cases = builder.getI64ArrayAttr(caseBits); |
2627 | return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces, |
2628 | initArgs, set, cases, |
2629 | /*caseRegionsCount=*/numCases); |
2630 | } |
2631 | |
2632 | ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) { |
2633 | |
2634 | SmallVector<Value> spaces; |
2635 | // The block argument list of each region is arranged in the order of |
2636 | // ([used coordinate list], [loop iteration args], [sparse iterator list]). |
2637 | SmallVector<OpAsmParser::Argument> blockArgs; |
2638 | if (parseSparseCoIterateLoop(parser, result, spaces, blockArgs)) |
2639 | return failure(); |
2640 | |
2641 | result.addAttribute("operandSegmentSizes", |
2642 | parser.getBuilder().getDenseI32ArrayAttr( |
2643 | {static_cast<int32_t>(spaces.size()), |
2644 | static_cast<int32_t>(result.types.size())})); |
2645 | |
2646 | SmallVector<Attribute> cases; |
2647 | while (succeeded(parser.parseOptionalKeyword("case"))) { |
2648 | // Parse one region per case. |
2649 | I64BitSet definedItSet; |
2650 | SmallVector<OpAsmParser::Argument> definedIts; |
2651 | if (parseOptionalDefinedList(parser, result, definedItSet, definedIts, |
2652 | spaces.size(), OpAsmParser::Delimiter::None)) |
2653 | return failure(); |
2654 | |
2655 | cases.push_back(parser.getBuilder().getI64IntegerAttr(definedItSet)); |
2656 | |
2657 | for (auto [i, definedIdx] : llvm::enumerate(definedItSet.bits())) { |
2658 | // Resolve the iterator type based on the iteration space type. |
2659 | auto spaceTp = llvm::cast<IterSpaceType>(spaces[definedIdx].getType()); |
2660 | definedIts[i].type = spaceTp.getIteratorType(); |
2661 | } |
2662 | definedIts.insert(definedIts.begin(), blockArgs.begin(), blockArgs.end()); |
2663 | Region *body = result.addRegion(); |
2664 | if (parser.parseRegion(*body, definedIts)) |
2665 | return failure(); |
2666 | |
2667 | CoIterateOp::ensureTerminator(*body, parser.getBuilder(), result.location); |
2668 | } |
2669 | |
2670 | result.addAttribute("cases", ArrayAttr::get(parser.getContext(), cases)); |
2671 | |
2672 | // Parse the optional attribute list. |
2673 | if (parser.parseOptionalAttrDict(result.attributes)) |
2674 | return failure(); |
2675 | |
2676 | return success(); |
2677 | } |
2678 | |
2679 | void CoIterateOp::print(OpAsmPrinter &p) { |
2680 | p << " ("; |
2681 | llvm::interleaveComma(getIterSpaces(), p, [&](auto s) { p << s; }); |
2682 | p << ")"; |
2683 | |
2684 | if (!getCrdUsedLvls().empty()) { |
2685 | p << " at("; |
2686 | printOptionalDefinedList(p, getSpaceDim(), getCrds(0), getCrdUsedLvls()); |
2687 | p << ")"; |
2688 | } |
2689 | |
2690 | printInitializationList(p, getRegionIterArgs(0), getInitArgs(), " iter_args"); |
2691 | |
2692 | p << " : (" << getIterSpaces().getTypes() << ")"; |
2693 | if (!getInitArgs().empty()) |
2694 | p.printArrowTypeList(getInitArgs().getTypes()); |
2695 | |
2696 | for (unsigned idx = 0, e = getRegions().size(); idx < e; idx++) { |
2697 | p.printNewline(); |
2698 | p << "case "; |
2699 | printOptionalDefinedList(p, getIterSpaces().size(), getRegionIterators(idx), |
2700 | getRegionDefinedSpace(idx)); |
2701 | p << " "; |
2702 | p.printRegion(getRegion(idx), /*printEntryBlockArgs=*/false, |
2703 | /*printBlockTerminators=*/!getInitArgs().empty()); |
2704 | } |
2705 | } |
2706 | |
2707 | ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) { |
2708 | return cast<sparse_tensor::YieldOp>( |
2709 | getRegion(regionIdx).getBlocks().front().getTerminator()) |
2710 | .getResults(); |
2711 | } |
2712 | |
LogicalResult CoIterateOp::verifyRegions() {
  for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
    if (getNumRegionIterArgs() != getNumResults())
      return emitOpError(
          "mismatch in number of basic block args and defined values");

    auto initArgs = getInitArgs();
    auto iterArgs = getRegionIterArgs(r);
    auto yieldVals = getYieldedValues(r);
    auto opResults = getResults();
    if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
                          opResults.size()})) {
      return emitOpError()
             << "number mismatch between iter args and results on " << r
             << "th region";
    }

    for (auto [i, init, iter, yield, ret] :
         llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
      if (init.getType() != ret.getType())
        return emitOpError()
               << "types mismatch between " << i
               << "th iter operand and defined value on " << r << "th region";
      if (iter.getType() != ret.getType())
        return emitOpError() << "types mismatch between " << i
                             << "th iter region arg and defined value on " << r
                             << "th region";
      if (yield.getType() != ret.getType())
        return emitOpError()
               << "types mismatch between " << i
               << "th yield value and defined value on " << r << "th region";
    }
  }

  auto cases = getRegionDefinedSpaces();
  llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
  if (set.size() != getNumRegions())
    return emitOpError("contains duplicated cases.");

  return success();
}

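/// Returns the case regions whose sets of defined iterators are subsets of the
/// set defined by the case region `regionIdx`.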
SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
  SmallVector<Region *> ret;
  I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
  for (Region &r : getCaseRegions())
    if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
      ret.push_back(&r);

  return ret;
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Dialect Setups.
//===----------------------------------------------------------------------===//

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
                                                    Attribute value, Type type,
                                                    Location loc) {
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  return nullptr;
}

void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
  declarePromisedInterfaces<
      bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp,
      NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp,
      ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"